Source code for pyhs3.distributions

"""
HS3 Distribution implementations.

Provides classes for handling various probability distributions including
Gaussian, Mixture, Product, Crystal Ball, and Generic distributions.
Includes both standard HS3 distributions and CMS-specific extensions.
"""

from __future__ import annotations

from collections.abc import Iterator
from typing import Annotated, Any, TypeVar

from pydantic import Field, PrivateAttr, RootModel

# Import modules instead of individual classes
from pyhs3.distributions import (
    basic,
    cms,
    composite,
    histfactory,
    histogram,
    mathematical,
    physics,
)
from pyhs3.distributions.core import Distribution
from pyhs3.exceptions import custom_error_msg

# Export distribution classes for backwards compatibility: callers can keep
# importing them directly from ``pyhs3.distributions`` rather than from the
# individual family submodules.
# Plain ``Name = module.Name`` assignments are used on purpose: type checkers
# treat such simple assignments as implicit type aliases, so these names stay
# usable in annotations as well as at runtime.

# Basic distributions
GaussianDist = basic.GaussianDist
UniformDist = basic.UniformDist
PoissonDist = basic.PoissonDist
ExponentialDist = basic.ExponentialDist
LogNormalDist = basic.LogNormalDist
LandauDist = basic.LandauDist

# Composite distributions (built from other distributions)
MixtureDist = composite.MixtureDist
ProductDist = composite.ProductDist

# Histogram distributions
HistogramDist = histogram.HistogramDist

# HistFactory distributions
HistFactoryDistChannel = histfactory.HistFactoryDistChannel

# Mathematical distributions
GenericDist = mathematical.GenericDist
PolynomialDist = mathematical.PolynomialDist
BernsteinPolyDist = mathematical.BernsteinPolyDist

# Physics distributions
CrystalBallDist = physics.CrystalBallDist
AsymmetricCrystalBallDist = physics.AsymmetricCrystalBallDist
ArgusDist = physics.ArgusDist

# CMS-specific distributions
FastVerticalInterpHistPdf2Dist = cms.FastVerticalInterpHistPdf2Dist
GGZZBackgroundDist = cms.GGZZBackgroundDist
QQZZBackgroundDist = cms.QQZZBackgroundDist
FastVerticalInterpHistPdf2D2Dist = cms.FastVerticalInterpHistPdf2D2Dist

# Public API exposed by ``from pyhs3.distributions import *``: every
# distribution class, the ``Distribution`` base class, the ``Distributions``
# collection and the merged ``registered_distributions`` registry.
# Entries are kept in case-sensitive (ASCII) sorted order.
# NOTE(review): ``DistributionType`` and ``DistT`` (defined below) are not
# listed, so star-imports do not pick them up; explicit imports still work.
__all__ = [
    "ArgusDist",
    "AsymmetricCrystalBallDist",
    "BernsteinPolyDist",
    "CrystalBallDist",
    "Distribution",
    "Distributions",
    "ExponentialDist",
    "FastVerticalInterpHistPdf2D2Dist",
    "FastVerticalInterpHistPdf2Dist",
    "GGZZBackgroundDist",
    "GaussianDist",
    "GenericDist",
    "HistFactoryDistChannel",
    "HistogramDist",
    "LandauDist",
    "LogNormalDist",
    "MixtureDist",
    "PoissonDist",
    "PolynomialDist",
    "ProductDist",
    "QQZZBackgroundDist",
    "UniformDist",
    "registered_distributions",
]

DistT = TypeVar("DistT", bound="Distribution")

# Single lookup table built from the ``distributions`` mapping that each family
# submodule exposes. Families are folded in the order listed below; if two
# families register the same key, the one listed later wins, while the key
# keeps the position where it first appeared.
registered_distributions: dict[str, type[Distribution]] = {
    key: dist_cls
    for family in (
        basic,
        composite,
        histfactory,
        histogram,
        mathematical,
        physics,
        cms,
    )
    for key, dist_cls in family.distributions.items()
}

# Type alias for all distribution types using a Pydantic discriminated union.
# When validating, Pydantic reads each item's ``type`` field and validates the
# item against the union member whose ``type`` matches, rather than trying
# every member in turn.
# NOTE(review): this union is maintained by hand, separately from
# ``registered_distributions`` above; presumably a class added to a submodule
# registry should also be added here -- confirm.
DistributionType = Annotated[
    # Basic distributions
    basic.GaussianDist
    | basic.UniformDist
    | basic.PoissonDist
    | basic.ExponentialDist
    | basic.LogNormalDist
    | basic.LandauDist
    # Composite distributions
    | composite.MixtureDist
    | composite.ProductDist
    # Histogram distributions
    | histogram.HistogramDist
    # HistFactory distributions
    | histfactory.HistFactoryDistChannel
    # Mathematical distributions
    | mathematical.GenericDist
    | mathematical.PolynomialDist
    | mathematical.BernsteinPolyDist
    # Physics distributions
    | physics.CrystalBallDist
    | physics.AsymmetricCrystalBallDist
    | physics.ArgusDist
    # CMS distributions
    | cms.FastVerticalInterpHistPdf2Dist
    | cms.GGZZBackgroundDist
    | cms.QQZZBackgroundDist
    | cms.FastVerticalInterpHistPdf2D2Dist,
    # The comma above closes the union; what follows is Annotated metadata
    # telling Pydantic which field discriminates between the members.
    Field(discriminator="type"),
]


class Distributions(RootModel[list[DistributionType]]):
    """
    Collection of distributions for a probabilistic model.

    A Pydantic root model whose value is a list of distribution instances.
    Each item is validated against the ``DistributionType`` discriminated
    union, selected by the item's ``type`` field. An unrecognised ``type`` is
    reported with an "Unknown distribution type ..." error that lists the
    expected tags.

    After validation, a name -> distribution index is built so distributions
    can be looked up with ``collection[name]`` and tested with
    ``name in collection``. Iteration and ``len()`` operate on the underlying
    list, in input order.

    Attributes:
        root: The validated distribution instances, in input order. Defaults
            to an empty list.

    Note:
        The name index is built once, in ``model_post_init``. If two
        distributions share a name, the later one shadows the earlier for
        name lookups, while both remain in ``root`` and in iteration.
        Mutating ``root`` after construction does not update the index.

    HS3 Reference:
        :hs3:label:`distributions <hs3.sec:distributions>`
    """

    root: Annotated[
        list[DistributionType],
        custom_error_msg(
            {
                "union_tag_invalid": "Unknown distribution type '{tag}' does not match any of the expected distributions: {expected_tags}"
            }
        ),
    ] = Field(default_factory=list)
    # Private index from distribution name to instance; (re)built in
    # model_post_init from ``root``.
    _map: dict[str, Distribution] = PrivateAttr(default_factory=dict)

    def model_post_init(self, __context: Any, /) -> None:
        """Initialize computed collections after Pydantic validation."""
        # Later entries overwrite earlier ones with the same name.
        self._map = {dist.name: dist for dist in self.root}

    def __getitem__(self, item: str) -> Distribution:
        """Return the distribution named ``item``; raise KeyError if absent."""
        return self._map[item]

    def __contains__(self, item: str) -> bool:
        """Return True if a distribution named ``item`` exists."""
        return item in self._map

    def __iter__(self) -> Iterator[Distribution]:  # type: ignore[override] # https://github.com/pydantic/pydantic/issues/8872
        """Iterate over the distribution instances (not their names), in order."""
        return iter(self.root)

    def __len__(self) -> int:
        """Return the number of distributions in ``root``."""
        return len(self.root)