Source code for jaxili.loss

"""Loss.

This module contains the loss functions used to train the neural networks.
"""

from typing import Any

import jax
import jax.numpy as jnp
from jaxtyping import Array, Float, PyTree

# Default kernel widths used by `gaussian_kernel_matrix` when its `sigmas`
# argument is None. Each entry `s` contributes one term exp(-d / (2 * s)),
# where `d` is a squared Euclidean distance, and the terms are summed.
MMD_BANDWIDTH_LIST = [
    1e-6,
    1e-5,
    1e-4,
    1e-3,
    1e-2,
    1e-1,
    1,
    5,
    10,
    15,
    20,
    25,
    30,
    35,
    100,
    1e3,
    1e4,
    1e5,
    1e6,
]


def loss_nll_npe(model: Any, params: PyTree, batch: Any) -> Array:
    """
    Negative log-likelihood loss for NPE, averaged over the batch.

    The model's ``log_prob`` method is evaluated on the parameters,
    conditioned on the simulation outputs, i.e. it is called as
    ``log_prob(thetas, xs)``.

    Parameters
    ----------
    model : Any
        Neural network model from `jaxili.model`. It must expose
        ``apply(variables, *args, method=...)`` and a ``log_prob`` method.
    params : PyTree
        Parameters of the neural network, passed to ``apply`` under the
        ``"params"`` key.
    batch : Any
        Pair ``(thetas, xs)`` where the first element holds the parameters
        and the second the simulation outputs.

    Returns
    -------
    Array
        Mean of the negative log-probability across the batch.
    """
    parameters, observations = batch
    variables = {"params": params}
    log_probs = model.apply(variables, parameters, observations, method="log_prob")
    return -jnp.mean(log_probs)
def loss_nll_nle(model: Any, params: PyTree, batch: Any):
    """
    Negative log-likelihood loss for NLE, averaged over the batch.

    The model's ``log_prob`` method is evaluated on the simulation outputs,
    conditioned on the parameters, i.e. it is called as
    ``log_prob(xs, thetas)``.

    Parameters
    ----------
    model : Any
        Neural network model from `jaxili.model`. It must expose
        ``apply(variables, *args, method=...)`` and a ``log_prob`` method.
    params : PyTree
        Parameters of the neural network, passed to ``apply`` under the
        ``"params"`` key.
    batch : Any
        Pair ``(thetas, xs)`` where the first element holds the parameters
        and the second the simulation outputs.

    Returns
    -------
    Array
        Mean of the negative log-probability across the batch.
    """
    parameters, observations = batch
    variables = {"params": params}
    # Argument order is swapped relative to the batch order: the outputs are
    # the modelled variable and the parameters are the condition.
    log_probs = model.apply(variables, observations, parameters, method="log_prob")
    return -jnp.mean(log_probs)
def gaussian_kernel_matrix(x, y, sigmas=None):
    """
    Compute a sum of Gaussian radial basis function (RBF) kernels between x and y.

    For every pair of samples the squared Euclidean distance ``d`` is computed
    and the kernel value is ``sum_i exp(-d / (2 * sigma_i))``.

    Parameters
    ----------
    x : array of shape (num_draws_x, num_features)
    y : array of shape (num_draws_y, num_features)
    sigmas : array-like of float, optional, default: None
        Widths of the Gaussians in the kernel. Lists, tuples and arrays are
        accepted. `MMD_BANDWIDTH_LIST` is used if sigmas is None.

    Returns
    -------
    kernel values : array of shape (num_draws_y, num_draws_x)
        Entry ``[j, i]`` is the kernel value between ``x[i]`` and ``y[j]``.
    """
    # `jnp.expand_dims` and friends reject plain Python sequences, so coerce
    # first; this makes a user-supplied list behave like the default.
    sigmas = jnp.asarray(MMD_BANDWIDTH_LIST if sigmas is None else sigmas)
    # One inverse-width coefficient per kernel, shaped to broadcast over pairs.
    beta = 1.0 / (2.0 * jnp.reshape(sigmas, (-1, 1, 1)))
    # Pairwise differences, shape (num_draws_y, num_draws_x, num_features).
    diff = jnp.expand_dims(y, 1) - jnp.expand_dims(x, 0)
    sq_dist = jnp.sum(diff**2, axis=-1)
    # Sum the individual kernels over the bandwidth axis.
    return jnp.sum(jnp.exp(-beta * sq_dist), axis=0)
