Source code for jaxili.posterior.direct_posterior
"""
Direct Posterior.
This module contains the Direct Posterior class. It is used when doing neural posterior estimation where the neural network is trained to approximate the posterior directly.
"""
from typing import Optional
import jax.numpy as jnp
from jaxtyping import Array
from jaxili.model import NDENetwork
from jaxili.posterior import NeuralPosterior
from jaxili.train import TrainState
[docs]
class DirectPosterior(NeuralPosterior):
    r"""
    Posterior $p(\theta|x)$ with `log_prob()` and `sample()` methods.

    The class wraps a neural density estimator trained with Neural Posterior
    Estimation (NPE), i.e. a network that approximates the posterior directly.
    Sampling and density evaluation are delegated to the network's ``sample``
    and ``log_prob`` methods through ``model.apply``.
    """

    def __init__(
        self,
        model: NDENetwork,
        state: TrainState,
        verbose: bool = False,
        x: Optional[Array] = None,
    ):
        """
        Initialize the Direct Posterior.

        Parameters
        ----------
        model : NDENetwork
            The neural network used to generate the posterior.
        state : TrainState
            The training state of the neural network; its ``params`` are
            used when sampling and evaluating the density.
        verbose : bool
            Whether to print information. (Default: False)
        x : Array, optional
            Default data used to condition the posterior when no ``x`` is
            passed to ``sample`` or ``unnormalized_log_prob``. (Default: None)
        """
        super().__init__(model, state, verbose, x)

    def _resolve_x(self, x: Optional[Array]) -> Array:
        """Return ``x`` if given, else the default ``self.x``; raise if neither is set."""
        if x is None:
            x = self.x
        if x is None:
            raise ValueError(
                "Please set the default data `x` using `set_default_x()` or provide `x` as an argument."
            )
        return x

    def sample(self, num_samples: int, key: Array, x: Optional[Array] = None):
        r"""
        Sample from the posterior.

        Parameters
        ----------
        num_samples : int
            The number of samples to draw.
        key : Array
            The random key used to generate the samples.
        x : Array, optional
            The data used to condition the posterior. Falls back to the
            default data set on the posterior when omitted. It is passed to
            the network unchanged (no batch axis is added here).

        Returns
        -------
        theta : Array
            The samples from the posterior, as returned by the network's
            ``sample`` method.

        Raises
        ------
        ValueError
            If ``x`` is not given and no default data has been set.
        """
        x = self._resolve_x(x)
        return self.model.apply(
            {"params": self.state.params}, x, num_samples, key, method="sample"
        )

    def unnormalized_log_prob(self, theta: Array, x: Optional[Array] = None):
        r"""
        Compute the unnormalized log probability of the posterior.

        ``theta`` and ``x`` may each be a single sample (1-D array) or a batch
        whose leading axis is the batch dimension. A batch of size one on
        either side is broadcast to the batch size of the other.

        Parameters
        ----------
        theta : Array
            The parameters at which to evaluate the log probability, with
            shape ``(dim_theta,)`` or ``(batch, dim_theta)``.
        x : Array, optional
            The data used to condition the posterior, with shape ``(dim_x,)``
            or ``(batch, dim_x)``. Falls back to the default data set on the
            posterior when omitted.

        Returns
        -------
        log_prob : Array
            The log probability returned by the network's ``log_prob`` method
            for the (broadcast) batch.

        Raises
        ------
        ValueError
            If ``x`` is not given and no default data has been set, or if the
            batch sizes of ``theta`` and ``x`` differ and neither is one.
        """
        x = jnp.asarray(self._resolve_x(x))
        theta = jnp.asarray(theta)

        # A 1-D array is a single sample: add a leading batch axis so that
        # `shape[0]` is the batch size for both inputs. Without this, a single
        # parameter vector's dimension would be mistaken for its batch size.
        if theta.ndim == 1:
            theta = jnp.expand_dims(theta, axis=0)
        if x.ndim == 1:
            x = jnp.expand_dims(x, axis=0)

        n_theta, n_x = theta.shape[0], x.shape[0]
        if n_x == 1 and n_theta > 1:
            # Evaluate many parameter sets under the same observation.
            x = jnp.repeat(x, n_theta, axis=0)
        elif n_theta == 1 and n_x > 1:
            # Evaluate one parameter set under many observations.
            theta = jnp.repeat(theta, n_x, axis=0)
        elif n_x != n_theta:
            raise ValueError(
                "The batch size of `x` must be the same as the batch size of parameters `theta`."
            )

        return self.model.apply(
            {"params": self.state.params}, theta, x, method="log_prob"
        )