"""Train.
This module implements an object to perform the training of Normalizing Flows and more generally of neural networks.
"""
import json
import os
import time
import warnings
from collections import defaultdict
from copy import copy, deepcopy
from typing import Any, Callable, Dict, Iterator, Optional, Sequence, Tuple
import jax
import jax.numpy as jnp
import optax
import orbax.checkpoint as ocp
from flax import linen as nn
from flax.training import checkpoints, orbax_utils, train_state
from flax.training.early_stopping import EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger
from tqdm import tqdm
from jaxili.model import NDENetwork
from jaxili.inventory.func_dict import jax_nn_dict, jaxili_loss_dict, jaxili_nn_dict
from jaxili.utils import handle_non_serializable
[docs]
class TrainState(train_state.TrainState):
"""
A simple extension of TrainState to also include batch statistics.
If a model has no batch statistics, it is None.
Keep an rng state for dropout or init.
"""
batch_stats: Any = (None,)
rng: Any = None
[docs]
class TrainerModule:
"""
A module to perform the training of Normalizing Flows.
This module contains the training loop, evaluation, logging, and checkpointing. It can also be used to load a model from a checkpoint.
"""
def __init__(
self,
model_class: NDENetwork,
model_hparams: Dict[str, Any],
optimizer_hparams: Dict[str, Any],
loss_fn: Callable,
exmp_input: Any,
seed: int = 42,
logger_params: Dict[str, Any] = None,
enable_progress_bar: bool = True,
debug: bool = False,
check_val_every_epoch: int = 1,
nde_class: str = "NPE",
**kwargs,
):
"""
Initialize a basic Trainer module summarizing most training functionalities like logging, model initialization, training loop, etc...
Attributes
----------
model_class : jaxili.model.NDENetwork
The class of the model that should be trained.
model_hparams : Dict[str, Any]
A dictionnary of the hyperparameters of the model. Is used as input to the model when it is created.
optimizer_hparams : Dict[str, Any]
A dictionnary of the hyperparameters of the optimizer. Used during initialization of the optimizer.
exmp_input : Any
Input to the model for initialisation and tabulate.
seed : int
Seed to initialise PRNG.
logger_params : Dict[str, Any]
A dictionary containing the specifications of the logger.
enable_progress_bar : bool
Whether to enable the progress bar. Default is True.
debug : bool
If True, no jitting is applied. Can be helpful for debugging. Default is False.
check_val_every_epoch : int
How often to check the validation set. Default is 1.
nde_class : str
The class of the Neural Density Estimator. Default is "NPE". Only "NPE" and "NLE" are allowed.
"""
super().__init__()
self.model_class = model_class
self.model_hparams = model_hparams
self.loss_fn = loss_fn
self.optimizer_hparams = optimizer_hparams
self.enable_progress_bar = enable_progress_bar
self.debug = debug
self.seed = seed
self.key_rng = jax.random.PRNGKey(seed)
self.check_val_every_epoch = check_val_every_epoch
self.nde_class = nde_class
assert (
nde_class == "NPE" or nde_class == "NLE"
), "Choose a valid class of Neural Density Estimator. (NPE or NLE)"
self.exmp_input = exmp_input
if self.nde_class == "NLE":
self.exmp_input = (self.exmp_input[1], self.exmp_input[0])
self.generate_config(logger_params)
self.config.update(kwargs)
# Create an empty model. Note: no parameters yet
self.model = self.model_class(**self.model_hparams)
self.init_apply_fn()
self.print_tabulate(self.exmp_input)
# Init trainer parts
self.init_logger(logger_params)
self.create_jitted_functions()
self.init_model(self.exmp_input)
# Initialize checkpointer
self.init_checkpointer()
[docs]
def init_logger(self, logger_params: Optional[Dict] = None):
"""
Initialize the logger and created a logging directory.
Parameters
----------
logger_params : Dict[str, Any]
A dictionary containing the specifications of the logger.
"""
if logger_params is None:
logger_params = dict()
# Determine logging directory
log_dir = logger_params.get("log_dir", None)
if log_dir is None:
base_log_dir = logger_params.get("base_log_dir", "checkpoints/")
# Prepare logging
log_dir = os.path.join(base_log_dir, self.config["model_class"])
if "logger_name" in logger_params:
log_dir = os.path.join(log_dir, logger_params["logger_name"])
version = None
else:
version = ""
# Create logger object
logger_type = logger_params.get("logger_type", "TensorBoard").lower()
if logger_type == "tensorboard":
self.logger = TensorBoardLogger(save_dir=log_dir, version=version, name="")
elif logger_type == "wandb":
self.logger = WandbLogger(save_dir=log_dir, version=version, name="")
else:
assert False, f'Unknown logger type "{logger_type}"'
# Save hyperparameters
log_dir = self.logger.log_dir
if not os.path.isfile(os.path.join(log_dir, "hparams.json")):
os.makedirs(os.path.join(log_dir, "metrics/"), exist_ok=True)
try:
self.write_config(log_dir)
except:
warnings.warn("Could not save hyperparameters.", Warning)
self.log_dir = log_dir
[docs]
def write_config(self, log_dir):
"""Write the config of the trainer in a JSON file."""
with open(os.path.join(log_dir, "hparams.json"), "w") as f:
json.dump(self.config, f, indent=4, default=handle_non_serializable)
[docs]
def init_checkpointer(self):
"""Initialize the checkpointer to save the model."""
options = ocp.CheckpointManagerOptions(max_to_keep=1, create=True)
self.checkpoint_manager = ocp.CheckpointManager(self.log_dir, options=options)
[docs]
def init_model(self, exmp_input: Any):
"""
Create an initial training state with newly generated network parameters.
Parameters
----------
exmp_input : Any
An input to the model with which the shapes are inferred.
"""
# Prepare PRNG and input
init_rng, self.key_rng = jax.random.split(self.key_rng)
exmp_input = (
[exmp_input] if not isinstance(exmp_input, (list, tuple)) else exmp_input
)
# Run model initialization
variables = self.run_model_init(exmp_input, init_rng)
# Create default state. Optimizer is initialized later
model_rng, self.key_rng = jax.random.split(self.key_rng)
self.state = TrainState(
step=0,
apply_fn=self.apply_fn,
params=variables["params"],
batch_stats=variables.get("batch_stats"),
rng=model_rng,
tx=None,
opt_state=None,
)
[docs]
def init_apply_fn(self):
"""Initialize a default apply function for the model."""
self.apply_fn = self.model.log_prob
[docs]
def generate_config(self, logger_params):
"""Generate a configuration dictionary for the trainer."""
self.config = {
"model_class": self.model_class.__name__,
"model_hparams": deepcopy(self.model_hparams),
"loss_fn": self.loss_fn.__name__,
"optimizer_hparams": self.optimizer_hparams,
"logger_params": logger_params,
"enable_progress_bar": self.enable_progress_bar,
"debug": self.debug,
"check_val_every_epoch": self.check_val_every_epoch,
"seed": self.seed,
"nde_class": self.nde_class,
}
if "activation" in self.model_hparams.keys():
self.config["model_hparams"]["activation"] = self.model_hparams[
"activation"
].__name__
[docs]
def run_model_init(self, exmp_input: Any, init_rng: Any) -> Dict:
"""
Initialize the model by calling it on the example input.
Parameters
----------
exmp_input : Dict[str, Any]
An input to the model with which the shapes are inferred.
init_rng : Array
A jax.random.PRNGKey
Returns
-------
The initialized variable dictionary.
"""
return self.model.init(init_rng, *exmp_input, method="log_prob")
[docs]
def print_tabulate(self, exmp_input: Any):
"""
Print a summary of the model represented as a table.
Parameters
----------
exmp_input : Any
An input to the model with which the shapes are inferred.
"""
try:
print(self.model.tabulate(jax.random.PRNGKey(0), *exmp_input))
except Exception as e:
print(f"Could not tabulate model: {e}")
[docs]
def init_optimizer(self, num_epochs: int, num_steps_per_epoch: int):
"""
Initialize the optimizer and learning rate scheduler.
Parameters
----------
num_epochs : int
Number of epochs to train.
num_steps_per_epoch : int
Number of steps per epoch.
"""
hparams = copy(self.optimizer_hparams)
# Initialize optimizer
optimizer_name = hparams.pop("optimizer_name", "adam")
if optimizer_name.lower() == "adam":
opt_class = optax.adam
elif optimizer_name.lower() == "sgd":
opt_class = optax.sgd
elif optimizer_name.lower() == "adamw":
opt_class = optax.adamw
else:
assert False, f'Unknown optimizer "{optimizer_name}"'
# Initialize learning rate scheduler
# A cosine decay scheduler is used, but others are also possible
lr = hparams.pop("lr", 1e-3)
warmup = hparams.pop(
"warmup", num_steps_per_epoch
) # By default linear warmup during the first epoch
decay_steps = hparams.pop(
"decay_steps", int(num_epochs // 2 * num_steps_per_epoch)
)
lr_schedule = optax.warmup_cosine_decay_schedule(
init_value=0.01 * lr,
peak_value=lr,
warmup_steps=warmup,
decay_steps=decay_steps,
end_value=0.01 * lr,
)
# Clip gradients at max value, and evt. apply weight decay
transf = [optax.clip_by_global_norm(hparams.pop("gradient_clip", 5.0))]
if opt_class == optax.sgd and "weight_decay" in hparams:
transf.append(optax.add_decayed_weights(hparams.pop("weight_decay", 0.0)))
hparams.pop(
"weight_decay", None
) # removes weight decay if the opt_class is not sgd.
optimizer = optax.chain(*transf, opt_class(lr_schedule, **hparams))
# Initialize training state
self.state = TrainState.create(
apply_fn=self.state.apply_fn,
params=self.state.params,
batch_stats=self.state.batch_stats,
tx=optimizer,
rng=self.state.rng,
)
[docs]
def create_jitted_functions(self):
"""
Create jitted versions of the training, validation and evaluation functions.
If self.debug is True, no jitting is applied.
"""
train_step, eval_step = self.create_functions()
if self.debug: # Skip jitting
print("Skipping jitting due to debug=True")
self.train_step = train_step
self.eval_step = eval_step
else:
self.train_step = jax.jit(train_step)
self.eval_step = jax.jit(eval_step)
[docs]
def create_functions(
self,
) -> Tuple[
Callable[[TrainState, Any], Tuple[TrainState, Dict]],
Callable[[TrainState, Any], Tuple[TrainState, Dict]],
]:
"""
Create and returns functions for the training and evaluation step.
The functions take as input the training state and a batch from the train/val/test loader.
Both functions are expected to return a dictionary of logging metrics, and the training function a new train state. This function can be overwritten by a subclass. The train_step and eval_step functions here are examples for the signature of the functions.
"""
def train_step(state: TrainState, batch: Any):
loss_fn = lambda params: self.loss_fn(self.model, params, batch)
loss, grads = jax.value_and_grad(loss_fn)(state.params)
state = state.apply_gradients(grads=grads)
metrics = {"loss": loss}
return state, metrics
def eval_step(state: TrainState, batch: Any):
loss = self.loss_fn(self.model, state.params, batch)
metrics = {"loss": loss}
return metrics
return train_step, eval_step
[docs]
def train_model(
self,
train_loader: Iterator,
val_loader: Iterator,
test_loader: Optional[Iterator] = None,
num_epochs: int = 500,
min_delta: float = 1e-3,
patience: int = 20,
) -> Dict[str, Any]:
"""
Start a training loop for the given number of epochs.
Parameters
----------
train_loader : Iterator
An iterator over the training data.
val_loader : Iterator
An iterator over the validation data.
test_loader : Iterator
If given, best model will be evaluated on the test set.
num_epochs : int
Number of epochs for which to train the model.
min_delta : float
Minimum change in the monitored metric to qualify as an improvement.
patience : int
Number of epochs with no improvement after which training will be stopped. Default is 20.
Returns
-------
Dict[str, Any]
A dictionary of the train, validation and evt. test metrics for the best model on the validation set.
"""
# Create optimizer and the scheduler for the given numer of epochs
self.init_optimizer(num_epochs, len(train_loader))
# Prepare training loop
self.on_training_start()
best_eval_metrics = None
best_epoch = None
early_stop = EarlyStopping(min_delta, patience)
pbar = self.tracker(range(1, num_epochs + 1), desc="Epochs")
for epoch_idx in pbar:
train_metrics = self.train_epoch(train_loader)
self.logger.log_metrics(train_metrics, step=epoch_idx)
self.on_training_epoch_end(epoch_idx)
# Validation every N epochs
if epoch_idx % self.check_val_every_epoch == 0:
eval_metrics = self.eval_model(val_loader, log_prefix="val/")
self.on_validation_epoch_end(epoch_idx, eval_metrics, val_loader)
self.logger.log_metrics(eval_metrics, step=epoch_idx)
self.save_metrics(f"eval_epoch_{str(epoch_idx).zfill(3)}", eval_metrics)
# Save best model
if self.is_new_model_better(eval_metrics, best_eval_metrics):
best_eval_metrics = eval_metrics
best_eval_metrics.update(train_metrics)
best_epoch = epoch_idx
self.save_model(step=epoch_idx)
self.save_metrics("best_eval", best_eval_metrics)
early_stop = early_stop.update(eval_metrics["val/loss"])
if early_stop.should_stop:
print(f"Neural network training stopped after {epoch_idx} epochs.")
print(
f"Early stopping with best validation metric: {early_stop.best_metric}"
)
print(f"Best model saved at epoch {best_epoch}")
print(
f"Early stopping parameters: min_delta={min_delta}, patience={patience}"
)
break
if self.enable_progress_bar:
pbar.set_description(
f"Epochs: Val loss {eval_metrics['val/loss']:.3f}/ Best val loss {early_stop.best_metric:.3f}"
)
# Test best model if possible
if test_loader is not None:
self.load_model()
test_metrics = self.eval_model(test_loader, log_prefix="test/")
self.logger.log_metrics(test_metrics, step=epoch_idx)
self.save_metrics("test", test_metrics)
best_eval_metrics.update(test_metrics)
# Close logger
self.logger.finalize("success")
return best_eval_metrics
[docs]
def train_epoch(self, train_loader: Iterator) -> Dict[str, Any]:
"""
Train the model for one epoch.
Parameters
----------
train_loader : Iterator
An iterator over the training data.
Returns
-------
Dict[str, Any]
A dictionary of the average training metrics over all batches for logging
"""
# Train model for one epoch, and log avg loss and accuracy
metrics = defaultdict(float)
num_train_steps = len(train_loader)
start_time = time.time()
for batch in train_loader:
self.state, step_metrics = self.train_step(self.state, batch)
for key in step_metrics:
metrics["train/" + key] += step_metrics[key] / num_train_steps
metrics = {key: metrics[key].item() for key in metrics}
metrics["epoch_time"] = time.time() - start_time
return metrics
[docs]
def eval_model(
self, data_loader: Iterator, log_prefix: Optional[str] = ""
) -> Dict[str, Any]:
"""
Evaluate the model on a dataset.
Parameters
----------
data_loader : Iterator
An iterator over the data.
log_prefix : str
A prefix to add to all metrics.
Returns
-------
Dict[str, Any]
A dictionary of the evaluation metrics, averaged over data points in the dataset
"""
# Test model on all element of the dataloader and return avg loss
metrics = defaultdict(float)
num_elements = 0
for batch in data_loader:
step_metrics = self.eval_step(self.state, batch)
batch_size = (
batch[0].shape[0]
if isinstance(batch, (list, tuple))
else batch.shape[0]
)
for key in step_metrics:
metrics[key] += step_metrics[key] * batch_size
num_elements += batch_size
metrics = {
(log_prefix + key): (metrics[key] / num_elements).item() for key in metrics
}
return metrics
[docs]
def is_new_model_better(
self, new_metrics: Dict[str, Any], old_metrics: Dict[str, Any]
) -> bool:
"""
Compare two sets of evaluation metrics to decide whether the new model is better than the previous ones or not.
Parameters
----------
new_metrics : Iterator
A dictionary of the evaluation metrics of the new model
old_metrics : Iterator
A dictionary of the evaluation metrics of the previously best model.
Returns
-------
bool
True if the new model is better, False otherwise.
"""
if old_metrics is None:
return True
for key, is_larger in [
("val/val_metric", False),
("val/acc", True),
("val/loss", False),
]:
if key in new_metrics:
if is_larger:
return new_metrics[key] > old_metrics[key]
else:
return new_metrics[key] < old_metrics[key]
assert False, f"No known metrics to log on: {new_metrics}"
[docs]
def tracker(self, iterator: Iterator, **kwargs) -> Iterator:
"""
Wrap an iterator in a progress bar track (tqdm) if the progress bar is enabled.
Parameters
----------
iterator : Iterator
Iterator to wrap in tqdm.
kwargs : Any
additional arguments to tqdm.
Returns
-------
Iterator
Wrapped iterator if progress bar is enabled, otherwise same iterator than input.
"""
if self.enable_progress_bar:
return tqdm(iterator, **kwargs)
else:
return iterator
[docs]
def save_metrics(self, filename: str, metrics: Dict[str, Any]):
"""
Save a dictionary of metrics to file.
This can be used as a textual representation of the validation performance for checking in the terminal.
Parameters
----------
filename : str
Name of the metrics file without folders and postfix
metrics : Dict[str, Any]
A dictionary of the metrics to save
"""
with open(os.path.join(self.log_dir, f"metrics/{filename}.json"), "w") as f:
json.dump(metrics, f, indent=4)
[docs]
def on_training_start(self):
"""
Perform any necessary operations before the training starts.
Method called before training is started. Can be used for additional initialization operations etc.
"""
pass
[docs]
def on_training_epoch_end(self, epoch_idx: int):
"""
Perform any necessary operations at the end of each training epoch.
Method called at the end of each training epoch. Can be used for additional logging or similar.
Parameters
----------
epoch_idx : int
Index of the epoch that just finished.
"""
pass
[docs]
def on_validation_epoch_end(
self, epoch_idx: int, eval_metrics: Dict[str, Any], val_loader: Iterator
):
"""
Perform any necessary operations at the end of each validation epoch.
Method called at the end of each validation epoch. Can be used for additional logging and evaluation.
Parameters
----------
epoch_idx : int
Index of the epoch that just finished.
eval_metrics : Dict[str, Any]
A dictionary of the evaluation metrics.
val_loader : Iterator
DataLoader of the validation set to support additional evaluation.
"""
pass
[docs]
def save_model(self, step: int = 0):
"""
Save current training state at certain training iteration.
Only the model parameters and batch statistics are saved to reduce memory footprint. To allow the training to be continued from a checkpoint, this method can be extended to include the optimizer state as well.
Parameters
----------
step : int
Index of the step to save the model at, e.g. epoch.
"""
target = {"params": self.state.params, "batch_stats": self.state.batch_stats}
self.checkpoint_manager.save(step, args=ocp.args.StandardSave(target))
self.checkpoint_manager.wait_until_finished()
[docs]
def load_model(self):
"""Load model and batch statistics from the logging directory."""
step = self.checkpoint_manager.latest_step()
state_dict = self.checkpoint_manager.restore(step)
self.state = TrainState.create(
apply_fn=self.apply_fn,
params=state_dict["params"],
batch_stats=state_dict["batch_stats"],
tx=self.state.tx if self.state.tx else optax.sgd(0.1),
rng=self.state.rng,
)
[docs]
def bind_model(self):
"""
Return a model with parameters bound to it. Enables an easier inference access.
Returns
-------
The model with parameters and evt. batch statistics bound to it.
"""
params = {"params": self.state.params}
if self.state.batch_stats:
params["batch_stats"] = self.state.batch_stats
return self.model.bind(params)
[docs]
@classmethod
def load_from_checkpoints(
cls, model_class: NDENetwork, checkpoint: str, exmp_input: Any
) -> Any:
"""
Create a Trainer object with same hyperparameters and loaded model from a checkpoint directory.
Parameters
----------
model_class : jaxili.model.NDENetwork
The class of the model that should be loaded.
checkpoint : str
Folder in which the checkpoint and hyperparameter file is stored
exmp_input : Any
An input to the model with which the shapes are inferred.
Returns
-------
A Trainer object with model loaded from the checkpoint folder.
"""
hparams_file = os.path.join(checkpoint, "hparams.json")
assert os.path.isfile(hparams_file), "Could not find hparams file."
with open(hparams_file, "r") as f:
hparams = json.load(f)
assert (
hparams["model_class"] == model_class.__name__
), "Model class does not match. Check that you are using the correct architecture."
hparams.pop("model_class")
# Check if an activation function is used as a hyperparameter if the neural network.
assert (
hparams["loss_fn"] in jaxili_loss_dict
), "Unknown loss function. Check that the loss function you used comes from `jax.nn`."
hparams["loss_fn"] = jaxili_loss_dict[hparams["loss_fn"]]
if "activation" in hparams["model_hparams"].keys():
hparams["model_hparams"]["activation"] = jax_nn_dict[
hparams["model_hparams"]["activation"]
]
if "nde" in hparams["model_hparams"].keys():
hparams["model_hparams"]["nde"] = jaxili_nn_dict[
hparams["model_hparams"]["nde"]
]
if not hparams["logger_params"]:
hparams["logger_params"] = dict()
hparams["logger_params"]["log_dir"] = checkpoint
trainer = cls(model_class=model_class, exmp_input=exmp_input, **hparams)
trainer.load_model()
return trainer