Beta Distribution

stamox.distribution.pbeta(q: Union[Float, ArrayLike], a: Union[Float, ArrayLike], b: Union[Float, ArrayLike], lower_tail = True, log_prob = False, dtype = <class 'jax.numpy.float64'>) -> ArrayLike

Computes the cumulative distribution function of the beta distribution.
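For reference, these are the standard definitions of the beta density and CDF in terms of the two shape parameters `a` (α) and `b` (β). They are textbook formulas, not excerpts from the stamox source. The last line works them out for `a = 2`, `b = 3`, the values used in the examples on this page.

```latex
f(x; a, b) = \frac{x^{a-1} (1 - x)^{b-1}}{B(a, b)}, \qquad 0 < x < 1,\; a, b > 0,
\qquad B(a, b) = \frac{\Gamma(a)\,\Gamma(b)}{\Gamma(a + b)}

F(x; a, b) = P[X \le x] = \int_0^x f(t; a, b)\,dt = I_x(a, b)

% a = 2, b = 3:  B(2, 3) = 1!\,2!/4! = 1/12
f(x; 2, 3) = 12\,x(1 - x)^2, \qquad F(x; 2, 3) = 6x^2 - 8x^3 + 3x^4, \qquad F(0.5; 2, 3) = 0.6875
```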

Parameters:

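| Name | Type | Description | Default |
|---|---|---|---|
| `q` | `Union[Float, ArrayLike]` | Quantiles at which to evaluate the CDF. | required |
| `a` | `Union[Float, ArrayLike]` | First shape parameter (α). | required |
| `b` | `Union[Float, ArrayLike]` | Second shape parameter (β). | required |
| `lower_tail` | `bool` | If True, probabilities are P[X ≤ x]; otherwise, P[X > x]. | `True` |
| `log_prob` | `bool` | If True, probabilities are returned as log(P). | `False` |
| `dtype` | `jnp.dtype` | The dtype of the output. Defaults to `jnp.float64`, which JAX truncates to `float32` unless 64-bit mode (`jax_enable_x64`) is enabled, as in the examples. | `jnp.float64` |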

Returns:

| Type | Description |
|---|---|
| `ArrayLike` | The probability, or its logarithm, for each quantile. |

Examples:

>>> import jax.numpy as jnp
>>> from stamox.distribution import pbeta
>>> q = jnp.array([0.1, 0.5, 0.9])
>>> a = 2.0
>>> b = 3.0
>>> pbeta(q, a, b)
Array([0.05230004, 0.68749976, 0.9963    ], dtype=float32)
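The `lower_tail` and `log_prob` options can be illustrated with a plain-Python sketch. This is not stamox's JAX implementation. It assumes positive integer shape parameters, so the regularized incomplete beta function can use the exact identity I_x(a, b) = P[Binomial(a + b - 1, x) ≥ a]. The names `beta_cdf_int` and `pbeta_sketch` are hypothetical.

```python
import math


def beta_cdf_int(x: float, a: int, b: int) -> float:
    """Regularized incomplete beta I_x(a, b) for positive integer a, b.

    Uses the identity I_x(a, b) = P[Binomial(a + b - 1, x) >= a].
    """
    if x <= 0.0:
        return 0.0
    if x >= 1.0:
        return 1.0
    n = a + b - 1
    return sum(math.comb(n, j) * x**j * (1.0 - x) ** (n - j) for j in range(a, n + 1))


def pbeta_sketch(q: float, a: int, b: int,
                 lower_tail: bool = True, log_prob: bool = False) -> float:
    """Hypothetical plain-Python mirror of pbeta's lower_tail/log_prob options."""
    p = beta_cdf_int(q, a, b)      # P[X <= q]
    if not lower_tail:
        p = 1.0 - p                # P[X > q]
    return math.log(p) if log_prob else p


# Same inputs as the pbeta example above (a = 2, b = 3).
print([round(pbeta_sketch(q, 2, 3), 6) for q in (0.1, 0.5, 0.9)])
# [0.0523, 0.6875, 0.9963]
```

For non-integer shape parameters, JAX provides `jax.scipy.special.betainc(a, b, x)`, which computes the same regularized incomplete beta function.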

stamox.distribution.qbeta(p: Union[Float, ArrayLike], a: Union[Float, ArrayLike], b: Union[Float, ArrayLike], lower_tail: bool = True, log_prob: bool = False, dtype = <class 'jax.numpy.float64'>) -> ArrayLike

Computes the quantile function (inverse CDF) of the beta distribution.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `p` | `Union[Float, ArrayLike]` | Probabilities at which to evaluate the quantile function. | required |
| `a` | `Union[Float, ArrayLike]` | First shape parameter (α). | required |
| `b` | `Union[Float, ArrayLike]` | Second shape parameter (β). | required |
| `lower_tail` | `bool` | If True, `p` is interpreted as P[X ≤ x]; otherwise, as P[X > x]. | `True` |
| `log_prob` | `bool` | If True, `p` is given as log(p). | `False` |
| `dtype` | | The dtype of the output. Defaults to `jnp.float64` (truncated to `float32` unless `jax_enable_x64` is set). | `jnp.float64` |

Returns:

| Type | Description |
|---|---|
| `ArrayLike` | The quantile(s) of the beta distribution corresponding to the probabilities `p`. |

Examples:

>>> from stamox.distribution import qbeta
>>> qbeta(0.5, 2, 3)
Array(0.38572744, dtype=float32)
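A quantile can be sketched as the inverse of the CDF. The plain-Python version below is a hypothetical illustration limited to integer shape parameters. It first undoes `log_prob` and `lower_tail`, then bisects on [0, 1], which works because the beta CDF is increasing there. stamox may well use a different root-finding scheme; `qbeta_sketch` and `beta_cdf_int` are illustrative names only.

```python
import math


def beta_cdf_int(x: float, a: int, b: int) -> float:
    """I_x(a, b) for positive integer a, b via the binomial-sum identity."""
    n = a + b - 1
    return sum(math.comb(n, j) * x**j * (1.0 - x) ** (n - j) for j in range(a, n + 1))


def qbeta_sketch(p: float, a: int, b: int, lower_tail: bool = True,
                 log_prob: bool = False, tol: float = 1e-12) -> float:
    """Hypothetical quantile sketch: invert the CDF by bisection on [0, 1]."""
    if log_prob:
        p = math.exp(p)          # p was supplied as log(p)
    if not lower_tail:
        p = 1.0 - p              # p was supplied as P[X > x]
    lo, hi = 0.0, 1.0
    while hi - lo > tol:
        mid = 0.5 * (lo + hi)
        if beta_cdf_int(mid, a, b) < p:
            lo = mid             # CDF too small: the quantile lies to the right
        else:
            hi = mid             # CDF large enough: the quantile lies at or left of mid
    return 0.5 * (lo + hi)


x = qbeta_sketch(0.5, 2, 3)
print(round(x, 5))               # 0.38573, the median of Beta(2, 3)
```

Bisection converges more slowly than Newton-type iterations, but it cannot overshoot the interval, which keeps the sketch short and robust.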

stamox.distribution.dbeta(x: Union[Float, ArrayLike], a: Union[Float, ArrayLike], b: Union[Float, ArrayLike], lower_tail: bool = True, log_prob: bool = False, dtype = <class 'jax.numpy.float64'>) -> ArrayLike

Calculates the probability density function of the beta distribution.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `Union[Float, ArrayLike]` | Value(s) at which to evaluate the PDF. | required |
| `a` | `Union[Float, ArrayLike]` | First shape parameter (α) of the beta distribution. | required |
| `b` | `Union[Float, ArrayLike]` | Second shape parameter (β) of the beta distribution. | required |
| `lower_tail` | `bool` | Whether to calculate the lower tail. | `True` |
| `log_prob` | `bool` | If True, return the logarithm of the PDF. | `False` |
| `dtype` | | The dtype of the output. Defaults to `jnp.float64` (truncated to `float32` unless `jax_enable_x64` is set). | `jnp.float64` |

Returns:

| Type | Description |
|---|---|
| `ArrayLike` | The probability density function of the beta distribution evaluated at `x`. |

Examples:

>>> from stamox.distribution import dbeta
>>> dbeta(0.5, 2, 3, lower_tail=True, log_prob=False)
Array(1.4999996, dtype=float32, weak_type=True)
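The density can be evaluated in log space with `math.lgamma`, which avoids overflow in B(a, b) for large shape parameters. This hypothetical `dbeta_sketch` is a plain-Python illustration, not the stamox code. It covers only `log_prob` and points strictly inside (0, 1), and leaves out `lower_tail` and `dtype`.

```python
import math


def dbeta_sketch(x: float, a: float, b: float, log_prob: bool = False) -> float:
    """Hypothetical density sketch: f(x) = x^(a-1) (1-x)^(b-1) / B(a, b) on 0 < x < 1."""
    if not 0.0 < x < 1.0:
        raise ValueError("this sketch only handles 0 < x < 1")
    # log B(a, b) = lgamma(a) + lgamma(b) - lgamma(a + b); logs avoid overflow.
    log_beta = math.lgamma(a) + math.lgamma(b) - math.lgamma(a + b)
    # log1p(-x) computes log(1 - x) accurately when x is small.
    log_f = (a - 1.0) * math.log(x) + (b - 1.0) * math.log1p(-x) - log_beta
    return log_f if log_prob else math.exp(log_f)


print(round(dbeta_sketch(0.5, 2, 3), 6))   # 1.5, i.e. 12 * 0.5 * 0.5**2
```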

stamox.distribution.rbeta(key: Union[jax.Array, jax._src.prng.PRNGKeyArray], sample_shape: Optional[Sequence[int]] = None, a: Union[Float, ArrayLike] = None, b: Union[Float, ArrayLike] = None, lower_tail: bool = True, log_prob: bool = False, dtype = <class 'jax.numpy.float64'>) -> ArrayLike

Generates random numbers from the Beta distribution.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `key` | `Union[jax.Array, jax._src.prng.PRNGKeyArray]` | A PRNG key used for random number generation. | required |
| `sample_shape` | `Optional[Sequence[int]]` | Optional shape of the output samples. | `None` |
| `a` | `Union[Float, ArrayLike]` | First shape parameter (α) of the beta distribution; a float or an array-like object. | `None` |
| `b` | `Union[Float, ArrayLike]` | Second shape parameter (β) of the beta distribution; a float or an array-like object. | `None` |
| `lower_tail` | `bool` | Whether to return the lower tail probability. | `True` |
| `log_prob` | `bool` | Whether to return the log probability. | `False` |
| `dtype` | | The dtype of the output. Defaults to `jnp.float64` (truncated to `float32` unless `jax_enable_x64` is set). | `jnp.float64` |

Returns:

| Type | Description |
|---|---|
| `ArrayLike` | Random samples from the beta distribution. |

Examples:

>>> import jax
>>> from stamox.distribution import rbeta
>>> key = jax.random.PRNGKey(0)
>>> rbeta(key, sample_shape=(3,), a=2, b=3)
Array([0.02809353, 0.13760717, 0.49360353], dtype=float32)
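One standard way to draw beta variates is the gamma-ratio construction: if G1 ~ Gamma(a, 1) and G2 ~ Gamma(b, 1) are independent, then G1 / (G1 + G2) ~ Beta(a, b). The sketch below uses Python's `random` module rather than JAX keys, and `rbeta_sketch` is a hypothetical name. This page does not say which method stamox uses internally; JAX itself also provides `jax.random.beta(key, a, b, shape)`.

```python
import random
from typing import List


def rbeta_sketch(rng: random.Random, n: int, a: float, b: float) -> List[float]:
    """Hypothetical sampler: if G1 ~ Gamma(a, 1) and G2 ~ Gamma(b, 1) are
    independent, then G1 / (G1 + G2) ~ Beta(a, b)."""
    out = []
    for _ in range(n):
        g1 = rng.gammavariate(a, 1.0)   # shape a, scale 1
        g2 = rng.gammavariate(b, 1.0)   # shape b, scale 1
        out.append(g1 / (g1 + g2))
    return out


rng = random.Random(0)                  # seeded for reproducibility; unrelated to JAX keys
xs = rbeta_sketch(rng, 20_000, 2.0, 3.0)
print(sum(xs) / len(xs))                # close to the Beta(2, 3) mean a / (a + b) = 0.4
```

Unlike `jax.random`, Python's `random.Random` carries hidden state, so this sketch does not reproduce JAX's explicit key-splitting model.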