
Normal Distribution

stamox.distribution.pnorm(q: Union[Float, ArrayLike], mean: Union[Float, ArrayLike] = 0.0, sd: Union[Float, ArrayLike] = 1.0, lower_tail = True, log_prob = False, dtype = jnp.float64) -> ArrayLike

Calculate the cumulative distribution function (CDF) of the normal distribution.
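For X ~ N(μ, σ²), with μ = mean and σ = sd, lower_tail selects between the two standard tail probabilities below, and log_prob=True returns the log of the selected value. These are the textbook definitions, not a description of how stamox evaluates them internally.

```latex
P(X \le q) = \Phi\!\left(\frac{q - \mu}{\sigma}\right)
           = \frac{1}{2}\left[1 + \operatorname{erf}\!\left(\frac{q - \mu}{\sigma\sqrt{2}}\right)\right],
\qquad
P(X > q) = 1 - \Phi\!\left(\frac{q - \mu}{\sigma}\right)
```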

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| q | Union[Float, ArrayLike] | The quantiles at which to evaluate the CDF. | required |
| mean | Union[Float, ArrayLike] | The mean of the normal distribution. | 0.0 |
| sd | Union[Float, ArrayLike] | The standard deviation of the normal distribution. | 1.0 |
| lower_tail | bool | If True, return P(X ≤ q), the probability that X is less than or equal to the given quantile(s). If False, return P(X > q). | True |
| log_prob | bool | If True, return the log of the probability instead of the probability itself. | False |
| dtype | jnp.dtype | The dtype of the output. Defaults to jnp.float_ (float64). With JAX's default configuration (64-bit mode disabled), results are float32, as in the examples below. | jnp.float64 |

Returns:

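| Type | Description |
|------|-------------|
| ArrayLike | The CDF of the normal distribution evaluated at q (the upper-tail probability if lower_tail=False, and its log if log_prob=True). |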

Examples:

>>> pnorm(2.0)
Array(0.97724986, dtype=float32)
>>> pnorm([1.5, 2.0, 2.5], mean=2.0, sd=0.5, lower_tail=False)
Array([0.8413447 , 0.5       , 0.15865529], dtype=float32)
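As a cross-check of these examples, the standard-library sketch below evaluates the same CDF with math.erfc. It is an illustrative reference, not stamox's implementation: pnorm_ref is a hypothetical name, it handles scalars only (no JAX arrays or broadcasting), and it computes the upper tail directly with erfc instead of as 1 − CDF, which keeps far-tail probabilities accurate.

```python
import math


def pnorm_ref(q, mean=0.0, sd=1.0, lower_tail=True, log_prob=False):
    """Scalar normal CDF via the complementary error function.

    Hypothetical reference helper for checking pnorm; not part of stamox.
    """
    z = (q - mean) / sd
    if lower_tail:
        # P(X <= q) = 0.5 * erfc(-z / sqrt(2)) = 0.5 * (1 + erf(z / sqrt(2)))
        p = 0.5 * math.erfc(-z / math.sqrt(2.0))
    else:
        # P(X > q) = 0.5 * erfc(z / sqrt(2)), computed without 1 - CDF cancellation
        p = 0.5 * math.erfc(z / math.sqrt(2.0))
    return math.log(p) if log_prob else p


print(round(pnorm_ref(2.0), 6))  # 0.97725
print([round(pnorm_ref(q, mean=2.0, sd=0.5, lower_tail=False), 6) for q in (1.5, 2.0, 2.5)])
# [0.841345, 0.5, 0.158655]
```

The printed values agree with the float32 doctest output to the precision shown.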

stamox.distribution.qnorm(p: Union[Float, ArrayLike], mean: Union[Float, ArrayLike] = 0.0, sd: Union[Float, ArrayLike] = 1.0, lower_tail = True, log_prob = False, dtype = jnp.float64) -> ArrayLike

Calculate the quantile function (inverse CDF) of the normal distribution at the given probabilities.
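The quantile function inverts the CDF. In standard notation, with μ = mean, σ = sd and Φ⁻¹ the standard normal quantile, the two lower_tail settings correspond to the formulas below; the last line works the second example by hand.

```latex
x_p = \mu + \sigma\,\Phi^{-1}(p) \quad (\texttt{lower\_tail=True}),
\qquad
x_p = \mu + \sigma\,\Phi^{-1}(1 - p) \quad (\texttt{lower\_tail=False})

\texttt{qnorm}(0.75,\ \mu = 3,\ \sigma = 2) = 3 + 2\,\Phi^{-1}(0.75) \approx 3 + 2(0.67449) = 4.34898
```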

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| p | Union[Float, ArrayLike] | Probability values. | required |
| mean | Union[Float, ArrayLike] | The mean of the normal distribution. | 0.0 |
| sd | Union[Float, ArrayLike] | The standard deviation of the normal distribution. | 1.0 |
| lower_tail | bool | If True, p is interpreted as the lower-tail probability P(X ≤ x). If False, p is interpreted as P(X > x). | True |
| log_prob | bool | If True, returns the logarithm of the quantile function. | False |
| dtype | jnp.dtype | The dtype of the output. Defaults to jnp.float_. | jnp.float64 |

Returns:

| Type | Description |
|------|-------------|
| ArrayLike | The inverse cumulative distribution function (quantile function) of the normal distribution evaluated at p. |

Examples:

>>> qnorm(0.5)
Array([0.], dtype=float32)
>>> qnorm([0.25, 0.75], mean=3, sd=2)
Array([1.6510204, 4.3489795], dtype=float32)
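A matching standard-library sketch for the quantile function uses statistics.NormalDist.inv_cdf (Python 3.8+). qnorm_ref is a hypothetical scalar helper, not stamox's code; it covers lower_tail only and omits log_prob. Note that the stamox example returns a one-element array for a scalar input, while this sketch returns a plain float.

```python
from statistics import NormalDist


def qnorm_ref(p, mean=0.0, sd=1.0, lower_tail=True):
    """Scalar normal quantile (inverse CDF); hypothetical helper, not stamox."""
    if not lower_tail:
        # p is an upper-tail probability P(X > x), so invert the CDF at 1 - p
        p = 1.0 - p
    # inv_cdf raises StatisticsError unless 0 < p < 1
    return NormalDist(mu=mean, sigma=sd).inv_cdf(p)


print(qnorm_ref(0.5))  # 0.0
print([round(qnorm_ref(p, mean=3, sd=2), 6) for p in (0.25, 0.75)])
# [1.65102, 4.34898]

# Round trip: the CDF evaluated at the quantile recovers the probability
print(round(NormalDist().cdf(qnorm_ref(0.9)), 12))
```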

stamox.distribution.dnorm(x: Union[Float, ArrayLike], mean: Union[Float, ArrayLike] = 0.0, sd: Union[Float, ArrayLike] = 1.0, lower_tail = True, log_prob = False, dtype = jnp.float64) -> ArrayLike

Calculate the probability density function (PDF) of the normal distribution.
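In standard notation, with μ = mean and σ = sd, the density and its log (the log_prob=True case) are given below. At the mode of the standard normal the density is 1/√(2π) ≈ 0.39894.

```latex
f(x) = \frac{1}{\sigma\sqrt{2\pi}} \exp\!\left(-\frac{(x - \mu)^2}{2\sigma^2}\right),
\qquad
\log f(x) = -\frac{(x - \mu)^2}{2\sigma^2} - \log\sigma - \tfrac{1}{2}\log(2\pi)
```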

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| x | Union[Float, ArrayLike] | The value(s) at which to evaluate the PDF. | required |
| mean | Union[Float, ArrayLike] | The mean of the normal distribution. | 0.0 |
| sd | Union[Float, ArrayLike] | The standard deviation of the normal distribution. | 1.0 |
| lower_tail | bool | Tail flag shared with pnorm and qnorm. A density is evaluated pointwise rather than accumulated over a tail, so this flag is not expected to affect the result. | True |
| log_prob | bool | If True, return the log-density instead of the density. | False |
| dtype | jnp.dtype | The dtype of the output. Defaults to jnp.float_. | jnp.float64 |

Returns:

| Type | Description |
|------|-------------|
| ArrayLike | The probability density function evaluated at point(s) x (or its log, if log_prob=True). |

Examples:

>>> import jax.numpy as jnp
>>> x = jnp.array([0.5, 1.0, -1.5])
>>> dnorm(x)
Array([0.35206532, 0.24197075, 0.12951761], dtype=float32)
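As a scalar cross-check of this example, the sketch below evaluates the normal density with the standard library. dnorm_ref is a hypothetical helper, not stamox's implementation. Working in log space is one reasonable design choice for log_prob (it stays finite far in the tails, where the density itself underflows to 0), not necessarily the one stamox makes. lower_tail is omitted because a density has no tail.

```python
import math


def dnorm_ref(x, mean=0.0, sd=1.0, log_prob=False):
    """Scalar normal PDF; hypothetical reference helper, not part of stamox."""
    z = (x - mean) / sd
    # log f(x) = -z**2 / 2 - log(sd) - log(2*pi) / 2, kept in log space so
    # log_prob=True stays finite where exp() would underflow to 0
    log_density = -0.5 * z * z - math.log(sd) - 0.5 * math.log(2.0 * math.pi)
    return log_density if log_prob else math.exp(log_density)


print([round(dnorm_ref(x), 6) for x in (0.5, 1.0, -1.5)])
# [0.352065, 0.241971, 0.129518]
```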

stamox.distribution.rnorm(key: PRNGKeyArray, sample_shape: Optional[Sequence[int]] = None, mean: Union[Float, ArrayLike] = 0.0, sd: Union[Float, ArrayLike] = 1.0, lower_tail: Bool = True, log_prob: Bool = False, dtype = jnp.float64) -> ArrayLike

Generate random samples from a normal distribution.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| key | PRNGKeyArray | A JAX PRNG key used to generate the random numbers. | required |
| sample_shape | Optional[Sequence[int]] | An optional tuple of integers specifying the shape of the output samples. | None |
| mean | Union[Float, ArrayLike] | The mean of the normal distribution. | 0.0 |
| sd | Union[Float, ArrayLike] | The standard deviation of the normal distribution. | 1.0 |
| lower_tail | Bool | Tail flag shared with pnorm and qnorm (True: lower tail, from negative infinity up to x; False: upper tail, from x to positive infinity). The function still returns random samples, not CDF values. | True |
| log_prob | Bool | Log-scale flag shared with pnorm, qnorm, and dnorm. The function still returns random samples, not log-probabilities. | False |
| dtype | jnp.dtype | The dtype of the output. Defaults to jnp.float_. | jnp.float64 |

Returns:

| Type | Description |
|------|-------------|
| ArrayLike | Random samples drawn from the normal distribution with the given mean and sd. |

Examples:

>>> from jax import random
>>> key = random.PRNGKey(0)
>>> rnorm(key, sample_shape=(3, 2))
Array([[ 0.18784384, -1.2833426 ],
       [ 0.6494181 ,  1.2490594 ],
       [ 0.24447003, -0.11744965]], dtype=float32)
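One standard way to draw normal samples is inverse-transform sampling: draw U ~ Uniform(0, 1) and return the normal quantile at U. The sketch below shows that technique with the standard library. It is not stamox's implementation and cannot reproduce the doctest values: rnorm_ref is a hypothetical name, and an integer seed for Python's random.Random stands in for a JAX PRNG key. In this sketch, sample_shape=None (or ()) yields a single float.

```python
import random
from statistics import NormalDist, fmean, stdev


def rnorm_ref(seed, sample_shape=None, mean=0.0, sd=1.0):
    """Inverse-transform normal sampler; hypothetical sketch, not stamox.

    `seed` seeds Python's random.Random and stands in for a JAX PRNG key.
    Returns nested lists shaped like `sample_shape` (a float if None or ()).
    """
    rng = random.Random(seed)
    dist = NormalDist(mu=mean, sigma=sd)
    shape = () if sample_shape is None else tuple(sample_shape)

    def draw(dims):
        if not dims:
            # U ~ Uniform[0, 1); redraw the rare exact 0.0, where inv_cdf is undefined
            u = rng.random()
            while u == 0.0:
                u = rng.random()
            return dist.inv_cdf(u)  # quantile of U is a N(mean, sd**2) draw
        return [draw(dims[1:]) for _ in range(dims[0])]

    return draw(shape)


samples = rnorm_ref(0, sample_shape=(3, 2))
print(len(samples), len(samples[0]))  # 3 2

flat = rnorm_ref(42, sample_shape=(20_000,), mean=5.0, sd=2.0)
print(round(fmean(flat), 2), round(stdev(flat), 2))  # sample mean and sd, near 5.0 and 2.0
```

In JAX code, the more direct route would be jax.random.normal(key, shape) rescaled by sd and shifted by mean; the inverse-transform route is shown here because it reuses the quantile function.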