Random sample
fenbux.rand(dist: AbstractDistribution, key: Array, shape: collections.abc.Sequence[int] = (), dtype: Union[Any, str, numpy.dtype, fenbux.core._SupportsDType] = float) -> PyTree
Draw random samples from a distribution.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dist | AbstractDistribution | Distribution object to sample from. | required |
key | KeyArray | JAX random number generator key. | required |
shape | Shape | Shape of the array of samples to draw. | () |
dtype | Union[Any, str, numpy.dtype, _SupportsDType] | Data type of the samples. | float |
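To make the roles of `key`, `shape` and `dtype` concrete, here is a minimal sketch that assumes, for a `Normal(mean, sd)` distribution, `rand` amounts to shifting and scaling a standard-normal draw from `jax.random.normal`. The helper `normal_rand_sketch` is hypothetical and not part of fenbux, whose actual implementation may differ; the sketch only shows how `shape` sets the shape of the returned array and `dtype` its element type.

```python
import jax.numpy as jnp
import jax.random as jr


def normal_rand_sketch(mean, sd, key, shape=(), dtype=float):
    # Hypothetical stand-in for rand(Normal(mean, sd), key, shape, dtype),
    # assuming it reduces to a shifted and scaled standard-normal draw.
    z = jr.normal(key, shape=shape, dtype=dtype)  # samples with the requested shape/dtype
    return mean + sd * z


key = jr.PRNGKey(0)
x = normal_rand_sketch(0.0, 1.0, key, (2, 3))
print(x.shape, x.dtype)  # (2, 3) float32 under JAX's default 32-bit mode
```

With the default `shape=()` the sketch returns a single scalar sample, matching the `()` default in the table.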
Examples:
>>> import jax.random as jr
>>> from fenbux import Normal, rand
>>> key = jr.PRNGKey(0)
>>> dist = Normal(0.0, 1.0)
>>> rand(dist, key, (2,))
Array([-0.20584235, 0.46256348], dtype=float32)
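JAX keys are explicit and stateless: passing the same `key` again reproduces exactly the same numbers, so independent draws need fresh subkeys from `jax.random.split`. The sketch below calls `jax.random.normal` directly so that it runs without fenbux; assuming `rand` follows the usual JAX key semantics, you would pass `sub1` and `sub2` as its `key` argument in the same way.

```python
import jax.numpy as jnp
import jax.random as jr

key = jr.PRNGKey(0)

# Reusing one key reproduces the same draw exactly.
a = jr.normal(key, (2,))
b = jr.normal(key, (2,))

# Split into a carry-forward key plus two fresh subkeys for independent draws.
key, sub1, sub2 = jr.split(key, 3)
c = jr.normal(sub1, (2,))
d = jr.normal(sub2, (2,))

print(bool(jnp.array_equal(a, b)))  # True: same key, same numbers
print(bool(jnp.array_equal(c, d)))  # False: different subkeys
```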