Training a normalizing flow
In this notebook, we use the API in flowMC to train two different normalizing flow networks to approximate a simple test distribution. The API is built on top of equinox and optax, companion libraries of JAX for neural networks and optimization.
In typical applications of flowMC, where the goal is to obtain samples from a given posterior distribution, you will not need to interact with this level of the API: the training is handled directly within the sampling. You will, however, need to choose the normalizing flow model, and this tutorial illustrates the capabilities of the two models currently available in the package.
We train both a RealNVP flow from [Dinh et al. 2016] and a more complex normalizing flow model, the rational quadratic spline model [Durkan et al. 2019].
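Both models are trained by maximum likelihood through the change-of-variables formula. Here we assume the flow $f_\theta$ is an invertible map from data space to a standard normal latent space; libraries differ on which direction they call "forward". Under that convention, the log-density of a point $x$ and the loss over a batch of $N$ points are:

$$
\log p_\theta(x) = \log \mathcal{N}\!\left(f_\theta(x);\, 0, I\right) + \log\left|\det \frac{\partial f_\theta(x)}{\partial x}\right|,
\qquad
\mathcal{L}(\theta) = -\frac{1}{N}\sum_{i=1}^{N} \log p_\theta(x_i).
$$

Coupling layers keep this tractable because their Jacobian is triangular. The log-determinant is therefore just a sum of per-coordinate log-derivatives. We assume, without having confirmed it in flowMC's source, that the training loss reported later in this notebook is an estimate of $\mathcal{L}(\theta)$.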
import jax
import jax.numpy as jnp # JAX NumPy
import jax.random as random # JAX random
import optax # Optimizers
import equinox as eqx # Equinox
from flowMC.nfmodel.realNVP import RealNVP
from flowMC.nfmodel.rqSpline import MaskedCouplingRQSpline
We will use make_moons from scikit-learn to create a two-dimensional toy dataset.
from sklearn.datasets import make_moons
import matplotlib.pyplot as plt
data = jnp.array(make_moons(n_samples=100000, noise=0.05)[0])
plt.scatter(data[:, 0], data[:, 1], s=0.5, alpha=0.5, label="data")
plt.legend()
RealNVPs
We first use the RealNVP model to fit the data. We need to specify:
n_layers: the number of coupling layers.
n_hidden: the width of the hidden layer in the one-hidden-layer MLPs that learn the scales and translations of the affine coupling layers.
Increasing these numbers makes the normalizing flow more flexible, at the cost of a larger computational budget.
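The sketch below shows a single affine coupling layer in the spirit of Dinh et al., written in plain JAX, to make n_layers and n_hidden more concrete. The coordinates selected by a binary mask pass through unchanged, and they condition a scale and a shift that are applied to the remaining coordinates. Several details are illustrative assumptions and not flowMC's implementation: the helper names (init_mlp, mlp, coupling_forward, coupling_inverse), the initialization, and the tanh bound on the log-scale.

```python
import jax
import jax.numpy as jnp


def init_mlp(key, n_in, n_hidden, n_out):
    """Random parameters for a one-hidden-layer MLP (illustrative initialization)."""
    k1, k2 = jax.random.split(key)
    return {
        "W1": 0.1 * jax.random.normal(k1, (n_in, n_hidden)),
        "b1": jnp.zeros(n_hidden),
        "W2": 0.1 * jax.random.normal(k2, (n_hidden, n_out)),
        "b2": jnp.zeros(n_out),
    }


def mlp(params, x):
    hidden = jnp.tanh(x @ params["W1"] + params["b1"])
    return hidden @ params["W2"] + params["b2"]


def coupling_forward(scale_params, shift_params, mask, x):
    """y = x on kept coordinates (mask == 1), y = x * exp(s) + t elsewhere."""
    x_kept = x * mask
    s = jnp.tanh(mlp(scale_params, x_kept)) * (1 - mask)  # bounded log-scale, zero on kept coords
    t = mlp(shift_params, x_kept) * (1 - mask)
    y = x_kept + (1 - mask) * (x * jnp.exp(s) + t)
    log_det = jnp.sum(s)  # triangular Jacobian: log|det| is the sum of the log-scales
    return y, log_det


def coupling_inverse(scale_params, shift_params, mask, y):
    """Kept coordinates are unchanged, so s and t can be recomputed and undone exactly."""
    y_kept = y * mask
    s = jnp.tanh(mlp(scale_params, y_kept)) * (1 - mask)
    t = mlp(shift_params, y_kept) * (1 - mask)
    return y_kept + (1 - mask) * (y - t) * jnp.exp(-s)


demo_key = jax.random.PRNGKey(0)
k_scale, k_shift, k_x = jax.random.split(demo_key, 3)
demo_mask = jnp.array([1.0, 0.0])              # keep coordinate 0, transform coordinate 1
scale_params = init_mlp(k_scale, 2, 100, 2)    # hidden width 100, like n_hidden in the RealNVP cell
shift_params = init_mlp(k_shift, 2, 100, 2)
x0 = jax.random.normal(k_x, (2,))
y0, log_det = coupling_forward(scale_params, shift_params, demo_mask, x0)
x0_back = coupling_inverse(scale_params, shift_params, demo_mask, y0)
```

A RealNVP stacks n_layers such layers and alternates the mask so that every coordinate gets transformed. Each layer is exactly invertible. If the forward pass maps data to the latent space, sampling therefore amounts to drawing from the base Gaussian and applying the inverse layers in reverse order.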
# Model parameters
n_feature = 2
n_layers = 10
n_hidden = 100
key, subkey = jax.random.split(jax.random.PRNGKey(0), 2)
model = RealNVP(
n_feature,
n_layers,
n_hidden,
subkey,
data_mean=jnp.mean(data, axis=0),
data_cov=jnp.cov(data.T),
)
Next, we create an optax Adam optimizer and initialize its state from the array leaves of the equinox model (eqx.filter(model, eqx.is_array)) before launching the training.
# Optimization parameters
num_epochs = 100
batch_size = 10000
learning_rate = 0.001
momentum = 0.9
optim = optax.adam(learning_rate, b1=momentum)  # momentum sets Adam's first-moment decay rate b1 (0.9 is also optax's default)
state = optim.init(eqx.filter(model, eqx.is_array))
key, subkey = jax.random.split(key)
key, model, state, loss = model.train(key, data, optim, state, num_epochs, batch_size, verbose=True)
Training NF, current loss: 1.704: 100%|██████████| 100/100 [00:19<00:00, 5.08it/s]
plt.plot(loss)
plt.xlabel("Epoch")
plt.ylabel("Loss")
Finally we can visualize what the flow has learned by comparing the data distribution to the distribution of samples from the flow.
key, subkey = jax.random.split(key, 2)
nf_samples = model.sample(subkey, 10000)
plt.figure()
plt.scatter(data[:, 0], data[:, 1], s=0.5, alpha=0.5, label="data")
plt.scatter(
nf_samples[:, 0], nf_samples[:, 1], s=0.5, alpha=0.5, label="RealNVP samples"
)
plt.legend()
plt.show()
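The scatter plot is only a qualitative check. One simple quantitative complement, which is not part of flowMC, is the squared maximum mean discrepancy (MMD) between two sample sets under a Gaussian kernel. With the biased estimator used here, it is exactly zero for identical sets and grows as the two distributions differ. The function names and the default bandwidth of 0.5 are arbitrary illustrative choices. The estimator builds n-by-m kernel matrices, so pass random subsets of a few thousand points of data and nf_samples rather than the full arrays.

```python
import jax
import jax.numpy as jnp


def rbf_kernel(a, b, bandwidth):
    """Gaussian kernel matrix k(a_i, b_j) = exp(-|a_i - b_j|^2 / (2 h^2))."""
    sq_dists = jnp.sum((a[:, None, :] - b[None, :, :]) ** 2, axis=-1)
    return jnp.exp(-sq_dists / (2.0 * bandwidth**2))


def mmd2(x, y, bandwidth=0.5):
    """Biased estimate of the squared MMD between samples x (n, d) and y (m, d)."""
    k_xx = rbf_kernel(x, x, bandwidth).mean()
    k_yy = rbf_kernel(y, y, bandwidth).mean()
    k_xy = rbf_kernel(x, y, bandwidth).mean()
    return k_xx + k_yy - 2.0 * k_xy


# Illustration on synthetic Gaussians (not the moons)
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(42), 3)
ref = jax.random.normal(k1, (500, 2))
same_dist = jax.random.normal(k2, (500, 2))
shifted = jax.random.normal(k3, (500, 2)) + jnp.array([2.0, 0.0])
mmd_same = mmd2(ref, same_dist)        # small: same distribution
mmd_shifted = mmd2(ref, shifted)       # larger: the mean differs
```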
RQSplines
The second type of flow available is the RQSpline. These models are also based on coupling layers, but they allow transformations more expressive than affine ones, namely monotonic splines whose pieces are ratios of quadratic functions. Here the parameters are:
n_layers: the number of coupling layers.
n_hiddens: the list of hidden-layer widths of the MLPs that learn the spline parameters (bin widths, bin heights and knot derivatives).
n_bins: the number of bins of each spline.
As before, larger values make the normalizing flow more flexible but increase the computational cost of each training iteration. RQSplines are generally more computationally demanding per training step than RealNVPs. Choosing this more sophisticated model can still be a favorable trade-off, as it may require fewer iterations to converge to a satisfactory solution.
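The rational-quadratic transform of Durkan et al. fits in a few lines. The sketch below evaluates a monotonic spline defined by K+1 increasing knots (x_k, y_k) and positive knot derivatives d_k. It returns the output together with the log-derivative needed in the change-of-variables formula. In an RQSpline coupling layer, the MLPs would produce the knots and derivatives from the untransformed coordinates, with n_bins playing the role of K. In this sketch they are fixed by hand instead. rq_spline is a hypothetical name. Inputs outside the knot range are not handled; Durkan et al. use the identity there.

```python
import jax
import jax.numpy as jnp


def rq_spline(x, knots_x, knots_y, derivs):
    """Monotonic rational-quadratic spline (Durkan et al. 2019) at a scalar x.

    knots_x, knots_y: strictly increasing arrays of shape (K + 1,).
    derivs: positive derivatives at the knots, shape (K + 1,).
    Assumes knots_x[0] <= x <= knots_x[-1]. Returns (y, log dy/dx).
    """
    # Index k of the bin [x_k, x_{k+1}] containing x
    k = jnp.clip(jnp.searchsorted(knots_x, x, side="right") - 1, 0, knots_x.shape[0] - 2)
    x_k, x_k1 = knots_x[k], knots_x[k + 1]
    y_k, y_k1 = knots_y[k], knots_y[k + 1]
    d_k, d_k1 = derivs[k], derivs[k + 1]
    s = (y_k1 - y_k) / (x_k1 - x_k)   # average slope of the bin
    xi = (x - x_k) / (x_k1 - x_k)     # relative position within the bin, in [0, 1]
    denom = s + (d_k1 + d_k - 2.0 * s) * xi * (1.0 - xi)
    y = y_k + (y_k1 - y_k) * (s * xi**2 + d_k * xi * (1.0 - xi)) / denom
    dydx = s**2 * (d_k1 * xi**2 + 2.0 * s * xi * (1.0 - xi) + d_k * (1.0 - xi) ** 2) / denom**2
    return y, jnp.log(dydx)


knots_x = jnp.linspace(-3.0, 3.0, 5)               # K = 4 bins on [-3, 3]
knots_y = jnp.array([-3.0, -2.0, 0.5, 1.0, 3.0])   # increasing, same end points
derivs = jnp.array([1.0, 0.5, 2.0, 1.0, 1.0])      # positive slopes at the knots
xs = jnp.linspace(-3.0, 3.0, 201)
ys, log_dets = jax.vmap(rq_spline, in_axes=(0, None, None, None))(xs, knots_x, knots_y, derivs)
```

The spline passes exactly through the knots and matches the prescribed slopes there. Because the slopes are positive, it is strictly increasing and hence invertible. Its inverse can be obtained in closed form by solving a quadratic within the relevant bin.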
# Model parameters
n_feature = 2
n_layers = 8
n_hiddens = [64, 64]
n_bins = 8
key, subkey = jax.random.split(jax.random.PRNGKey(1))
model = MaskedCouplingRQSpline(
n_feature,
n_layers,
n_hiddens,
n_bins,
subkey,
data_cov=jnp.cov(data.T),
data_mean=jnp.mean(data, axis=0),
)
num_epochs = 100
batch_size = 10000
learning_rate = 0.001
momentum = 0.9
optim = optax.adam(learning_rate, b1=momentum)  # momentum sets Adam's first-moment decay rate b1 (0.9 is also optax's default)
state = optim.init(eqx.filter(model, eqx.is_array))
key, subkey = jax.random.split(key)
key, model, state, loss = model.train(key, data, optim, state, num_epochs, batch_size, verbose=True)
Training NF, current loss: 1.267: 100%|██████████| 100/100 [00:20<00:00, 4.96it/s]
plt.plot(loss)
plt.xlabel("Epoch")
plt.ylabel("Loss")
key, subkey = jax.random.split(key, 2)
nf_samples = model.sample(subkey, 10000)
plt.figure()
plt.scatter(data[:, 0], data[:, 1], s=0.5, alpha=0.5, label="data")
plt.scatter(
nf_samples[:, 0], nf_samples[:, 1], s=0.5, alpha=0.5, label="RQSpline samples"
)
plt.legend()
plt.show()