
Common

Gaussian

Bases: Distribution

Multivariate Gaussian distribution.

Parameters:

- mean (Array): Mean. Required.
- cov (Array): Covariance matrix. Required.
- learnable (bool): Whether the mean and covariance matrix are learnable parameters. Default: False.

Attributes:

- mean (Array): Mean.
- cov (Array): Covariance matrix.

Source code in flowMC/nfmodel/common.py
class Gaussian(Distribution):

    r"""Multivariate Gaussian distribution.

    Args:
        mean (Array): Mean.
        cov (Array): Covariance matrix.
        learnable (bool): Whether the mean and covariance matrix are learnable parameters.

    Attributes:
        mean (Array): Mean.
        cov (Array): Covariance matrix.
    """

    _mean: Float[Array, "n_dim"]
    _cov: Float[Array, "n_dim n_dim"]
    learnable: bool = False

    @property
    def mean(self) -> Float[Array, "n_dim"]:
        if self.learnable:
            return self._mean
        else:
            return jax.lax.stop_gradient(self._mean)

    @property
    def cov(self) -> Float[Array, "n_dim n_dim"]:
        if self.learnable:
            return self._cov
        else:
            return jax.lax.stop_gradient(self._cov)

    def __init__(self, mean: Float[Array, "n_dim"], cov: Float[Array, "n_dim n_dim"], learnable: bool = False):
        self._mean = mean
        self._cov = cov
        self.learnable = learnable

    def log_prob(self, x: Float[Array, "n_dim"]) -> Float:
        return jax.scipy.stats.multivariate_normal.logpdf(x, self.mean, self.cov)

    def sample(self, key: PRNGKeyArray, n_samples: int = 1) -> Float[Array, "n_samples n_dim"]:
        return jax.random.multivariate_normal(key, self.mean, self.cov, (n_samples,))
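
For orientation, here is a minimal usage sketch (the dimensionality and values are illustrative): build a 2-D standard normal, draw a few samples, and evaluate a log-density.

```python
import jax
import jax.numpy as jnp
from flowMC.nfmodel.common import Gaussian

key = jax.random.PRNGKey(0)
dist = Gaussian(mean=jnp.zeros(2), cov=jnp.eye(2), learnable=False)

samples = dist.sample(key, n_samples=5)  # shape (5, 2)
log_p = dist.log_prob(jnp.zeros(2))      # scalar log-density at the mean
```

With learnable=False, the mean and cov properties wrap the stored parameters in jax.lax.stop_gradient, so gradients taken through log_prob do not flow into them.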

MLP

Bases: Module

Multilayer perceptron.

Parameters:

- shape (List[int]): Shape of the MLP. The first element is the input dimension, the last element is the output dimension. Required.
- key (PRNGKeyArray): Random key. Required.
- scale (Float): Scale of the hidden-layer weight initialization. Default: 1e-4.
- activation (Callable): Activation function. Default: jax.nn.relu.
- use_bias (bool): Whether to use bias. Default: True.

Attributes:

- layers (List): List of layers.
- activation (Callable): Activation function.
- use_bias (bool): Whether to use bias.

Source code in flowMC/nfmodel/common.py
class MLP(eqx.Module):
    r"""Multilayer perceptron.

    Args:
        shape (List[int]): Shape of the MLP. The first element is the input dimension, the last element is the output dimension.
        key (PRNGKeyArray): Random key.

    Attributes:
        layers (List): List of layers.
        activation (Callable): Activation function.
        use_bias (bool): Whether to use bias.
    """
    layers: List

    def __init__(
        self,
        shape: List[int],
        key: PRNGKeyArray,
        scale: Float = 1e-4,
        activation: Callable = jax.nn.relu,
        use_bias: bool = True,
    ):
        self.layers = []
        for i in range(len(shape) - 2):
            key, subkey1, subkey2 = jax.random.split(key, 3)
            layer = eqx.nn.Linear(
                shape[i], shape[i + 1], key=subkey1, use_bias=use_bias
            )
            weight = jax.random.normal(subkey2, (shape[i + 1], shape[i])) * jnp.sqrt(
                scale / shape[i]
            )
            layer = eqx.tree_at(lambda l: l.weight, layer, weight)
            self.layers.append(layer)
            self.layers.append(activation)
        key, subkey = jax.random.split(key)
        self.layers.append(
            eqx.nn.Linear(shape[-2], shape[-1], key=subkey, use_bias=use_bias)
        )

    def __call__(self, x: Float[Array, "n_in"]) -> Float[Array, "n_out"]:
        for layer in self.layers:
            x = layer(x)
        return x

    @property
    def n_input(self) -> int:
        return self.layers[0].in_features

    @property
    def n_output(self) -> int:
        return self.layers[-1].out_features

    @property
    def dtype(self) -> jnp.dtype:
        return self.layers[0].weight.dtype
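
A minimal construction and call sketch follows; the layer widths and inputs are illustrative.

```python
import jax
import jax.numpy as jnp
from flowMC.nfmodel.common import MLP

key = jax.random.PRNGKey(1)
mlp = MLP(shape=[3, 32, 32, 2], key=key)     # 3 inputs, two hidden layers of 32, 2 outputs

y = mlp(jnp.ones(3))                         # shape (2,)
y_batch = jax.vmap(mlp)(jnp.ones((10, 3)))   # map over a batch, as usual for Equinox modules
```

The hidden-layer weights are re-initialized from a normal distribution scaled by sqrt(scale / fan_in), while the final Linear layer keeps Equinox's default initialization.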

MaskedCouplingLayer

Bases: Bijection

Masked coupling layer.

f(x) = (1 - m) * b(x; c(m * x; z)) + m * x, where b is the inner bijector, m is the mask, and c is the conditioner.

Parameters:

- bijector (Bijection): Inner bijector in the masked coupling layer. Required.
- mask (Array): Mask. 0 for the input variables that are transformed, 1 for the input variables that are not transformed. Required.
Source code in flowMC/nfmodel/common.py
class MaskedCouplingLayer(Bijection):

    r"""Masked coupling layer.

    f(x) = (1-m)*b(x;c(m*x;z)) + m*x
    where b is the inner bijector, m is the mask, and c is the conditioner.

    Args:
        bijector (Bijection): inner bijector in the masked coupling layer.
        mask (Array): Mask. 0 for the input variables that are transformed, 1 for the input variables that are not transformed.

    """

    _mask: Float[Array, "n_dim"]
    bijector: Bijection

    @property
    def mask(self) -> Float[Array, "n_dim"]:
        return jax.lax.stop_gradient(self._mask)

    def __init__(self, bijector: Bijection, mask: Float[Array, "n_dim"]):
        self.bijector = bijector
        self._mask = mask

    def forward(self, x: Float[Array, "n_dim"]) -> Tuple[Float[Array, "n_dim"], Float[Array, "n_dim"]]:
        y, log_det = self.bijector(x, x * self.mask) # type: ignore
        y = (1 - self.mask) * y + self.mask * x
        log_det = ((1 - self.mask) * log_det).sum()
        return y, log_det

    def inverse(self, x: Float[Array, "n_dim"]) -> Tuple[Float[Array, "n_dim"], Float[Array, "n_dim"]]:
        y, log_det = self.bijector.inverse(x, x * self.mask) # type: ignore
        y = (1 - self.mask) * y + self.mask * x
        log_det = ((1 - self.mask) * log_det).sum()
        return y, log_det
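
To make the update rule concrete, here is a standalone sketch of the forward pass that does not use flowMC's Bijection classes; the conditioner and the shift-only inner transform are hypothetical stand-ins, chosen so the log-Jacobian term is trivially zero.

```python
import jax.numpy as jnp

def toy_conditioner(x_masked):
    # Hypothetical conditioner c: produces a per-dimension shift from the
    # coordinates that pass through unchanged.
    return jnp.sum(x_masked) * jnp.ones_like(x_masked)

def masked_coupling_forward(x, mask):
    shift = toy_conditioner(x * mask)      # c(m * x)
    y_inner = x + shift                    # b(x; c(m * x)): a pure shift here
    log_det_inner = jnp.zeros_like(x)      # a shift has zero log-Jacobian
    y = (1 - mask) * y_inner + mask * x    # only unmasked (mask == 0) coordinates change
    log_det = ((1 - mask) * log_det_inner).sum()
    return y, log_det

x = jnp.array([1.0, 2.0, 3.0, 4.0])
mask = jnp.array([1.0, 1.0, 0.0, 0.0])     # first two coordinates are left untouched
y, log_det = masked_coupling_forward(x, mask)
```

In MaskedCouplingLayer itself, the inner bijector supplies both forward and inverse passes, and the same masking is applied in each direction.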