Exponential Distribution

stamox.distribution.pexp(q: Union[Float, ArrayLike], rate: Union[Float, ArrayLike], lower_tail: bool = True, log_prob: bool = False, dtype = jnp.float64) -> ArrayLike

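Computes the cumulative distribution function (CDF) of the exponential distribution at a value or array of values: P(X <= q) = 1 - exp(-rate * q) for q >= 0, or the upper-tail probability P(X > q) = exp(-rate * q) when lower_tail=False.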

Parameters:

q (Union[Float, ArrayLike], required): The value or array of values at which to evaluate the CDF.
rate (Union[Float, ArrayLike], required): The rate parameter of the exponential distribution (positive; the mean is 1 / rate).
lower_tail (bool, default True): If True, return P(X <= q); otherwise return P(X > q).
log_prob (bool, default False): If True, return the natural log of the probability.
dtype (jnp.dtype, default jnp.float64): The dtype of the output. JAX disables 64-bit types unless x64 mode is enabled, so results are float32 by default, as in the examples.

Returns:

ArrayLike: P(X <= q), or P(X > q) when lower_tail=False; returned on the log scale when log_prob=True.

Examples:

>>> pexp(1.0, 0.5)
Array(0.39346933, dtype=float32, weak_type=True)
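To make the lower_tail and log_prob switches concrete, here is a minimal scalar sketch of the same formulas using only Python's standard library. `pexp_reference` is a hypothetical helper written for illustration, not stamox's implementation, and it assumes the usual convention that the CDF is 0 for q < 0.

```python
import math


def pexp_reference(q: float, rate: float, lower_tail: bool = True, log_prob: bool = False) -> float:
    """Hypothetical scalar reference for the exponential CDF (not the stamox code)."""
    if rate <= 0:
        raise ValueError("rate must be positive")
    x = max(q, 0.0)  # the distribution puts no mass below zero
    if lower_tail:
        p = -math.expm1(-rate * x)  # 1 - exp(-rate * x), accurate when rate * x is tiny
    else:
        p = math.exp(-rate * x)  # upper tail P(X > q)
    if not log_prob:
        return p
    return math.log(p) if p > 0 else -math.inf  # log(0) taken as -inf


print(pexp_reference(1.0, 0.5))  # ≈ 0.39346934
print(pexp_reference(1.0, 0.5, lower_tail=False))  # ≈ 0.60653066
```

Using `expm1` for the lower tail avoids the cancellation in `1 - exp(...)` when `rate * q` is close to zero.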

stamox.distribution.qexp(p: Union[float, ArrayLike], rate: Union[float, ArrayLike], lower_tail = True, log_prob = False, dtype = jnp.float64) -> ArrayLike

Computes the quantile function (inverse CDF) of the exponential distribution: the value x such that P(X <= x) = p, that is, x = -log(1 - p) / rate.

Parameters:

p (Union[float, ArrayLike], required): Probability, or log probability when log_prob=True.
rate (Union[float, ArrayLike], required): Rate parameter of the exponential distribution.
lower_tail (bool, default True): If True, p is the lower-tail probability P(X <= x); otherwise it is the upper-tail probability P(X > x).
log_prob (bool, default False): Whether p is given as a log probability.
dtype (jnp.dtype, default jnp.float64): The dtype of the output.

Returns:

ArrayLike: The quantile(s) of the exponential distribution corresponding to p.

Examples:

>>> qexp(0.5, 1.0)
Array([0.6931472], dtype=float32, weak_type=True)
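The quantile function simply inverts the CDF. A minimal standard-library sketch, assuming R-style semantics for lower_tail (p is P(X > x) when it is False); `qexp_reference` is a hypothetical name for illustration only:

```python
import math


def qexp_reference(p: float, rate: float, lower_tail: bool = True, log_prob: bool = False) -> float:
    """Hypothetical scalar reference for the exponential quantile function."""
    if rate <= 0:
        raise ValueError("rate must be positive")
    if log_prob:
        p = math.exp(p)  # p was supplied on the log scale
    if not 0.0 <= p <= 1.0:
        raise ValueError("p must lie in [0, 1]")
    if lower_tail:
        # Solve 1 - exp(-rate * x) = p for x.
        return math.inf if p == 1.0 else -math.log1p(-p) / rate
    # Solve exp(-rate * x) = p for x (p is the upper-tail probability).
    return math.inf if p == 0.0 else -math.log(p) / rate


print(qexp_reference(0.5, 1.0))  # ≈ 0.6931472 (log 2)
```

Feeding the result back through the CDF, `-math.expm1(-rate * x)`, recovers p up to rounding.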

stamox.distribution.dexp(x: Union[Float, ArrayLike], rate: Union[Float, ArrayLike], lower_tail: bool = True, log_prob: bool = False, dtype = jnp.float64) -> ArrayLike

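Computes the probability density function (PDF) of the exponential distribution, f(x) = rate * exp(-rate * x) for x >= 0. The density is the derivative of the CDF computed by pexp.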

Parameters:

x (Union[Float, ArrayLike], required): The value or array of values at which to evaluate the density.
rate (Union[Float, ArrayLike], required): The rate parameter of the exponential distribution.
lower_tail (bool, default True): Whether to calculate the lower tail probability.
log_prob (bool, default False): Whether to return the log density.
dtype (jnp.dtype, default jnp.float64): The dtype of the output.

Returns:

ArrayLike: The density of the exponential distribution evaluated at x (the log density when log_prob=True).

Examples:

>>> dexp(1.0, 0.5, lower_tail=True, log_prob=False)
Array([0.30326533], dtype=float32, weak_type=True)
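Because the density is the derivative of the CDF, a central finite difference of the CDF should reproduce it. The standard-library sketch below computes the density directly and checks it numerically; `dexp_reference` and `exp_cdf` are hypothetical helpers, not stamox functions, and the sketch omits lower_tail because a density value does not depend on a tail choice.

```python
import math


def dexp_reference(x: float, rate: float, log_prob: bool = False) -> float:
    """Hypothetical scalar reference for the exponential density."""
    if rate <= 0:
        raise ValueError("rate must be positive")
    if x < 0:
        return -math.inf if log_prob else 0.0  # no density below zero
    if log_prob:
        return math.log(rate) - rate * x  # log f(x)
    return rate * math.exp(-rate * x)  # f(x)


def exp_cdf(x: float, rate: float) -> float:
    """Lower-tail CDF, used here only to check the density numerically."""
    return -math.expm1(-rate * x) if x > 0 else 0.0


# Central finite difference of the CDF at x = 1.0, rate = 0.5.
h = 1e-6
numeric = (exp_cdf(1.0 + h, 0.5) - exp_cdf(1.0 - h, 0.5)) / (2 * h)
print(dexp_reference(1.0, 0.5))  # ≈ 0.30326533
print(numeric)  # agrees with the line above to about ten decimal places
```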

stamox.distribution.rexp(key: Union[jax.Array, jax._src.prng.PRNGKeyArray], sample_shape: Optional[Sequence[int]] = None, rate: Union[Float, ArrayLike] = None, lower_tail: bool = True, log_prob: bool = False, dtype = jnp.float64) -> ArrayLike

Generates random samples from the exponential distribution.

Parameters:

key (KeyArray, required): A PRNG key (for example from jax.random.PRNGKey) used to generate the random numbers.
sample_shape (Optional[Shape], default None): The shape of the output array.
rate (Union[Float, ArrayLike], default None): The rate parameter of the exponential distribution.
lower_tail (bool, default True): Whether to return the lower tail of the distribution.
log_prob (bool, default False): Whether to return the log probability of the samples.
dtype (jnp.dtype, default jnp.float64): The dtype of the output.

Returns:

ArrayLike: An array of random samples from the exponential distribution.

Examples:

>>> key = jax.random.PRNGKey(0)
>>> samples = rexp(key, sample_shape=(2, 3), rate=1.0)
>>> samples.shape
(2, 3)
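How rexp draws its samples internally is not documented here. One standard technique is inverse-transform sampling: pass a uniform U through the quantile function, X = -log(1 - U) / rate. The sketch below uses Python's `random` module with a fixed seed standing in for a JAX PRNG key; `rexp_reference` is a hypothetical helper, not stamox's implementation. In JAX itself, `jax.random.exponential(key, shape)` returns unit-rate exponential draws that can be divided by the rate.

```python
import math
import random


def rexp_reference(rng: random.Random, n: int, rate: float) -> list:
    """Hypothetical sketch: n exponential draws via inverse-transform sampling."""
    if rate <= 0:
        raise ValueError("rate must be positive")
    samples = []
    for _ in range(n):
        u = rng.random()  # uniform on [0, 1)
        # 1 - u lies in (0, 1], so the log is always finite.
        samples.append(-math.log(1.0 - u) / rate)
    return samples


rng = random.Random(0)  # fixed seed, playing the role of a PRNG key
draws = rexp_reference(rng, 100_000, rate=2.0)
print(sum(draws) / len(draws))  # close to the theoretical mean 1 / rate = 0.5
```

As with a JAX key, reusing the same seed reproduces the same draws.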