jaxili.loss module

Loss.

This module contains loss functions used to train the neural networks.

jaxili.loss.gaussian_kernel_matrix(x, y, sigmas=None)

Compute a Gaussian radial basis function (RBF) kernel matrix between the samples of x and y.

The kernel is a sum of several Gaussian kernels, each with its own width sigma_i.

Parameters:
  • x (array of shape (num_draws_x, num_features))

  • y (array of shape (num_draws_y, num_features))

  • sigmas (list(float), optional, default: None) – Widths of the individual Gaussians in the kernel. A default range is used if sigmas is None.

Returns:

kernel values

Return type:

array of shape (num_draws_x, num_draws_y)
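
As a rough illustration of the "sum of Gaussian kernels" idea, here is a minimal JAX sketch. It is not the jaxili implementation: the function name gaussian_kernel_sum, the default widths, and the exp(-d² / (2 sigma²)) parametrisation are assumptions made for this example, and the library's default range of sigmas may differ.

```python
import jax.numpy as jnp


def gaussian_kernel_sum(x, y, sigmas=(0.5, 1.0, 2.0)):
    """Sum of Gaussian kernels between the rows of x and y (illustrative only)."""
    # Pairwise squared Euclidean distances, shape (num_draws_x, num_draws_y).
    diff = x[:, None, :] - y[None, :, :]
    sq_dist = jnp.sum(diff**2, axis=-1)
    # One Gaussian kernel per width (assumed widths), stacked on a leading axis.
    sigmas = jnp.asarray(sigmas)
    kernels = jnp.exp(-sq_dist[None, :, :] / (2.0 * sigmas[:, None, None] ** 2))
    # Sum over the widths to get a single kernel matrix.
    return jnp.sum(kernels, axis=0)


x = jnp.zeros((4, 3))
y = jnp.ones((5, 3))
k = gaussian_kernel_sum(x, y)
print(k.shape)  # (4, 5)
```

Summing over several widths makes the kernel sensitive to discrepancies at more than one length scale, which is a common reason for using a range of sigmas rather than a single value.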

jaxili.loss.loss_mmd_npe(model, params, batch)

Compute the Maximum Mean Discrepancy (MMD) loss for Neural Posterior Estimation.

Parameters:
  • compress (function) – Neural network function to compress the data.

  • nf (function) – Neural network function to compute the log-probability conditioned on a random variable.

  • params (jnp.array) – Parameters of the neural network.

  • batch (jnp.array) – Batch of data.

Returns:

Maximum Mean Discrepancy (MMD) loss.

Return type:

float

jaxili.loss.loss_nll_nle(model: Any, params: PyTree, batch: Any)

Negative log-likelihood loss function for NLE methods using a given neural network as a model.

In NLE, the log-probability is given by the density estimator in observation space, conditioned on the parameters. The batch should be a tuple of arrays, where the first array holds the parameters and the second holds the simulation outputs.

Parameters:
  • model (Any) – Neural network model from jaxili.model.

  • params (PyTree) – Parameters of the neural network.

  • batch (Any) – Batch of (parameters, outputs) to compute the loss.

Returns:

Mean of the negative log-likelihood loss across the batch.

Return type:

Array

jaxili.loss.loss_nll_npe(model: Any, params: PyTree, batch: Any) → Array

Negative log-likelihood loss function for NPE methods using a given neural network as a model.

In NPE, the log-probability is given by the density estimator in parameter space, conditioned on the data. The batch should be a tuple of arrays, where the first array holds the parameters and the second holds the simulation outputs.

Parameters:
  • model (Any) – Neural network model from jaxili.model.

  • params (PyTree) – Parameters of the neural network.

  • batch (Any) – Batch of (parameters, outputs) to compute the loss.

Returns:

Mean of the negative log-likelihood loss across the batch.

Return type:

Array
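
The sketch below illustrates the batch convention shared by loss_nll_npe and loss_nll_nle, and how the two losses differ only in which array is scored and which is conditioned on. It is a toy example under stated assumptions: toy_log_prob is a hypothetical Gaussian conditional density standing in for a jaxili model, and the function names nll_npe and nll_nle are made up here, so this shows the idea rather than the library's code.

```python
import jax.numpy as jnp
from jax.scipy.stats import norm


def toy_log_prob(params, value, condition):
    """Hypothetical conditional density: value ~ Normal(w * condition + b, 1)."""
    mean = params["w"] * condition + params["b"]
    # Sum the per-feature log-densities to get one log-probability per sample.
    return jnp.sum(norm.logpdf(value, loc=mean, scale=1.0), axis=-1)


def nll_npe(params, batch):
    theta, x = batch  # (parameters, simulation outputs)
    # NPE: density over the parameters, conditioned on the simulation outputs.
    return -jnp.mean(toy_log_prob(params, theta, x))


def nll_nle(params, batch):
    theta, x = batch  # same batch layout as above
    # NLE: density over the simulation outputs, conditioned on the parameters.
    return -jnp.mean(toy_log_prob(params, x, theta))


params = {"w": jnp.array(1.0), "b": jnp.array(0.0)}
theta = jnp.array([[0.0], [1.0]])
x = jnp.array([[0.0], [1.0]])
loss_npe = nll_npe(params, (theta, x))
loss_nle = nll_nle(params, (theta, x))
```

Both functions return a scalar array, the mean over the batch, which is the quantity an optimiser would differentiate with respect to params.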

jaxili.loss.maximum_mean_discrepancy(source_samples, target_samples, kernel='gaussian', mmd_weight=1.0, minimum=0.0)

Compute the Maximum Mean Discrepancy (MMD) between source and target samples.

Parameters:
  • source_samples (array of shape (N, num_features)) – Samples from the source distribution.

  • target_samples (array of shape (M, num_features)) – Samples from the target distribution.

  • kernel (str, optional) – Kernel function to use for the MMD computation. Default: "gaussian"

  • mmd_weight (float, optional) – Weight for the MMD loss. Default: 1.0

  • minimum (float, optional) – Minimum value for the MMD loss. Default: 0.0

Returns:

Maximum Mean Discrepancy (MMD) between source and target samples.

Return type:

float

jaxili.loss.mmd_kernel(x, y, kernel)

Compute the Maximum Mean Discrepancy (MMD) between samples of x and y.

Parameters:
  • x (array of shape (num_draws_x, num_features))

  • y (array of shape (num_draws_y, num_features))

  • kernel (function) – A kernel function which computes the similarity between two sets of samples.

Returns:

Maximum Mean Discrepancy (MMD) between x and y.

Return type:

float
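
For intuition, a kernel-based MMD can be estimated from three kernel matrices. The sketch below uses the standard biased estimate of the squared MMD, mean k(x, x') + mean k(y, y') - 2 mean k(x, y). Whether jaxili uses this exact estimator is an assumption, and rbf_kernel and mmd_squared are hypothetical names introduced only for this example.

```python
import jax
import jax.numpy as jnp


def rbf_kernel(x, y, sigma=1.0):
    """Single-width Gaussian kernel matrix (illustrative stand-in kernel)."""
    sq_dist = jnp.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)
    return jnp.exp(-sq_dist / (2.0 * sigma**2))


def mmd_squared(x, y, kernel):
    # Biased estimate: within-x similarity + within-y similarity
    # minus twice the cross similarity between x and y.
    return (
        jnp.mean(kernel(x, x))
        + jnp.mean(kernel(y, y))
        - 2.0 * jnp.mean(kernel(x, y))
    )


key_x, key_y = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (64, 2))
y = jax.random.normal(key_y, (64, 2)) + 3.0  # shifted distribution

same = mmd_squared(x, x, rbf_kernel)     # identical samples
shifted = mmd_squared(x, y, rbf_kernel)  # clearly different samples
```

In this sketch, comparing a sample set with itself gives zero, while the shifted samples give a strictly larger value, which is the behaviour that makes the MMD usable as a loss.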

jaxili.loss.mmd_summary_space(summary_outputs, rng, z_dist='gaussian', kernel='gaussian')

Compute the Maximum Mean Discrepancy (MMD) between the summary outputs and samples from a unit Gaussian distribution.

Parameters:
  • summary_outputs (array of shape (num_samples, num_features)) – Summary outputs from the neural network.

  • rng (jax.random.PRNGKey) – Random key for reproducibility.

  • z_dist (str, optional) – Distribution of the samples. Default: "gaussian"

  • kernel (str, optional) – Kernel function to use for the MMD computation. Default: "gaussian"

Returns:

Maximum Mean Discrepancy (MMD) between the summary outputs and samples from a unit Gaussian distribution.

Return type:

float
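
The role of rng can be sketched as follows: reference samples are drawn from a unit Gaussian with the same shape as the summary outputs, and the two sets are then compared with an MMD. This is a minimal sketch under assumptions rather than the library's code. The helper names rbf_kernel and mmd_to_unit_gaussian are hypothetical, a single-width kernel is used for brevity, and the biased squared-MMD estimate is assumed.

```python
import jax
import jax.numpy as jnp


def rbf_kernel(a, b, sigma=1.0):
    """Single-width Gaussian kernel matrix (illustrative stand-in kernel)."""
    sq_dist = jnp.sum((a[:, None, :] - b[None, :, :]) ** 2, axis=-1)
    return jnp.exp(-sq_dist / (2.0 * sigma**2))


def mmd_to_unit_gaussian(summary_outputs, rng):
    # Reference draws from N(0, I), same shape as the summary outputs.
    z = jax.random.normal(rng, summary_outputs.shape)
    k_ss = rbf_kernel(summary_outputs, summary_outputs)
    k_zz = rbf_kernel(z, z)
    k_sz = rbf_kernel(summary_outputs, z)
    # Biased squared-MMD estimate between the summaries and the reference.
    return jnp.mean(k_ss) + jnp.mean(k_zz) - 2.0 * jnp.mean(k_sz)


key_data, key_ref = jax.random.split(jax.random.PRNGKey(0))
close = jax.random.normal(key_data, (128, 2))  # already roughly unit Gaussian
far = 0.1 * close + 5.0                        # narrow and far from the origin

penalty_close = mmd_to_unit_gaussian(close, key_ref)
penalty_far = mmd_to_unit_gaussian(far, key_ref)
```

Summaries that already look like unit Gaussian draws receive a small penalty, while summaries concentrated far from the origin receive a much larger one. Passing the key explicitly keeps the reference draws reproducible.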