Discrete
fenbux.univariate.Bernoulli
Bernoulli distribution.
X ~ Bernoulli(p)
Parameters:
Name | Type | Description | Default |
---|---|---|---|
p | PyTree | Probability of success. | 0.0 |
dtype | jax.numpy.dtype | dtype of the distribution, default jnp.float_. | <class 'jax.numpy.float64'> |
use_batch | bool | Whether to use with vmap. Default False. | False |
Examples:
>>> import jax.numpy as jnp
>>> from fenbux import logpmf
>>> from fenbux.univariate import Bernoulli
>>> dist = Bernoulli(0.5)
>>> logpmf(dist, jnp.ones((10, )))
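As a sanity check on the example above, the log-probability of a Bernoulli outcome can be computed by hand. The sketch below uses only the Python standard library and is not fenbux's implementation; the helper name `bernoulli_logpmf` is hypothetical, and it assumes 0 < p < 1.

```python
import math


def bernoulli_logpmf(x: int, p: float) -> float:
    """Reference log-PMF of Bernoulli(p); hypothetical helper, assumes 0 < p < 1."""
    if x not in (0, 1):
        # Values outside the support have probability zero.
        return -math.inf
    # log(p) for a success, log(1 - p) for a failure.
    return math.log(p) if x == 1 else math.log1p(-p)


# Ten observations equal to 1, mirroring jnp.ones((10,)) in the example above.
values = [bernoulli_logpmf(1, 0.5) for _ in range(10)]
print(values[0])
```

With p = 0.5 every entry equals log(0.5), so a library result can be compared element-wise against this value.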
fenbux.univariate.Binomial
Binomial distribution.
X ~ Binomial(n, p)
Parameters:
Name | Type | Description | Default |
---|---|---|---|
n | PyTree | Number of trials. | 0.0 |
p | PyTree | Probability of success. | 0.0 |
dtype | jax.numpy.dtype | dtype of the distribution, default jnp.float_. | <class 'jax.numpy.float64'> |
use_batch | bool | Whether to use with vmap. Default False. | False |
Examples:
>>> import jax.numpy as jnp
>>> from fenbux import logpmf
>>> from fenbux.univariate import Binomial
>>> dist = Binomial(10, 0.5)
>>> logpmf(dist, jnp.ones((10, )))
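For reference, the Binomial log-PMF is log C(n, k) + k log p + (n - k) log(1 - p). The following standard-library sketch evaluates it independently of fenbux; the helper name `binomial_logpmf` is hypothetical, and it assumes an integer n and 0 < p < 1.

```python
import math


def binomial_logpmf(k: int, n: int, p: float) -> float:
    """Reference log-PMF of Binomial(n, p); hypothetical helper, assumes 0 < p < 1."""
    if k < 0 or k > n:
        # Outside the support {0, ..., n}.
        return -math.inf
    # log C(n, k) via log-gamma to avoid large intermediate factorials.
    log_comb = math.lgamma(n + 1) - math.lgamma(k + 1) - math.lgamma(n - k + 1)
    return log_comb + k * math.log(p) + (n - k) * math.log1p(-p)


# Same parameters and evaluation point as the example above: n=10, p=0.5, k=1.
value = binomial_logpmf(1, 10, 0.5)
print(value)
```

For n = 10, p = 0.5, and k = 1 this equals log(10 / 1024), which gives a value to compare against the library output.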
fenbux.univariate.Geometric
Geometric(p=0.0, dtype=
fenbux.univariate.Poisson
Poisson distribution.
X ~ Poisson(λ)
Parameters:
Name | Type | Description | Default |
---|---|---|---|
rate | PyTree | Rate parameter of the distribution. | 0.0 |
dtype | jax.numpy.dtype | dtype of the distribution, default jnp.float_. | <class 'jax.numpy.float64'> |
use_batch | bool | Whether to use with vmap. Default False. | False |
Examples:
>>> import jax.numpy as jnp
>>> from fenbux import logpmf
>>> from fenbux.univariate import Poisson
>>> dist = Poisson(1.0)
>>> logpmf(dist, jnp.ones((10, )))
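The Poisson log-PMF is k log λ - λ - log k!. As a rough cross-check for the example above, the sketch below computes it with the standard library only; `poisson_logpmf` is a hypothetical helper, not part of fenbux, and it assumes rate > 0.

```python
import math


def poisson_logpmf(k: int, rate: float) -> float:
    """Reference log-PMF of Poisson(rate); hypothetical helper, assumes rate > 0."""
    if k < 0:
        # Negative counts are outside the support.
        return -math.inf
    # lgamma(k + 1) equals log(k!).
    return k * math.log(rate) - rate - math.lgamma(k + 1)


# Same parameters and evaluation point as the example above: rate=1.0, k=1.
value = poisson_logpmf(1, 1.0)
print(value)
```

With rate = 1.0 and k = 1 the first and last terms vanish, leaving exactly -rate.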
fenbux.univariate.BetaBinomial
BetaBinomial distribution.
X ~ BetaBinomial(n, a, b)
Parameters:
Name | Type | Description | Default |
---|---|---|---|
n | PyTree | Number of trials. | 1 |
a | PyTree | Shape parameter a. | 1.0 |
b | PyTree | Shape parameter b. | 1.0 |
dtype | jax.numpy.dtype | dtype of the distribution, default jnp.float_. | <class 'jax.numpy.float64'> |
use_batch | bool | Whether to use with vmap. Default False. | False |
Examples:
>>> import jax.numpy as jnp
>>> from fenbux import logpmf
>>> from fenbux.univariate import BetaBinomial
>>> dist = BetaBinomial(10, 1.0, 1.0)
>>> logpmf(dist, jnp.ones((10, )))
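The BetaBinomial log-PMF can be written as log C(n, k) + log B(k + a, n - k + b) - log B(a, b), where B is the Beta function. The sketch below is a standard-library reference computation, not fenbux's implementation; the helper names `log_beta` and `betabinomial_logpmf` are hypothetical, and a, b > 0 is assumed.

```python
import math


def log_beta(x: float, y: float) -> float:
    """log B(x, y) expressed through log-gamma; hypothetical helper."""
    return math.lgamma(x) + math.lgamma(y) - math.lgamma(x + y)


def betabinomial_logpmf(k: int, n: int, a: float, b: float) -> float:
    """Reference log-PMF of BetaBinomial(n, a, b); assumes a > 0 and b > 0."""
    if k < 0 or k > n:
        # Outside the support {0, ..., n}.
        return -math.inf
    log_comb = math.lgamma(n + 1) - math.lgamma(k + 1) - math.lgamma(n - k + 1)
    return log_comb + log_beta(k + a, n - k + b) - log_beta(a, b)


# Same parameters and evaluation point as the example above: n=10, a=1.0, b=1.0, k=1.
value = betabinomial_logpmf(1, 10, 1.0, 1.0)
print(value)
```

With a = b = 1 the distribution is uniform on {0, ..., n}, so for n = 10 every point in the support has log-probability log(1 / 11).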