Source code for jaxili.posterior.base_posterior

"""
Base Posterior.

This module contains the base class for Neural Posteriors. Classes used to sample in NPE and NLE will inherit from this class.
"""

from abc import abstractmethod
from typing import Any, Dict, Optional

from jaxtyping import Array

from jaxili.model import NDENetwork
from jaxili.train import TrainState


[docs] class NeuralPosterior: r""" Posterior $p(\theta|x)$ with `log_prob()` and `sample()` methods. The class wraps the trained neural network such that one can directly evaluate the log-probability and sample from the posterior. """ def __init__( self, model: NDENetwork, state: TrainState, verbose: bool = False, x: Optional[Array] = None, ): """ Initialize the Neural Posterior. Parameters ---------- model : NDENetwork The neural network used to generate the posterior. state : dict The state of the neural network. verbose : bool Whether to print information. (Default: False) """ self.model = model self.state = state self.verbose = verbose self.x = x
[docs] @abstractmethod def sample( self, num_samples: int, key: Array, x: Array, mcmc_method: Optional[str] = None, mcmc_kwargs: Optional[Dict[str, Any]] = None, ): """Define abstract method to sample from the posterior. The sampling method depends on the methodology used.""" pass
[docs] @abstractmethod def unnormalized_log_prob( self, theta: Array, ): """Define abstract method to compute the unnormalized log-probability of a given parameter.""" pass
[docs] def set_default_x(self, x: Array): """Set the default data for the posterior.""" self.x = x