Skip to content

Base

ProposalBase

Source code in flowMC/proposal/base.py (lines 9–154)
@jax.tree_util.register_pytree_node_class
class ProposalBase:
    """
    Base class for MCMC proposals.

    Subclasses implement ``kernel``, ``update`` and ``sample``. The
    constructor builds chain-batched (``jax.vmap``) versions of ``logpdf``,
    ``kernel`` and ``update`` and, when ``jit`` is enabled, wraps the
    kernel/update functions and all batched helpers in ``jax.jit``.

    NOTE(review): ``@abstractmethod`` is only enforced by ``ABCMeta``; this
    class does not inherit from ``abc.ABC``, so a subclass that forgets an
    override is not rejected at instantiation time.
    """

    def __init__(
        self,
        logpdf: Callable[[Float[Array, " n_dim"], PyTree], Float],
        jit: bool,
        **kwargs,
    ):
        """
        Initialize the proposal.

        Args:
            logpdf: Log-density evaluated as ``logpdf(position, data)`` for a
                single position.
            jit: If true, ``kernel``, ``update`` and the vmapped helpers are
                wrapped in ``jax.jit``.
            **kwargs: Extra proposal parameters. They are stored on
                ``self.kwargs`` and round-tripped through
                ``tree_flatten``/``tree_unflatten``.
        """
        self.logpdf = logpdf
        self.jit = jit
        # Batch over chains: positions along axis 0, ``data`` shared.
        self.logpdf_vmap = jax.vmap(logpdf, in_axes=(0, None))
        # One rng key, position and log-prob per chain; ``data`` shared.
        self.kernel_vmap = jax.vmap(self.kernel, in_axes=(0, 0, 0, None))
        # The step index ``i`` is shared by all chains. Every state entry is
        # batched per chain except ``data``, which is broadcast in and out.
        self.update_vmap = jax.vmap(
            self.update,
            in_axes=(None, (0, 0, 0, 0, None)),
            out_axes=(0, 0, 0, 0, None),
        )
        self.kwargs = kwargs
        if self.jit:
            self.logpdf_vmap = jax.jit(self.logpdf_vmap)
            self.kernel = jax.jit(self.kernel)
            self.kernel_vmap = jax.jit(self.kernel_vmap)
            self.update = jax.jit(self.update)
            self.update_vmap = jax.jit(self.update_vmap)

    def precompilation(self, n_chains, n_dims, n_step, data):
        """
        Warm up the vmapped functions with dummy inputs of the sampling shapes.

        With ``jit`` enabled, ``logpdf_vmap``, ``kernel_vmap`` and
        ``update_vmap`` are replaced by ahead-of-time compiled executables
        specialised to these shapes. Without ``jit``, each function is called
        once and its result is discarded, so the attributes stay callable.

        Args:
            n_chains: Number of chains (leading batch axis).
            n_dims: Dimension of a single position.
            n_step: Number of steps held per chain in the ``update`` state.
            data: Data forwarded unchanged to ``logpdf``/``kernel``/``update``.
        """
        key = jax.random.split(jax.random.PRNGKey(0), n_chains)
        # Dummy arguments: only their shapes/dtypes matter for tracing.
        logpdf_args = (jnp.ones((n_chains, n_dims)), data)
        kernel_args = (
            key,
            jnp.ones((n_chains, n_dims)),
            jnp.ones((n_chains,)),
            data,
        )
        update_args = (
            1,
            (
                key,
                jnp.ones((n_chains, n_step, n_dims)),
                jnp.ones((n_chains, n_step)),
                jnp.zeros((n_chains, n_step)),
                data,
            ),
        )

        if self.jit:
            print("jit is requested, precompiling kernels and update...")
            self.logpdf_vmap = (
                jax.jit(self.logpdf_vmap).lower(*logpdf_args).compile()
            )
            self.kernel_vmap = (
                jax.jit(self.kernel_vmap).lower(*kernel_args).compile()
            )
            self.update_vmap = (
                jax.jit(self.update_vmap).lower(*update_args).compile()
            )
        else:
            print("jit is not requested, compiling only vmap functions...")
            # Results are intentionally discarded. Assigning them (as an
            # earlier version did for logpdf_vmap) would replace the function
            # with an array.
            self.logpdf_vmap(*logpdf_args)
            self.kernel_vmap(*kernel_args)
            self.update_vmap(*update_args)

    @abstractmethod
    def kernel(
        self,
        rng_key: PRNGKeyArray,
        position: Float[Array, "nstep  n_dim"],
        log_prob: Float[Array, "nstep 1"],
        data: PyTree,
    ) -> tuple[
        Float[Array, "nstep  n_dim"], Float[Array, "nstep 1"], Int[Array, "n_step 1"]
    ]:
        """
        Kernel for one step in the proposal cycle.

        Must be overridden by subclasses. ``kernel_vmap`` batches this over
        chains, so each call sees the inputs of one chain.

        NOTE(review): ``precompilation`` traces it with a per-chain position
        of shape ``(n_dims,)`` and a scalar log-prob, which differs from the
        ``nstep`` axes in the annotations. Confirm the intended shapes.

        Returns:
            New position, its log-prob, and an integer array (presumably
            acceptance flags; confirm in subclasses).
        """

    @abstractmethod
    def update(
        self,
        i: Float,
        state: tuple[
            PRNGKeyArray,
            Float[Array, "nstep  n_dim"],
            Float[Array, "nstep 1"],
            Int[Array, "n_step 1"],
            PyTree,
        ],
    ) -> tuple[
        PRNGKeyArray,
        Float[Array, "nstep  n_dim"],
        Float[Array, "nstep 1"],
        Int[Array, "n_step 1"],
        PyTree,
    ]:
        """
        Update function applied at step ``i`` of a multi-step run.

        Must be overridden by subclasses. It takes and returns a state tuple
        ``(rng_key, positions, log_probs, int_array, data)``. ``update_vmap``
        batches every entry except ``i`` and ``data`` over chains.
        """

    @abstractmethod
    def sample(
        self,
        rng_key: PRNGKeyArray,
        n_steps: int,
        initial_position: Float[Array, "n_chains  n_dim"],
        data: PyTree,
        verbose: bool = False,
    ) -> tuple[
        PRNGKeyArray,
        Float[Array, "n_chains n_steps  n_dim"],
        Float[Array, "n_chains n_steps 1"],
        Int[Array, "n_chains n_steps 1"],
    ]:
        """
        Run the sampler for multiple chains from the given initial positions.

        Must be overridden by subclasses.

        Returns:
            An rng key, the positions and log-probs for each chain and step,
            and an integer array per chain and step (presumably acceptance
            flags; confirm in subclasses).
        """

    def tree_flatten(self):
        """
        Flatten for JAX pytrees.

        The proposal has no array leaves. Everything needed to rebuild it is
        static auxiliary data.
        """
        children = ()
        aux_data = {"logpdf": self.logpdf, "jit": self.jit, "kwargs": self.kwargs}
        return (children, aux_data)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """
        Rebuild an instance from ``tree_flatten`` output.

        The stored ``kwargs`` dict is unpacked back into keyword arguments so
        that ``self.kwargs`` round-trips unchanged. Passing it as a literal
        ``kwargs=`` keyword would nest it, or fail for subclasses without such
        a parameter. A missing ``"kwargs"`` entry is treated as empty.
        """
        constructor_kwargs = dict(aux_data)
        extra_kwargs = constructor_kwargs.pop("kwargs", {})
        return cls(*children, **constructor_kwargs, **extra_kwargs)

__init__(logpdf, jit, **kwargs)

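Initialize the proposal: store the log-density and the jit flag, and build the vmapped (and, if requested, jitted) logpdf, kernel, and update functions.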

Source code in flowMC/proposal/base.py (lines 11–35)
def __init__(
    self,
    logpdf: Callable[[Float[Array, " n_dim"], PyTree], Float],
    jit: bool,
    **kwargs,
):
    """
    Set up the proposal and its chain-batched helpers.

    Stores ``logpdf`` and ``jit`` and builds ``jax.vmap`` versions of
    ``logpdf``, ``kernel`` and ``update``. Extra keyword arguments are kept
    in ``self.kwargs``. When ``jit is True``, ``kernel``, ``update`` and the
    three batched helpers are each wrapped in ``jax.jit``.
    """
    self.logpdf, self.jit = logpdf, jit

    # Positions are batched over chains; ``data`` is shared.
    batched_logpdf = jax.vmap(logpdf, in_axes=(0, None))
    # Per-chain key/position/log-prob; shared ``data``.
    batched_kernel = jax.vmap(self.kernel, in_axes=(0, 0, 0, None))
    # Shared step index; per-chain state except the broadcast ``data``.
    batched_update = jax.vmap(
        self.update,
        in_axes=(None, (0, 0, 0, 0, None)),
        out_axes=(0, 0, 0, 0, None),
    )
    self.logpdf_vmap = batched_logpdf
    self.kernel_vmap = batched_kernel
    self.update_vmap = batched_update
    self.kwargs = kwargs

    if self.jit is not True:
        return
    jitted_attrs = (
        "logpdf_vmap",
        "kernel",
        "kernel_vmap",
        "update",
        "update_vmap",
    )
    for attr_name in jitted_attrs:
        setattr(self, attr_name, jax.jit(getattr(self, attr_name)))

kernel(rng_key, position, log_prob, data) — abstract method

Kernel for one step in the proposal cycle.

Source code in flowMC/proposal/base.py (lines 92–104)
@abstractmethod
def kernel(
    self,
    rng_key: PRNGKeyArray,
    position: Float[Array, "nstep  n_dim"],
    log_prob: Float[Array, "nstep 1"],
    data: PyTree,
) -> tuple[
    Float[Array, "nstep  n_dim"],
    Float[Array, "nstep 1"],
    Int[Array, "n_step 1"],
]:
    """
    Perform one step of the proposal cycle (to be overridden).

    Args:
        rng_key: Random key for this step.
        position: Current position.
        log_prob: Log-density at ``position``.
        data: Extra data forwarded to the log-density.

    Returns:
        The new position, its log-prob, and an integer array (presumably
        acceptance flags; confirm in concrete subclasses).
    """

sample(rng_key, n_steps, initial_position, data, verbose=False) — abstract method

Run the sampler for multiple chains, starting from the given initial positions.

Source code in flowMC/proposal/base.py (lines 128–144)
@abstractmethod
def sample(
    self,
    rng_key: PRNGKeyArray,
    n_steps: int,
    initial_position: Float[Array, "n_chains  n_dim"],
    data: PyTree,
    verbose: bool = False,
) -> tuple[
    PRNGKeyArray,
    Float[Array, "n_chains n_steps  n_dim"],
    Float[Array, "n_chains n_steps 1"],
    Int[Array, "n_chains n_steps 1"],
]:
    """
    Run ``n_steps`` of sampling for every chain (to be overridden).

    Args:
        rng_key: Random key.
        n_steps: Number of steps to take.
        initial_position: Starting position of each chain.
        data: Extra data forwarded to the log-density.
        verbose: Whether to report progress.

    Returns:
        An rng key, the per-chain, per-step positions and log-probs, and a
        per-chain, per-step integer array (presumably acceptance flags;
        confirm in concrete subclasses).
    """

update(i, state) — abstract method

Make the update function for multiple steps

Source code in flowMC/proposal/base.py (lines 106–126)
@abstractmethod
def update(
    self,
    i: Float,
    state: tuple[
        PRNGKeyArray,
        Float[Array, "nstep  n_dim"],
        Float[Array, "nstep 1"],
        Int[Array, "n_step 1"],
        PyTree,
    ],
) -> tuple[
    PRNGKeyArray,
    Float[Array, "nstep  n_dim"],
    Float[Array, "nstep 1"],
    Int[Array, "n_step 1"],
    PyTree,
]:
    """
    Advance the multi-step state at step index ``i`` (to be overridden).

    ``state`` and the returned value share the layout
    ``(rng_key, positions, log_probs, int_array, data)``. The integer array
    presumably holds acceptance flags; confirm in concrete subclasses.
    """