Gamma Distribution

stamox.distribution.pgamma(q: Union[Float, ArrayLike], shape: Union[Float, ArrayLike] = 1.0, rate: Union[Float, ArrayLike] = 1.0, lower_tail: bool = True, log_prob: bool = False, dtype = <class 'jax.numpy.float64'>) -> ArrayLike

Computes the cumulative distribution function of the gamma distribution.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| q | Union[Float, ArrayLike] | A float or array-like object giving the value(s) at which to evaluate the CDF. | required |
| shape | Union[Float, ArrayLike] | A float or array-like object giving the shape parameter of the gamma distribution. | 1.0 |
| rate | Union[Float, ArrayLike] | A float or array-like object giving the rate parameter of the gamma distribution. | 1.0 |
| lower_tail | bool | A boolean indicating whether to compute the lower tail of the gamma distribution. | True |
| log_prob | bool | A boolean indicating whether to return the logarithm of the probability. | False |
| dtype | | The dtype of the output. Defaults to jnp.float_. | <class 'jax.numpy.float64'> |

Returns:

| Type | Description |
| --- | --- |
| ArrayLike | The CDF evaluated at q, for a single value or an array of values. |

Examples:

>>> from stamox.distribution import pgamma
>>> pgamma(1.0, 0.5, 0.5)
Array(0.6826893, dtype=float32, weak_type=True)
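
The value in this example can be cross-checked without stamox. A gamma distribution with shape 0.5 has the closed-form CDF erf(sqrt(rate * q)), so the sketch below needs only the standard library. The helper name `gamma_half_cdf` is hypothetical, and its handling of `lower_tail` and `log_prob` (upper tail as the complement, log applied last) is an assumed R-style convention, not something confirmed from the stamox source.

```python
import math


def gamma_half_cdf(q, rate=1.0, lower_tail=True, log_prob=False):
    """CDF of a gamma distribution with shape fixed at 0.5 (hypothetical helper).

    Uses the identity P(1/2, x) = erf(sqrt(x)) for the regularized
    incomplete gamma function, so it only covers shape = 0.5.
    """
    p = math.erf(math.sqrt(rate * q)) if q > 0 else 0.0
    if not lower_tail:
        p = 1.0 - p  # assumed convention: upper tail is the complement
    if log_prob:
        # assumed convention: the log is applied to the final probability
        return math.log(p) if p > 0 else -math.inf
    return p


lower = gamma_half_cdf(1.0, rate=0.5)
upper = gamma_half_cdf(1.0, rate=0.5, lower_tail=False)
print(f"{lower:.4f}")  # 0.6827
print(f"{upper:.4f}")  # 0.3173
```

The lower-tail result agrees with the documented output to float32 precision.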

stamox.distribution.qgamma(p: Union[Float, ArrayLike], shape: Union[Float, ArrayLike] = 1.0, rate: Union[Float, ArrayLike] = 1.0, lower_tail: bool = True, log_prob: bool = False, dtype = <class 'jax.numpy.float64'>) -> ArrayLike

Computes the quantile function (inverse CDF) of the gamma distribution.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| p | Union[Float, ArrayLike] | A float or array-like object giving the probability (or probabilities) for which to compute the quantile. | required |
| shape | Union[Float, ArrayLike] | A float or array-like object giving the shape parameter of the gamma distribution. | 1.0 |
| rate | Union[Float, ArrayLike] | A float or array-like object giving the rate parameter of the gamma distribution. | 1.0 |
| lower_tail | bool | A boolean indicating whether to compute the lower tail (default) or upper tail. | True |
| log_prob | bool | A boolean indicating whether to compute the log probability (default False). | False |
| dtype | | The dtype of the output. Defaults to jnp.float_. | <class 'jax.numpy.float64'> |

Returns:

| Type | Description |
| --- | --- |
| ArrayLike | The quantile of the gamma distribution. |

Examples:

>>> from stamox.distribution import qgamma
>>> qgamma(0.5, 0.5, 0.5)
Array([0.45493677], dtype=float32)
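
Because the quantile function is the inverse of the CDF, the example value can be reproduced by solving CDF(x) = p numerically. The sketch below uses bisection on the closed-form CDF for shape 0.5, erf(sqrt(rate * x)). Both function names are hypothetical, and bisection is only an illustrative technique, not necessarily the one stamox uses internally.

```python
import math


def cdf_shape_half(x, rate=0.5):
    """Closed-form gamma CDF, valid only for shape = 0.5."""
    return math.erf(math.sqrt(rate * x)) if x > 0 else 0.0


def quantile_by_bisection(p, cdf, lo=0.0, hi=1e3, tol=1e-10):
    """Find x in [lo, hi] with cdf(x) close to p, assuming cdf is increasing."""
    while hi - lo > tol:
        mid = 0.5 * (lo + hi)
        if cdf(mid) < p:
            lo = mid  # the quantile lies to the right of mid
        else:
            hi = mid  # the quantile lies at or to the left of mid
    return 0.5 * (lo + hi)


median = quantile_by_bisection(0.5, cdf_shape_half)
print(f"{median:.4f}")  # 0.4549
```

Feeding the result back into the CDF returns 0.5 up to the tolerance, which is the round-trip property that links qgamma and pgamma.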

stamox.distribution.dgamma(x: Union[Float, ArrayLike], shape: Union[Float, ArrayLike] = 1.0, rate: Union[Float, ArrayLike] = 1.0, lower_tail: bool = True, log_prob: bool = False, dtype = <class 'jax.numpy.float64'>) -> ArrayLike

Computes the density of the gamma distribution.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| x | Union[Float, ArrayLike] | The value at which to evaluate the gamma density. | required |
| shape | Union[Float, ArrayLike] | The shape parameter of the gamma distribution. Defaults to 1.0. | 1.0 |
| rate | Union[Float, ArrayLike] | The rate parameter of the gamma distribution. Defaults to 1.0. | 1.0 |
| lower_tail | bool | Whether to compute the lower tail of the gamma distribution. Defaults to True. | True |
| log_prob | bool | Whether to return the log probability. Defaults to False. | False |
| dtype | jnp.dtype | The dtype of the output. Defaults to jnp.float_. | <class 'jax.numpy.float64'> |

Returns:

| Type | Description |
| --- | --- |
| ArrayLike | The density of the gamma distribution evaluated at x. If log_prob is True, returns the log probability. |

Examples:

>>> from stamox.distribution import dgamma
>>> dgamma(1.0, 0.5, 0.5)
Array(0.24197064, dtype=float32, weak_type=True)
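
As a cross-check on the example, the gamma density in the shape/rate parameterization is rate^shape / Gamma(shape) * x^(shape - 1) * exp(-rate * x) for x > 0. The sketch below evaluates it in log space with the standard library. The function names are hypothetical and this is not taken from the stamox implementation.

```python
import math


def gamma_logpdf(x, shape=1.0, rate=1.0):
    """Log-density of the gamma distribution (shape/rate form), for x > 0."""
    return (
        shape * math.log(rate)
        - math.lgamma(shape)
        + (shape - 1.0) * math.log(x)
        - rate * x
    )


def gamma_pdf(x, shape=1.0, rate=1.0):
    """Density obtained by exponentiating the log-density."""
    return math.exp(gamma_logpdf(x, shape, rate))


density = gamma_pdf(1.0, shape=0.5, rate=0.5)
print(f"{density:.4f}")  # 0.2420
```

Working in log space first avoids overflow in Gamma(shape) for large shape values, and it gives the log density directly when that is what the caller wants.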

stamox.distribution.rgamma(key, sample_shape: Optional[Sequence[int]] = None, shape: Union[Float, ArrayLike] = 1.0, rate: Union[Float, ArrayLike] = 1.0, lower_tail: bool = True, log_prob: bool = False, dtype = <class 'jax.numpy.float64'>) -> ArrayLike

Generates random gamma values.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| key | | A PRNGKey to use for the random number generation. | required |
| sample_shape | Optional[Sequence[int]] | An optional shape for the output array. | None |
| shape | Union[Float, ArrayLike] | The shape parameter of the gamma distribution. | 1.0 |
| rate | Union[Float, ArrayLike] | The rate parameter of the gamma distribution. | 1.0 |
| lower_tail | bool | Whether to return the lower tail of the distribution. | True |
| log_prob | bool | Whether to return the log probability of the result. | False |
| dtype | | The dtype of the output. Defaults to jnp.float_. | <class 'jax.numpy.float64'> |

Returns:

| Type | Description |
| --- | --- |
| ArrayLike | A random gamma value or an array of random gamma values. |

Examples:

>>> from stamox.distribution import rgamma
>>> rgamma(key, shape=0.5, rate=0.5)
Array(0.3384059, dtype=float32)
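
Individual draws depend on the key, so a more useful check is statistical: a gamma distribution with the given shape and rate has mean shape / rate. The sketch below illustrates this with the standard library's `random.gammavariate`, which takes a scale parameter, so the rate is converted as scale = 1 / rate. The helper `sample_gamma` is hypothetical, and the seeded `random.Random` instance only stands in for the explicit key. It does not reproduce the stamox value shown above.

```python
import random


def sample_gamma(rng, shape=1.0, rate=1.0, n=1):
    """Draw n gamma samples using the shape/rate parameterization."""
    scale = 1.0 / rate  # random.gammavariate expects (shape, scale)
    return [rng.gammavariate(shape, scale) for _ in range(n)]


rng = random.Random(0)  # explicit, seeded generator for reproducibility
draws = sample_gamma(rng, shape=0.5, rate=0.5, n=100_000)
sample_mean = sum(draws) / len(draws)
print(sample_mean)  # close to shape / rate = 1.0
```

Mixing up rate and scale is a common source of error when moving between libraries. With shape=0.5, passing 0.5 as a scale instead of a rate would give a mean of 0.25 rather than 1.0.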