def mmd_kernel(x, y, kernel):
    """
    Compute the Maximum Mean Discrepancy (MMD) between samples of x and y.

    The estimate is ``mean(k(x, x)) + mean(k(y, y)) - 2 * mean(k(x, y))``.

    Parameters
    ----------
    x : array of shape (num_draws_x, num_features)
    y : array of shape (num_draws_y, num_features)
    kernel : function
        Callable ``kernel(a, b)`` returning an array of similarities between
        two sets of samples.

    Returns
    -------
    float
        Maximum Mean Discrepancy (MMD) between x and y.
    """
    within_x = jnp.mean(kernel(x, x))
    within_y = jnp.mean(kernel(y, y))
    across = jnp.mean(kernel(x, y))
    return within_x + within_y - 2 * across
def maximum_mean_discrepancy(
    source_samples, target_samples, kernel="gaussian", mmd_weight=1.0, minimum=0.0
):
    """
    Compute the Maximum Mean Discrepancy (MMD) between source and target samples.

    The raw MMD is clipped from below at `minimum` and then multiplied by
    `mmd_weight`.

    Parameters
    ----------
    source_samples : array of shape (N, num_features)
        Samples from the source distribution.
    target_samples : array of shape (M, num_features)
        Samples from the target distribution.
    kernel : str, optional
        Kernel to use for the MMD computation. Only "gaussian" is supported.
    mmd_weight : float, optional
        Weight for the MMD loss. Default: 1.0
    minimum : float, optional
        Minimum value for the MMD loss. Default: 0.0

    Returns
    -------
    float
        Weighted, clipped MMD between source and target samples.

    Raises
    ------
    AssertionError
        If `kernel` is not "gaussian".
    """
    # Raised explicitly instead of via `assert` so the check is not stripped
    # under `python -O`; the exception type and message are unchanged.
    if kernel != "gaussian":
        raise AssertionError("Only Gaussian kernel is supported for now")

    loss_value = mmd_kernel(
        source_samples, target_samples, kernel=gaussian_kernel_matrix
    )
    return jnp.maximum(loss_value, minimum) * mmd_weight
def mmd_summary_space(summary_outputs, rng, z_dist="gaussian", kernel="gaussian"):
    """
    Compute the MMD between the summary outputs and samples from a unit Gaussian.

    A reference sample with the same shape as `summary_outputs` is drawn from
    a standard normal distribution and compared to `summary_outputs` with
    `maximum_mean_discrepancy`.

    Parameters
    ----------
    summary_outputs : array of shape (num_samples, num_features)
        Summary outputs from the neural network.
    rng : jax.random.PRNGKey
        Random key used to draw the reference sample.
    z_dist : str, optional
        Reference distribution. Only "gaussian" is supported.
    kernel : str, optional
        Kernel to use for the MMD computation. Only "gaussian" is supported.

    Returns
    -------
    float
        MMD between the summary outputs and the unit Gaussian sample.

    Raises
    ------
    AssertionError
        If `z_dist` or `kernel` is not "gaussian".
    """
    # Raised explicitly instead of via `assert` so the checks are not stripped
    # under `python -O`; exception type and messages are unchanged.
    if z_dist != "gaussian":
        raise AssertionError("Only Gaussian distribution is supported for now")
    if kernel != "gaussian":
        raise AssertionError("Only Gaussian kernel is supported for now")

    z_samples = jax.random.normal(rng, shape=summary_outputs.shape)
    return maximum_mean_discrepancy(summary_outputs, z_samples, kernel=kernel)
def loss_mmd_npe(model, params, batch, *, rng_key=None):
    """
    Compute the NPE loss with an added MMD penalty on the compressed data.

    The simulation outputs are compressed with the model's ``compress``
    method. The loss is the mean negative log-probability returned by the
    model's ``log_prob_from_compressed`` method plus the MMD between the
    compressed outputs and a unit Gaussian sample.

    Parameters
    ----------
    model : Any
        Model exposing ``apply(variables, *args, method=...)`` together with
        ``compress`` and ``log_prob_from_compressed`` methods.
    params : PyTree
        Parameters of the neural network, passed to ``apply`` under the
        ``"params"`` key.
    batch : Any
        Pair ``(theta, x)`` of parameters and simulation outputs.
    rng_key : jax.random.PRNGKey, optional
        Key used to draw the Gaussian reference sample for the MMD term.
        Defaults to ``jax.random.PRNGKey(0)``, which reuses the same
        reference draw on every call; pass a fresh key per step to vary it.

    Returns
    -------
    float
        Mean negative log-probability plus the MMD loss.
    """
    if rng_key is None:
        # Backward-compatible default: a fixed key, as before.
        rng_key = jax.random.PRNGKey(0)

    theta, x = batch
    variables = {"params": params}

    # Compress the data.
    z = model.apply(variables, x, method="compress")

    # MMD between the compressed data and a unit Gaussian sample.
    mmd_loss = mmd_summary_space(z, rng_key)

    # Log-probability of the parameters given the compressed data.
    log_prob = model.apply(
        variables, z, theta, model="NPE", method="log_prob_from_compressed"
    )
    return -jnp.mean(log_prob) + mmd_loss