# Source code for jaxili.inference.npe

"""
NPE.

This module provides the Neural Posterior Estimation (NPE) class to train a neural density estimator to perform NPE.
"""

import os
import json
import re
import warnings
import copy
from typing import Any, Callable, Dict, Iterable, Optional, Union

import distrax
import flax.linen as nn
import jax
import jax.numpy as jnp
import jax.random as jr
import jax_dataloader as jdl
import numpy as np
from jaxtyping import Array, Float, PyTree

import jaxili
from jaxili.loss import loss_nll_npe
from jaxili.model import (
    ConditionalMAF,
    ConditionalRealNVP,
    MixtureDensityNetwork,
    NDE_w_Standardization,
)
from jaxili.compressor import Identity, Standardizer
from jaxili.posterior import DirectPosterior
from jaxili.train import TrainerModule
from jaxili.utils import *
from jaxili.utils import (
    check_density_estimator,
    create_data_loader,
    validate_theta_x,
)
from jaxili.inventory.func_dict import jaxili_loss_dict, jax_nn_dict, jaxili_nn_dict

# Hyperparameters used as the default value of `model_hparams` in this module.
default_maf_hparams = dict(
    n_layers=5,
    layers=[50, 50],
    activation=jax.nn.relu,
    use_reverse=True,
    seed=42,
)


