Batch With VMap
In JAX, we can use vmap to vectorize a function. For example,
```python
import jax.numpy as jnp
from jax import vmap

def f(x):
    return x ** 2

vmap(f)(jnp.arange(3))  # Array([0, 1, 4], dtype=int32)
```
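A useful way to read vmap(f) is: apply f to each slice along axis 0 and stack the results. The sketch below checks this against an explicit Python loop. Here sq_norm is a hypothetical helper written for a single vector. It shows why vmap matters: calling sq_norm directly on the whole (3, 2) batch would attempt a dot product of two (3, 2) matrices, which fails on incompatible shapes, instead of computing one squared norm per row.

```python
import jax.numpy as jnp
from jax import vmap


def sq_norm(v):
    # Hypothetical helper written for ONE vector v of shape (n,).
    return jnp.dot(v, v)


xs = jnp.arange(6).reshape(3, 2)  # a batch of 3 vectors: [0, 1], [2, 3], [4, 5]

# vmap maps sq_norm over axis 0 of xs: the result is as if sq_norm were
# applied to each row and the outputs stacked (but f is traced once, not looped).
batched = vmap(sq_norm)(xs)

# The same computation written as an explicit Python loop over the rows.
looped = jnp.stack([sq_norm(v) for v in xs])

print(batched.tolist())  # [1, 13, 41]
print(bool(jnp.array_equal(batched, looped)))  # True
```

Because vmap rewrites f into batched array operations rather than running a Python loop, it usually avoids the per-element overhead of the looped version.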
We can also specify which axis to map over, i.e. the batch axis, through the in_axes argument. For example,
```python
import jax.numpy as jnp
from jax import vmap

def f(x):
    return x ** 2

# in_axes=(0,): map the first (and only) argument over its axis 0
vmap(f, in_axes=(0,))(jnp.arange(3))  # Array([0, 1, 4], dtype=int32)
```
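in_axes takes one entry per positional argument: an integer picks the axis to map over, and None marks an argument that is not mapped, so the same value is passed for every batch element. A minimal sketch, using a hypothetical scale function, that maps over the columns (axis 1) of a matrix while sharing one scalar:

```python
import jax.numpy as jnp
from jax import vmap


def scale(x, s):
    # Hypothetical helper: x is one vector, s is a scalar factor.
    return x * s


xs = jnp.arange(6).reshape(2, 3)  # [[0, 1, 2], [3, 4, 5]]

# in_axes=(1, None): map x over axis 1 (the 3 columns, each of shape (2,)),
# and pass the same s to every column instead of mapping over it.
out = vmap(scale, in_axes=(1, None))(xs, 10)

# By default (out_axes=0) the mapped axis becomes the leading output axis.
print(out.shape)     # (3, 2)
print(out.tolist())  # [[0, 30], [10, 40], [20, 50]]
```

Passing out_axes=1 instead would place the mapped axis at position 1, giving shape (2, 3), i.e. the same layout as xs.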
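However, specifying a batch axis for a custom PyTree node is harder, because the in_axes entry for that argument has to mirror the node's structure. In fenbux every distribution is a PyTree, so you can construct a distribution with use_batch=True, passing batch axes (or None) in place of its parameters, and use it as the in_axes entry for that distribution. For example,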
```python
import jax.numpy as jnp
from jax import vmap
from fenbux import logpdf
from fenbux.univariate import Normal

dist = Normal(0, jnp.ones((2, 3, 5)))  # each batch has shape (2, 3)
x = jnp.zeros((2, 3, 5))

# use_batch=True lets Normal(...) serve as the in_axes spec for dist
vmap(logpdf, in_axes=(Normal(None, 2, use_batch=True), 2))(dist, x)
```
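Here Normal(None, 2, use_batch=True) mirrors the structure of dist: None means the first parameter, the mean, is not mapped (the same mean is used for every batch), and 2 means the second parameter, sd, is mapped over its third dimension (axis 2), the batch dimension. The second entry of in_axes, 2, maps x over the same axis, so each of the 5 batch elements is a (2, 3) slice of sd and x.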
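The same mechanism can be sketched in plain JAX without fenbux. In the sketch below, ToyNormal and toy_logpdf are hypothetical stand-ins, not fenbux code; the assumption is that fenbux's Normal(None, 2, use_batch=True) plays the analogous role of an in_axes entry with the same PyTree structure as the distribution, holding an axis (or None) for each parameter. A NamedTuple is used because JAX already treats it as a PyTree.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp
from jax import vmap
from jax.scipy.stats import norm


class ToyNormal(NamedTuple):
    # Hypothetical stand-in for a distribution PyTree (NOT fenbux's Normal).
    # NamedTuples are PyTrees, so an instance can also serve as an in_axes spec.
    mean: jax.Array
    sd: jax.Array


def toy_logpdf(dist, x):
    # Elementwise Normal log-density, broadcasting mean/sd against x.
    return norm.logpdf(x, loc=dist.mean, scale=dist.sd)


dist = ToyNormal(jnp.array(0.0), jnp.ones((2, 3, 5)))
x = jnp.zeros((2, 3, 5))

# The spec mirrors the argument structure (dist, x):
#   ToyNormal(None, 2) -> do not map dist.mean, map dist.sd over axis 2
#   2                  -> map x over axis 2
out = vmap(toy_logpdf, in_axes=(ToyNormal(None, 2), 2))(dist, x)
print(out.shape)  # (5, 2, 3): the size-5 batch axis moves to the front

# Sanity check: same values as the unbatched call with axis 2 moved to the front.
expected = jnp.moveaxis(norm.logpdf(x, loc=dist.mean, scale=dist.sd), 2, 0)
print(bool(jnp.allclose(out, expected)))  # True
```

Passing out_axes=2 to vmap would put the batch axis back in the last position, giving shape (2, 3, 5).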