Skip to content

HMC

HMC ¤

Bases: ProposalBase

Hamiltonian Monte Carlo sampler class building the hmc_sampler method from the target logpdf.

Parameters:

Name Type Description Default
logpdf Callable[[Float[Array, ' n_dim'], PyTree], Float]

target logpdf function

required
jit bool

whether to jit the sampler

required
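condition_matrix Float[Array, ' n_dim n_dim'] | Float

metric (condition matrix) used in the kinetic energy

1
step_size Float

leapfrog step size

0.1
n_leapfrog Int

number of leapfrog steps per trajectory

10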
Source code in flowMC/proposal/HMC.py
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
class HMC(ProposalBase):
    """
    Hamiltonian Monte Carlo sampler class building the hmc_sampler method
    from the target logpdf.

    The potential energy is ``-logpdf(x, data)``. The kinetic energy uses
    ``condition_matrix`` (``C``) as its metric: ``0.5 * sum(p**2 * C)`` for a
    scalar or per-dimension ``C`` and ``0.5 * p @ C @ p`` for an
    ``(n_dim, n_dim)`` ``C``. Momenta are drawn from ``N(0, C^{-1})`` so the
    momentum distribution matches that kinetic energy.

    Args:
        logpdf: target logpdf function, called as ``logpdf(x, data)``.
        jit: whether to jit the sampler.
        condition_matrix: metric ``C`` used by the kinetic energy; a scalar,
            a per-dimension vector, or an ``(n_dim, n_dim)`` matrix.
            Defaults to 1 (identity metric).
        step_size: leapfrog step size.
        n_leapfrog: number of full leapfrog steps per trajectory. The
            integrator scans ``n_leapfrog + 2`` iterations: an initial half
            momentum kick, ``n_leapfrog`` full steps, and a final position
            update followed by a half kick.
    """

    condition_matrix: Float[Array, " n_dim n_dim"]
    step_size: Float
    n_leapfrog: Int

    def __init__(
        self,
        logpdf: Callable[[Float[Array, " n_dim"], PyTree], Float],
        jit: bool,
        condition_matrix: Float[Array, " n_dim n_dim"] | Float = 1,
        step_size: Float = 0.1,
        n_leapfrog: Int = 10,
    ):
        super().__init__(
            logpdf,
            jit,
            condition_matrix=condition_matrix,
            step_size=step_size,
            n_leapfrog=n_leapfrog,
        )

        # Potential energy is the negative target log density.
        self.potential: Callable[
            [Float[Array, " n_dim"], PyTree], Float
        ] = lambda x, data: -logpdf(x, data)
        self.grad_potential: Callable[
            [Float[Array, " n_dim"], PyTree], Float[Array, " n_dim"]
        ] = jax.grad(self.potential)

        self.condition_matrix = condition_matrix
        self.step_size = step_size
        self.n_leapfrog = n_leapfrog

        # Per-iteration (position, momentum) step multipliers consumed by the
        # scan in ``leapfrog_step``. Row 0 only applies a half momentum kick.
        # The last row applies a full position step plus a half kick. Every
        # row in between is a full leapfrog step.
        coefs = jnp.ones((self.n_leapfrog + 2, 2))
        coefs = coefs.at[0].set(jnp.array([0, 0.5]))
        coefs = coefs.at[-1].set(jnp.array([1, 0.5]))
        self.leapfrog_coefs = coefs

        def kinetic(p, metric):
            metric = jnp.asarray(metric)
            if metric.ndim >= 2:
                # Dense metric: the proper quadratic form p^T C p. The
                # element-wise form below only agrees with it when C is diagonal.
                return 0.5 * p @ metric @ p
            # Scalar or per-dimension (diagonal) metric.
            return 0.5 * (p**2 * metric).sum()

        self.kinetic: Callable[
            [Float[Array, " n_dim"], Float[Array, " n_dim n_dim"]], Float
        ] = kinetic
        self.grad_kinetic = jax.grad(self.kinetic)

    def _sample_momentum(
        self,
        rng_key: PRNGKeyArray,
        position: Float[Array, " n_dim"],
    ) -> Float[Array, " n_dim"]:
        """Draw a momentum from ``N(0, C^{-1})`` with ``C = self.condition_matrix``."""
        noise = jax.random.normal(rng_key, shape=position.shape)
        metric = jnp.asarray(self.condition_matrix)
        if metric.ndim >= 2:
            # z @ L^T with L L^T = C^{-1} has covariance C^{-1}.
            return noise @ jnp.linalg.cholesky(jnp.linalg.inv(metric)).T
        # Scalar or diagonal metric: scale each coordinate by C^{-1/2}.
        # (Element-wise ``C ** -0.5`` would turn zero off-diagonals of a
        # matrix into inf, hence the branch above.)
        return noise / jnp.sqrt(metric)

    def get_initial_hamiltonian(
        self,
        rng_key: PRNGKeyArray,
        position: Float[Array, " n_dim"],
        data: PyTree,
    ):
        """
        Compute the Hamiltonian at ``position`` with a randomly drawn momentum.

        The momentum is drawn from ``N(0, C^{-1})`` where ``C`` is
        ``self.condition_matrix``. It is a standard normal draw only when
        ``C`` is 1.

        Args:
            rng_key: random key used to draw the momentum.
            position: position at which the potential energy is evaluated.
            data: extra data forwarded to the potential.

        Returns:
            Potential energy at ``position`` plus the kinetic energy of the
            drawn momentum.
        """
        momentum = self._sample_momentum(rng_key, position)
        return self.potential(position, data) + self.kinetic(
            momentum, self.condition_matrix
        )

    def leapfrog_kernel(self, carry, extras):
        """
        One scan iteration of the leapfrog integrator.

        ``carry`` is ``(position, momentum, data, metric, index)``. ``index``
        selects the row of ``self.leapfrog_coefs`` that scales the position
        and momentum updates. ``extras`` is passed through unchanged.
        """
        position, momentum, data, metric, index = carry
        position = position + self.step_size * self.leapfrog_coefs[index][
            0
        ] * self.grad_kinetic(momentum, metric)
        momentum = momentum - self.step_size * self.leapfrog_coefs[index][
            1
        ] * self.grad_potential(position, data)
        index = index + 1
        return (position, momentum, data, metric, index), extras

    def leapfrog_step(
        self,
        position: Float[Array, " n_dim"],
        momentum: Float[Array, " n_dim"],
        data: PyTree,
        metric: Float[Array, " n_dim n_dim"],
    ) -> tuple[Float[Array, " n_dim"], Float[Array, " n_dim"]]:
        """
        Integrate one leapfrog trajectory.

        Runs ``n_leapfrog + 2`` iterations of ``leapfrog_kernel``.

        Returns:
            The final ``(position, momentum)``.
        """
        (position, momentum, *_), _ = jax.lax.scan(
            self.leapfrog_kernel,
            (position, momentum, data, metric, 0),
            jnp.arange(self.n_leapfrog + 2),
        )
        return position, momentum

    def kernel(
        self,
        rng_key: PRNGKeyArray,
        position: Float[Array, " n_dim"],
        log_prob: Float[Array, "1"],
        data: PyTree,
    ) -> tuple[Float[Array, " n_dim"], Float[Array, "1"], Int[Array, "1"]]:
        """
        Propose a new position by one leapfrog trajectory, then accept or reject it.

        The potential function is the negative log likelihood, so a drop in
        potential energy corresponds to a rise in log probability.

        Args:
            rng_key: random key for this transition.
            position: current position of the chain.
            log_prob: log probability at ``position`` (minus its potential energy).
            data: extra data forwarded to ``logpdf``.

        Returns:
            ``(position, log_prob, do_accept)`` after the Metropolis step.
        """
        key1, key2 = jax.random.split(rng_key)

        momentum = self._sample_momentum(key1, position)
        H = -log_prob + self.kinetic(momentum, self.condition_matrix)
        proposed_position, proposed_momentum = self.leapfrog_step(
            position, momentum, data, self.condition_matrix
        )
        proposed_PE = self.potential(proposed_position, data)
        proposed_ham = proposed_PE + self.kinetic(
            proposed_momentum, self.condition_matrix
        )
        # Metropolis correction. A NaN energy makes the comparison False,
        # which rejects the proposal.
        log_acc = H - proposed_ham
        log_uniform = jnp.log(jax.random.uniform(key2))

        do_accept = log_uniform < log_acc

        position = jnp.where(do_accept, proposed_position, position)
        log_prob = jnp.where(do_accept, -proposed_PE, log_prob)  # type: ignore

        return position, log_prob, do_accept

    def update(
        self, i, state
    ) -> tuple[
        PRNGKeyArray,
        Float[Array, "nstep  n_dim"],
        Float[Array, "nstep 1"],
        Int[Array, "n_step 1"],
        PyTree,
    ]:
        """
        Advance one chain by one HMC step and write the result into row ``i``.

        ``sample`` calls this through ``update_vmap``.

        Args:
            i: index of the step to fill. The chain is advanced from row ``i - 1``.
            state: ``(key, positions, log_probs, acceptance, data)`` for one chain.

        Returns:
            The updated ``state`` tuple.
        """
        key, positions, log_probs, acceptance, data = state
        # Keep ``key`` for the next step and give the kernel an independent
        # subkey, so no key is both consumed and split again later.
        key, subkey = jax.random.split(key)
        new_position, new_log_prob, do_accept = self.kernel(
            subkey, positions[i - 1], log_probs[i - 1], data
        )
        positions = positions.at[i].set(new_position)
        log_probs = log_probs.at[i].set(new_log_prob)
        acceptance = acceptance.at[i].set(do_accept)
        return (key, positions, log_probs, acceptance, data)

    def sample(
        self,
        rng_key: PRNGKeyArray,
        n_steps: int,
        initial_position: Float[Array, "n_chains  n_dim"],
        data: PyTree,
        verbose: bool = False,
    ) -> tuple[
        PRNGKeyArray,
        Float[Array, "n_chains n_steps  n_dim"],
        Float[Array, "n_chains n_steps 1"],
        Int[Array, "n_chains n_steps 1"],
    ]:
        """
        Run ``n_steps - 1`` HMC transitions for every chain.

        Args:
            rng_key: one key per chain. ``n_chains`` is read from its leading axis.
            n_steps: number of stored samples per chain, including the initial state.
            initial_position: starting point of each chain.
            data: extra data forwarded to ``logpdf``.
            verbose: show a tqdm progress bar.

        Returns:
            ``(rng_key, positions, log_prob, acceptance)``. Row 0 of each
            per-chain array holds the initial state. ``acceptance`` is stored
            in a float array (built with ``jnp.zeros``).
        """
        keys = jax.vmap(jax.random.split)(rng_key)
        rng_key = keys[:, 0]
        logp = self.logpdf_vmap(initial_position, data)
        n_chains = rng_key.shape[0]
        acceptance = jnp.zeros((n_chains, n_steps))
        # Pre-fill every step with the initial state; update() overwrites rows 1..n_steps-1.
        all_positions = (
            jnp.zeros((n_chains, n_steps) + initial_position.shape[-1:])
            + initial_position[:, None]
        )
        all_logp = jnp.zeros((n_chains, n_steps)) + logp[:, None]
        state = (rng_key, all_positions, all_logp, acceptance, data)

        if verbose:
            iterator_loop = tqdm(
                range(1, n_steps),
                desc="Sampling Locally",
                miniters=int(n_steps / 10),
            )
        else:
            iterator_loop = range(1, n_steps)

        for i in iterator_loop:
            state = self.update_vmap(i, state)

        state = (state[0], state[1], state[2], state[3])
        return state

get_initial_hamiltonian(rng_key, position, data) ¤

Compute the value of the Hamiltonian at the given position, with the initial momentum drawn at random from a normal distribution scaled by the condition matrix.

Source code in flowMC/proposal/HMC.py
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
def get_initial_hamiltonian(
    self,
    rng_key: PRNGKeyArray,
    position: Float[Array, " n_dim"],
    data: PyTree,
):
    """
    Compute the Hamiltonian at ``position`` with a randomly drawn momentum.

    The momentum is drawn from ``N(0, C^{-1})`` where ``C`` is
    ``self.condition_matrix``. It is a standard normal draw only when ``C``
    is 1.

    Args:
        rng_key: random key used to draw the momentum.
        position: position at which the potential energy is evaluated.
        data: extra data forwarded to the potential.

    Returns:
        Potential energy at ``position`` plus the kinetic energy of the
        drawn momentum.
    """
    noise = jax.random.normal(rng_key, shape=position.shape)
    metric = jnp.asarray(self.condition_matrix)
    if metric.ndim >= 2:
        # For a matrix, element-wise ``C ** -0.5`` turns zero off-diagonals
        # into inf. Use a Cholesky factor of C^{-1} instead, as ``kernel`` does.
        momentum = noise @ jnp.linalg.cholesky(jnp.linalg.inv(metric)).T
    else:
        # Scalar or per-dimension metric: scale each coordinate by C^{-1/2}.
        momentum = noise / jnp.sqrt(metric)
    return self.potential(position, data) + self.kinetic(
        momentum, self.condition_matrix
    )

kernel(rng_key, position, log_prob, data) ¤

Note that since the potential function is the negative log likelihood, the Hamiltonian goes down while the likelihood goes up.

Parameters:

Name Type Description Default
rng_key (n_chains, 2)

random key

required
position (n_chains, n_dim)

current position

required
log_prob (n_chains)

Log probability of the current position (the negative potential energy)

required
Source code in flowMC/proposal/HMC.py
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
def kernel(
    self,
    rng_key: PRNGKeyArray,
    position: Float[Array, " n_dim"],
    log_prob: Float[Array, "1"],
    data: PyTree,
) -> tuple[Float[Array, " n_dim"], Float[Array, "1"], Int[Array, "1"]]:
    """
    Propose a new position by one leapfrog trajectory, then accept or reject it.

    The potential function is the negative log likelihood, so a drop in
    potential energy corresponds to a rise in log probability.

    Args:
        rng_key: random key for this transition.
        position: current position of the chain.
        log_prob: log probability at ``position`` (minus its potential energy).
        data: extra data forwarded to ``logpdf``.

    Returns:
        ``(position, log_prob, do_accept)`` after the Metropolis step.
    """
    key1, key2 = jax.random.split(rng_key)

    # Momentum ~ N(0, C^{-1}) with C = self.condition_matrix. A single draw
    # replaces the earlier duplicated (and discarded) one. The scalar/vector
    # branch avoids ``jnp.linalg.inv``, which fails on non-matrix input such
    # as the default ``condition_matrix=1``.
    noise = jax.random.normal(key1, shape=position.shape)
    metric = jnp.asarray(self.condition_matrix)
    if metric.ndim >= 2:
        momentum = noise @ jnp.linalg.cholesky(jnp.linalg.inv(metric)).T
    else:
        momentum = noise / jnp.sqrt(metric)

    H = -log_prob + self.kinetic(momentum, self.condition_matrix)
    proposed_position, proposed_momentum = self.leapfrog_step(
        position, momentum, data, self.condition_matrix
    )
    proposed_PE = self.potential(proposed_position, data)
    proposed_ham = proposed_PE + self.kinetic(
        proposed_momentum, self.condition_matrix
    )
    # Metropolis correction. A NaN energy makes the comparison False, which
    # rejects the proposal.
    log_acc = H - proposed_ham
    log_uniform = jnp.log(jax.random.uniform(key2))

    do_accept = log_uniform < log_acc

    position = jnp.where(do_accept, proposed_position, position)
    log_prob = jnp.where(do_accept, -proposed_PE, log_prob)  # type: ignore

    return position, log_prob, do_accept