Cauchy Distribution
stamox.distribution.pcauchy(q: Union[Float, ArrayLike], loc: Union[Float, ArrayLike] = 0.0, scale: Union[Float, ArrayLike] = 1.0, lower_tail: Bool = True, log_prob: Bool = False, dtype = <class 'jax.numpy.float64'>) -> ArrayLike
Calculates the cumulative distribution function (CDF) of the Cauchy distribution.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
q | Union[Float, ArrayLike] | The value at which to evaluate the CDF. | required |
loc | Union[Float, ArrayLike] | The location parameter of the Cauchy distribution. Defaults to 0.0. | 0.0 |
scale | Union[Float, ArrayLike] | The scale parameter of the Cauchy distribution. Defaults to 1.0. | 1.0 |
lower_tail | Bool | Whether to return the lower tail probability. Defaults to True. | True |
log_prob | Bool | Whether to return the log probability. Defaults to False. | False |
dtype | jnp.dtype | The dtype of the output. Defaults to jnp.float_. | <class 'jax.numpy.float64'> |
Returns:
Type | Description |
---|---|
ArrayLike | The cumulative distribution function of the Cauchy distribution, evaluated at q. |
Examples:
>>> from stamox.distribution import pcauchy
>>> pcauchy(1.0, loc=0.0, scale=1.0, lower_tail=True, log_prob=False)
Array(0.75, dtype=float32, weak_type=True)
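For reference, the Cauchy CDF has the closed form F(q) = 1/2 + arctan((q - loc) / scale) / π, so with loc = 0 and scale = 1 the example gives 1/2 + (π/4)/π = 0.75. The plain-Python sketch below evaluates that formula for a single scalar as a cross-check; it is not stamox's implementation, and the handling of `lower_tail` and `log_prob` is an assumption (R-style semantics: the upper tail is 1 - F(q), and `log_prob` returns the natural log of the selected tail probability). The name `cauchy_cdf` is hypothetical.

```python
import math


def cauchy_cdf(q, loc=0.0, scale=1.0, lower_tail=True, log_prob=False):
    """Scalar Cauchy CDF, a hypothetical plain-Python reference for pcauchy.

    Assumes R-style semantics: the upper tail is 1 - F(q), and log_prob
    returns the natural log of whichever tail probability is selected.
    """
    if scale <= 0:
        raise ValueError("scale must be positive")
    z = (q - loc) / scale
    # F(q) = 1/2 + atan(z)/pi, and the upper tail 1 - F(q) = 1/2 - atan(z)/pi.
    shift = math.atan(z) / math.pi
    p = 0.5 + shift if lower_tail else 0.5 - shift
    return math.log(p) if log_prob else p


print(cauchy_cdf(1.0))                    # 0.75, matching the pcauchy example
print(cauchy_cdf(1.0, lower_tail=False))  # 0.25
```

Writing the upper tail as 1/2 - atan(z)/π rather than 1 - F(q) avoids subtracting two nearly equal numbers when q lies far in the right tail.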
stamox.distribution.qcauchy(q: Union[Float, ArrayLike], loc: Union[Float, ArrayLike] = 0.0, scale: Union[Float, ArrayLike] = 1.0, lower_tail: Bool = True, log_prob: Bool = False, dtype = <class 'jax.numpy.float64'>) -> ArrayLike
Computes the quantile function (inverse CDF) of the Cauchy distribution.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
q | Union[Float, ArrayLike] | The probabilities at which to compute the quantiles. | required |
loc | Union[Float, ArrayLike] | The location parameter. Defaults to 0.0. | 0.0 |
scale | Union[Float, ArrayLike] | The scale parameter. Defaults to 1.0. | 1.0 |
lower_tail | Bool | Whether to compute the lower tail. Defaults to True. | True |
log_prob | Bool | Whether to compute the log probability. Defaults to False. | False |
dtype | jnp.dtype | The dtype of the output. Defaults to jnp.float_. | <class 'jax.numpy.float64'> |
Returns:
Type | Description |
---|---|
ArrayLike | The quantiles of the Cauchy distribution. |
Examples:
>>> from stamox.distribution import qcauchy
>>> qcauchy(0.5, loc=1.0, scale=2.0)
Array([1.], dtype=float32, weak_type=True)
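Inverting F(x) = 1/2 + arctan((x - loc) / scale) / π gives the closed-form quantile Q(p) = loc + scale · tan(π (p - 1/2)); at p = 0.5 the tangent vanishes, which is why the example above returns loc = 1.0. The sketch below is a hypothetical scalar reference for that formula, not stamox's code. How it treats `lower_tail` and `log_prob` is an assumption borrowed from R's `qcauchy` (a log-probability input is exponentiated first, and an upper-tail probability p is converted to 1 - p).

```python
import math


def cauchy_quantile(p, loc=0.0, scale=1.0, lower_tail=True, log_prob=False):
    """Scalar Cauchy quantile (inverse CDF), a hypothetical reference for qcauchy.

    Assumes R-style semantics: with log_prob=True, p is a log probability,
    and with lower_tail=False, p is an upper-tail probability.
    """
    if log_prob:
        p = math.exp(p)
    if not lower_tail:
        p = 1.0 - p
    if not 0.0 <= p <= 1.0:
        raise ValueError("p must lie in [0, 1]")
    # Solve 1/2 + atan((x - loc)/scale)/pi = p for x.
    return loc + scale * math.tan(math.pi * (p - 0.5))


print(cauchy_quantile(0.5, loc=1.0, scale=2.0))  # 1.0, matching the qcauchy example
```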
stamox.distribution.dcauchy(x: Union[Float, ArrayLike], loc: Union[Float, ArrayLike] = 0.0, scale: Union[Float, ArrayLike] = 1.0, lower_tail: Bool = True, log_prob: Bool = False, dtype = <class 'jax.numpy.float64'>) -> ArrayLike
Computes the probability density function (PDF) of the Cauchy distribution.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
x | Union[Float, ArrayLike] | The values at which to evaluate the PDF. | required |
loc | Union[Float, ArrayLike] | The location parameter. Defaults to 0.0. | 0.0 |
scale | Union[Float, ArrayLike] | The scale parameter. Defaults to 1.0. | 1.0 |
lower_tail | Bool | Whether to compute the lower tail. Defaults to True. | True |
log_prob | Bool | Whether to return the log of the density. Defaults to False. | False |
dtype | jnp.dtype | The dtype of the output. Defaults to jnp.float_. | <class 'jax.numpy.float64'> |
Returns:
Type | Description |
---|---|
ArrayLike | The PDF of the Cauchy distribution, evaluated at x. |
Examples:
>>> from stamox.distribution import dcauchy
>>> dcauchy(1.0, loc=0.0, scale=1.0)
Array([0.15915494], dtype=float32, weak_type=True)
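The Cauchy density is f(x) = 1 / (π · scale · (1 + z²)) with z = (x - loc) / scale, so at x = 1 with the standard parameters it equals 1 / (2π) ≈ 0.15915494, the value in the example. The sketch below is a hypothetical scalar reference for that formula, not stamox's implementation. It assumes `log_prob=True` means "return log f(x)", and it omits `lower_tail`, since the table above does not say how that flag affects a density.

```python
import math


def cauchy_pdf(x, loc=0.0, scale=1.0, log_prob=False):
    """Scalar Cauchy density, a hypothetical plain-Python reference for dcauchy."""
    if scale <= 0:
        raise ValueError("scale must be positive")
    z = (x - loc) / scale
    # log f(x) = -log(pi * scale) - log(1 + z^2); log1p stays accurate for small z.
    log_density = -math.log(math.pi * scale) - math.log1p(z * z)
    return log_density if log_prob else math.exp(log_density)


print(cauchy_pdf(1.0))  # approximately 0.15915494, i.e. 1 / (2 * pi)
```

Computing the log density first and exponentiating only when needed keeps the `log_prob=True` path free of an extra log of a possibly tiny number in the far tails.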
stamox.distribution.rcauchy(key: PRNGKeyArray, sample_shape: Optional[Sequence[int]] = None, loc: Union[Float, ArrayLike] = 0.0, scale: Union[Float, ArrayLike] = 1.0, lower_tail: Bool = True, log_prob: Bool = False, dtype = <class 'jax.numpy.float64'>) -> ArrayLike
Generates random samples from the Cauchy distribution.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
key | PRNGKeyArray | A PRNGKey to use for generating the samples. | required |
sample_shape | Optional[Sequence[int]] | The shape of the output array. | None |
loc | Union[Float, ArrayLike] | The location parameter of the Cauchy distribution. | 0.0 |
scale | Union[Float, ArrayLike] | The scale parameter of the Cauchy distribution. | 1.0 |
lower_tail | Bool | Whether to return the lower tail probability. | True |
log_prob | Bool | Whether to return the log probability. | False |
dtype | jnp.dtype | The dtype of the output. | <class 'jax.numpy.float64'> |
Returns:
Type | Description |
---|---|
ArrayLike | An array of samples from the Cauchy distribution. |
Examples:
>>> import jax
>>> from stamox.distribution import rcauchy
>>> key = jax.random.PRNGKey(0)
>>> rcauchy(key, sample_shape=(2, 3), loc=0.0, scale=1.0)
Array([[ 0.23841971, -3.0880406 , 0.9507532 ],
[ 2.8963416 , 0.31303588, -0.14792857]], dtype=float32)
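One standard way to draw Cauchy samples is inverse-transform sampling: draw u uniformly on [0, 1) and map it through the quantile function x = loc + scale · tan(π (u - 1/2)). The sketch below illustrates that technique with Python's `random` module; it is an assumption for illustration, not necessarily how rcauchy generates its draws, and because it does not use a JAX PRNG key its output will not reproduce the array above. The name `sample_cauchy` and its `seed` parameter are hypothetical.

```python
import math
import random


def sample_cauchy(n, loc=0.0, scale=1.0, seed=0):
    """Draw n Cauchy samples by inverse-transform sampling (hypothetical sketch).

    Uses Python's random module rather than a JAX PRNG key, so the draws
    will not match rcauchy's output.
    """
    rng = random.Random(seed)
    samples = []
    for _ in range(n):
        u = rng.random()  # uniform on [0, 1)
        # Map u through the quantile function x = loc + scale * tan(pi * (u - 1/2)).
        samples.append(loc + scale * math.tan(math.pi * (u - 0.5)))
    return samples


draws = sample_cauchy(10_001, loc=3.0, scale=1.0)
median = sorted(draws)[len(draws) // 2]
print(median)  # sample median, an estimate of loc
```

The median is used as the location estimate because the Cauchy distribution has no mean: the average of n standard Cauchy draws is itself standard Cauchy, so the sample mean does not settle down as n grows.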