
Multivariate

fenbux.multivariate.MultivariateNormal

Multivariate normal distribution. X ~ Normal(μ, Σ)
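For reference, the log-density of a k-dimensional normal with mean μ and symmetric positive-definite covariance Σ is the standard textbook expression below. It is stated here as the usual formula and is not taken from the fenbux source.

```latex
\log p(x \mid \mu, \Sigma)
  = -\frac{1}{2}\Big[(x-\mu)^\top \Sigma^{-1} (x-\mu) + \log\det\Sigma + k\log(2\pi)\Big],
  \qquad x,\ \mu \in \mathbb{R}^k
```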

Parameters:

Name       Type              Description                                      Default
mean       PyTree            Mean of the distribution.                        0.0
cov        PyTree            Covariance matrix of the distribution.           0.0
dtype      jax.numpy.dtype   dtype of the distribution, default jnp.float_.   <class 'jax.numpy.float64'>
use_batch  bool              Whether to use with vmap.                        False
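The `use_batch` flag is only described as "whether to use with vmap". As a rough illustration of what vmapping over distribution parameters means in plain JAX, the sketch below evaluates the log-density for a batch of means and covariances in one vectorised call. It does not use fenbux; it relies on `jax.vmap` and `jax.scipy.stats.multivariate_normal.logpdf`, and how fenbux wires `use_batch` internally is not covered here. The batch values are made up for illustration.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal

# Illustrative batch: 4 three-dimensional normals, one mean and one covariance each.
batch, dim = 4, 3
means = jnp.arange(batch * dim, dtype=jnp.float32).reshape(batch, dim) / 10.0
covs = jnp.stack([jnp.eye(dim) * (i + 1.0) for i in range(batch)])  # shape (4, 3, 3)
x = jnp.zeros((dim,))  # every distribution is evaluated at the same point

# in_axes: x is shared (None); means and covs are mapped over their leading axis.
batched_logpdf = jax.vmap(multivariate_normal.logpdf, in_axes=(None, 0, 0))
values = batched_logpdf(x, means, covs)
print(values.shape)  # (4,)

# Equivalent Python loop, for comparison with the vectorised result.
looped = jnp.stack([multivariate_normal.logpdf(x, m, c) for m, c in zip(means, covs)])
```

The vmapped call traces the per-distribution function once and evaluates all batch elements together, which is generally preferable to the Python loop under `jax.jit`.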

Examples:

>>> import jax.numpy as jnp
>>> from fenbux import logpdf
>>> from fenbux.multivariate import MultivariateNormal
>>> dist = MultivariateNormal(jnp.zeros((10, )), jnp.eye(10))
>>> logpdf(dist, jnp.zeros((10, )))
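The example above evaluates a 10-dimensional standard normal at the origin, where the log-density reduces to -(k/2) log(2π) = -5 log(2π). The sketch below computes the general log-density by hand via a Cholesky factorisation, then checks it against `jax.scipy.stats.multivariate_normal.logpdf` and the closed form. `mvn_logpdf` is a hypothetical helper written for illustration and is not fenbux's implementation.

```python
import math

import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular
from jax.scipy.stats import multivariate_normal


def mvn_logpdf(x, mean, cov):
    """Hypothetical helper: log N(x | mean, cov) via a Cholesky factorisation of cov."""
    k = mean.shape[-1]
    chol = jnp.linalg.cholesky(cov)                    # cov = L @ L.T, L lower-triangular
    z = solve_triangular(chol, x - mean, lower=True)   # z = L^{-1} (x - mean)
    maha = jnp.sum(z ** 2)                             # (x - mean)^T cov^{-1} (x - mean)
    log_det = 2.0 * jnp.sum(jnp.log(jnp.diag(chol)))   # log det(cov)
    return -0.5 * (maha + log_det + k * jnp.log(2.0 * jnp.pi))


# Same inputs as the docstring example.
mean = jnp.zeros((10,))
cov = jnp.eye(10)
x = jnp.zeros((10,))

manual = mvn_logpdf(x, mean, cov)
reference = multivariate_normal.logpdf(x, mean, cov)
closed_form = -5.0 * math.log(2.0 * math.pi)  # -(k/2) log(2*pi) with k = 10
print(manual, reference, closed_form)
```

Solving against the Cholesky factor avoids forming Σ⁻¹ explicitly, and the log-determinant falls out of the factor's diagonal.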

Attributes:

Name   Type    Description
mean   PyTree  Mean of the distribution.
cov    PyTree  Covariance matrix of the distribution.

fenbux.multivariate.Dirichlet

Dirichlet(alpha=0.0, dtype=<class 'jax.numpy.float64'>, use_batch=False)
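No description accompanies the Dirichlet signature. For reference, the standard Dirichlet distribution with concentration parameters α = (α_1, ..., α_K), all α_i > 0, lives on the probability simplex and has log-density log Γ(Σ α_i) - Σ log Γ(α_i) + Σ (α_i - 1) log x_i. The sketch below implements that formula in plain JAX and checks it against `jax.scipy.stats.dirichlet.logpdf`. `dirichlet_logpdf` is a hypothetical helper for illustration, not fenbux's implementation, and the parameter values are made up.

```python
import math

import jax.numpy as jnp
from jax.scipy.special import gammaln
from jax.scipy.stats import dirichlet


def dirichlet_logpdf(x, alpha):
    """Hypothetical helper: log Dir(x | alpha) for x on the simplex (x_i > 0, sum 1)."""
    log_norm = gammaln(jnp.sum(alpha)) - jnp.sum(gammaln(alpha))  # log(1 / B(alpha))
    return log_norm + jnp.sum((alpha - 1.0) * jnp.log(x))


alpha = jnp.array([2.0, 3.0, 4.0])
x = jnp.array([0.25, 0.25, 0.5])  # exactly representable, sums to 1

# Closed form: 1/B(alpha) = 8! / (1! * 2! * 3!) = 3360 and prod x_i^(alpha_i - 1) = 1/512.
closed_form = math.log(3360 / 512)
print(dirichlet_logpdf(x, alpha), dirichlet.logpdf(x, alpha), closed_form)
```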