class NPE:
    """
    NPE.

    Base class for Neural Posterior Estimation (NPE) methods.
    Default configuration uses a `ConditionalMAF` to learn the posterior function.

    Examples
    --------
    >>> from jaxili.inference import NPE
    >>> inference = NPE()
    >>> theta, x = ... # Load parameters and simulation outputs
    >>> inference.append_simulations(theta, x) #Push your simulations in the trainer
    >>> inference.train() #Train your density estimator
    """

    def __init__(
        self,
        model_class: jaxili.model.NDENetwork = ConditionalMAF,
        logging_level: Union[int, str] = "WARNING",
        verbose: bool = True,
        model_hparams: Optional[Dict[str, Any]] = default_maf_hparams,
        loss_fn: Callable = loss_nll_npe,
    ):
        """
        Initialize class for Neural Posterior Estimation (NPE) methods.

        Parameters
        ----------
        model_class : jaxili.model.NDENetwork
            Class of the neural density estimator to use. Default: ConditionalMAF.
        logging_level : Union[int, str], optional
            Logging level to use. Default is "WARNING".
        verbose : bool, optional
            Whether to print progress messages and enable the trainer progress bar.
            Default is True.
        model_hparams : Dict[str, Any], optional
            Hyperparameters to use for the model. The dictionary is copied, so the
            caller's object (or the module-level default) is never modified.
            `None` is treated as an empty dictionary.
        loss_fn : Callable, optional
            Loss function to use for training. Default: loss_nll_npe.
        """
        self._model_class = model_class
        # Keep a private copy: `_build_neural_network` writes "n_in"/"n_cond" into
        # this dict, which must not leak into the shared default dictionary.
        self._model_hparams = dict(model_hparams) if model_hparams is not None else {}
        self._logging_level = logging_level
        self._loss_fn = loss_fn
        self.verbose = verbose

    def set_model_hparams(self, hparams):
        """
        Set the hyperparameters of the model.

        Parameters
        ----------
        hparams : Dict[str, Any]
            Hyperparameters to use for the model. A copy is stored; `None` is
            treated as an empty dictionary.
        """
        self._model_hparams = dict(hparams) if hparams is not None else {}

    def set_loss_fn(self, loss_fn):
        """
        Set the loss function to use for training.

        Parameters
        ----------
        loss_fn : Callable
            Loss function to use for training.
        """
        self._loss_fn = loss_fn

    def set_dataset(self, dataset, type):
        """
        Set the dataset to use for training, validation or testing.

        Parameters
        ----------
        dataset : data.Dataset
            Dataset to use.
        type : str
            Type of the dataset. Can be 'train', 'val' or 'test'.
        """
        assert type in [
            "train",
            "val",
            "test",
        ], "Type should be 'train', 'val' or 'test'."
        # Stored as `_train_dataset`, `_val_dataset` or `_test_dataset`.
        setattr(self, f"_{type}_dataset", dataset)

    def set_dataloader(self, dataloader, type):
        """
        Set the dataloader to use for training, validation or testing.

        `train` only skips the automatic creation of the loaders when both the
        'train' and the 'val' loaders have been set.

        Parameters
        ----------
        dataloader : data.DataLoader
            dataloader to use.
        type : str
            Type of the dataloader. Can be 'train', 'val' or 'test'.
        """
        assert type in [
            "train",
            "val",
            "test",
        ], "Type should be 'train', 'val' or 'test'."
        # `train` and `_create_data_loader` read `_<type>_loader`.
        setattr(self, f"_{type}_loader", dataloader)
        # Previous attribute name, kept so existing readers keep working.
        setattr(self, f"_{type}_dataloader", dataloader)

    def append_simulations(
        self,
        theta: Array,
        x: Array,
        train_test_split: Iterable[float] = (0.7, 0.2, 0.1),
        key: Optional[PyTree] = None,
    ):
        """
        Store parameters and simulation outputs to use them for later training.

        Data is stored in a Dataset object from `jax-dataloader`. Each call
        replaces the stored train/val/test datasets. Data loaders that already
        exist on the object are not rebuilt by this method.

        Parameters
        ----------
        theta : Array
            Parameters of the simulations.
        x : Array
            Simulation outputs.
        train_test_split : Iterable[float], optional
            Fractions to split the dataset into training, validation and test sets.
            Should be of length 2 or 3. A length 2 list will not generate a test set.
            Default is (0.7, 0.2, 0.1).
        key : PyTree, optional
            Key to use for the random permutation of the dataset. Default is None,
            in which case a key is drawn from NumPy's global random state.

        Returns
        -------
        self : NPE
            The object itself, to allow chaining.

        Raises
        ------
        ValueError
            If `train_test_split` does not have 2 or 3 elements.
        """
        # Verify theta and x typing and size of the dataset
        theta, x, num_sims = validate_theta_x(theta, x)
        if self.verbose:
            print("[!] Inputs are valid.")
            print(f"[!] Appending {num_sims} simulations to the dataset.")

        self._dim_params = theta.shape[1]
        self._dim_cond = x.shape[1]
        self._num_sims = num_sims

        # Split the dataset into training, validation and test sets
        fractions = tuple(train_test_split)
        if len(fractions) not in (2, 3):
            raise ValueError("train_test_split should have 2 or 3 elements.")
        assert np.isclose(
            sum(fractions), 1.0
        ), "The sum of the split fractions should be 1."
        is_test_set = len(fractions) == 3
        train_fraction, val_fraction = fractions[:2]

        if key is None:
            key = jr.PRNGKey(np.random.randint(0, 1000))
        index_permutation = jr.permutation(key, num_sims)

        n_train = int(train_fraction * num_sims)
        n_train_val = int((train_fraction + val_fraction) * num_sims)
        train_idx = index_permutation[:n_train]
        val_idx = index_permutation[n_train:n_train_val]
        test_idx = index_permutation[n_train_val:] if is_test_set else None

        self.set_dataset(jdl.ArrayDataset(theta[train_idx], x[train_idx]), type="train")
        self.set_dataset(jdl.ArrayDataset(theta[val_idx], x[val_idx]), type="val")
        self.set_dataset(
            jdl.ArrayDataset(theta[test_idx], x[test_idx]) if is_test_set else None,
            type="test",
        )

        if self.verbose:
            print("[!] Dataset split into training, validation and test sets.")
            print(f"[!] Training set: {len(train_idx)} simulations.")
            print(f"[!] Validation set: {len(val_idx)} simulations.")
            if is_test_set:
                print(f"[!] Test set: {len(test_idx)} simulations.")

        return self

    def _create_data_loader(self, **kwargs):
        """
        Create DataLoaders for the training, validation and test datasets.

        Can only be executed after appending simulations.

        Parameters
        ----------
        batch_size : int
            Batch size to use for the DataLoader. Default is 128.
        **kwargs
            Any other keyword argument is forwarded to `create_data_loader`.

        Raises
        ------
        ValueError
            If the training or validation dataset has not been set.
        """
        if not hasattr(self, "_train_dataset"):
            raise ValueError(
                "No training dataset found. Please append simulations first."
            )
        if not hasattr(self, "_val_dataset"):
            raise ValueError(
                "No validation dataset found. Please append simulations first."
            )
        # The test set is optional; it may never have been set at all.
        test_dataset = getattr(self, "_test_dataset", None)

        batch_size = kwargs.get("batch_size", 128)
        # Always pass the batch size explicitly so the value printed below is the
        # one actually used, with or without a test set.
        loader_kwargs = {**kwargs, "batch_size": batch_size}
        if self.verbose:
            print(f"[!] Creating DataLoaders with batch_size {batch_size}.")

        if test_dataset is None:
            self._train_loader, self._val_loader = create_data_loader(
                self._train_dataset,
                self._val_dataset,
                train=[True, False],
                **loader_kwargs,
            )
            self._test_loader = None
        else:
            self._train_loader, self._val_loader, self._test_loader = (
                create_data_loader(
                    self._train_dataset,
                    self._val_dataset,
                    test_dataset,
                    train=[True, False, False],
                    **loader_kwargs,
                )
            )

    @staticmethod
    def _standardization_stats(data, min_std):
        """Return the mean and standard deviation of `data` along axis 0, with the deviation floored at `min_std`."""
        shift = jnp.mean(data, axis=0)
        scale = jnp.std(data, axis=0)
        # Avoid dividing by (almost) zero for constant features.
        scale = scale.at[scale < min_std].set(min_std)
        return shift, scale

    def _build_neural_network(
        self,
        z_score_theta: bool = True,
        z_score_x: bool = True,
        embedding_net: nn.Module = Identity,
        embedding_hparams: dict = None,
        **kwargs,
    ):
        """
        Build the neural network for the density estimator.

        Parameters
        ----------
        z_score_theta : bool, optional
            Whether to z-score the parameters. Default is True.
        z_score_x : bool, optional
            Whether to z-score the simulation outputs. Default is True.
        embedding_net : nn.Module, optional
            Class of the neural network to use for embedding. Default is `Identity`.
        embedding_hparams : dict, optional
            Hyperparameters used to instantiate `embedding_net`. If None while a
            non-`Identity` class is given, an `Identity` embedding is used instead.
        min_std : float, optional
            Keyword-only. Lower bound applied to the standard deviations used for
            z-scoring. Default is 1e-14.

        Returns
        -------
        model : NDE_w_Standardization
            The density estimator wrapped with its embedding net and transformation.

        Raises
        ------
        ValueError
            If no training dataset has been set.
        """
        if self.verbose:
            print("[!] Building the neural network.")

        # Check if the model class and hparams are correct
        if self._model_class == ConditionalMAF:
            check_hparams_maf(self._model_hparams)
        elif self._model_class == ConditionalRealNVP:
            check_hparams_realnvp(self._model_hparams)
        elif self._model_class == MixtureDensityNetwork:
            check_hparams_mdn(self._model_hparams)
        else:
            warnings.warn(
                f"Model class {self._model_class} is not a base class of JaxILI.\n Check that the hyperparameters of your network are consistent.",
                Warning,
            )

        if not hasattr(self, "_train_dataset"):
            raise ValueError(
                "No training dataset found. Please append simulations first."
            )

        min_std = kwargs.get("min_std", 1e-14)

        # Check if z-score is required for theta.
        if z_score_theta:
            shift, scale = self._standardization_stats(
                self._train_dataset[:][0], min_std
            )
        else:
            shift = jnp.zeros(self._dim_params)
            scale = jnp.ones(self._dim_params)
        self._transformation_hparams = {"shift": shift, "scale": scale}
        self._transformation = distrax.ScalarAffine(scale=scale, shift=shift)

        # Check if z-score is required for x.
        if z_score_x:
            x_shift, x_scale = self._standardization_stats(
                self._train_dataset[:][1], min_std
            )
            standardizer = Standardizer(x_shift, x_scale)
        else:
            standardizer = Identity()

        # `embedding_net` is a class here; instantiate it (or fall back to Identity).
        if embedding_net == Identity:
            embedding_net = Identity()
        elif embedding_hparams is None:
            warnings.warn(
                "An embedding net has been specified but not its hyperparameters. Creating an embedding of the instance `Identity` instead."
            )
            embedding_net = Identity()
        else:
            embedding_net = embedding_net(**embedding_hparams)

        self._embedding_net = nn.Sequential([standardizer, embedding_net])

        # The conditioning size seen by the NDE is the output size of the embedding.
        if isinstance(embedding_net, Identity):
            n_cond = self._dim_cond
        else:
            n_cond = embedding_net.output_size

        self._model_hparams["n_in"] = self._dim_params
        self._model_hparams["n_cond"] = n_cond

        self._nde = self._model_class(**self._model_hparams)

        model = NDE_w_Standardization(
            nde=self._nde,
            embedding_net=self._embedding_net,
            transformation=self._transformation,
        )
        return model

    def create_trainer(
        self,
        optimizer_hparams: Dict[str, Any],
        seed: int = 42,
        logger_params: Dict[str, Any] = None,
        debug: bool = False,
        check_val_every_epoch: int = 1,
        **kwargs,
    ):
        """
        Create a TrainerModule for the density estimator.

        The neural network is built first if it does not exist yet.

        Parameters
        ----------
        optimizer_hparams : Dict[str, Any]
            Hyperparameters to use for the optimizer.
        seed : int, optional
            Seed to use for the trainer. Default is 42.
        logger_params : Dict[str, Any], optional
            Parameters to use for the logger. Default is None.
        debug : bool, optional
            Whether to use debug mode. Default is False.
        check_val_every_epoch : int, optional
            Frequency at which to check the validation loss. Default is 1.
        **kwargs : dict, optional
            Options used when the network has to be built: `z_score_theta`,
            `z_score_x`, `embedding_net`, `embedding_hparams` and `min_std`.
            Other keys are ignored.
        """
        # Read outside the build branch: it is also needed for the config below
        # when the network already exists.
        embedding_net_hparams = kwargs.get("embedding_hparams", None)

        if not hasattr(self, "_nde"):
            build_kwargs = {}
            if "min_std" in kwargs:
                build_kwargs["min_std"] = kwargs["min_std"]
            _ = self._build_neural_network(
                z_score_theta=kwargs.get("z_score_theta", True),
                z_score_x=kwargs.get("z_score_x", True),
                embedding_net=kwargs.get("embedding_net", Identity),
                embedding_hparams=embedding_net_hparams,
                **build_kwargs,
            )

        nde_w_std_hparams = {
            "nde": self._nde,
            "embedding_net": self._embedding_net,
            "transformation": self._transformation,
        }

        exmp_input = (jnp.zeros((1, self._dim_params)), jnp.zeros((1, self._dim_cond)))

        if self.verbose:
            print("[!] Creating the Trainer module.")

        self.trainer = TrainerModule(
            model_class=NDE_w_Standardization,
            model_hparams=nde_w_std_hparams,
            optimizer_hparams=optimizer_hparams,
            loss_fn=self._loss_fn,
            exmp_input=exmp_input,
            seed=seed,
            logger_params=logger_params,
            enable_progress_bar=self.verbose,
            debug=debug,
            check_val_every_epoch=check_val_every_epoch,
        )

        config = self.trainer.config
        config.update({"nde_hparams": copy.deepcopy(self._model_hparams)})
        # Activation functions are stored by name in the configuration.
        if "activation" in self._model_hparams:
            config["nde_hparams"]["activation"] = config["nde_hparams"][
                "activation"
            ].__name__
        config.update(
            {"transformation_hparams": copy.deepcopy(self._transformation_hparams)}
        )
        if embedding_net_hparams is not None:
            config.update({"embedding_hparams": copy.deepcopy(embedding_net_hparams)})
            # Activation functions are stored by name in the configuration.
            if "activation" in embedding_net_hparams:
                config["embedding_hparams"]["activation"] = config[
                    "embedding_hparams"
                ]["activation"].__name__
        self.trainer.write_config(self.trainer.log_dir)

    def train(
        self,
        training_batch_size: int = 50,
        learning_rate: float = 5e-4,
        patience: int = 20,
        num_epochs: int = 2**31 - 1,
        check_val_every_epoch: int = 1,
        **kwargs,
    ):
        r"""
        Train the density estimator to approximate the distribution $p(\theta|x)$.

        If no trainer exists yet, one is created from the arguments below. If a
        trainer already exists, it is reused and `**kwargs` is forwarded unchanged
        to its `train_model` method.

        Parameters
        ----------
        training_batch_size : int, optional
            Batch size to use during training. Default is 50.
        learning_rate: float, optional
            Learning rate to use during training. Default is 5e-4.
        patience: int, optional
            Number of epochs to wait before early stopping. Default is 20.
        num_epochs: int, optional
            Maximum number of epochs to train. Default is 2**31 - 1.
        check_val_every_epoch: int, optional
            Frequency at which to check the validation loss. Default is 1.
        **kwargs : dict, optional
            Additional keyword arguments for training customization:
            - optimizer_name (str): Name of the optimizer to use (default: 'adam').
            - gradient_clip (float): Value for gradient clipping (default: 5.0).
            - warmup (float): Warmup proportion for learning rate scheduling (default: 0.1).
            - weight_decay (float): Weight decay (L2 regularization) (default: 0.0).
            - checkpoint_path (str): Directory to save training checkpoints (default: 'checkpoints/').
            - log_dir (str or None): Directory for logging (default: None).
            - logger_type (str): Type of logger to use (default: 'TensorBoard').
            - seed (int): Random seed for reproducibility (default: 42).
            - debug (bool): Whether to run in debug mode (default: False).
            - min_delta (float): Minimum change in validation loss to qualify as improvement (default: 1e-3).

        Returns
        -------
        metrics : Dict[str, float]
            Dictionary containing the training, validation and test losses.
        density_estimator : nn.Module
            The trained density estimator.

        Raises
        ------
        ValueError
            If no training dataset has been set.
        """
        if not hasattr(self, "_train_dataset"):
            raise ValueError(
                "No training dataset found. Please append simulations first."
            )

        # Create the dataloaders to perform the training
        if not (hasattr(self, "_train_loader") and hasattr(self, "_val_loader")):
            self._create_data_loader(batch_size=training_batch_size)
        test_loader = getattr(self, "_test_loader", None)

        if hasattr(self, "trainer"):
            # Reuse the existing trainer.
            train_kwargs = kwargs
        else:
            if kwargs.get("optimizer_hparams", None) is not None:
                warnings.warn(
                    "The optimizer hyperparameters specified will not be taken into account. Please refer to the documentation to modify it. Falling back to default optimizer hyperparameters."
                )
            optimizer_hparams = {
                "lr": learning_rate,
                "optimizer_name": kwargs.get("optimizer_name", "adam"),
                "gradient_clip": kwargs.get("gradient_clip", 5.0),
                "warmup": kwargs.get("warmup", 0.1),
                "weight_decay": kwargs.get("weight_decay", 0.0),
            }
            logger_params = {
                "base_log_dir": kwargs.get("checkpoint_path", "checkpoints/"),
                "log_dir": kwargs.get("log_dir", None),
                "logger_type": kwargs.get("logger_type", "TensorBoard"),
            }
            # These keys are passed explicitly below; forwarding them a second
            # time through **kwargs would raise a duplicate-keyword TypeError.
            explicit = ("optimizer_hparams", "seed", "logger_params", "debug")
            extra_kwargs = {k: v for k, v in kwargs.items() if k not in explicit}
            self.create_trainer(
                optimizer_hparams=optimizer_hparams,
                seed=kwargs.get("seed", 42),
                logger_params=logger_params,
                debug=kwargs.get("debug", False),
                check_val_every_epoch=check_val_every_epoch,
                **extra_kwargs,
            )
            train_kwargs = {"min_delta": kwargs.get("min_delta", 1e-3)}

        if self.verbose:
            print("[!] Training the density estimator.")
        metrics = self.trainer.train_model(
            self._train_loader,
            self._val_loader,
            test_loader=test_loader,
            num_epochs=num_epochs,
            patience=patience,
            **train_kwargs,
        )

        if self.verbose:
            print(f"[!] Training loss: {metrics['train/loss']}")
            print(f"[!] Validation loss: {metrics['val/loss']}")
            if test_loader is not None:
                print(f"[!] Test loss: {metrics['test/loss']}")

        density_estimator = self.trainer.bind_model()
        return metrics, density_estimator

    def build_posterior(
        self, verbose: Optional[bool] = None, x: Optional[Array] = None
    ):
        r"""
        Build the posterior distribution $p(\theta|x)$ using the trained density estimator.

        Parameters
        ----------
        verbose : bool, optional
            Verbosity of the posterior and of this method. Default is None, in
            which case the verbosity of the NPE object is used.
        x : Array, optional
            The data used to condition the posterior. Default is None.

        Returns
        -------
        posterior : NeuralPosterior
            The posterior distribution allowing to sample and evaluate the unnormalized log-probability.

        Raises
        ------
        ValueError
            If no trainer has been created.
        """
        if not hasattr(self, "trainer"):
            raise ValueError("No trainer found. You must first create a trainer.")

        if verbose is None:
            verbose = self.verbose

        posterior = DirectPosterior(
            model=self.trainer.model, state=self.trainer.state, verbose=verbose, x=x
        )
        if verbose:
            print(
                r"[!] Posterior $p(\theta| x)$ built. The class DirectPosterior is used to sample and evaluate the log probability."
            )
        return posterior

    @classmethod
    def load_from_checkpoints(
        cls, checkpoint: str, exmp_input: Any, embedding_net_class=Identity
    ) -> Any:
        """
        Create a NPE object where the TrainerModule is loading the already existing weights for the neural network.

        The network is rebuilt from the `hparams.json` file found in `checkpoint`.
        Several entries of that file are string representations of objects and
        are parsed with regular expressions.

        Parameters
        ----------
        checkpoint: str
            Folder in which the checkpoint and hyperparameter file is stored
        exmp_input : Any
            An input to the model with which the shapes are inferred.
        embedding_net_class: nn.Module
            Class used to create the embedding net. (Default: Identity)

        Returns
        -------
        A NPE object containing a model with the pre-trained weights loaded.

        Raises
        ------
        ValueError
            If the stored embedding net cannot be rebuilt from the file.
        """
        hparams_file = os.path.join(checkpoint, "hparams.json")
        assert os.path.isfile(hparams_file), "Could not find hparams file."
        with open(hparams_file, "r", encoding="utf-8") as f:
            hparams = json.load(f)
        assert (
            hparams["model_class"] == NDE_w_Standardization.__name__
        ), "The model has not been trained with NDE_w_Standardization. Check the checkpoint path is correct."
        hparams.pop("model_class")

        # Check that the embedding class name is correct.
        embedding_str = hparams["model_hparams"]["embedding_net"]
        # Find all class names in the layers list
        class_names = re.findall(r"(\w+)\s*\(", embedding_str)
        # Skip the leading "Sequential" and the `Array(...)` entries; what is left
        # is [standardizer class, embedding class].
        embedding_classes = [name for name in class_names[1:] if name != "Array"]
        assert (
            embedding_classes[1] == embedding_net_class.__name__
        ), "The embedding class does not match. Check that you are using the correct architecture."

        # Check if the loss function is correct.
        assert (
            hparams["loss_fn"] in jaxili_loss_dict
        ), "Unknown loss function. Check that the loss function you used comes from `jax.nn`."
        hparams["loss_fn"] = jaxili_loss_dict[hparams["loss_fn"]]

        # Recover the NDE class from the leading class name of its string form.
        nde_str = hparams["model_hparams"]["nde"]
        nde_class_match = re.match(r"(\w+)\s*\(", nde_str)
        nde_class_name = nde_class_match.group(1) if nde_class_match else None
        nde_class = jaxili_nn_dict[nde_class_name]
        nde_hparams = hparams["nde_hparams"]
        if "activation" in nde_hparams:
            # Activations are stored by name in the file.
            nde_hparams["activation"] = jax_nn_dict[nde_hparams["activation"]]

        # Create object from the class NPE
        inference = cls(
            model_class=nde_class, model_hparams=nde_hparams, loss_fn=hparams["loss_fn"]
        )

        # Create the NDE
        inference._nde = nde_class(**nde_hparams)

        # Regenerate the embedding net
        if embedding_classes[0] == "Identity":
            standardizer = Identity()
        elif embedding_classes[0] == "Standardizer":
            # DOTALL lets the patterns match array strings that span several lines.
            mean_match = re.search(
                r"mean\s*=\s*Array\((\[.*?\])", embedding_str, re.DOTALL
            )
            std_match = re.search(
                r"std\s*=\s*Array\((\[.*?\])", embedding_str, re.DOTALL
            )
            if mean_match is None or std_match is None:
                raise ValueError(
                    "Could not read the mean and std of the `Standardizer` from the hyperparameter file."
                )
            # Convert extracted values into NumPy arrays
            mean_array = np.fromstring(mean_match.group(1).strip("[]"), sep=", ")
            std_array = np.fromstring(std_match.group(1).strip("[]"), sep=", ")
            standardizer = Standardizer(mean=mean_array, std=std_array)
        else:
            raise ValueError(
                "The first class of the embedding net should be `Identity` or `Standardizer`."
            )

        if embedding_classes[1] != "Identity":
            if "embedding_hparams" not in hparams:
                raise ValueError(
                    "The embedding net hyperparameters can't be find. Check that you are using the correct checkpoint path."
                )
            if "activation" in hparams["embedding_hparams"]:
                hparams["embedding_hparams"]["activation"] = jax_nn_dict[
                    hparams["embedding_hparams"]["activation"]
                ]
            embedding_net = embedding_net_class(**hparams["embedding_hparams"])
        else:
            embedding_net = Identity()

        inference._embedding_net = nn.Sequential(layers=[standardizer, embedding_net])

        # Regenerate the transformation of the parameters from the whitespace
        # separated "[a b c]" strings stored in the file.
        shift_str = hparams["transformation_hparams"]["shift"]
        shift_list = [float(value) for value in shift_str.strip("[]").split()]
        scale_str = hparams["transformation_hparams"]["scale"]
        scale_list = [float(value) for value in scale_str.strip("[]").split()]
        inference._transformation = distrax.ScalarAffine(
            np.array(shift_list), np.array(scale_list)
        )

        model_hparams = {
            "nde": inference._nde,
            "embedding_net": inference._embedding_net,
            "transformation": inference._transformation,
        }

        # A missing or empty logger configuration is replaced by a fresh dict.
        if not hparams.get("logger_params"):
            hparams["logger_params"] = dict()
        hparams["logger_params"]["log_dir"] = checkpoint
        # The stored string form is replaced by the rebuilt modules above; all
        # remaining entries of the file are forwarded to the trainer.
        hparams.pop("model_hparams")
        inference.trainer = TrainerModule(
            model_class=NDE_w_Standardization,
            exmp_input=exmp_input,
            model_hparams=model_hparams,
            **hparams,
        )
        inference.trainer.load_model()
        return inference