jaxili.utils module

Utils.

This module contains utility functions used in the JaxILI package. Some functions format the input data for training; others check the validity of the input data.

jaxili.utils.check_density_estimator(estimator_arg: str)

Check whether the density estimator argument corresponds to one of the authorized networks.

Parameters:

estimator_arg (str) – Density estimator argument to check.
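As an illustration of what such a membership check might look like, here is a minimal sketch. The accepted names (`"maf"`, `"mdn"`, `"realnvp"`, inferred only from the `check_hparams_*` functions in this module), the case-insensitive comparison and the choice to raise `ValueError` are all assumptions; `check_density_estimator_sketch` and `AUTHORIZED_ESTIMATORS` are hypothetical names, not part of JaxILI.

```python
# Hypothetical set of authorized estimator names; the real list in jaxili may differ.
AUTHORIZED_ESTIMATORS = ("maf", "mdn", "realnvp")


def check_density_estimator_sketch(estimator_arg: str) -> None:
    """Raise an error if estimator_arg is not an authorized estimator name."""
    if not isinstance(estimator_arg, str):
        raise TypeError(f"estimator_arg must be a str, got {type(estimator_arg).__name__}.")
    # Compare case-insensitively so "MAF" and "maf" are treated the same (an assumption).
    if estimator_arg.lower() not in AUTHORIZED_ESTIMATORS:
        raise ValueError(
            f"Unknown density estimator {estimator_arg!r}; "
            f"expected one of {AUTHORIZED_ESTIMATORS}."
        )


check_density_estimator_sketch("maf")  # passes silently
```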

jaxili.utils.check_hparams_maf(hparams: dict)

Check the hyperparameters of the Masked Autoregressive Flow.

Parameters:

hparams (dict) – Dictionary with the hyperparameters of the MAF.

jaxili.utils.check_hparams_mdn(hparams: dict)

Check the hyperparameters of the Mixture Density Network.

Parameters:

hparams (dict) – Dictionary with the hyperparameters of the MDN.

jaxili.utils.check_hparams_realnvp(hparams: dict)

Check the hyperparameters of the RealNVP.

Parameters:

hparams (dict) – Dictionary with the hyperparameters of the RealNVP.
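The three `check_hparams_*` functions follow the same pattern. Below is a minimal sketch of a generic check, assuming each function verifies that required keys are present and hold values of the expected type. `check_required_hparams` and the example `MDN_SCHEMA` keys are hypothetical and may not match what JaxILI actually checks for the MAF, MDN or RealNVP.

```python
def check_required_hparams(hparams: dict, required: dict[str, type]) -> None:
    """Check that hparams contains every required key with a value of the expected type."""
    if not isinstance(hparams, dict):
        raise TypeError("hparams must be a dict.")
    missing = [key for key in required if key not in hparams]
    if missing:
        raise ValueError(f"Missing hyperparameters: {missing}")
    for key, expected_type in required.items():
        if not isinstance(hparams[key], expected_type):
            raise TypeError(
                f"Hyperparameter {key!r} should be of type {expected_type.__name__}, "
                f"got {type(hparams[key]).__name__}."
            )


# Hypothetical MDN schema: the keys checked by check_hparams_mdn may differ.
MDN_SCHEMA = {"n_components": int, "layers": list}

check_required_hparams({"n_components": 5, "layers": [50, 50]}, MDN_SCHEMA)  # passes
```

Keeping the schema as data makes it easy to reuse the same checker for each estimator; the trade-off is that constraints beyond type (for example, positive layer sizes) need extra code.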

jaxili.utils.create_data_loader(*datasets: Sequence[DataLoader], train: bool | Sequence[bool] = True, batch_size: int = 128)

Create data loaders from a set of datasets.

Parameters:
  • datasets – Datasets for which data loaders are created.

  • train (bool | Sequence[bool]) – Sequence indicating which datasets are used for training and which are not. If a single bool is given, the same value is used for all datasets.

  • batch_size (int) – Batch size to use in the data loaders.

Returns:

List of data loaders.

Return type:

list[jdl.DataLoader]
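The broadcasting of a single `train` flag to every dataset can be sketched with plain NumPy in place of `jdl.DataLoader`. This assumes, as is common in JAX training utilities but not confirmed here, that training loaders shuffle and drop the last incomplete batch while evaluation loaders do neither. `BatchLoader` and `create_loaders_sketch` are hypothetical names.

```python
from typing import Iterator, Sequence, Union

import numpy as np


class BatchLoader:
    """Re-iterable mini-batch loader over the first axis of a NumPy array."""

    def __init__(self, data: np.ndarray, batch_size: int, train: bool, seed: int = 0):
        self.data = data
        self.batch_size = batch_size
        self.train = train
        self.rng = np.random.default_rng(seed)

    def __iter__(self) -> Iterator[np.ndarray]:
        n = len(self.data)
        indices = np.arange(n)
        if self.train:
            self.rng.shuffle(indices)  # new sample order on every epoch
        # Training drops the final incomplete batch; evaluation keeps it.
        stop = n - n % self.batch_size if self.train else n
        for start in range(0, stop, self.batch_size):
            yield self.data[indices[start:start + self.batch_size]]


def create_loaders_sketch(*datasets: np.ndarray,
                          train: Union[bool, Sequence[bool]] = True,
                          batch_size: int = 128) -> list:
    """Build one BatchLoader per dataset, broadcasting a single train flag to all."""
    if isinstance(train, bool):
        train = [train] * len(datasets)
    if len(train) != len(datasets):
        raise ValueError("train must be a bool or contain one flag per dataset.")
    return [BatchLoader(d, batch_size, t) for d, t in zip(datasets, train)]


train_set = np.arange(10.0).reshape(10, 1)
val_set = np.arange(6.0).reshape(6, 1)
train_loader, val_loader = create_loaders_sketch(train_set, val_set, train=[True, False], batch_size=4)
print([b.shape[0] for b in train_loader])  # [4, 4]
print([b.shape[0] for b in val_loader])    # [4, 2]
```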

jaxili.utils.handle_non_serializable(obj)

Replace or transform objects into serializable representations so that training metadata can be saved.

Custom handler for non-serializable objects.

Parameters:

obj (Any) – Object to handle.
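A handler of this kind is typically passed as the `default` argument of `json.dumps`, which calls it for every object the JSON encoder cannot encode natively. The minimal sketch below assumes callables are stored by name, array-like objects (NumPy and JAX arrays, NumPy scalars) become Python lists or scalars, and everything else falls back to `str`; the conversion rules used by JaxILI's `handle_non_serializable` may differ, and `handle_non_serializable_sketch` is a hypothetical name.

```python
import json

import numpy as np


def handle_non_serializable_sketch(obj):
    """Fallback used by json.dumps for objects it cannot encode natively."""
    if callable(obj):
        # Functions, ufuncs and classes are stored by name.
        return getattr(obj, "__name__", repr(obj))
    if hasattr(obj, "__array__"):
        # NumPy/JAX arrays become nested lists; 0-d arrays and scalars become Python scalars.
        return np.asarray(obj).tolist()
    # Anything else falls back to its string representation.
    return str(obj)


metadata = {"lr": np.float32(1e-3), "layers": np.array([50, 50]), "activation": np.tanh}
print(json.dumps(metadata, default=handle_non_serializable_sketch))
```

Checking `callable` first ensures that classes such as `np.ndarray`, which expose an `__array__` attribute, are stored by name rather than passed to `np.asarray`.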

jaxili.utils.validate_theta_x(theta: Any, x: Any)

Check whether the passed $(\theta, x)$ pair is valid.

We check that:

  • $\theta$ and $x$ are JAX arrays;

  • $\theta$ and $x$ have the same number of samples;

  • $\theta$ and $x$ have dtype float32.

Raises:

AssertionError – if $\theta$ and $x$ are not JAX arrays, do not have the same batch size, or do not have dtype float32.

Parameters:
  • theta (Any) – Parameters of the simulations.

  • x (Any) – Simulation outputs.
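The three checks listed above can be sketched with plain assertions, assuming the first axis of both arrays indexes simulations. `validate_theta_x_sketch` is a hypothetical name, not JaxILI's implementation; it uses `jax.Array` (available in JAX 0.4.1 and later) for the type check.

```python
import jax
import jax.numpy as jnp


def validate_theta_x_sketch(theta, x) -> None:
    """Assert that theta and x are float32 JAX arrays with the same number of samples."""
    assert isinstance(theta, jax.Array), "theta must be a JAX array."
    assert isinstance(x, jax.Array), "x must be a JAX array."
    # The first axis is assumed to index simulations (the batch dimension).
    assert theta.shape[0] == x.shape[0], (
        f"theta and x must have the same number of samples, got {theta.shape[0]} and {x.shape[0]}."
    )
    assert theta.dtype == jnp.float32, f"theta must be float32, got {theta.dtype}."
    assert x.dtype == jnp.float32, f"x must be float32, got {x.dtype}."


theta = jnp.zeros((100, 3), dtype=jnp.float32)
x = jnp.zeros((100, 10), dtype=jnp.float32)
validate_theta_x_sketch(theta, x)  # passes silently
```

Because the checks rely on `assert`, they are skipped when Python runs with the `-O` flag; raising `ValueError` or `TypeError` explicitly would be an alternative if validation must always run.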