
Base

Bijection ¤

Bases: Module

Base class for bijective transformations.

This is an abstract template and should not be used directly.

Source code in flowMC/nfmodel/base.py, lines 195–216:
class Bijection(eqx.Module):
    """
    Base class for bijective transformations.

    This is an abstract template that should not be directly used."""

    @abstractmethod
    def __init__(self):
        return NotImplemented

    def __call__(
        self, x: Array, key: Optional[PRNGKeyArray] = None
    ) -> tuple[Array, Array]:
        return self.forward(x)

    @abstractmethod
    def forward(self, x: Array) -> tuple[Array, Array]:
        # Returns the transformed input and the log determinant of the Jacobian.
        return NotImplemented

    @abstractmethod
    def inverse(self, x: Array) -> tuple[Array, Array]:
        # Returns the inverse-transformed input and the log determinant of the Jacobian.
        return NotImplemented
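
A concrete subclass fills in forward and inverse, each returning the transformed value together with the log determinant of the Jacobian. A minimal sketch for an elementwise exponential transform (the ExpBijection name is illustrative, not part of flowMC):

import equinox as eqx
import jax.numpy as jnp
from jaxtyping import Array

class ExpBijection(Bijection):
    """Elementwise exponential transform: y = exp(x)."""

    def __init__(self):
        pass

    def forward(self, x: Array) -> tuple[Array, Array]:
        # The Jacobian of elementwise exp is diag(exp(x)), so log|det J| = sum(x).
        return jnp.exp(x), jnp.sum(x)

    def inverse(self, x: Array) -> tuple[Array, Array]:
        # The inverse is elementwise log, with log|det J| = -sum(log(x)).
        return jnp.log(x), -jnp.sum(jnp.log(x))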

Distribution ¤

Bases: Module

Base class for probability distributions.

This is an abstract template and should not be used directly.

Source code in flowMC/nfmodel/base.py, lines 219–241:
class Distribution(eqx.Module):
    """
    Base class for probability distributions.

    This is an abstract template that should not be directly used.
    """

    @abstractmethod
    def __init__(self):
        return NotImplemented

    def __call__(self, x: Array, key: Optional[PRNGKeyArray] = None) -> Array:
        return self.log_prob(x)

    @abstractmethod
    def log_prob(self, x: Array) -> Array:
        return NotImplemented

    @abstractmethod
    def sample(
        self, rng_key: PRNGKeyArray, n_samples: int
    ) -> Float[Array, " n_samples n_features"]:
        return NotImplemented
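
A concrete subclass provides log_prob and sample. A minimal sketch for a unit Gaussian (the StandardNormal name is illustrative, not part of flowMC):

import jax
import jax.numpy as jnp
from jaxtyping import Array, Float, PRNGKeyArray

class StandardNormal(Distribution):
    """Unit Gaussian over n_dim features."""

    n_dim: int

    def __init__(self, n_dim: int):
        self.n_dim = n_dim

    def log_prob(self, x: Array) -> Array:
        return -0.5 * jnp.sum(x**2) - 0.5 * self.n_dim * jnp.log(2 * jnp.pi)

    def sample(
        self, rng_key: PRNGKeyArray, n_samples: int
    ) -> Float[Array, " n_samples n_features"]:
        return jax.random.normal(rng_key, (n_samples, self.n_dim))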

NFModel ¤

Bases: Module

Base class for normalizing flow models.

This is an abstract template and should not be used directly.

Source code in flowMC/nfmodel/base.py, lines 14–192:
class NFModel(eqx.Module):
    """
    Base class for normalizing flow models.

    This is an abstract template that should not be directly used.
    """

    @abstractmethod
    def __init__(self):
        raise NotImplementedError

    def __call__(self, x: Float[Array, "n_dim"]) -> tuple[Float[Array, "n_dim"], Float]:
        """
        Forward pass of the model.

        Args:
            x (Float[Array, "n_dim"]): Input data.

        Returns:
            tuple[Float[Array, "n_dim"], Float]: Output data and log determinant of the Jacobian.
        """
        return self.forward(x)

    @abstractmethod
    def log_prob(self, x: Float[Array, "n_dim"]) -> Float:
        return NotImplemented

    @abstractmethod
    def sample(self, rng_key: PRNGKeyArray, n_samples: int) -> Array:
        return NotImplemented

    @abstractmethod
    def forward(
        self, x: Float[Array, "n_dim"], key: Optional[PRNGKeyArray] = None
    ) -> tuple[Float[Array, "n_dim"], Float]:
        """
        Forward pass of the model.

        Args:
            x (Float[Array, "n_dim"]): Input data.

        Returns:
            tuple[Float[Array, "n_dim"], Float]: Output data and log determinant of the Jacobian.
        """
        return NotImplemented

    @abstractmethod
    def inverse(self, x: Float[Array, "n_dim"]) -> tuple[Float[Array, "n_dim"], Float]:
        """
        Inverse pass of the model.

        Args:
            x (Float[Array, "n_dim"]): Input data.

        Returns:
            tuple[Float[Array, "n_dim"], Float]: Output data and log determinant of the Jacobian.
        """
        return NotImplemented

    @abstractmethod
    def n_features(self) -> int:
        return NotImplemented

    def save_model(self, path: str):
        eqx.tree_serialise_leaves(path + ".eqx", self)

    def load_model(self, path: str) -> Self:
        # eqx.Module instances are immutable, so the deserialised model is
        # returned instead of rebinding `self`, which would be a no-op.
        return eqx.tree_deserialise_leaves(path + ".eqx", self)

    @eqx.filter_value_and_grad
    def loss_fn(self, x):
        return -jnp.mean(self.log_prob(x))

    @eqx.filter_jit
    def train_step(
        self: Self,
        x: Float[Array, "n_batch n_dim"],
        optim: optax.GradientTransformation,
        state: optax.OptState,
    ) -> tuple[Float, Self, optax.OptState]:
        """Train for a single step.

        Args:
            model (eqx.Model): NF model to train.
            x (Array): Training data.
            opt_state (optax.OptState): Optimizer state.

        Returns:
            loss (Array): Loss value.
            model (eqx.Model): Updated model.
            opt_state (optax.OptState): Updated optimizer state.
        """
        model = self
        loss, grads = model.loss_fn(x)
        updates, state = optim.update(grads, state)
        model = eqx.apply_updates(model, updates)
        return loss, model, state

    def train_epoch(
        self: Self,
        rng: PRNGKeyArray,
        optim: optax.GradientTransformation,
        state: optax.OptState,
        data: Float[Array, "n_example n_dim"],
        batch_size: int,
    ) -> tuple[Float, Self, optax.OptState]:
        """Train for a single epoch."""
        model = self
        train_ds_size = len(data)
        steps_per_epoch = train_ds_size // batch_size
        if steps_per_epoch > 0:
            perms = jax.random.permutation(rng, train_ds_size)

            perms = perms[: steps_per_epoch * batch_size]  # skip incomplete batch
            perms = perms.reshape((steps_per_epoch, batch_size))
            for perm in perms:
                batch = data[perm, ...]
                value, model, state = model.train_step(batch, optim, state)
        else:
            value, model, state = model.train_step(data, optim, state)

        return value, model, state

    def train(
        self: Self,
        rng: PRNGKeyArray,
        data: Array,
        optim: optax.GradientTransformation,
        state: optax.OptState,
        num_epochs: int,
        batch_size: int,
        verbose: bool = True,
    ) -> tuple[PRNGKeyArray, Self, optax.OptState, Array]:
        """Train a normalizing flow model.

        Args:
            rng (PRNGKeyArray): JAX PRNGKey.
            model (eqx.Module): NF model to train.
            data (Array): Training data.
            num_epochs (int): Number of epochs to train for.
            batch_size (int): Batch size.
            verbose (bool): Whether to print progress.

        Returns:
            rng (PRNGKeyArray): Updated JAX PRNGKey.
            model (eqx.Model): Updated NF model.
            loss_values (Array): Loss values.
        """
        loss_values = jnp.zeros(num_epochs)
        if verbose:
            pbar = trange(num_epochs, desc="Training NF", miniters=int(num_epochs / 10))
        else:
            pbar = range(num_epochs)

        best_model = model = self
        best_state = state
        best_loss = 1e9
        for epoch in pbar:
            # Use a fresh PRNG key to shuffle the training data each epoch
            rng, input_rng = jax.random.split(rng)
            # Run an optimization step over a training batch
            value, model, state = model.train_epoch(
                input_rng, optim, state, data, batch_size
            )
            loss_values = loss_values.at[epoch].set(value)
            if loss_values[epoch] < best_loss:
                best_model = model
                best_state = state
                best_loss = loss_values[epoch]
            if verbose:
                assert isinstance(pbar, tqdm)
                if num_epochs > 10:
                    if epoch % int(num_epochs / 10) == 0:
                        pbar.set_description(f"Training NF, current loss: {value:.3f}")
                else:
                    # epoch runs from 0 to num_epochs - 1, so this fires on the last epoch.
                    if epoch == num_epochs - 1:
                        pbar.set_description(f"Training NF, current loss: {value:.3f}")

        return rng, best_model, best_state, loss_values
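
save_model and load_model wrap Equinox's leaf (de)serialisation, with the current instance acting as the structural template when loading. A short usage sketch (the path and model are illustrative; model stands for any concrete NFModel instance):

# Writes the model's array leaves to "my_flow.eqx".
model.save_model("my_flow")

# Restores the leaves into a model with the same structure.
model = model.load_model("my_flow")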

__call__(x) ¤

Forward pass of the model.

Parameters:

    x (Float[Array, "n_dim"]): Input data. (required)

Returns:

    tuple[Float[Array, "n_dim"], Float]: Output data and log determinant of the Jacobian.

Source code in flowMC/nfmodel/base.py, lines 25–35:
def __call__(self, x: Float[Array, "n_dim"]) -> tuple[Float[Array, "n_dim"], Float]:
    """
    Forward pass of the model.

    Args:
        x (Float[Array, "n_dim"]): Input data.

    Returns:
        tuple[Float[Array, "n_dim"], Float]: Output data and log determinant of the Jacobian.
    """
    return self.forward(x)

forward(x, key=None) abstractmethod ¤

Forward pass of the model.

Parameters:

    x (Float[Array, "n_dim"]): Input data. (required)
    key (Optional[PRNGKeyArray]): Optional PRNG key. (default: None)

Returns:

    tuple[Float[Array, "n_dim"], Float]: Output data and log determinant of the Jacobian.

Source code in flowMC/nfmodel/base.py, lines 45–58:
@abstractmethod
def forward(
    self, x: Float[Array, "n_dim"], key: Optional[PRNGKeyArray] = None
) -> tuple[Float[Array, "n_dim"], Float]:
    """
    Forward pass of the model.

    Args:
        x (Float[Array, "n_dim"]): Input data.

    Returns:
        tuple[Float[Array, "n_dim"], Float]: Output data and log determinant of the Jacobian.
    """
    return NotImplemented

inverse(x) abstractmethod ¤

Inverse pass of the model.

Parameters:

    x (Float[Array, "n_dim"]): Input data. (required)

Returns:

    tuple[Float[Array, "n_dim"], Float]: Output data and log determinant of the Jacobian.

Source code in flowMC/nfmodel/base.py, lines 60–71:
@abstractmethod
def inverse(self, x: Float[Array, "n_dim"]) -> tuple[Float[Array, "n_dim"], Float]:
    """
    Inverse pass of the model.

    Args:
        x (Float[Array, "n_dim"]): Input data.

    Returns:
        tuple[Float[Array, "n_dim"], Float]: Output data and log determinant of the Jacobian.
    """
    return NotImplemented
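
For a valid flow, inverse undoes forward and the two log determinants cancel, which makes a cheap property test. A sketch, assuming model is a concrete NFModel and x a point of the right shape (both illustrative):

import jax.numpy as jnp

y, log_det_fwd = model.forward(x)
x_rec, log_det_inv = model.inverse(y)

# The round trip should recover x, and the log determinants should sum to zero.
assert jnp.allclose(x, x_rec, atol=1e-5)
assert jnp.allclose(log_det_fwd + log_det_inv, 0.0, atol=1e-5)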

train(rng, data, optim, state, num_epochs, batch_size, verbose=True) ¤

Train a normalizing flow model.

Parameters:

    rng (PRNGKeyArray): JAX PRNG key. (required)
    data (Array): Training data. (required)
    optim (optax.GradientTransformation): Optax optimizer. (required)
    state (optax.OptState): Initial optimizer state. (required)
    num_epochs (int): Number of epochs to train for. (required)
    batch_size (int): Batch size. (required)
    verbose (bool): Whether to print progress. (default: True)

Returns:

    rng (PRNGKeyArray): Updated JAX PRNG key.
    model (Self): Best model found during training.
    state (optax.OptState): Optimizer state of the best model.
    loss_values (Array): Loss value per epoch.

Source code in flowMC/nfmodel/base.py, lines 137–192:
def train(
    self: Self,
    rng: PRNGKeyArray,
    data: Array,
    optim: optax.GradientTransformation,
    state: optax.OptState,
    num_epochs: int,
    batch_size: int,
    verbose: bool = True,
) -> tuple[PRNGKeyArray, Self, optax.OptState, Array]:
    """Train a normalizing flow model.

    Args:
        rng (PRNGKeyArray): JAX PRNGKey.
        model (eqx.Module): NF model to train.
        data (Array): Training data.
        num_epochs (int): Number of epochs to train for.
        batch_size (int): Batch size.
        verbose (bool): Whether to print progress.

    Returns:
        rng (PRNGKeyArray): Updated JAX PRNGKey.
        model (eqx.Model): Updated NF model.
        loss_values (Array): Loss values.
    """
    loss_values = jnp.zeros(num_epochs)
    if verbose:
        pbar = trange(num_epochs, desc="Training NF", miniters=int(num_epochs / 10))
    else:
        pbar = range(num_epochs)

    best_model = model = self
    best_state = state
    best_loss = 1e9
    for epoch in pbar:
        # Use a fresh PRNG key to shuffle the training data each epoch
        rng, input_rng = jax.random.split(rng)
        # Run an optimization step over a training batch
        value, model, state = model.train_epoch(
            input_rng, optim, state, data, batch_size
        )
        loss_values = loss_values.at[epoch].set(value)
        if loss_values[epoch] < best_loss:
            best_model = model
            best_state = state
            best_loss = loss_values[epoch]
        if verbose:
            assert isinstance(pbar, tqdm)
            if num_epochs > 10:
                if epoch % int(num_epochs / 10) == 0:
                    pbar.set_description(f"Training NF, current loss: {value:.3f}")
            else:
                # epoch runs from 0 to num_epochs - 1, so this fires on the last epoch.
                if epoch == num_epochs - 1:
                    pbar.set_description(f"Training NF, current loss: {value:.3f}")

    return rng, best_model, best_state, loss_values
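
An end-to-end sketch of driving train with optax (the model, data, and hyperparameters are illustrative; flowMC ships concrete NFModel subclasses such as MaskedCouplingRQSpline):

import equinox as eqx
import jax
import optax

rng = jax.random.PRNGKey(0)
optim = optax.adam(learning_rate=1e-3)
# Optax state is initialised from the model's trainable array leaves.
state = optim.init(eqx.filter(model, eqx.is_array))

rng, model, state, loss_values = model.train(
    rng, data, optim, state, num_epochs=100, batch_size=256
)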

train_epoch(rng, optim, state, data, batch_size) ¤

Train for a single epoch.

Source code in flowMC/nfmodel/base.py, lines 112–135:
def train_epoch(
    self: Self,
    rng: PRNGKeyArray,
    optim: optax.GradientTransformation,
    state: optax.OptState,
    data: Float[Array, "n_example n_dim"],
    batch_size: int,
) -> tuple[Float, Self, optax.OptState]:
    """Train for a single epoch."""
    model = self
    train_ds_size = len(data)
    steps_per_epoch = train_ds_size // batch_size
    if steps_per_epoch > 0:
        perms = jax.random.permutation(rng, train_ds_size)

        perms = perms[: steps_per_epoch * batch_size]  # skip incomplete batch
        perms = perms.reshape((steps_per_epoch, batch_size))
        for perm in perms:
            batch = data[perm, ...]
            value, model, state = model.train_step(batch, optim, state)
    else:
        value, model, state = model.train_step(data, optim, state)

    return value, model, state
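
The epoch logic shuffles example indices once, drops the incomplete tail batch, and reshapes the rest into one row of indices per step. A standalone illustration of that index bookkeeping (the sizes are illustrative):

import jax

rng = jax.random.PRNGKey(0)
n_example, batch_size = 10, 3
steps_per_epoch = n_example // batch_size  # 3 full batches; 1 example is skipped

perms = jax.random.permutation(rng, n_example)
perms = perms[: steps_per_epoch * batch_size].reshape((steps_per_epoch, batch_size))
print(perms.shape)  # (3, 3): one row of example indices per train_step call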

train_step(x, optim, state) ¤

Train for a single step.

Parameters:

    x (Float[Array, "n_batch n_dim"]): Batch of training data. (required)
    optim (optax.GradientTransformation): Optax optimizer. (required)
    state (optax.OptState): Current optimizer state. (required)

Returns:

    loss (Float): Loss value.
    model (Self): Updated model.
    state (optax.OptState): Updated optimizer state.

Source code in flowMC/nfmodel/base.py, lines 87–110:
@eqx.filter_jit
def train_step(
    self: Self,
    x: Float[Array, "n_batch n_dim"],
    optim: optax.GradientTransformation,
    state: optax.OptState,
) -> tuple[Float, Self, optax.OptState]:
    """Train for a single step.

    Args:
        model (eqx.Model): NF model to train.
        x (Array): Training data.
        opt_state (optax.OptState): Optimizer state.

    Returns:
        loss (Array): Loss value.
        model (eqx.Model): Updated model.
        opt_state (optax.OptState): Updated optimizer state.
    """
    model = self
    loss, grads = model.loss_fn(x)
    updates, state = optim.update(grads, state)
    model = eqx.apply_updates(model, updates)
    return loss, model, state
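
Given the optimizer setup from the train sketch above, a single step looks like this (batch is illustrative). train_step is functional: the updated model and optimizer state are returned rather than mutated in place:

loss, model, state = model.train_step(batch, optim, state)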