
Random sample

fenbux.rand(dist: AbstractDistribution, key: Array, shape: collections.abc.Sequence[int] = (), dtype: Union[Any, str, numpy.dtype, fenbux.core._SupportsDType] = <class 'float'>) -> PyTree

Draw random samples from a distribution.

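Parameters:

| Name  | Type                 | Description                             | Default         |
|-------|----------------------|-----------------------------------------|-----------------|
| dist  | AbstractDistribution | Distribution object to sample from.     | required        |
| key   | KeyArray             | JAX random number generator (PRNG) key. | required        |
| shape | Shape                | Shape of the sample to draw.            | ()              |
| dtype |                      | Data type of the sampled values.        | <class 'float'> |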

Examples:

>>> import jax.random as jr
>>> from fenbux import Normal, rand
>>> key = jr.PRNGKey(0)
>>> dist = Normal(0.0, 1.0)
>>> rand(dist, key, (2,))
Array([-0.20584235,  0.46256348], dtype=float32)