
realNVP

AffineCoupling

Bases: Module

Affine coupling layer, as defined in the RealNVP paper (https://arxiv.org/abs/1605.08803). tanh is used as the default activation function.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| n_features | int | The number of features in the input. | required |
| n_hidden | int | The number of hidden units in the MLP. | required |
| mask | Array | Alternating mask for the affine coupling layer. | required |
| key | PRNGKeyArray | Random key used to initialize the two MLPs. | required |
| dt | Float | Scaling factor for the affine coupling layer. | 1 |
| scale | Float | Initialization scale passed to the MLPs. | 1e-4 |
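
For reference, RealNVP (documented below) builds half-and-half masks that flip between successive layers rather than a per-feature alternating pattern; a minimal sketch for n_features = 4, mirroring its make_layer helper:

import jax.numpy as jnp

mask = jnp.ones(4).at[:2].set(0)  # [0, 0, 1, 1]: first half passes through
flipped = 1 - mask                # [1, 1, 0, 0]: used on alternating layers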
Source code in flowMC/nfmodel/realNVP.py
class AffineCoupling(eqx.Module):
    """
    Affine coupling layer.
    (Defined in the RealNVP paper https://arxiv.org/abs/1605.08803)
    We use tanh as the default activation function.

    Args:
        n_features: (int) The number of features in the input.
        n_hidden: (int) The number of hidden units in the MLP.
        mask: (ndarray) Alternating mask for the affine coupling layer.
        dt: (Float) Scaling factor for the affine coupling layer.
    """

    _mask: Array
    scale_MLP: MLP
    translate_MLP: MLP
    dt: Float = 1

    def __init__(
        self,
        n_features: int,
        n_hidden: int,
        mask: Array,
        key: PRNGKeyArray,
        dt: Float = 1,
        scale: Float = 1e-4,
    ):
        self._mask = mask
        self.dt = dt
        key, scale_subkey, translate_subkey = jax.random.split(key, 3)
        features = [n_features, n_hidden, n_features]
        self.scale_MLP = MLP(features, key=scale_subkey, scale=scale)
        self.translate_MLP = MLP(features, key=translate_subkey, scale=scale)

    @property
    def mask(self):
        return jax.lax.stop_gradient(self._mask)

    @property
    def n_features(self):
        return self.scale_MLP.n_input

    def __call__(self, x: Array):
        return self.forward(x)

    def forward(self, x: Array) -> Tuple[Array, Array]:
        """From latent space to data space

        Args:
            x: (Array) Latent space.

        Returns:
            outputs: (Array) Data space.
            log_det: (Array) Log determinant of the Jacobian.
        """
        s = self.mask * self.scale_MLP(x * (1 - self.mask))
        s = jnp.tanh(s) * self.dt
        t = self.mask * self.translate_MLP(x * (1 - self.mask)) * self.dt

        # Compute log determinant of the Jacobian
        log_det = s.sum()

        # Apply the transformation
        outputs = (x + t) * jnp.exp(s)
        return outputs, log_det

    def inverse(self, x: Array) -> Tuple[Array, Array]:
        """From data space to latent space

        Args:
            x: (Array) Data space.

        Returns:
            outputs: (Array) Latent space.
            log_det: (Array) Log determinant of the Jacobian.
        """
        s = self.mask * self.scale_MLP(x * (1 - self.mask))
        s = jnp.tanh(s) * self.dt
        t = self.mask * self.translate_MLP(x * (1 - self.mask)) * self.dt
        log_det = -s.sum()
        outputs = x * jnp.exp(-s) - t
        return outputs, log_det
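
A minimal usage sketch (the hidden width and test input here are arbitrary). forward and inverse are exact inverses, and their log-determinants cancel:

import jax
import jax.numpy as jnp
from flowMC.nfmodel.realNVP import AffineCoupling

n_features = 4
mask = jnp.ones(n_features).at[: n_features // 2].set(0)
layer = AffineCoupling(
    n_features, n_hidden=16, mask=mask, key=jax.random.PRNGKey(0)
)

x = jnp.arange(n_features, dtype=jnp.float32)
y, log_det = layer(x)                  # same as layer.forward(x)
x_rec, log_det_inv = layer.inverse(y)  # x_rec ≈ x, log_det_inv == -log_det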

forward(x)

From latent space to data space

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| x | Array | Latent space. | required |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| outputs | Array | Data space. |
| log_det | Array | Log determinant of the Jacobian. |

Source code in flowMC/nfmodel/realNVP.py
def forward(self, x: Array) -> Tuple[Array, Array]:
    """From latent space to data space

    Args:
        x: (Array) Latent space.

    Returns:
        outputs: (Array) Data space.
        log_det: (Array) Log determinant of the Jacobian.
    """
    s = self.mask * self.scale_MLP(x * (1 - self.mask))
    s = jnp.tanh(s) * self.dt
    t = self.mask * self.translate_MLP(x * (1 - self.mask)) * self.dt

    # Compute log determinant of the Jacobian
    log_det = s.sum()

    # Apply the transformation
    outputs = (x + t) * jnp.exp(s)
    return outputs, log_det
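
Why log_det is simply s.sum(): writing m for the binary mask and \odot for the elementwise product, forward computes

s = m \odot \tanh\big(\mathrm{MLP}_s(x \odot (1 - m))\big)\, dt, \qquad
t = m \odot \mathrm{MLP}_t(x \odot (1 - m))\, dt, \qquad
y = (x + t) \odot e^{s}.

Since s and t depend only on the pass-through coordinates (where m_i = 0, hence s_i = t_i = 0 and y_i = x_i), the Jacobian is triangular with diagonal entries e^{s_i}, so

\log \left| \det \frac{\partial y}{\partial x} \right| = \sum_i s_i.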

inverse(x)

From data space to latent space

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| x | Array | Data space. | required |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| outputs | Array | Latent space. |
| log_det | Array | Log determinant of the Jacobian. |

Source code in flowMC/nfmodel/realNVP.py
def inverse(self, x: Array) -> Tuple[Array, Array]:
    """From data space to latent space

    Args:
        x: (Array) Data space.

    Returns:
        outputs: (Array) Latent space.
        log_det: (Array) Log determinant of the Jacobian.
    """
    s = self.mask * self.scale_MLP(x * (1 - self.mask))
    s = jnp.tanh(s) * self.dt
    t = self.mask * self.translate_MLP(x * (1 - self.mask)) * self.dt
    log_det = -s.sum()
    outputs = x * jnp.exp(-s) - t
    return outputs, log_det
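
Because the pass-through coordinates agree between x and y, the networks recover the same s and t from either side, so the forward map inverts in closed form (same notation as above):

x = y \odot e^{-s} - t, \qquad
\log \left| \det \frac{\partial x}{\partial y} \right| = -\sum_i s_i.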

RealNVP

Bases: NFModel

RealNVP model defined in the paper https://arxiv.org/abs/1605.08803. An MLP is used to make sure the scaling between layers is more or less the same.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| n_features | int | The number of features in the input. | required |
| n_layers | int | The number of affine coupling layers. | required |
| n_hidden | int | The number of hidden units in the MLP. | required |
| key | PRNGKeyArray | Random key used to initialize the coupling layers. | required |

The keyword arguments base_dist, data_mean, and data_cov may also be passed (see the constructor below); they default to a standard Gaussian, a zero vector, and the identity matrix, respectively.
Properties

data_mean: (ndarray) Mean of the data, used for standardization.
data_cov: (ndarray) Covariance of the data, used for standardization.
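
As the source below shows, these enter sample and log_prob as a per-feature standardization; only the diagonal of data_cov is used. With \mu = data_mean and \Sigma = data_cov:

x_{\mathrm{std}} = \frac{x - \mu}{\sqrt{\operatorname{diag}(\Sigma)}}, \qquad
x_{\mathrm{sample}} = z \odot \sqrt{\operatorname{diag}(\Sigma)} + \mu.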

Source code in flowMC/nfmodel/realNVP.py
class RealNVP(NFModel):
    """
    RealNVP model defined in the paper https://arxiv.org/abs/1605.08803.
    An MLP is used to make sure the scaling between layers is more or less
    the same.

    Args:
        n_features: (int) The number of features in the input.
        n_layers: (int) The number of affine coupling layers.
        n_hidden: (int) The number of hidden units in the MLP.
        key: (PRNGKeyArray) Random key used to initialize the coupling layers.

    Keyword Args:
        base_dist: (Distribution) Base distribution; defaults to a standard Gaussian.
        data_mean: (Array) Mean of the data, used for standardization.
        data_cov: (Array) Covariance of the data, used for standardization.

    Properties:
        data_mean: (ndarray) Mean of the data, used for standardization.
        data_cov: (ndarray) Covariance of the data, used for standardization.
    """

    base_dist: Distribution
    affine_coupling: List[MaskedCouplingLayer]
    _n_features: int
    _data_mean: Float[Array, " n_dim"] | None
    _data_cov: Float[Array, " n_dim n_dim"] | None

    @property
    def n_features(self):
        return self._n_features

    @property
    def data_mean(self):
        return jax.lax.stop_gradient(self._data_mean)

    @property
    def data_cov(self):
        return jax.lax.stop_gradient(jnp.atleast_2d(self._data_cov))

    def __init__(
        self,
        n_features: int,
        n_layers: int,
        n_hidden: int,
        key: PRNGKeyArray,
        **kwargs
    ):
        if kwargs.get("base_dist") is not None:
            self.base_dist = kwargs.get("base_dist")  # type: ignore
        else:
            self.base_dist = Gaussian(
                jnp.zeros(n_features), jnp.eye(n_features), learnable=False
            )

        if kwargs.get("data_mean") is not None:
            self._data_mean = kwargs.get("data_mean")
        else:
            self._data_mean = jnp.zeros(n_features)

        if kwargs.get("data_cov") is not None:
            self._data_cov = kwargs.get("data_cov")
        else:
            self._data_cov = jnp.eye(n_features)

        self._n_features = n_features

        def make_layer(i: int, key: PRNGKeyArray):
            key, scale_subkey, shift_subkey = jax.random.split(key, 3)
            mask = jnp.ones(n_features)
            mask = mask.at[: int(n_features / 2)].set(0)
            mask = jax.lax.cond(i % 2 == 0, lambda x: 1 - x, lambda x: x, mask)
            scale_MLP = MLP([n_features, n_hidden, n_features], key=scale_subkey)
            shift_MLP = MLP([n_features, n_hidden, n_features], key=shift_subkey)
            return MaskedCouplingLayer(MLPAffine(scale_MLP, shift_MLP), mask)

        keys = jax.random.split(key, n_layers)
        self.affine_coupling = eqx.filter_vmap(make_layer)(jnp.arange(n_layers), keys)

    def __call__(self, x: Array) -> Tuple[Array, Float]:
        return self.forward(x)

    def forward(self, x: Array) -> Tuple[Array, Float]:
        log_det = 0.0
        dynamics, statics = eqx.partition(self.affine_coupling, eqx.is_array)

        def f(carry, data):
            x, log_det = carry
            layers = eqx.combine(data, statics)
            x, log_det_i = layers(x)
            return (x, log_det + log_det_i), None

        (x, log_det), _ = jax.lax.scan(f, (x, log_det), dynamics)
        return x, log_det

    @partial(jax.vmap, in_axes=(None, 0))
    def inverse(self, x: Array) -> Tuple[Array, Float]:
        """From latent space to data space"""
        log_det = 0.0
        dynamics, statics = eqx.partition(self.affine_coupling, eqx.is_array)

        def f(carry, data):
            x, log_det = carry
            layers = eqx.combine(data, statics)
            x, log_det_i = layers.inverse(x)
            return (x, log_det + log_det_i), None

        (x, log_det), _ = jax.lax.scan(f, (x, log_det), dynamics, reverse=True)
        return x, log_det

    @eqx.filter_jit
    def sample(self, rng_key: PRNGKeyArray, n_samples: int) -> Array:
        samples = self.base_dist.sample(rng_key, n_samples)
        samples = self.inverse(samples)[0]
        samples = samples * jnp.sqrt(jnp.diag(self.data_cov)) + self.data_mean
        return samples

    @eqx.filter_jit
    @partial(jax.vmap, in_axes=(None, 0))
    def log_prob(self, x: Array) -> Array:
        x = (x - self.data_mean) / jnp.sqrt(jnp.diag(self.data_cov))
        y, log_det = self.__call__(x)
        log_det = log_det + jax.scipy.stats.multivariate_normal.logpdf(
            y, jnp.zeros(self.n_features), jnp.eye(self.n_features)
        )
        return log_det
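
A minimal end-to-end sketch (the flow is untrained here; training utilities live elsewhere in flowMC and are not shown):

import jax
from flowMC.nfmodel.realNVP import RealNVP

init_key, sample_key = jax.random.split(jax.random.PRNGKey(0))
model = RealNVP(n_features=2, n_layers=4, n_hidden=32, key=init_key)

samples = model.sample(sample_key, 1000)  # (1000, 2) draws in data space
log_p = model.log_prob(samples)           # (1000,) log densities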

inverse(x)

From latent space to data space

Source code in flowMC/nfmodel/realNVP.py
@partial(jax.vmap, in_axes=(None, 0))
def inverse(self, x: Array) -> Tuple[Array, Float]:
    """From latent space to data space"""
    log_det = 0.0
    dynamics, statics = eqx.partition(self.affine_coupling, eqx.is_array)

    def f(carry, data):
        x, log_det = carry
        layers = eqx.combine(data, statics)
        x, log_det_i = layers.inverse(x)
        return (x, log_det + log_det_i), None

    (x, log_det), _ = jax.lax.scan(f, (x, log_det), dynamics, reverse=True)
    return x, log_det
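
forward and inverse share one Equinox idiom: make_layer is wrapped in eqx.filter_vmap, so all coupling layers are stored as a single module whose arrays carry a leading layer axis; eqx.partition then separates those arrays from the static fields, and jax.lax.scan threads (x, log_det) through the stack (reverse=True walks it backwards for the inverse). A generic sketch of the idiom, with eqx.nn.Linear standing in for the coupling layers:

import equinox as eqx
import jax
import jax.numpy as jnp

n_layers, dim = 4, 2
keys = jax.random.split(jax.random.PRNGKey(0), n_layers)

# Stack n_layers modules into one: every array gains a leading layer axis.
stacked = eqx.filter_vmap(lambda k: eqx.nn.Linear(dim, dim, key=k))(keys)

dynamics, statics = eqx.partition(stacked, eqx.is_array)

def f(x, layer_arrays):
    layer = eqx.combine(layer_arrays, statics)  # rebuild one layer per step
    return layer(x), None

x, _ = jax.lax.scan(f, jnp.ones(dim), dynamics)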