Bootstrap Function

stamox.sample.bootstrap_sample(data: ArrayLike, num_samples: int, *, key: PRNGKeyArray = None) -> ArrayLike

Generates num_samples bootstrap samples from data with replacement.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| data | array-like | The original data. | required |
| num_samples | int | The number of bootstrap samples to generate. | required |
| key | jrandom.KeyArray | A random key array. Defaults to None. | None |

Returns:

| Type | Description |
|---|---|
| ArrayLike | An array of shape (num_samples, len(data)) containing the bootstrap samples. |

Examples:

>>> import jax.numpy as jnp
>>> import jax.random as jrandom
>>> from stamox.sample import bootstrap_sample
>>> data = jnp.arange(10)
>>> key = jrandom.PRNGKey(0)
>>> bootstrap_sample(data, num_samples=3, key=key)
Array([[9, 1, 6, 2, 9, 3, 9, 9, 4, 5],
        [4, 0, 4, 4, 6, 2, 5, 6, 5, 3],
        [7, 6, 9, 0, 0, 7, 0, 5, 8, 4]], dtype=int32)
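
To make "sampling with replacement" concrete, the following is a minimal sketch of one way such a resampler could be written in JAX. It is not stamox's actual implementation: the helper name `bootstrap_sample_sketch` is hypothetical, and drawing indices with `jrandom.randint` is an assumption about how the resampling might be done.

```python
import jax.numpy as jnp
import jax.random as jrandom


def bootstrap_sample_sketch(data, num_samples, *, key):
    """Hypothetical helper: draw `num_samples` resamples of `data` with replacement."""
    data = jnp.asarray(data)
    n = data.shape[0]
    # One row of indices per bootstrap sample; each index is drawn
    # uniformly from [0, n), so the same element may appear repeatedly.
    idx = jrandom.randint(key, shape=(num_samples, n), minval=0, maxval=n)
    # Integer-array indexing gathers the resampled values,
    # giving an array of shape (num_samples, n).
    return data[idx]


data = jnp.arange(10)
samples = bootstrap_sample_sketch(data, 3, key=jrandom.PRNGKey(0))
print(samples.shape)  # (3, 10)
```

Because the indices come from a different sampling routine than the library may use, the values in `samples` are not expected to match the output shown above, only the shape and the fact that every entry is drawn from `data`.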

stamox.sample.bootstrap(data: ArrayLike, call: Callable[..., ReturnValue], num_samples: int, *, key: PRNGKeyArray = None) -> PyTree

Generates num_samples bootstrap samples from data with replacement, and calls call on each sample.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| data | array-like | The original data. | required |
| call | Callable[..., ReturnValue] | The function to call on each bootstrap sample. | required |
| num_samples | int | The number of bootstrap samples to generate. | required |
| key | jrandom.KeyArray | A random key array. Defaults to None. | None |

Returns:

| Type | Description |
|---|---|
| PyTree | The return value of call on each bootstrap sample. |

Examples:

>>> import jax.numpy as jnp
>>> import jax.random as jrandom
>>> from stamox.sample import bootstrap
>>> data = jnp.arange(10)
>>> key = jrandom.PRNGKey(0)
>>> bootstrap(data, jnp.mean, 3, key=key)
Array([5.7000003, 3.9      , 4.6      ], dtype=float32)
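
A typical use of per-sample statistics like these is to estimate the variability of an estimator. The sketch below shows the idea under stated assumptions: `bootstrap_sketch` is a hypothetical stand-in written with `jax.vmap`, not stamox's own code, and the choice of 200 resamples and a 95% percentile interval is illustrative only.

```python
import jax
import jax.numpy as jnp
import jax.random as jrandom


def bootstrap_sketch(data, call, num_samples, *, key):
    """Hypothetical helper: apply `call` to each bootstrap resample of `data`."""
    data = jnp.asarray(data)
    n = data.shape[0]
    # Resample indices with replacement, one row per bootstrap sample.
    idx = jrandom.randint(key, shape=(num_samples, n), minval=0, maxval=n)
    # vmap maps `call` over the leading axis, i.e. over the resamples.
    return jax.vmap(call)(data[idx])


data = jnp.arange(10)
means = bootstrap_sketch(data, jnp.mean, 200, key=jrandom.PRNGKey(0))

# Bootstrap standard error of the mean: spread of the resampled means.
std_err = jnp.std(means, ddof=1)
# Percentile interval: the middle 95% of the resampled means.
lo, hi = jnp.percentile(means, jnp.array([2.5, 97.5]))
print(means.shape)  # (200,)
```

Mapping with `vmap` assumes `call` is a JAX-traceable function such as `jnp.mean`; a function that relies on Python-side control flow over array values would need a plain loop instead.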