9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
@jax.tree_util.register_pytree_node_class
class ProposalBase:
    """Base class for proposal kernels that operate on a batch of chains.

    Subclasses implement :meth:`kernel`, :meth:`update` and :meth:`sample`.
    The constructor builds ``jax.vmap``-ed versions of ``logpdf``, ``kernel``
    and ``update`` and, when ``jit`` is truthy, wraps them (and the
    un-vmapped ``kernel``/``update``) in ``jax.jit``.

    Instances are registered as JAX pytrees with no leaves: ``logpdf``,
    ``jit`` and the extra keyword arguments are carried as auxiliary data.

    NOTE(review): the class does not inherit from ``abc.ABC``, so the
    ``@abstractmethod`` markers below are not enforced at instantiation time.
    """

    def __init__(
        self,
        logpdf: Callable[[Float[Array, " n_dim"], PyTree], Float],
        jit: bool,
        **kwargs,
    ):
        """Initialize the sampler class.

        Args:
            logpdf: Log-density, called as ``logpdf(position, data)``.
            jit: If truthy, wrap the kernel/update/logpdf helpers in
                ``jax.jit``.
            **kwargs: Extra proposal-specific options. They are stored on
                ``self.kwargs`` and survive pytree flatten/unflatten.
        """
        self.logpdf = logpdf
        self.jit = jit
        self.kwargs = kwargs

        # Map over the leading (chain) axis of the per-chain arguments; the
        # trailing ``data`` argument is shared by all chains (in_axes=None).
        self.logpdf_vmap = jax.vmap(logpdf, in_axes=(0, None))
        self.kernel_vmap = jax.vmap(self.kernel, in_axes=(0, 0, 0, None))
        # The step index ``i`` and the last state element (``data``) are
        # shared; the first four state elements are per-chain.
        self.update_vmap = jax.vmap(
            self.update,
            in_axes=(None, (0, 0, 0, 0, None)),
            out_axes=(0, 0, 0, 0, None),
        )

        if self.jit:
            self.logpdf_vmap = jax.jit(self.logpdf_vmap)
            self.kernel = jax.jit(self.kernel)
            self.kernel_vmap = jax.jit(self.kernel_vmap)
            self.update = jax.jit(self.update)
            self.update_vmap = jax.jit(self.update_vmap)

    def precompilation(self, n_chains, n_dims, n_step, data):
        """Warm up the vmapped functions by running them on dummy inputs.

        With ``jit`` truthy, ``logpdf_vmap``, ``kernel_vmap`` and
        ``update_vmap`` are replaced by ahead-of-time compiled executables
        lowered for the dummy input shapes below. Otherwise each function is
        simply called once and the result discarded, so the attributes stay
        callable.

        Args:
            n_chains: Number of chains (leading batch axis).
            n_dims: Dimension of a single position.
            n_step: Number of steps in the per-chain state passed to
                ``update_vmap``.
            data: Extra argument forwarded to ``logpdf``/``kernel``/``update``.
        """
        keys = jax.random.split(jax.random.PRNGKey(0), n_chains)
        logpdf_args = (jnp.ones((n_chains, n_dims)), data)
        kernel_args = (
            keys,
            jnp.ones((n_chains, n_dims)),
            jnp.ones((n_chains,)),
            data,
        )
        update_args = (
            1,
            (
                keys,
                jnp.ones((n_chains, n_step, n_dims)),
                jnp.ones((n_chains, n_step)),
                jnp.zeros((n_chains, n_step)),
                data,
            ),
        )

        if self.jit:
            print("jit is requested, precompiling kernels and update...")
            self.logpdf_vmap = (
                jax.jit(self.logpdf_vmap).lower(*logpdf_args).compile()
            )
            self.kernel_vmap = (
                jax.jit(self.kernel_vmap).lower(*kernel_args).compile()
            )
            self.update_vmap = (
                jax.jit(self.update_vmap).lower(*update_args).compile()
            )
        else:
            print("jit is not requested, compiling only vmap functions...")
            # Results are intentionally discarded: assigning them back (as
            # was done previously for logpdf_vmap) would replace the function
            # with an array and break every later call.
            self.logpdf_vmap(*logpdf_args)
            self.kernel_vmap(*kernel_args)
            self.update_vmap(*update_args)

    @abstractmethod
    def kernel(
        self,
        rng_key: PRNGKeyArray,
        position: Float[Array, "nstep n_dim"],
        log_prob: Float[Array, "nstep 1"],
        data: PyTree,
    ) -> tuple[
        Float[Array, "nstep n_dim"], Float[Array, "nstep 1"], Int[Array, "n_step 1"]
    ]:
        """
        Kernel for one step in the proposal cycle.

        ``kernel_vmap`` maps this over the leading axis of ``rng_key``,
        ``position`` and ``log_prob`` while sharing ``data``.
        """

    @abstractmethod
    def update(
        self,
        i: Float,
        state: tuple[
            PRNGKeyArray,
            Float[Array, "nstep n_dim"],
            Float[Array, "nstep 1"],
            Int[Array, "n_step 1"],
            PyTree,
        ],
    ) -> tuple[
        PRNGKeyArray,
        Float[Array, "nstep n_dim"],
        Float[Array, "nstep 1"],
        Int[Array, "n_step 1"],
        PyTree,
    ]:
        """
        Make the update function for multiple steps.

        ``update_vmap`` maps this over the first four ``state`` elements and
        shares ``i`` and the final ``data`` element across chains; the
        returned state must have the same 5-tuple structure.
        """

    @abstractmethod
    def sample(
        self,
        rng_key: PRNGKeyArray,
        n_steps: int,
        initial_position: Float[Array, "n_chains n_dim"],
        data: PyTree,
        verbose: bool = False,
    ) -> tuple[
        PRNGKeyArray,
        Float[Array, "n_chains n_steps n_dim"],
        Float[Array, "n_chains n_steps 1"],
        Int[Array, "n_chains n_steps 1"],
    ]:
        """
        Make the sampler for multiple chains given initial positions.
        """

    def tree_flatten(self):
        """Return ``(children, aux_data)`` for JAX pytree flattening.

        The proposal has no array leaves; everything needed to rebuild it is
        stored in ``aux_data``.

        NOTE(review): JAX may hash/compare treedefs (e.g. when the object is
        passed through ``jax.jit``); a dict ``aux_data`` is not hashable —
        confirm that path is never exercised.
        """
        children = ()
        aux_data = {"logpdf": self.logpdf, "jit": self.jit, "kwargs": self.kwargs}
        return (children, aux_data)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """Rebuild an instance from the output of :meth:`tree_flatten`.

        The stored extra options are expanded back into keyword arguments.
        Passing them as a single ``kwargs=`` argument would nest the dict one
        level deeper on every flatten/unflatten round trip.
        """
        constructor_kwargs = dict(aux_data)
        extra_kwargs = constructor_kwargs.pop("kwargs", {})
        return cls(*children, **constructor_kwargs, **extra_kwargs)
|