Source code for jaxili.compressor
"""
Compressor.
This module contains classes that implement compressors used in JaxILI.
"""
from typing import Any, Callable
from jaxtyping import Array
import flax.linen as nn
[docs]
class Identity(nn.Module):
"""Identity transformation."""
@nn.compact
def __call__(self, x):
"""
Forward pass of the identity transformation.
Parameters
----------
x : jnp.Array
Input data.
Returns
-------
jnp.Array
Output data.
"""
return x
[docs]
class Standardizer(nn.Module):
"""Standardizer transformation."""
mean: Array
std: Array
@nn.compact
def __call__(self, x):
"""
Forward pass of the standardizer transformation. The standardization uses the z-score.
Parameters
----------
x : jnp.Array
Input data.
Returns
-------
jnp.Array
Standardized data.
"""
return (x - self.mean) / self.std
[docs]
class MLPCompressor(nn.Module):
"""
Base class of a MLP Compressor.
Defines a MLP compressor to send the summary statistic to the same dimension than the parameters.
Parameters
----------
hidden_size : list
List with the size of the hidden layers.
activation : Callable
Activation function. Preferably from `jax.nn` or `jax.nn.activation`.
output_size : int
Size of the output layer.
"""
hidden_size: list
activation: Callable
output_size: int
@nn.compact
def __call__(self, x):
"""
Forward pass of the MLP Compressor.
Parameters
----------
x : jnp.array
Input data.
Returns
-------
jnp.array
Compressed data.
"""
for size in self.hidden_size:
x = nn.Dense(size)(x)
x = self.activation(x)
x = nn.Dense(self.output_size)(x)
return x
[docs]
class CNN2DCompressor(nn.Module):
"""
Base class of a CNN2D Compressor.
Defines a 2 dimensional Convolutional Neural Network to compress the data to the same dimension as the parameters.
Parameters
----------
output_size : int
Size of the output layer
activation : Callable
Activation function. Preferably from `jax.nn` or `jax.nn.activation`.
"""
output_size: int
activation: Callable
@nn.compact
def __call__(self, inputs):
"""
Forward pass of the CNN2D Compressor.
Parameters
----------
inputs : jnp.array
Input data.
Returns
-------
jnp.array
Compressed data.
"""
net_x = nn.Conv(32, 3, 2)(inputs)
net_x = self.activation(net_x)
net_x = nn.Conv(64, 3, 2)(net_x)
net_x = self.activation(net_x)
net_x = nn.Conv(128, 3, 2)(net_x)
net_x = self.activation(net_x)
net_x = nn.avg_pool(net_x, (16, 16), (8, 8), padding="SAME")
# Flatten the tensor
net_x = net_x.reshape((net_x.shape[0], -1))
net_x = nn.Dense(64)(net_x)
net_x = self.activation(net_x)
net_x = nn.Dense(self.output_size)(net_x)
return net_x.squeeze()