Scipy With PyTree

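In fenbux, we provide wrappers of scipy.stats for distributions that accept pytrees as inputs and return pytrees as outputs. Distribution parameters such as loc and scale can be pytrees (for example dicts), and the results keep the same tree structure as the parameters.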

```python
from fenbux import scipy_stats as stats

# Create a normal distribution
loc = {'a': 0.0, 'b': 1.0}
scale = {'a': 1.0, 'b': 2.0}
x = [0.0, 1.0]
normal = stats.norm(loc=loc, scale=scale)
normal.mean(), normal.pdf(x)
```

```
({'a': 0.0, 'b': 1.0},
 {'a': [0.3989422804014327, 0.24197072451914337],
  'b': [0.17603266338214976, 0.19947114020071635]})
```
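To illustrate what "pytree in, pytree out" means here, the following is a minimal sketch, not fenbux's implementation. In JAX the general-purpose tool for this pattern is `jax.tree_util.tree_map`. The sketch instead uses a small hypothetical stand-in, `tree_map`, that only recurses into dicts, along with a hypothetical `PyTreeNorm` class, so it runs with the standard library alone. Each leaf of `loc` is paired with the matching leaf of `scale`, and the same list of points `x` is evaluated at every leaf.

```python
import math
from typing import Any, Callable, Sequence


# Hypothetical helper: a simplified stand-in for jax.tree_util.tree_map.
# It recurses into dicts only; any other value is treated as a leaf.
def tree_map(fn: Callable[..., Any], tree: Any, *rest: Any) -> Any:
    if isinstance(tree, dict):
        # Pair each key of `tree` with the same key in every other tree.
        return {k: tree_map(fn, tree[k], *(r[k] for r in rest)) for k in tree}
    return fn(tree, *rest)


def normal_pdf(x: float, loc: float, scale: float) -> float:
    """Density of the normal distribution N(loc, scale**2) at x."""
    z = (x - loc) / scale
    return math.exp(-0.5 * z * z) / (scale * math.sqrt(2.0 * math.pi))


class PyTreeNorm:
    """Hypothetical normal distribution whose parameters are dict pytrees."""

    def __init__(self, loc: Any, scale: Any) -> None:
        self.loc = loc
        self.scale = scale

    def mean(self) -> Any:
        # The mean of N(loc, scale**2) is loc, computed leaf by leaf.
        return tree_map(lambda m: m, self.loc)

    def pdf(self, x: Sequence[float]) -> Any:
        # Evaluate the same points x at every (loc, scale) leaf pair.
        return tree_map(
            lambda m, s: [normal_pdf(xi, m, s) for xi in x],
            self.loc,
            self.scale,
        )


normal = PyTreeNorm(loc={'a': 0.0, 'b': 1.0}, scale={'a': 1.0, 'b': 2.0})
print(normal.mean())  # {'a': 0.0, 'b': 1.0}
print(normal.pdf([0.0, 1.0]))
```

Because the mapping is driven by the structure of `loc`, the output dict has the same keys as the parameters. A real JAX version would typically also handle lists, tuples, and registered custom nodes, and would return arrays rather than Python lists.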