Multivariate
fenbux.multivariate.MultivariateNormal
Multivariate normal distribution. X ~ Normal(μ, Σ)
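For reference, the standard log-density of a k-dimensional multivariate normal with mean vector μ and positive-definite covariance matrix Σ is:

```latex
\log p(x \mid \mu, \Sigma) = -\tfrac{1}{2}\left[ k \log(2\pi) + \log\det\Sigma + (x - \mu)^\top \Sigma^{-1} (x - \mu) \right]
```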
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `mean` | PyTree | Mean vector μ of the distribution. | `0.0` |
| `cov` | PyTree | Covariance matrix Σ of the distribution. | `0.0` |
| `dtype` | jax.numpy.dtype | dtype of the distribution. | `jnp.float_` (shown as `jax.numpy.float64`) |
| `use_batch` | bool | Whether the distribution is used under `jax.vmap`. | `False` |
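The `use_batch` flag concerns running the distribution under `jax.vmap`. fenbux's own batching logic is not reproduced here. As a plain-JAX sketch of what vmap-batching a multivariate normal log-density looks like, the example below maps `jax.scipy.stats.multivariate_normal.logpdf` (a JAX function, not part of fenbux) over a hypothetical batch of mean vectors that share one covariance matrix.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal

# Hypothetical batch: 4 mean vectors in 3 dimensions, sharing one covariance.
means = jnp.arange(12.0).reshape(4, 3)
cov = jnp.eye(3)
x = jnp.zeros(3)

# Map over axis 0 of `means` only; `x` and `cov` are broadcast (in_axes=None).
batched_logpdf = jax.vmap(multivariate_normal.logpdf, in_axes=(None, 0, None))
log_densities = batched_logpdf(x, means, cov)

print(log_densities.shape)  # (4,)
```

Each entry equals the log-density of `x` under the corresponding mean, exactly as if the unbatched function were called in a Python loop, but compiled as a single vectorized computation.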
Examples:
>>> import jax.numpy as jnp
>>> from fenbux import logpdf
>>> from fenbux.multivariate import MultivariateNormal
>>> dist = MultivariateNormal(jnp.zeros((10,)), jnp.eye(10))
>>> logpdf(dist, jnp.zeros((10,)))
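For the example above, a 10-dimensional standard normal evaluated at its mean has log-density −(10/2)·log(2π) ≈ −9.1894. As a cross-check that does not depend on fenbux, the sketch below computes the same quantity in plain JAX, once with a hand-written Cholesky-based helper `mvn_logpdf` (hypothetical, not a fenbux function) and once with `jax.scipy.stats.multivariate_normal.logpdf`.

```python
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular
from jax.scipy.stats import multivariate_normal


def mvn_logpdf(x, mean, cov):
    """Log-density of N(mean, cov) at x, computed via a Cholesky factorization."""
    k = mean.shape[-1]
    L = jnp.linalg.cholesky(cov)                    # cov = L @ L.T
    z = solve_triangular(L, x - mean, lower=True)   # z = L^{-1} (x - mean)
    log_det = 2.0 * jnp.sum(jnp.log(jnp.diag(L)))   # log det(cov)
    return -0.5 * (k * jnp.log(2.0 * jnp.pi) + log_det + z @ z)


mean = jnp.zeros((10,))
cov = jnp.eye(10)
x = jnp.zeros((10,))

print(mvn_logpdf(x, mean, cov))                  # approx -9.1894
print(multivariate_normal.logpdf(x, mean, cov))  # same value from jax.scipy
```

Solving against the triangular factor L gives (x − μ)ᵀ Σ⁻¹ (x − μ) as z·z without forming Σ⁻¹ explicitly, which is cheaper and numerically more stable than inverting the covariance.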
Attributes:

| Name | Type | Description |
|---|---|---|
| `mean` | PyTree | Mean vector μ of the distribution. |
| `cov` | PyTree | Covariance matrix Σ of the distribution. |
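To make concrete what the `mean` and `cov` attributes describe, the following plain-JAX sketch draws samples from N(μ, Σ) with `jax.random.multivariate_normal` (a JAX function, not fenbux's sampler) and checks that the sample mean and sample covariance approach μ and Σ. The 2-D parameter values are arbitrary illustrations.

```python
import jax
import jax.numpy as jnp

# Hypothetical 2-D parameters chosen for illustration.
mean = jnp.array([1.0, -2.0])
cov = jnp.array([[2.0, 0.6],
                 [0.6, 1.0]])

key = jax.random.PRNGKey(0)
samples = jax.random.multivariate_normal(key, mean, cov, shape=(100_000,))

# Empirical moments should approach `mean` and `cov` as the sample size grows.
emp_mean = samples.mean(axis=0)
emp_cov = jnp.cov(samples, rowvar=False)  # rows are observations, columns are variables
print(emp_mean)
print(emp_cov)
```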
fenbux.multivariate.Dirichlet
Dirichlet(alpha=0.0, dtype=