Skip to content

Batch With VMap

In JAX, we can use vmap to vectorize a function. For example,

import jax.numpy as jnp
from jax import vmap

def f(x):
    return x ** 2

vmap(f)(jnp.arange(3)) # Array([0, 1, 4], dtype=int32)

And we can specify the axis to vectorize on, namely a batch axis. For example,

import jax.numpy as jnp
from jax import vmap

def f(x):
    return x ** 2

vmap(f, in_axes=(0,))(jnp.arange(3)) # Array([0, 1, 4], dtype=int32)

However, it's difficult to specify batch axis for a customized PyTree node. In fenbux, every distribution is treated as a PyTree, user can use use_batch=True to specify the batch axis . For example,

import jax.numpy as jnp
from jax import vmap

from fenbux import logpdf
from fenbux.univariate import Normal

dist = Normal(0, jnp.ones((2, 3, 5))) # each batch shape is (2, 3)
x = jnp.zeros((2, 3, 5))
# set claim use_batch=True to use vmap
vmap(logpdf, in_axes=(Normal(None, 2, use_batch=True), 2))(dist, x) 

Here Normal(None, 0, use_batch=True) means that we don't care about the batch axis of the first argument mean, and we want to vectorize on the second argument sd on 3rd dimension, namely the batch dimension.