Exponential Distribution
stamox.distribution.pexp(q: Union[Float, ArrayLike], rate: Union[Float, ArrayLike], lower_tail: bool = True, log_prob: bool = False, dtype = <class 'jax.numpy.float64'>) -> ArrayLike
Computes the cumulative distribution function (CDF) of the exponential distribution at a given value or array of values.
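For reference, the exponential distribution with rate $\lambda > 0$ has the lower-tail (CDF) and upper-tail (survival) probabilities below. This page does not state how the flags map onto them. Presumably `lower_tail` chooses between the two and `log_prob` returns their logarithms, as in R's `pexp`, but that reading is an assumption.

$$
P(X \le q) = 1 - e^{-\lambda q}, \qquad P(X > q) = e^{-\lambda q}, \qquad \ln P(X > q) = -\lambda q, \qquad q \ge 0,
$$

with $P(X \le q) = 0$ for $q < 0$. For instance, $\lambda = 0.5$ and $q = 1$ give $1 - e^{-0.5} \approx 0.393469$, which is the value shown in the `pexp` example.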
Parameters:
Name | Type | Description | Default |
---|---|---|---|
q | Union[Float, ArrayLike] | The value or array of values at which to evaluate the CDF. | required |
rate | Union[Float, ArrayLike] | The rate parameter of the exponential distribution. | required |
lower_tail | bool | Whether to return the lower-tail probability P(X <= q) rather than the upper-tail probability P(X > q). | True |
log_prob | bool | Whether to return the log of the probability. | False |
dtype | jnp.dtype | The dtype of the output. | jnp.float64 |
Returns:
Type | Description |
---|---|
ArrayLike | The lower- or upper-tail probability at the given value or values, on the log scale if log_prob is True. |
Examples:
>>> from stamox.distribution import pexp
>>> pexp(1.0, 0.5)
Array(0.39346933, dtype=float32, weak_type=True)
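The formula behind the example above can be checked with a small scalar reference written against Python's `math` module. `exp_cdf` is a hypothetical helper for illustration only and is not stamox code. It assumes the R-style meaning of `lower_tail` and `log_prob`. The example prints `float32` even though the default dtype is `float64`. A likely reason is that JAX computes in 32-bit unless `jax_enable_x64` is enabled.

```python
import math


def exp_cdf(q: float, rate: float, lower_tail: bool = True, log_prob: bool = False) -> float:
    """Scalar reference for the exponential CDF (hypothetical helper, not stamox code)."""
    if rate <= 0:
        raise ValueError("rate must be positive")
    if q <= 0:
        # All mass lies on [0, inf), so P(X <= q) = 0 and P(X > q) = 1 here.
        p = 0.0 if lower_tail else 1.0
    elif lower_tail:
        # P(X <= q) = 1 - exp(-rate * q); expm1 avoids cancellation when rate * q is tiny.
        p = -math.expm1(-rate * q)
    else:
        # Upper tail P(X > q) = exp(-rate * q); its log is exactly -rate * q.
        if log_prob:
            return -rate * q
        p = math.exp(-rate * q)
    if not log_prob:
        return p
    # Log scale: report log(0) as -inf instead of raising a math domain error.
    return math.log(p) if p > 0.0 else -math.inf


print(exp_cdf(1.0, 0.5))  # about 0.393469, i.e. 1 - exp(-0.5)
```

Using `expm1` and returning `-rate * q` directly for the log upper tail keeps precision in the far tails, where computing `1 - exp(...)` or `log(exp(...))` would lose digits.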
stamox.distribution.qexp(p: Union[float, ArrayLike], rate: Union[float, ArrayLike], lower_tail = True, log_prob = False, dtype = <class 'jax.numpy.float64'>) -> ArrayLike
Computes the quantile function (inverse CDF) of the exponential distribution.
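Solving $1 - e^{-\lambda x} = p$ (lower tail) or $e^{-\lambda x} = p$ (upper tail) for $x$ gives the quantile in closed form. If R's `qexp` convention applies, an assumption not confirmed here, then `log_prob=True` means $p$ is supplied as $\ell = \ln p \le 0$, so $p = e^{\ell}$ is substituted below.

$$
Q_{\text{lower}}(p;\lambda) = -\frac{\ln(1 - p)}{\lambda}, \qquad Q_{\text{upper}}(p;\lambda) = -\frac{\ln p}{\lambda}, \qquad 0 \le p \le 1,
$$

with $Q_{\text{lower}}(1;\lambda) = Q_{\text{upper}}(0;\lambda) = \infty$. For example, $Q_{\text{lower}}(0.5; 1) = \ln 2 \approx 0.693147$, the median shown in the `qexp` example.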
Parameters:
Name | Type | Description | Default |
---|---|---|---|
p | Union[float, ArrayLike] | Probability or log probability. | required |
rate | Union[float, ArrayLike] | Rate parameter of the exponential distribution. | required |
lower_tail | bool | Whether p is interpreted as a lower-tail probability. | True |
log_prob | bool | Whether p is given as a log probability. | False |
dtype | jnp.dtype | The dtype of the output. | jnp.float64 |
Returns:
Type | Description |
---|---|
ArrayLike | The quantile of the exponential distribution corresponding to p. |
Examples:
>>> from stamox.distribution import qexp
>>> qexp(0.5, 1.0)
Array([0.6931472], dtype=float32, weak_type=True)
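A scalar sketch of the quantile helps in checking results such as the one above. `exp_quantile` is a hypothetical helper built on Python's `math` module, not the stamox implementation. It assumes that `log_prob=True` means `p` holds a log probability, following R's convention.

```python
import math


def exp_quantile(p: float, rate: float, lower_tail: bool = True, log_prob: bool = False) -> float:
    """Scalar reference for the exponential quantile (hypothetical helper, not stamox code)."""
    if rate <= 0:
        raise ValueError("rate must be positive")
    if log_prob:
        # p carries log(probability), which must be <= 0; map it back to [0, 1].
        if p > 0:
            raise ValueError("log probability must be <= 0")
        p = math.exp(p)
    if not 0.0 <= p <= 1.0:
        raise ValueError("probability must lie in [0, 1]")
    if lower_tail:
        # Invert 1 - exp(-rate * x) = p, giving x = -log1p(-p) / rate.
        return math.inf if p == 1.0 else -math.log1p(-p) / rate
    # Invert exp(-rate * x) = p, giving x = -log(p) / rate.
    return math.inf if p == 0.0 else -math.log(p) / rate


print(exp_quantile(0.5, 1.0))  # the median for rate 1 is log(2), about 0.693147
```

Because this quantile inverts the CDF, plugging its result back into `1 - exp(-rate * x)` should recover `p` up to rounding.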
stamox.distribution.dexp(x: Union[Float, ArrayLike], rate: Union[Float, ArrayLike], lower_tail: bool = True, log_prob: bool = False, dtype = <class 'jax.numpy.float64'>) -> ArrayLike
Computes the probability density function (PDF) of the exponential distribution, that is, the derivative of its CDF.
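Differentiating the CDF $F(x;\lambda) = 1 - e^{-\lambda x}$ on $x \ge 0$ gives the density. Its logarithm is shown as well, since that is presumably what `log_prob=True` returns.

$$
f(x;\lambda) = \frac{d}{dx}\left(1 - e^{-\lambda x}\right) = \lambda e^{-\lambda x}, \qquad \ln f(x;\lambda) = \ln \lambda - \lambda x, \qquad x \ge 0,
$$

with $f(x;\lambda) = 0$ for $x < 0$. For example, $f(1; 0.5) = 0.5\,e^{-0.5} \approx 0.303265$, matching the `dexp` example.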
Parameters:
Name | Type | Description | Default |
---|---|---|---|
x | Union[Float, ArrayLike] | The value or array of values at which to evaluate the density. | required |
rate | Union[Float, ArrayLike] | The rate parameter of the exponential distribution. | required |
lower_tail | bool | Whether to calculate the lower tail probability. | True |
log_prob | bool | Whether to return the log of the density. | False |
dtype | jnp.dtype | The dtype of the output. | jnp.float64 |
Returns:
Type | Description |
---|---|
ArrayLike | The density of the exponential distribution evaluated at x. |
Examples:
>>> from stamox.distribution import dexp
>>> dexp(1.0, 0.5, lower_tail=True, log_prob=False)
Array([0.30326533], dtype=float32, weak_type=True)
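The density computation and its relation to the CDF can be sketched in scalar Python. `exp_density` is a hypothetical helper, not stamox code. It leaves out stamox's `lower_tail` argument because this page does not say how that flag affects a density. The finite-difference check at the end illustrates that the PDF is the derivative of the CDF.

```python
import math


def exp_density(x: float, rate: float, log_prob: bool = False) -> float:
    """Scalar reference for the exponential density (hypothetical helper, not stamox code)."""
    if rate <= 0:
        raise ValueError("rate must be positive")
    if x < 0:
        # The density is zero outside the support [0, inf).
        return -math.inf if log_prob else 0.0
    # log f(x) = log(rate) - rate * x; exponentiate only when the plain density is wanted.
    log_f = math.log(rate) - rate * x
    return log_f if log_prob else math.exp(log_f)


def cdf(t: float) -> float:
    """CDF 1 - exp(-rate * t) for the module-level rate used in the check below."""
    return 1.0 - math.exp(-rate * t)


# Sanity check: a central finite difference of the CDF should match the density.
rate, x, h = 0.5, 1.0, 1e-6
finite_diff = (cdf(x + h) - cdf(x - h)) / (2 * h)
print(exp_density(x, rate), finite_diff)  # both about 0.303265
```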
stamox.distribution.rexp(key: Union[jax.Array, jax._src.prng.PRNGKeyArray], sample_shape: Optional[Sequence[int]] = None, rate: Union[Float, ArrayLike] = None, lower_tail: bool = True, log_prob: bool = False, dtype = <class 'jax.numpy.float64'>) -> ArrayLike
Generates random samples from the exponential distribution.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
key | KeyArray | A PRNG key used to generate the random numbers. | required |
sample_shape | Optional[Shape] | The shape of the output array. | None |
rate | Union[Float, ArrayLike] | The rate parameter of the exponential distribution. | None |
lower_tail | bool | Whether to return the lower tail of the distribution. | True |
log_prob | bool | Whether to return the log probability of the samples. | False |
dtype | jnp.dtype | The dtype of the output. | jnp.float64 |
Returns:
Type | Description |
---|---|
ArrayLike | An array of random samples from the exponential distribution. |
Examples:
>>> import jax
>>> from stamox.distribution import rexp
>>> key = jax.random.PRNGKey(0)
>>> rexp(key, sample_shape=(2, 3), rate=1.0, lower_tail=False, log_prob=True)
Array([[-0.69314718, -0.69314718, -0.69314718],
       [-0.69314718, -0.69314718, -0.69314718]], dtype=float32)
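Exponential sampling can be illustrated with inverse-transform sampling: a uniform draw is passed through the quantile $-\ln(1-u)/\lambda$. This is a swapped-in technique for illustration, and stamox may generate samples differently. In JAX, `jax.random.exponential(key, shape)` produces standard exponential draws, which could be divided by the rate instead. `sample_exp` is a hypothetical helper. It uses a seeded `random.Random` in place of a JAX PRNG key and does not reproduce the `lower_tail` and `log_prob` options.

```python
import math
import random


def sample_exp(rng: random.Random, n: int, rate: float) -> list[float]:
    """Inverse-transform sampler for Exponential(rate) (hypothetical helper, not stamox code)."""
    if rate <= 0:
        raise ValueError("rate must be positive")
    samples = []
    for _ in range(n):
        u = rng.random()  # uniform on [0, 1)
        # Push u through the lower-tail quantile x = -log(1 - u) / rate;
        # since 1 - u lies in (0, 1], the result is finite and non-negative.
        samples.append(-math.log1p(-u) / rate)
    return samples


rng = random.Random(0)  # seeded for reproducibility, playing the role of a PRNG key
draws = sample_exp(rng, 100_000, rate=2.0)
print(sum(draws) / len(draws))  # close to the theoretical mean 1 / rate = 0.5
```

Like a JAX key, the seed makes the draws reproducible. Unlike JAX keys, however, `random.Random` carries mutable state, so repeated calls with the same generator return different samples.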