jaxili.inference.nle module
NLE.
This module provides a Neural Likelihood Estimation (NLE) class that trains a neural density estimator to approximate the likelihood $p(x|\theta)$.
- class jaxili.inference.nle.NLE(model_class: ~jaxili.model.NDENetwork = <class 'jaxili.model.ConditionalMAF'>, logging_level: int | str = 'WARNING', verbose: bool = True, model_hparams: ~typing.Dict[str, ~typing.Any] | None = {'activation': <jax._src.custom_derivatives.custom_jvp object>, 'layers': [50, 50], 'n_layers': 5, 'seed': 42, 'use_reverse': True}, loss_fn: ~typing.Callable = <function loss_nll_nle>)
Bases: object
NLE.
Base class for Neural Likelihood Estimation (NLE) methods. The default configuration uses a ConditionalMAF to learn the likelihood function $p(x|\theta)$.
Examples
>>> from jaxili.inference import NLE
>>> inference = NLE()
>>> theta, x = ...  # Load parameters and simulation outputs
>>> inference.append_simulations(theta, x)  # Pass your simulations to the trainer
>>> inference.train()  # Train your density estimator
Methods
append_simulations(theta, x[, ...]) – Store parameters and simulation outputs to use them for later training.
build_posterior(prior_distr[, verbose, x, ...]) – Build the posterior distribution $p(\theta|x)$ using the trained density estimator.
create_trainer(optimizer_hparams[, seed, ...]) – Create a TrainerModule for the density estimator.
load_from_checkpoints(checkpoint, exmp_input) – Create an NLE object whose TrainerModule loads existing pre-trained weights for the neural network.
set_dataloader(dataloader, type) – Set the dataloader to use for training, validation or testing.
set_dataset(dataset, type) – Set the dataset to use for training, validation or testing.
set_loss_fn(loss_fn) – Set the loss function to use for training.
set_model_hparams(hparams) – Set the hyperparameters of the model.
train([training_batch_size, learning_rate, ...]) – Train the density estimator to approximate the likelihood $p(x|\theta)$.
- append_simulations(theta: Array, x: Array, train_test_split: Iterable[float] = [0.7, 0.2, 0.1], key: PyTree | None = None)
Store parameters and simulation outputs to use them for later training.
The data are stored in a Dataset object from jax-dataloader.
- Parameters:
theta (Array) – Parameters of the simulations.
x (Array) – Simulation outputs.
train_test_split (Iterable[float], optional) – Fractions to split the dataset into training, validation and test sets. Should be of length 2 or 3. A length 2 list will not generate a test set. Default is [0.7, 0.2, 0.1].
key (PyTree, optional) – Key to use for the random permutation of the dataset. Default is None.
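The train/validation/test split described above can be pictured as: shuffle the sample indices once, then cut them at the cumulative fractions. The sketch below is pure Python; `split_indices` is a hypothetical helper, `random.Random(seed)` stands in for the JAX key, and the rounding rule for the cut points is an assumption, not jaxili's actual implementation.

```python
import random
from typing import List, Sequence, Tuple


def split_indices(n: int, fractions: Sequence[float], seed: int = 0) -> Tuple[List[int], ...]:
    # Hypothetical helper: shuffle range(n) and cut it at cumulative fractions.
    # jaxili uses a JAX PRNG key; random.Random(seed) stands in for it here.
    if len(fractions) not in (2, 3):
        raise ValueError("fractions must have length 2 or 3")
    if abs(sum(fractions) - 1.0) > 1e-6:
        raise ValueError("fractions must sum to 1")
    perm = list(range(n))
    random.Random(seed).shuffle(perm)
    # Cumulative cut points; round() avoids float artefacts such as 0.7 + 0.2 = 0.8999...
    cuts, total = [], 0.0
    for frac in fractions[:-1]:
        total += frac
        cuts.append(round(total * n))
    bounds = [0, *cuts, n]
    return tuple(perm[lo:hi] for lo, hi in zip(bounds[:-1], bounds[1:]))


train_idx, val_idx, test_idx = split_indices(100, [0.7, 0.2, 0.1], seed=0)
print(len(train_idx), len(val_idx), len(test_idx))  # 70 20 10
```

With a length-2 list such as `[0.8, 0.2]` the same helper returns only two index lists, mirroring the "no test set" case.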
- build_posterior(prior_distr: Distribution, verbose: bool | None = None, x: Array | None = None, mcmc_method: str | None = 'nuts_numpyro', mcmc_kwargs: Dict[str, Any] | None = {})
Build the posterior distribution $p(\theta|x)$ using the trained density estimator.
- Parameters:
prior_distr (dist.Distribution) – NumPyro distribution describing the prior over the parameters being inferred.
verbose (bool, optional) – Whether to print information. Default is the verbose flag of the trainer.
x (Array, optional) – The data used to condition the posterior. Default is None.
mcmc_method (str, optional) – The MCMC method to use. Default is ‘nuts_numpyro’.
mcmc_kwargs (dict, optional) – Keyword arguments used to sample from the posterior. Default is {}.
- Returns:
posterior – A posterior distribution object that can draw samples and evaluate the unnormalized log-probability. Sampling is performed with MCMC methods.
- Return type:
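In NLE the network learns the likelihood, so a posterior is usually obtained by adding the log-prior to the learned log-likelihood and sampling the resulting unnormalized density with MCMC. The sketch below illustrates that idea on a 1-D toy problem. It assumes a hypothetical Gaussian stand-in for the trained estimator, and it swaps the default NUTS sampler (`'nuts_numpyro'`) for a plain random-walk Metropolis sampler to stay dependency-free. All function names are hypothetical.

```python
import math
import random
from typing import Callable, List


def gaussian_logpdf(x: float, mean: float, std: float) -> float:
    """Log-density of a 1-D normal distribution N(mean, std**2) at x."""
    return -0.5 * ((x - mean) / std) ** 2 - math.log(std * math.sqrt(2.0 * math.pi))


def make_log_posterior(
    log_likelihood: Callable[[float, float], float],
    log_prior: Callable[[float], float],
    x_obs: float,
) -> Callable[[float], float]:
    """Unnormalized log posterior: log q(x_obs | theta) + log p(theta)."""
    def log_post(theta: float) -> float:
        return log_likelihood(x_obs, theta) + log_prior(theta)
    return log_post


def random_walk_metropolis(
    log_prob: Callable[[float], float],
    theta0: float,
    num_samples: int,
    step_size: float,
    burn_in: int = 1000,
    seed: int = 0,
) -> List[float]:
    """Random-walk Metropolis: only needs the log density up to a constant."""
    rng = random.Random(seed)
    theta, lp = theta0, log_prob(theta0)
    samples: List[float] = []
    for i in range(burn_in + num_samples):
        proposal = theta + rng.gauss(0.0, step_size)
        lp_prop = log_prob(proposal)
        # Accept with probability min(1, exp(lp_prop - lp)); the normalizer cancels.
        if rng.random() < math.exp(min(0.0, lp_prop - lp)):
            theta, lp = proposal, lp_prop
        if i >= burn_in:
            samples.append(theta)
    return samples


# Toy stand-ins: prior N(0, 1) and a "learned" likelihood q(x | theta) = N(theta, 0.5**2).
log_post = make_log_posterior(
    log_likelihood=lambda x, theta: gaussian_logpdf(x, theta, 0.5),
    log_prior=lambda theta: gaussian_logpdf(theta, 0.0, 1.0),
    x_obs=1.0,
)
samples = random_walk_metropolis(log_post, theta0=0.0, num_samples=20000, step_size=0.5)
# Conjugate check: the exact posterior here is N(0.8, 0.2), so the mean should be near 0.8.
print(sum(samples) / len(samples))
```

The toy is chosen to be conjugate so the sampler can be checked against the exact posterior mean of 0.8 and variance of 0.2. With a real flow, only `log_likelihood` would change.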
- create_trainer(optimizer_hparams: Dict[str, Any], seed: int = 42, logger_params: Dict[str, Any] = None, debug: bool = False, check_val_every_epoch: int = 1, **kwargs)
Create a TrainerModule for the density estimator.
- Parameters:
optimizer_hparams (Dict[str, Any]) – Hyperparameters to use for the optimizer.
loss_fn (Callable) – Loss function to use for training.
exmp_input (Any) – Example input to use for the model.
seed (int, optional) – Seed to use for the trainer. Default is 42.
logger_params (Dict[str, Any], optional) – Parameters to use for the logger. Default is None.
debug (bool, optional) – Whether to use debug mode. Default is False.
check_val_every_epoch (int, optional) – Frequency at which to check the validation loss. Default is 1.
- classmethod load_from_checkpoints(checkpoint: str, exmp_input: ~typing.Any, embedding_net_class=<class 'jaxili.compressor.Identity'>) → Any
Create an NLE object whose TrainerModule loads the existing pre-trained weights of the neural network.
- Parameters:
nde_class (NDENetwork) – Class used to create the neural density estimator
checkpoint (str) – Folder in which the checkpoint and hyperparameter files are stored
exmp_input (Any) – An example input to the model, from which the shapes are inferred.
embedding_net_class (nn.Module) – Class used to create the embedding net. (Default: Identity)
- Return type:
An NLE object containing a model with the pre-trained weights loaded.
- set_dataloader(dataloader, type)
Set the dataloader to use for training, validation or testing.
- Parameters:
dataloader (data.DataLoader) – dataloader to use.
type (str) – Type of the dataloader. Can be ‘train’, ‘val’ or ‘test’.
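For intuition about what a 'train' dataloader delivers, here is a dependency-free sketch of a shuffled mini-batch iterator over paired (theta, x) samples, such as batches of `training_batch_size` (default 50 in `train`). It is not jax-dataloader's DataLoader. The name `iterate_minibatches` and the choice to keep a final partial batch are assumptions made for illustration.

```python
import random
from typing import Iterator, List, Sequence, Tuple


def iterate_minibatches(
    theta: Sequence, x: Sequence, batch_size: int, shuffle: bool = True, seed: int = 0
) -> Iterator[Tuple[List, List]]:
    """Yield (theta_batch, x_batch) pairs; the last batch may be smaller."""
    if len(theta) != len(x):
        raise ValueError("theta and x must contain the same number of samples")
    if batch_size < 1:
        raise ValueError("batch_size must be positive")
    order = list(range(len(theta)))
    if shuffle:
        random.Random(seed).shuffle(order)  # one permutation keeps theta and x paired
    for start in range(0, len(order), batch_size):
        idx = order[start:start + batch_size]
        yield [theta[i] for i in idx], [x[i] for i in idx]


theta = list(range(10))
x = [10 * t for t in theta]
for theta_b, x_b in iterate_minibatches(theta, x, batch_size=4):
    print(len(theta_b))  # 4, 4, then 2
```

The key invariant is that a single permutation indexes both arrays, so each simulation output stays attached to the parameters that produced it.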
- set_dataset(dataset, type)
Set the dataset to use for training, validation or testing.
- Parameters:
dataset (data.Dataset) – Dataset to use.
type (str) – Type of the dataset. Can be ‘train’, ‘val’ or ‘test’.
- set_loss_fn(loss_fn)
Set the loss function to use for training.
- Parameters:
loss_fn (Callable) – Loss function to use for training.
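The default loss, `loss_nll_nle`, is a negative log-likelihood. For NLE that means averaging $-\log q(x_i|\theta_i)$ over a batch of simulated pairs. The exact call signature that `set_loss_fn` expects is not documented here, so the sketch below uses a hypothetical plain-Python signature. It also uses a toy Gaussian conditional density `toy_conditional_log_prob` in place of the trained flow.

```python
import math
from typing import Callable, Sequence


def toy_conditional_log_prob(x: float, theta: float, a: float = 2.0, sigma: float = 1.0) -> float:
    """Hypothetical q(x | theta) = N(x; a * theta, sigma**2), a stand-in for the flow."""
    mean = a * theta
    return -0.5 * ((x - mean) / sigma) ** 2 - math.log(sigma * math.sqrt(2.0 * math.pi))


def nle_nll_loss(
    log_prob: Callable[[float, float], float],
    theta_batch: Sequence[float],
    x_batch: Sequence[float],
) -> float:
    """Mean negative log-likelihood of simulated x given their parameters theta."""
    if len(theta_batch) != len(x_batch) or not x_batch:
        raise ValueError("batches must be non-empty and of equal length")
    # NLE conditions on theta and evaluates the density of x (the reverse of NPE).
    return -sum(log_prob(x, t) for t, x in zip(theta_batch, x_batch)) / len(x_batch)


theta = [0.0, 1.0, -0.5]
x = [0.0, 2.0, -1.0]  # each x sits exactly at the model mean a * theta
print(nle_nll_loss(toy_conditional_log_prob, theta, x))  # 0.5 * log(2 * pi) ≈ 0.9189
```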
- set_model_hparams(hparams)
Set the hyperparameters of the model.
- Parameters:
hparams (Dict[str, Any]) – Hyperparameters to use for the model.
- train(training_batch_size: int = 50, learning_rate: float = 0.0005, patience: int = 20, num_epochs: int = 2147483647, check_val_every_epoch: int = 1, **kwargs)
Train the density estimator to approximate the likelihood $p(x|\theta)$.
- Parameters:
training_batch_size (int, optional) – Batch size to use during training. Default is 50.
learning_rate (float, optional) – Learning rate to use during training. Default is 5e-4.
patience (int, optional) – Number of epochs to wait before early stopping. Default is 20.
num_epochs (int, optional) – Maximum number of epochs to train. Default is 2**31 - 1.
check_val_every_epoch (int, optional) – Frequency at which to check the validation loss. Default is 1.
**kwargs (dict, optional) –
Additional keyword arguments for training customization:
optimizer_name (str): Name of the optimizer to use (default: ‘adam’).
gradient_clip (float): Value for gradient clipping (default: 5.0).
warmup (float): Warmup proportion for learning rate scheduling (default: 0.1).
weight_decay (float): Weight decay (L2 regularization) (default: 0.0).
checkpoint_path (str): Directory to save training checkpoints (default: ‘checkpoints/’).
log_dir (str or None): Directory for logging (default: None).
logger_type (str): Type of logger to use (default: ‘TensorBoard’).
seed (int): Random seed for reproducibility (default: 42).
debug (bool): Whether to run in debug mode (default: False).
min_delta (float): Minimum change in validation loss to qualify as improvement (default: 1e-3).
- Returns:
metrics (Dict[str, Any]) – Dictionary containing the training, validation and test losses.
density_estimator (nn.Module) – The trained density estimator.
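The interaction between `patience` and `min_delta` can be sketched as follows. A validation loss counts as an improvement only if it beats the best loss so far by more than `min_delta`, and training stops after `patience` consecutive non-improving checks. This is an assumed reading of the parameter descriptions, not jaxili's code. Whether patience is counted in epochs or in validation checks (these coincide when `check_val_every_epoch=1`) and the exact comparison used are assumptions. `early_stopping_epoch` is a hypothetical name.

```python
from typing import Optional, Sequence


def early_stopping_epoch(
    val_losses: Sequence[float], patience: int = 20, min_delta: float = 1e-3
) -> Optional[int]:
    """Index of the validation check at which training stops, or None if it never triggers."""
    best = float("inf")
    checks_without_improvement = 0
    for i, loss in enumerate(val_losses):
        # Only a decrease larger than min_delta counts as an improvement.
        if loss < best - min_delta:
            best = loss
            checks_without_improvement = 0
        else:
            checks_without_improvement += 1
            if checks_without_improvement >= patience:
                return i
    return None


losses = [1.0, 0.8, 0.7995, 0.7993, 0.7991, 0.7990]
print(early_stopping_epoch(losses, patience=3))  # 4
```

Here the losses after index 1 keep decreasing, but each step is smaller than `min_delta`, so three non-improving checks trigger the stop at index 4. This is why the very large default `num_epochs` is safe in practice.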