jaxili.posterior.direct_posterior module


Direct Posterior.

This module contains the DirectPosterior class. It is used in neural posterior estimation (NPE), where the neural network is trained to approximate the posterior directly.

class jaxili.posterior.direct_posterior.DirectPosterior(model: NDENetwork, state: TrainState, verbose: bool = False, x: Array | None = None)

Bases: NeuralPosterior

Posterior $p(\theta \mid x)$ with log_prob() and sample() methods.

The class wraps a neural network trained with Neural Posterior Estimation (NPE).
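The following is a minimal sketch, not the jaxili implementation. The class ToyDirectPosterior is hypothetical: its "trained network" is replaced by a fixed affine map from x to the mean of a diagonal Gaussian over theta, with a fixed log standard deviation. It illustrates the NPE idea that, once the conditional density is learned, sampling and density evaluation need only a forward pass. It also assumes, as the optional x arguments suggest, that an omitted x falls back to the data stored with set_default_x(); that fallback is an assumption, not documented behaviour.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


class ToyDirectPosterior:
    """Hypothetical stand-in for DirectPosterior (not the jaxili class).

    The "network" is a fixed affine map x -> mean of a diagonal Gaussian
    q(theta | x), playing the role of a trained NPE density estimator.
    """

    def __init__(self, weight, bias, log_std):
        self.weight = weight    # (dim_theta, dim_x), assumed "trained" weights
        self.bias = bias        # (dim_theta,)
        self.log_std = log_std  # (dim_theta,)
        self.default_x = None

    def set_default_x(self, x):
        # Store the observation used when x is omitted (assumed behaviour).
        self.default_x = x

    def _resolve_x(self, x):
        if x is None:
            if self.default_x is None:
                raise ValueError("No x given and no default x set.")
            x = self.default_x
        return x

    def sample(self, num_samples, key, x=None):
        x = self._resolve_x(x)
        mean = self.weight @ x + self.bias
        # One forward pass plus Gaussian noise: no MCMC is needed.
        eps = jax.random.normal(key, (num_samples, mean.shape[0]))
        return mean + jnp.exp(self.log_std) * eps

    def unnormalized_log_prob(self, theta, x=None):
        x = self._resolve_x(x)
        mean = self.weight @ x + self.bias
        # Sum of per-dimension Gaussian log-densities; theta may be batched.
        return jnp.sum(norm.logpdf(theta, mean, jnp.exp(self.log_std)), axis=-1)


# Usage with an identity "network" and posterior std 0.1 in each dimension.
posterior = ToyDirectPosterior(jnp.eye(2), jnp.zeros(2), jnp.log(0.1) * jnp.ones(2))
posterior.set_default_x(jnp.array([1.0, -1.0]))
samples = posterior.sample(1000, jax.random.PRNGKey(0))
log_probs = posterior.unnormalized_log_prob(samples)
print(samples.shape, log_probs.shape)  # (1000, 2) (1000,)
```

Keeping the default observation on the object means repeated calls for the same dataset do not need to pass x each time, which is the convenience that set_default_x() appears to provide.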

Methods

sample(num_samples, key[, x])

Sample from the posterior.

set_default_x(x)

Set the default data for the posterior.

unnormalized_log_prob(theta[, x])

Compute the unnormalized log probability of the posterior.

sample(num_samples: int, key: Array, x: Array | None = None)

Sample from the posterior.

Parameters:
  • num_samples (int) – The number of samples to draw.

  • key (Array) – The random key used to generate the samples.

  • x (Array | None, optional) – The data used to condition the posterior. Defaults to None.

Returns:

theta – The samples from the posterior.

Return type:

Array
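In JAX, the key argument makes sampling deterministic: the same key always yields the same draws, so independent batches need a fresh subkey from jax.random.split. The sketch below shows this pattern with a hypothetical Gaussian sample_fn standing in for a trained posterior; it does not call jaxili, and the (num_samples, dim_theta) output shape is an assumption chosen for illustration.

```python
import jax
import jax.numpy as jnp


def sample_fn(num_samples, key, mean, std):
    """Hypothetical stand-in for posterior.sample: draws from N(mean, std^2)."""
    eps = jax.random.normal(key, (num_samples, mean.shape[0]))
    return mean + std * eps


mean = jnp.array([0.5, -2.0])
std = jnp.array([0.2, 0.3])

key = jax.random.PRNGKey(42)
a = sample_fn(500, key, mean, std)
b = sample_fn(500, key, mean, std)      # same key -> identical samples
key_1, key_2 = jax.random.split(key)
c = sample_fn(500, key_1, mean, std)    # fresh subkey -> different samples

print(a.shape)                 # (500, 2)
print(bool(jnp.all(a == b)))   # True
```

Reusing a key across calls is a common source of silently correlated samples, so splitting the key before each sample() call is the usual JAX practice.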

unnormalized_log_prob(theta: Array, x: Array | None = None)

Compute the unnormalized log probability of the posterior.

Parameters:
  • theta (Array) – The parameters at which to evaluate the log probability.

  • x (Array | None, optional) – The data used to condition the posterior. Defaults to None.

Returns:

log_prob – The unnormalized log probability.

Return type:

Array
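An unnormalized log probability is defined only up to an additive constant: differences between values at two thetas are meaningful, but absolute values are not. For a one-dimensional parameter, one way to recover a normalized density is to estimate the log normalizing constant on a grid. The sketch below uses a hypothetical log_density (a Gaussian with a deliberate +3 offset) in place of a trained posterior; the grid bounds and resolution are illustrative choices.

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def log_density(theta):
    """Hypothetical unnormalized log-density: N(1, 0.5^2) shifted by +3."""
    return -0.5 * ((theta - 1.0) / 0.5) ** 2 + 3.0


grid = jnp.linspace(-4.0, 6.0, 2001)
dtheta = grid[1] - grid[0]
log_p = log_density(grid)

# Riemann-sum estimate of log Z = log of the integral of exp(log_p);
# logsumexp avoids overflow when log_p values are large.
log_z = logsumexp(log_p) + jnp.log(dtheta)
normalized = jnp.exp(log_p - log_z)

print(float(log_z))  # close to 3 + log(0.5 * sqrt(2 * pi)), about 3.2258
```

Grid normalization scales poorly beyond a few dimensions; in higher dimensions, the unnormalized value is typically used directly, for example as a target for MCMC or for ranking parameter values.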