jaxili.inference.nle module

NLE.

This module provides a Neural Likelihood Estimation (NLE) class that trains a neural density estimator to approximate the likelihood.

class jaxili.inference.nle.NLE(model_class: jaxili.model.NDENetwork = <class 'jaxili.model.ConditionalMAF'>, logging_level: int | str = 'WARNING', verbose: bool = True, model_hparams: typing.Dict[str, typing.Any] | None = {'activation': <jax._src.custom_derivatives.custom_jvp object>, 'layers': [50, 50], 'n_layers': 5, 'seed': 42, 'use_reverse': True}, loss_fn: typing.Callable = <function loss_nll_nle>)

Bases: object

NLE.

Base class for Neural Likelihood Estimation (NLE) methods. The default configuration uses a ConditionalMAF to learn the likelihood function.

Examples

>>> from jaxili.inference import NLE
>>> inference = NLE()
>>> theta, x = ...  # Load parameters and simulation outputs
>>> inference.append_simulations(theta, x) #Push your simulations in the trainer
>>> inference.train() #Train your density estimator

Methods

append_simulations(theta, x[, ...])

Store parameters and simulation outputs to use them for later training.

build_posterior(prior_distr[, verbose, x, ...])

Build the posterior distribution $p(\theta|x)$ using the trained density estimator.

create_trainer(optimizer_hparams[, seed, ...])

Create a TrainerModule for the density estimator.

load_from_checkpoints(checkpoint, exmp_input)

Create an NLE object whose TrainerModule loads existing weights for the neural network.

set_dataloader(dataloader, type)

Set the dataloader to use for training, validation or testing.

set_dataset(dataset, type)

Set the dataset to use for training, validation or testing.

set_loss_fn(loss_fn)

Set the loss function to use for training.

set_model_hparams(hparams)

Set the hyperparameters of the model.

train([training_batch_size, learning_rate, ...])

Train the density estimator to approximate the likelihood $p(x|\theta)$.

append_simulations(theta: Array, x: Array, train_test_split: Iterable[float] = [0.7, 0.2, 0.1], key: PyTree | None = None)

Store parameters and simulation outputs to use them for later training.

The data is stored in a Dataset object from jax-dataloader.

Parameters:
  • theta (Array) – Parameters of the simulations.

  • x (Array) – Simulation outputs.

  • train_test_split (Iterable[float], optional) – Fractions to split the dataset into training, validation and test sets. Should be of length 2 or 3. A length 2 list will not generate a test set. Default is [0.7, 0.2, 0.1].

  • key (PyTree, optional) – Key to use for the random permutation of the dataset. Default is None.
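The effect of train_test_split can be pictured with a small standard-library sketch. It is not jaxili's implementation: the helper name split_indices, the integer seed standing in for the JAX key, and the rounding convention are all assumptions made for illustration.

```python
import random


def split_indices(n_samples, fractions=(0.7, 0.2, 0.1), seed=0):
    """Shuffle range(n_samples) and cut it into train/val/test index lists.

    Hypothetical helper: the rounding rule and the integer seed are
    assumptions, not taken from jaxili.
    """
    if len(fractions) not in (2, 3):
        raise ValueError("fractions must have length 2 or 3")
    indices = list(range(n_samples))
    random.Random(seed).shuffle(indices)  # stand-in for the random permutation
    n_train = round(fractions[0] * n_samples)
    n_val = round(fractions[1] * n_samples)
    train = indices[:n_train]
    val = indices[n_train:n_train + n_val]
    # Three fractions: the remainder is the test set. Two fractions: no test set.
    test = indices[n_train + n_val:] if len(fractions) == 3 else []
    return train, val, test


train, val, test = split_indices(1000)
print(len(train), len(val), len(test))
```

With the default fractions and 1000 simulations this yields 700 training, 200 validation and 100 test indices; passing a length-2 sequence such as (0.8, 0.2) leaves the test list empty.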

build_posterior(prior_distr: Distribution, verbose: bool | None = None, x: Array | None = None, mcmc_method: str | None = 'nuts_numpyro', mcmc_kwargs: Dict[str, Any] | None = {})

Build the posterior distribution $p(\theta|x)$ using the trained density estimator.

Parameters:
  • prior_distr (dist.Distribution) – Numpyro distribution sampling the prior used to estimate the parameters.

  • verbose (bool, optional) – Whether to print information. Default is the verbose boolean of the trainer.

  • x (Array, optional) – The data used to condition the posterior. Default is None.

  • mcmc_method (str, optional) – The MCMC method to use. Default is ‘nuts_numpyro’.

  • mcmc_kwargs (dict, optional) – The keyword arguments used to sample from the posterior.

Returns:

posterior – The posterior distribution, which supports sampling and evaluation of the unnormalized log-probability. Sampling is performed using MCMC methods.

Return type:

NeuralPosterior
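Because NLE learns a likelihood, the posterior is only available up to a constant: the unnormalized log-probability is the learned log-likelihood plus the log-prior, which is why sampling goes through MCMC. The sketch below illustrates that combination with toy Gaussian stand-ins. The function names are hypothetical and this is not how NeuralPosterior is implemented; in practice the likelihood term would come from the trained density estimator.

```python
import math


def gaussian_logpdf(value, mean, std):
    """Log-density of a univariate normal distribution."""
    z = (value - mean) / std
    return -0.5 * z * z - math.log(std) - 0.5 * math.log(2.0 * math.pi)


def make_unnormalized_log_posterior(log_likelihood, log_prior, x_obs):
    """Return the function theta -> log p(x_obs | theta) + log p(theta)."""
    def log_prob(theta):
        return log_likelihood(x_obs, theta) + log_prior(theta)
    return log_prob


def toy_log_likelihood(x, theta):
    """Toy stand-in (assumption) for the trained estimator: x ~ N(theta, 1)."""
    return gaussian_logpdf(x, mean=theta, std=1.0)


def toy_log_prior(theta):
    """Toy prior (assumption): theta ~ N(0, 2^2)."""
    return gaussian_logpdf(theta, mean=0.0, std=2.0)


log_post = make_unnormalized_log_posterior(toy_log_likelihood, toy_log_prior, x_obs=1.5)
print(log_post(1.0), log_post(3.0))
```

An MCMC sampler such as NUTS only needs a function of this form (and, for gradient-based samplers, its gradient), so the normalizing constant $p(x)$ never has to be computed.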

create_trainer(optimizer_hparams: Dict[str, Any], seed: int = 42, logger_params: Dict[str, Any] = None, debug: bool = False, check_val_every_epoch: int = 1, **kwargs)

Create a TrainerModule for the density estimator.

Parameters:
  • optimizer_hparams (Dict[str, Any]) – Hyperparameters to use for the optimizer.

  • loss_fn (Callable) – Loss function to use for training.

  • exmp_input (Any) – Example input to use for the model.

  • seed (int, optional) – Seed to use for the trainer. Default is 42.

  • logger_params (Dict[str, Any], optional) – Parameters to use for the logger. Default is None.

  • debug (bool, optional) – Whether to use debug mode. Default is False.

  • check_val_every_epoch (int, optional) – Frequency at which to check the validation loss. Default is 1.

classmethod load_from_checkpoints(checkpoint: str, exmp_input: typing.Any, embedding_net_class=<class 'jaxili.compressor.Identity'>) → Any

Create an NLE object whose TrainerModule loads existing weights for the neural network.

Parameters:
  • nde_class (NDENetwork) – Class used to create the neural density estimator.

  • checkpoint (str) – Folder in which the checkpoint and hyperparameter files are stored.

  • exmp_input (Any) – An input to the model with which the shapes are inferred.

  • embedding_net_class (nn.Module) – Class used to create the embedding net. (Default: Identity)

Returns:

An NLE object containing a model with the pre-trained weights loaded.

set_dataloader(dataloader, type)

Set the dataloader to use for training, validation or testing.

Parameters:
  • dataloader (data.DataLoader) – Dataloader to use.

  • type (str) – Type of the dataloader. Can be ‘train’, ‘val’ or ‘test’.

set_dataset(dataset, type)

Set the dataset to use for training, validation or testing.

Parameters:
  • dataset (data.Dataset) – Dataset to use.

  • type (str) – Type of the dataset. Can be ‘train’, ‘val’ or ‘test’.

set_loss_fn(loss_fn)

Set the loss function to use for training.

Parameters:

loss_fn (Callable) – Loss function to use for training.
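Judging by its name, the default loss_nll_nle is a negative log-likelihood. The call signature jaxili expects from a custom loss is not documented here, so the sketch below only illustrates the quantity such a loss computes, the batch average of $-\log q(x|\theta)$, using hypothetical function names and a toy density in place of the network.

```python
import math


def mean_negative_log_likelihood(log_prob_fn, thetas, xs):
    """Average of -log q(x | theta) over a batch of (theta, x) pairs."""
    total = 0.0
    for theta, x in zip(thetas, xs):
        total -= log_prob_fn(x, theta)
    return total / len(thetas)


def toy_log_prob(x, theta):
    """Hypothetical stand-in for the network: unit-variance Gaussian centred on theta."""
    return -0.5 * (x - theta) ** 2 - 0.5 * math.log(2.0 * math.pi)


# A batch where every x sits exactly at its theta gives the lowest possible loss
# for this toy density.
loss = mean_negative_log_likelihood(toy_log_prob, thetas=[0.0, 1.0], xs=[0.0, 1.0])
print(loss)
```

Minimizing this average over the training simulations pushes the estimator to assign high density to the simulated outputs given their parameters.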

set_model_hparams(hparams)

Set the hyperparameters of the model.

Parameters:

hparams (Dict[str, Any]) – Hyperparameters to use for the model.

train(training_batch_size: int = 50, learning_rate: float = 0.0005, patience: int = 20, num_epochs: int = 2147483647, check_val_every_epoch: int = 1, **kwargs)

Train the density estimator to approximate the likelihood $p(x|\theta)$.

Parameters:
  • training_batch_size (int, optional) – Batch size to use during training. Default is 50.

  • learning_rate (float, optional) – Learning rate to use during training. Default is 5e-4.

  • patience (int, optional) – Number of epochs to wait before early stopping. Default is 20.

  • num_epochs (int, optional) – Maximum number of epochs to train. Default is 2**31 - 1.

  • check_val_every_epoch (int, optional) – Frequency at which to check the validation loss. Default is 1.

  • **kwargs (dict, optional) –

    Additional keyword arguments for training customization:

    • optimizer_name (str): Name of the optimizer to use (default: ‘adam’).

    • gradient_clip (float): Value for gradient clipping (default: 5.0).

    • warmup (float): Warmup proportion for learning rate scheduling (default: 0.1).

    • weight_decay (float): Weight decay (L2 regularization) (default: 0.0).

    • checkpoint_path (str): Directory to save training checkpoints (default: ‘checkpoints/’).

    • log_dir (str or None): Directory for logging (default: None).

    • logger_type (str): Type of logger to use (default: ‘TensorBoard’).

    • seed (int): Random seed for reproducibility (default: 42).

    • debug (bool): Whether to run in debug mode (default: False).

    • min_delta (float): Minimum change in validation loss to qualify as improvement (default: 1e-3).

Returns:

  • metrics (Dict[str, Any]) – Dictionary containing the training, validation and test losses.

  • density_estimator (nn.Module) – The trained density estimator.
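Since num_epochs defaults to 2**31 - 1, training is in practice ended by early stopping, governed by patience and min_delta. The sketch below shows one common form of that rule. The exact comparison jaxili uses is an assumption here, the function name is hypothetical, and the losses are made up for illustration.

```python
def early_stopping_epoch(val_losses, patience=20, min_delta=1e-3):
    """Return the 1-based epoch at which training would stop, or None.

    An epoch counts as an improvement only if the validation loss drops
    below the best value seen so far by more than min_delta (assumed rule).
    """
    best = float("inf")
    epochs_without_improvement = 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best - min_delta:
            best = loss
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return epoch
    return None


# Illustrative validation losses: the last three changes are smaller than min_delta.
losses = [1.0, 0.8, 0.7, 0.6995, 0.6999, 0.6998]
print(early_stopping_epoch(losses, patience=3))
```

With patience=3 the run stops at epoch 6, because epochs 4 to 6 never improve on the best loss of 0.7 by more than min_delta. A larger check_val_every_epoch would make each validation check, and hence each patience step, span several training epochs.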