
Discrete

fenbux.univariate.Bernoulli

Bernoulli distribution. X ~ Bernoulli(p)

Parameters:

Name       Type             Description                                  Default
p          PyTree           Probability of success.                      0.0
dtype      jax.numpy.dtype  dtype of the distribution.                   jnp.float64
use_batch  bool             Whether the distribution is used with vmap.  False

Examples:

>>> import jax.numpy as jnp
>>> from fenbux import logpmf
>>> from fenbux.univariate import Bernoulli
>>> dist = Bernoulli(0.5)
>>> logpmf(dist, jnp.ones((10, )))
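For reference, the Bernoulli log-probability is log P(X = x) = x log p + (1 − x) log(1 − p) for x in {0, 1}. The plain-Python sketch below evaluates that formula so the log-probabilities computed in the example above can be sanity-checked by hand. The helper name bernoulli_logpmf is hypothetical and not part of fenbux; it assumes 0 < p < 1 and scalar inputs.

```python
import math


def bernoulli_logpmf(x: int, p: float) -> float:
    """log P(X = x) for X ~ Bernoulli(p), assuming 0 < p < 1 (hypothetical helper)."""
    if x not in (0, 1):
        return -math.inf  # outside the support {0, 1}
    # log p for a success; log1p(-p) computes log(1 - p) for a failure
    return math.log(p) if x == 1 else math.log1p(-p)


# Same setting as the example: p = 0.5 evaluated at ten ones
values = [bernoulli_logpmf(1, 0.5) for _ in range(10)]
print(values[0])  # -0.6931471805599453
```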

fenbux.univariate.Binomial

Binomial distribution. X ~ Binomial(n, p)

Parameters:

Name       Type             Description                                  Default
n          PyTree           Number of trials.                            0.0
p          PyTree           Probability of success.                      0.0
dtype      jax.numpy.dtype  dtype of the distribution.                   jnp.float64
use_batch  bool             Whether the distribution is used with vmap.  False

Examples:

>>> import jax.numpy as jnp
>>> from fenbux import logpmf
>>> from fenbux.univariate import Binomial
>>> dist = Binomial(10, 0.5)
>>> logpmf(dist, jnp.ones((10, )))
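The binomial log-probability is log P(X = k) = log C(n, k) + k log p + (n − k) log(1 − p). A minimal plain-Python sketch of that formula follows, computing log C(n, k) through log-gamma so it stays stable for large n; it can be used to check the values returned by the example above. The helper binomial_logpmf is hypothetical, not a fenbux function, and assumes 0 < p < 1 with integer n and k.

```python
import math


def binomial_logpmf(k: int, n: int, p: float) -> float:
    """log P(X = k) for X ~ Binomial(n, p), assuming 0 < p < 1 (hypothetical helper)."""
    if k < 0 or k > n:
        return -math.inf  # outside the support {0, ..., n}
    # log C(n, k) = lgamma(n + 1) - lgamma(k + 1) - lgamma(n - k + 1)
    log_coef = math.lgamma(n + 1) - math.lgamma(k + 1) - math.lgamma(n - k + 1)
    return log_coef + k * math.log(p) + (n - k) * math.log1p(-p)


# Binomial(10, 0.5) at k = 1: the probability is close to 10 / 1024 = 0.009765625
print(math.exp(binomial_logpmf(1, 10, 0.5)))
```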

fenbux.univariate.Geometric

Geometric(p=0.0, dtype=jnp.float64, use_batch=False)
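Only the signature is listed for Geometric, so the parametrization is not stated here. The sketch below ASSUMES the common "number of trials up to and including the first success" convention, with support k = 1, 2, ... and P(X = k) = (1 − p)^(k − 1) p; other libraries count failures before the first success (support k = 0, 1, ...), which shifts k by one. The helper geometric_logpmf is hypothetical and assumes 0 < p < 1.

```python
import math


def geometric_logpmf(k: int, p: float) -> float:
    """log P(X = k), ASSUMING X counts trials up to the first success (k >= 1).

    This convention is an assumption; fenbux's Geometric may count failures instead.
    Assumes 0 < p < 1 (hypothetical helper).
    """
    if k < 1:
        return -math.inf  # no success can occur before the first trial
    # (k - 1) failures followed by one success
    return (k - 1) * math.log1p(-p) + math.log(p)


# With p = 0.5 the first three probabilities are close to 0.5, 0.25 and 0.125
probs = [math.exp(geometric_logpmf(k, 0.5)) for k in (1, 2, 3)]
print(probs)
```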

fenbux.univariate.Poisson

Poisson distribution. X ~ Poisson(λ), with λ given by the rate parameter.

Parameters:

Name       Type             Description                                  Default
rate       PyTree           Rate parameter of the distribution.          0.0
dtype      jax.numpy.dtype  dtype of the distribution.                   jnp.float64
use_batch  bool             Whether the distribution is used with vmap.  False

Examples:

>>> import jax.numpy as jnp
>>> from fenbux import logpmf
>>> from fenbux.univariate import Poisson
>>> dist = Poisson(1.0)
>>> logpmf(dist, jnp.ones((10, )))
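The Poisson log-probability is log P(X = k) = k log λ − λ − log k!, where λ is the rate. A minimal plain-Python sketch of that formula is shown below, using lgamma(k + 1) for log k!; with rate 1.0 and k = 1, as in the example above, it gives log(e^−1) = −1. The helper poisson_logpmf is hypothetical, not part of fenbux, and assumes rate > 0.

```python
import math


def poisson_logpmf(k: int, rate: float) -> float:
    """log P(X = k) for X ~ Poisson(rate), assuming rate > 0 (hypothetical helper)."""
    if k < 0:
        return -math.inf  # outside the support {0, 1, 2, ...}
    # k * log(rate) - rate - log(k!), with log(k!) = lgamma(k + 1)
    return k * math.log(rate) - rate - math.lgamma(k + 1)


# Same setting as the example: rate = 1.0 evaluated at k = 1
print(poisson_logpmf(1, 1.0))  # -1.0
```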

fenbux.univariate.BetaBinomial

BetaBinomial distribution. X ~ BetaBinomial(n, a, b)

Parameters:

Name       Type             Description                                  Default
n          PyTree           Number of trials.                            1
a          PyTree           Shape parameter a.                           1.0
b          PyTree           Shape parameter b.                           1.0
dtype      jax.numpy.dtype  dtype of the distribution.                   jnp.float64
use_batch  bool             Whether the distribution is used with vmap.  False

Examples:

>>> import jax.numpy as jnp
>>> from fenbux import logpmf
>>> from fenbux.univariate import BetaBinomial
>>> dist = BetaBinomial(10, 1.0, 1.0)
>>> logpmf(dist, jnp.ones((10, )))
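The beta-binomial log-probability is log P(X = k) = log C(n, k) + log B(k + a, n − k + b) − log B(a, b), where B is the beta function. The plain-Python sketch below evaluates it with log-gamma. With a = b = 1, as in the example above, the distribution is uniform on {0, ..., n}, so BetaBinomial(10, 1.0, 1.0) assigns probability 1/11 to every k, which makes a convenient check. The helpers log_beta and betabinomial_logpmf are hypothetical, not fenbux functions, and assume a, b > 0.

```python
import math


def log_beta(x: float, y: float) -> float:
    # log B(x, y) = lgamma(x) + lgamma(y) - lgamma(x + y)
    return math.lgamma(x) + math.lgamma(y) - math.lgamma(x + y)


def betabinomial_logpmf(k: int, n: int, a: float, b: float) -> float:
    """log P(X = k) for X ~ BetaBinomial(n, a, b), assuming a, b > 0 (hypothetical helper)."""
    if k < 0 or k > n:
        return -math.inf  # outside the support {0, ..., n}
    log_coef = math.lgamma(n + 1) - math.lgamma(k + 1) - math.lgamma(n - k + 1)
    return log_coef + log_beta(k + a, n - k + b) - log_beta(a, b)


# a = b = 1 gives a uniform distribution on {0, ..., 10}: close to 1 / 11
print(math.exp(betabinomial_logpmf(1, 10, 1.0, 1.0)))
```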