"""
MCMC Posterior.
This module contains the MCMCPosterior class, a subclass of NeuralPosterior that draws posterior samples with Markov Chain Monte Carlo (MCMC).
"""
from typing import Any, Callable, Optional
import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
# from flowMC.resource.nf_model.rqSpline import MaskedCouplingRQSpline
# from flowMC.resource.local_kernel.MALA import MALA
# from flowMC.Sampler import Sampler
from jaxtyping import Array
from numpyro.infer import HMC, MCMC, NUTS
from numpyro.infer.util import init_to_value
from jaxili.model import NDENetwork
from jaxili.posterior import NeuralPosterior
from jaxili.train import TrainState
# MCMC back-ends accepted by `MCMCPosterior.set_mcmc_method`.
# ("mala_flowmc" is intentionally absent: its sampler is not implemented yet.)
implemented_method = ["nuts_numpyro", "hmc_numpyro"]
# Default option dicts for each back-end. They are used as default argument
# values, so treat them as read-only and never mutate them in place.
nuts_numpyro_kwargs_default = {}
hmc_numpyro_kwargs_default = {}
mala_flowmc_kwargs_default = {}
class MCMCPosterior(NeuralPosterior):
    r"""
    Posterior built from a likelihood $p(x|\theta)$ learned by Neural Likelihood Estimation (NLE).

    The class wraps the trained neural density estimator and combines its
    log-likelihood with a user-supplied prior. Samples from the posterior
    $p(\theta|x) \propto p(x|\theta)\,p(\theta)$ are drawn with Markov Chain
    Monte Carlo (MCMC) through `sample()`.
    """

    def __init__(
        self,
        model: NDENetwork,
        state: TrainState,
        prior_distr: dist.Distribution,
        verbose: Optional[bool] = False,
        x: Optional[Array] = None,
        mcmc_method: Optional[str] = "nuts_numpyro",
        mcmc_kwargs: Optional[dict] = nuts_numpyro_kwargs_default,
    ):
        """
        Initialize the MCMC Posterior.

        Parameters
        ----------
        model : NDENetwork
            The neural density estimator of the likelihood.
        state : TrainState
            The training state holding the network parameters.
        prior_distr : dist.Distribution
            The prior distribution of the parameters. (One must specify a prior to perform MCMC sampling.)
        verbose : bool
            Whether to print information. (Default: False)
        x : Array
            The data used to condition the posterior. (Default: None)
        mcmc_method : str
            The MCMC method to use, one of `implemented_method`. (Default: 'nuts_numpyro')
        mcmc_kwargs : dict
            The keyword arguments for the MCMC method. A copy is stored, so the
            caller's dict (and the module-level default) is never mutated.
            (Default: nuts_numpyro_kwargs_default, an empty dict)

        Raises
        ------
        NotImplementedError
            If `mcmc_method` is not an implemented method.
        """
        super().__init__(model, state, verbose, x)
        self.set_mcmc_method(mcmc_method)
        self.set_mcmc_kwargs(mcmc_kwargs)
        if self.verbose:
            print(f"Using MCMC method: {mcmc_method}")
            print(f"MCMC kwargs: {mcmc_kwargs}")
        self.set_prior(prior_distr)

    def sample(self, num_samples: int, key: Array, x: Optional[Array] = None, **kwargs):
        r"""
        Sample from the posterior using MCMC.

        Parameters
        ----------
        num_samples : int
            The number of samples to draw (numpyro draws this many per chain).
        key : Array
            The random key used to generate the samples.
        x : Array, optional
            The data used to condition the posterior. When omitted, the default
            data (set at construction or with `set_default_x()`) is used.
        **kwargs
            Forwarded to the initial-state search (`num_candidate_samples`,
            `num_batches`).

        Returns
        -------
        Array
            The samples from the posterior.

        Raises
        ------
        ValueError
            If no data is given and no default data is set.
        NotImplementedError
            If the configured MCMC method is not implemented.
        """
        x = self._resolve_x(x)
        samplers = {
            "nuts_numpyro": self._nuts_numpyro,
            "hmc_numpyro": self._hmc_numpyro,
            "mala_flowmc": self._mala_flowmc,
        }
        sampler = samplers.get(self.mcmc_method)
        # Fail fast, before the (expensive) initial-state search.
        if sampler is None:
            raise NotImplementedError(
                f"The MCMC method {self.mcmc_method} is not implemented. Check print_implemented_methods()."
            )
        # Per-call copy: the stored kwargs may be a shared module-level default
        # or the caller's own dict, so they must never be mutated here.
        run_kwargs = {**self.mcmc_kwargs, "num_samples": num_samples}
        num_chains = run_kwargs.get("num_chains", 1)
        sample_key, key = jax.random.split(key)
        initial_states = self._get_initial_state(x, num_chains, sample_key, **kwargs)
        return sampler(x, key, initial_states, run_kwargs)

    def _resolve_x(self, x: Optional[Array]) -> Array:
        """Return `x`, or the default data stored on the posterior when `x` is None.

        Raises
        ------
        ValueError
            If `x` is None and no default data is available.
        """
        if x is None:
            # The default may never have been set, or may have been set to None.
            x = getattr(self, "x", None)
        if x is None:
            raise ValueError(
                "The data x must be specified or loaded in the posterior with `set_default_x()`."
            )
        return x

    def log_prior(self, theta):
        r"""
        Compute the log prior of the parameters.

        For a factorized prior (non-empty `batch_shape`), the per-dimension log
        densities are summed so that one value is returned per parameter vector,
        whether `theta` is a single vector or a batch of vectors.

        Parameters
        ----------
        theta : Array
            The parameters to evaluate the log prior.

        Returns
        -------
        Array
            The log prior of the parameters.
        """
        log_prior = self.prior_distr.log_prob(theta)
        batch_shape = getattr(self.prior_distr, "batch_shape", None)
        if batch_shape is None:
            # Duck-typed prior without `batch_shape`: keep the historical rule.
            if log_prior.ndim > 1:
                log_prior = jnp.sum(log_prior, axis=-1)
            return log_prior
        if len(batch_shape) > 0:
            # log_prob has shape sample_shape + batch_shape; sum out the batch axes.
            log_prior = jnp.sum(log_prior, axis=tuple(range(-len(batch_shape), 0)))
        return log_prior

    def log_likelihood(self, x: Array, theta: Array):
        r"""
        Compute the log-likelihood learned by the neural density estimator.

        Parameters
        ----------
        x : Array
            The data used to condition the posterior.
        theta : Array
            The parameters to evaluate the log-likelihood.

        Returns
        -------
        Array
            The log-likelihood returned by the network's `log_prob`, with
            singleton axes squeezed out.
        """
        params = self.state.params
        log_likelihood = self.model.apply(
            {"params": params}, x, theta, method="log_prob"
        ).squeeze()
        return log_likelihood

    def unnormalized_log_prob(self, theta: Array, x: Optional[Array] = None):
        """
        Compute the unnormalized log probability of the posterior.

        Parameters
        ----------
        theta : Array
            The parameters to evaluate the log probability.
        x : Array, optional
            The data used to condition the posterior. Defaults to the data set
            with `set_default_x()` (or at construction).

        Returns
        -------
        Array
            The unnormalized log probability (log prior + log-likelihood).

        Raises
        ------
        ValueError
            If `x` is None and no default data is set.
        """
        x = self._resolve_x(x)
        return self.log_prior(theta) + self.log_likelihood(x, theta)

    def _build_model_numpyro(self, x: Array):
        """
        Create a function corresponding to the Bayesian model in numpyro.

        Parameters
        ----------
        x : Array
            The data used to condition the posterior (also the default of the
            model's `data` argument).

        Returns
        -------
        Callable
            The model function, with a single latent sample site named "theta".
        """

        def model(data=x):
            theta = numpyro.sample("theta", self.prior_distr)
            # The network expects a leading batch axis: evaluate a batch of one.
            log_lik = self.log_likelihood(data, jnp.reshape(theta, (1, -1)))
            numpyro.factor("log_likelihood", log_lik)

        return model

    def _run_numpyro_mcmc(
        self,
        kernel_cls: Callable,
        x: Array,
        key: Array,
        initial_state: Array,
        mcmc_kwargs: dict,
    ):
        """
        Run a numpyro MCMC kernel (NUTS or HMC) on the posterior.

        Parameters
        ----------
        kernel_cls : Callable
            The numpyro kernel class (`NUTS` or `HMC`).
        x : Array
            The data used to condition the posterior.
        key : Array
            The random key used to generate the samples.
        initial_state : Array
            One starting point per chain, stacked along the first axis.
        mcmc_kwargs : dict
            Options read: "adapt_step_size" (True), "num_warmup" (500),
            "num_samples" (2000), "num_chains" (1).

        Returns
        -------
        Array
            The samples of "theta", with chains flattened together.
        """
        model = self._build_model_numpyro(x)
        num_warmup = mcmc_kwargs.get("num_warmup", 500)
        num_samples = mcmc_kwargs.get("num_samples", 2000)
        num_chains = mcmc_kwargs.get("num_chains", 1)
        initial_state = jnp.asarray(initial_state)
        # `init_to_value` needs a mapping from sample-site name to a value with
        # the site's own shape; a bare array would be silently ignored.
        kernel = kernel_cls(
            model,
            adapt_step_size=mcmc_kwargs.get("adapt_step_size", True),
            init_strategy=init_to_value(values={"theta": initial_state[0]}),
        )
        mcmc = MCMC(
            kernel,
            num_warmup=num_warmup,
            num_samples=num_samples,
            num_chains=num_chains,
        )
        run_kwargs = {}
        if num_chains > 1:
            # An init strategy yields the same point for every chain, so give
            # each chain its own resampled start via `init_params`. numpyro
            # expects these in unconstrained space, hence the inverse bijection.
            from numpyro.distributions.transforms import biject_to

            to_unconstrained = biject_to(self.prior_distr.support).inv
            run_kwargs["init_params"] = {"theta": to_unconstrained(initial_state)}
        mcmc.run(key, data=x, **run_kwargs)
        return mcmc.get_samples()["theta"]

    def _nuts_numpyro(
        self,
        x: Array,
        key: Array,
        initial_state: Array,
        mcmc_kwargs: Optional[dict] = nuts_numpyro_kwargs_default,
    ):
        """
        Perform MCMC sampling using the No-U-Turn Sampler (NUTS) in numpyro.

        Parameters
        ----------
        x : Array
            The data used to condition the posterior.
        key : Array
            The random key used to generate the samples.
        initial_state : Array
            The initial state of the MCMC sampler (one row per chain).
        mcmc_kwargs: dict
            The keyword arguments for the MCMC method. (Default: nuts_numpyro_kwargs_default)

        Returns
        -------
        Array
            The samples from the posterior.
        """
        return self._run_numpyro_mcmc(NUTS, x, key, initial_state, mcmc_kwargs)

    def _hmc_numpyro(
        self,
        x: Array,
        key: Array,
        initial_state: Array,
        mcmc_kwargs: Optional[dict] = hmc_numpyro_kwargs_default,
    ):
        """
        Perform MCMC sampling using the Hamiltonian Monte Carlo (HMC) in numpyro.

        Parameters
        ----------
        x : Array
            The data used to condition the posterior.
        key : Array
            The random key used to generate the samples.
        initial_state : Array
            The initial state of the MCMC sampler (one row per chain).
        mcmc_kwargs: dict
            The keyword arguments for the MCMC method. (Default: hmc_numpyro_kwargs_default)

        Returns
        -------
        Array
            The samples from the posterior.
        """
        return self._run_numpyro_mcmc(HMC, x, key, initial_state, mcmc_kwargs)

    def _mala_flowmc(
        self,
        x: Array,
        key: Array,
        initial_state: Array,
        mcmc_kwargs: Optional[dict] = mala_flowmc_kwargs_default,
    ):
        """
        Perform MCMC sampling using the Metropolis-Adjusted Langevin Algorithm (MALA) in FlowMC.

        WARNING: Not implemented yet; always raises. Will be updated in future releases.

        Parameters
        ----------
        x : Array
            The data used to condition the posterior.
        key : Array
            The random key used to generate the samples.
        initial_state : Array
            The initial state of the MCMC sampler.
        mcmc_kwargs: dict
            The keyword arguments for the MCMC method. (Default: mala_flowmc_kwargs_default)

        Raises
        ------
        NotImplementedError
            Always.
        """
        # TODO: implement with flowMC (MaskedCouplingRQSpline + MALA local
        # kernel + Sampler); see the commented-out imports at the top of the file.
        raise NotImplementedError(
            "Sampling with FlowMC is not yet implemented in JaxILI."
        )

    def _get_initial_state(self, x: Array, num_chains: int, key: Array, **kwargs):
        """
        Get the initial state of the MCMC sampler.

        For each chain, prior samples are drawn and one is resampled with
        probability proportional to its unnormalized posterior density.

        Parameters
        ----------
        x : Array
            The data used to condition the posterior.
        num_chains : int
            The number of chains (must be >= 1).
        key : Array
            The random key used to generate the samples.
        **kwargs
            Forwarded to `_resample_proposal`.

        Returns
        -------
        Array
            The initial states, one row per chain.

        Raises
        ------
        ValueError
            If `num_chains` is smaller than 1.
        """
        if num_chains < 1:
            raise ValueError(f"num_chains must be at least 1, got {num_chains}.")

        def potential_fn(theta):
            # Repeat x along the candidate axis so every theta is paired with the data.
            x_ = x * jnp.ones((theta.shape[0], 1))
            return self.unnormalized_log_prob(theta, x_)

        initial_state = []
        for _ in range(num_chains):
            proposal_samples, key = self._resample_proposal(potential_fn, key, **kwargs)
            initial_state.append(proposal_samples)
        return jnp.concatenate(initial_state, axis=0)

    def _resample_proposal(
        self,
        potential_fn: Callable,
        key: Array,
        num_candidate_samples: int = 10_000,
        num_batches: int = 1,
        **kwargs,
    ):
        """
        Draw one starting point by importance-resampling prior samples.

        Parameters
        ----------
        potential_fn : Callable
            Maps a batch of parameters to their unnormalized log posterior.
        key : Array
            The random key used to generate the samples.
        num_candidate_samples : int
            The number of candidate samples drawn per batch.
        num_batches : int
            The number of batches of candidate samples.
        **kwargs
            Ignored; accepted so callers can forward extra options.

        Returns
        -------
        Array
            The selected proposal sample, with a leading axis of length 1.
        Array
            The updated random key.

        Raises
        ------
        ValueError
            If no candidate has a finite unnormalized log probability.
        """
        log_weights = []
        candidates = []
        for _ in range(num_batches):
            subkey, key = jax.random.split(key)
            batch_draws = self.prior_distr.sample(
                sample_shape=(num_candidate_samples,), key=subkey
            )
            candidates.append(batch_draws)
            log_weights.append(jnp.atleast_1d(potential_fn(batch_draws)))
        log_weights = jnp.concatenate(log_weights, axis=0)
        candidates = jnp.concatenate(candidates, axis=0)
        # Exclude NaN/inf weights *before* normalizing: a single non-finite entry
        # would otherwise make the normalizer, and hence every probability, NaN.
        log_weights = jnp.where(jnp.isfinite(log_weights), log_weights, -jnp.inf)
        if not bool(jnp.any(jnp.isfinite(log_weights))):
            raise ValueError(
                "All candidate initial states have a non-finite unnormalized log "
                "probability; cannot choose an MCMC starting point."
            )
        # Normalize in log-space for numerical stability.
        probs = jnp.exp(log_weights - jax.scipy.special.logsumexp(log_weights, axis=0))
        subkey, key = jax.random.split(key)
        idxs = jax.random.choice(
            subkey,
            candidates.shape[0],
            shape=(1,),
            replace=True,
            p=probs,
        )
        proposal_samples = candidates[idxs]
        return proposal_samples, key

    def set_default_x(self, x: Array):
        """Set the default data for the posterior."""
        self.x = x

    def set_prior(self, prior_distr: dist.Distribution):
        """Set the prior distribution for the parameters."""
        self.prior_distr = prior_distr

    def set_mcmc_method(self, mcmc_method: str):
        """Set the MCMC method to use.

        Raises
        ------
        NotImplementedError
            If `mcmc_method` is not listed in `implemented_method`.
        """
        if mcmc_method not in implemented_method:
            raise NotImplementedError(
                f"The MCMC method {mcmc_method} is not implemented. Check print_implemented_methods()."
            )
        self.mcmc_method = mcmc_method

    def set_mcmc_kwargs(self, mcmc_kwargs: Optional[dict]):
        """Set the keyword arguments for the MCMC method.

        A shallow copy is stored so that neither the caller's dict nor the
        module-level default dicts can be changed through this posterior.
        `None` is treated as "no options".
        """
        self.mcmc_kwargs = {} if mcmc_kwargs is None else dict(mcmc_kwargs)

    def print_implemented_methods(self):
        """Print the implemented MCMC methods."""
        print(f"Implemented MCMC methods: {implemented_method}")