Normal Distribution
stamox.distribution.pnorm(q: Union[Float, ArrayLike], mean: Union[Float, ArrayLike] = 0.0, sd: Union[Float, ArrayLike] = 1.0, lower_tail = True, log_prob = False, dtype = <class 'jax.numpy.float64'>) -> ArrayLike
Calculate the cumulative distribution function (CDF) of the normal distribution.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
q | Union[Float, ArrayLike] | The quantile(s) at which to evaluate the CDF. | required |
mean | Union[Float, ArrayLike] | The mean of the normal distribution. Defaults to 0.0. | 0.0 |
sd | Union[Float, ArrayLike] | The standard deviation of the normal distribution. Defaults to 1.0. | 1.0 |
lower_tail | bool | If True, return P(X ≤ q), the probability that a normal random variable X is less than or equal to the given quantile(s). If False, return P(X > q). Defaults to True. | True |
log_prob | bool | If True, return the log of the probability instead of the probability itself. Defaults to False. | False |
dtype | jnp.dtype | The dtype of the output. Defaults to jnp.float_. | <class 'jax.numpy.float64'> |
Returns:
Type | Description |
---|---|
ArrayLike | The CDF of the normal distribution evaluated at q, or the upper-tail probability if lower_tail is False (on the log scale if log_prob is True). |
Examples:
>>> pnorm(2.0)
Array(0.97724986, dtype=float32)
>>> pnorm([1.5, 2.0, 2.5], mean=2.0, sd=0.5, lower_tail=False)
Array([0.8413447 , 0.5 , 0.15865529], dtype=float32)
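As a cross-check of the examples above, the normal CDF can be written as Phi((q - mean) / sd), where Phi(z) = 0.5 * erfc(-z / sqrt(2)). The sketch below is a plain-Python reference using only the standard library; it is not stamox's implementation, the helper name `pnorm_ref` is hypothetical, and it ignores `dtype`. Computing the upper tail directly as 0.5 * erfc(z / sqrt(2)) avoids forming 1 - Phi(z), which loses precision far in the tail.

```python
import math
from typing import Iterable, List, Union

Number = Union[int, float]


def pnorm_ref(q: Union[Number, Iterable[Number]], mean: float = 0.0, sd: float = 1.0,
              lower_tail: bool = True, log_prob: bool = False) -> List[float]:
    """Hypothetical reference for the normal CDF (not the stamox code)."""
    qs = [q] if isinstance(q, (int, float)) else list(q)
    out = []
    for x in qs:
        z = (x - mean) / sd
        # P(X <= q) = 0.5 * erfc(-z / sqrt(2)); P(X > q) = 0.5 * erfc(z / sqrt(2)).
        sign = -1.0 if lower_tail else 1.0
        p = 0.5 * math.erfc(sign * z / math.sqrt(2.0))
        # log_prob=True returns log(p); p can underflow to 0.0 for extreme z.
        out.append(math.log(p) if log_prob else p)
    return out


print(pnorm_ref(2.0))  # ≈ [0.97725]
print(pnorm_ref([1.5, 2.0, 2.5], mean=2.0, sd=0.5, lower_tail=False))  # ≈ [0.84134, 0.5, 0.15866]
```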
stamox.distribution.qnorm(p: Union[Float, ArrayLike], mean: Union[Float, ArrayLike] = 0.0, sd: Union[Float, ArrayLike] = 1.0, lower_tail = True, log_prob = False, dtype = <class 'jax.numpy.float64'>) -> ArrayLike
Calculates the quantile function of the normal distribution for a given probability.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
p | float or jnp.ndarray | Probability values. | required |
mean | float or jnp.ndarray | Mean of the normal distribution. Default is 0.0. | 0.0 |
sd | float or jnp.ndarray | Standard deviation of the normal distribution. Default is 1.0. | 1.0 |
lower_tail | bool | If True (default), p is interpreted as P(X ≤ x); if False, as P(X > x). | True |
log_prob | bool | If | False |
dtype | jnp.dtype | The dtype of the output. Default is jnp.float_. | <class 'jax.numpy.float64'> |
Returns:
Type | Description |
---|---|
ArrayLike | The inverse cumulative distribution function (quantile function) of the normal distribution evaluated at p. |
Examples:
>>> qnorm(0.5)
Array([0.], dtype=float32)
>>> qnorm([0.25, 0.75], mean=3, sd=2)
Array([1.6510204, 4.3489795], dtype=float32)
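The quantile function is the inverse of the CDF, so the examples above can be reproduced with Python's `statistics.NormalDist.inv_cdf`. This is a minimal sketch, not stamox's implementation; `qnorm_ref` is a hypothetical name, and it omits `log_prob` and `dtype` because the behaviour of `log_prob` for `qnorm` is not spelled out above. For `lower_tail=False` it assumes the usual convention that `p` is the upper-tail probability P(X > x), so the matching lower-tail probability is 1 - p.

```python
from statistics import NormalDist
from typing import Iterable, List, Union


def qnorm_ref(p: Union[float, Iterable[float]], mean: float = 0.0, sd: float = 1.0,
              lower_tail: bool = True) -> List[float]:
    """Hypothetical quantile-function reference built on statistics.NormalDist."""
    dist = NormalDist(mu=mean, sigma=sd)
    ps = [p] if isinstance(p, (int, float)) else list(p)
    # With lower_tail=False, p is treated as P(X > x), i.e. the CDF value is 1 - p.
    return [dist.inv_cdf(pi if lower_tail else 1.0 - pi) for pi in ps]


print(qnorm_ref(0.5))                         # [0.0]
print(qnorm_ref([0.25, 0.75], mean=3, sd=2))  # ≈ [1.65102, 4.34898]

# Round trip: applying the CDF to a quantile recovers the probability.
x = qnorm_ref(0.9)[0]
print(NormalDist().cdf(x))  # ≈ 0.9
```

Note that `inv_cdf` raises `StatisticsError` for p outside the open interval (0, 1), whereas an array library would typically return infinities or NaN there.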
stamox.distribution.dnorm(x: Union[Float, ArrayLike], mean: Union[Float, ArrayLike] = 0.0, sd: Union[Float, ArrayLike] = 1.0, lower_tail = True, log_prob = False, dtype = <class 'jax.numpy.float64'>) -> ArrayLike
Probability density function (PDF) of the normal distribution.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
x | Union[Float, ArrayLike] | The input value(s) at which to evaluate the PDF. | required |
mean | Union[Float, ArrayLike] | The mean of the normal distribution. Defaults to 0.0. | 0.0 |
sd | Union[Float, ArrayLike] | The standard deviation of the normal distribution. Defaults to 1.0. | 1.0 |
lower_tail | bool | If True (default), returns the cumulative distribution function (CDF) from negative infinity up to x. Otherwise, returns the CDF from x to positive infinity. | True |
log_prob | bool | If True, returns the log-density instead of the density. Defaults to False. | False |
dtype | jnp.dtype | The dtype of the output. Default is jnp.float_. | <class 'jax.numpy.float64'> |
Returns:
Type | Description |
---|---|
ArrayLike | The probability density function evaluated at point(s) x. |
Examples:
>>> import jax.numpy as jnp
>>> x = jnp.array([0.5, 1.0, -1.5])
>>> dnorm(x)
Array([0.35206532, 0.24197075, 0.12951761], dtype=float32)
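The normal density is f(x) = exp(-z^2 / 2) / (sd * sqrt(2 * pi)) with z = (x - mean) / sd. The plain-Python sketch below evaluates it on the log scale, which gives the `log_prob=True` result directly instead of taking the log of a possibly tiny float. It is not stamox's implementation; `dnorm_ref` is a hypothetical name, and it leaves out `lower_tail` and `dtype`.

```python
import math
from typing import Iterable, List, Union


def dnorm_ref(x: Union[float, Iterable[float]], mean: float = 0.0, sd: float = 1.0,
              log_prob: bool = False) -> List[float]:
    """Hypothetical reference for the normal PDF (not the stamox code)."""
    xs = [x] if isinstance(x, (int, float)) else list(x)
    out = []
    for xi in xs:
        z = (xi - mean) / sd
        # log f(x) = -z^2 / 2 - log(sd) - log(2 * pi) / 2
        log_f = -0.5 * z * z - math.log(sd) - 0.5 * math.log(2.0 * math.pi)
        out.append(log_f if log_prob else math.exp(log_f))
    return out


print(dnorm_ref([0.5, 1.0, -1.5]))  # ≈ [0.352065, 0.241971, 0.129518]
```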
stamox.distribution.rnorm(key: PRNGKeyArray, sample_shape: Optional[Sequence[int]] = None, mean: Union[Float, ArrayLike] = 0.0, sd: Union[Float, ArrayLike] = 1.0, lower_tail: Bool = True, log_prob: Bool = False, dtype = <class 'jax.numpy.float64'>) -> ArrayLike
Generates random variables from a normal distribution.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
key | PRNGKeyArray | A JAX PRNG key used to generate the random numbers. | required |
sample_shape | Optional[Sequence[int]] | An optional tuple of integers specifying the shape of the generated samples. | None |
mean | Union[Float, ArrayLike] | The mean of the normal distribution. Defaults to 0.0. | 0.0 |
sd | Union[Float, ArrayLike] | The standard deviation of the normal distribution. Defaults to 1.0. | 1.0 |
lower_tail | Bool | If True (default), returns the cumulative distribution function (CDF) from negative infinity up to x. Otherwise, returns the CDF from x to positive infinity. | True |
log_prob | Bool | If True, returns the log-probability instead of the probability. | False |
dtype | jnp.dtype | The dtype of the output. Defaults to jnp.float_. | <class 'jax.numpy.float64'> |
Returns:
Type | Description |
---|---|
ArrayLike | Random samples from a normal distribution. |
Examples:
>>> from jax import random
>>> key = random.PRNGKey(0)
>>> rnorm(key, sample_shape=(3, 2))
Array([[ 0.18784384, -1.2833426 ],
[ 0.6494181 , 1.2490594 ],
[ 0.24447003, -0.11744965]], dtype=float32)
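A sampler with `mean` and `sd` parameters can reuse a standard-normal generator through the location-scale property X = mean + sd * Z, with Z ~ N(0, 1). The sketch below illustrates that idea with Python's `random` module and the Box-Muller transform instead of JAX's PRNG, so it will not reproduce the values above. `rnorm_ref` and its integer `seed` argument are hypothetical, and it ignores `lower_tail`, `log_prob`, and `dtype`.

```python
import math
import random
from typing import Any, Sequence


def rnorm_ref(seed: int, sample_shape: Sequence[int], mean: float = 0.0,
              sd: float = 1.0) -> Any:
    """Hypothetical sampler returning nested lists of shape sample_shape.

    Draws standard normals with the Box-Muller transform, then applies the
    location-scale map X = mean + sd * Z. Not the stamox code.
    """
    rng = random.Random(seed)  # stdlib PRNG, not a JAX PRNGKey

    def std_normal() -> float:
        u1 = 1.0 - rng.random()  # in (0, 1], so log(u1) is finite
        u2 = rng.random()
        # Only the cosine branch of Box-Muller is used; it is N(0, 1) on its own.
        return math.sqrt(-2.0 * math.log(u1)) * math.cos(2.0 * math.pi * u2)

    def fill(shape: Sequence[int]) -> Any:
        if not shape:
            return mean + sd * std_normal()
        return [fill(shape[1:]) for _ in range(shape[0])]

    return fill(tuple(sample_shape))


samples = rnorm_ref(0, (3, 2))
print(len(samples), len(samples[0]))  # 3 2
```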