# Source code for pyhs3.distributions

"""
HS3 Distribution implementations.

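Provides classes for handling various probability distributions, including
basic (Gaussian, Uniform, Poisson, Exponential, LogNormal, Landau), composite
(Mixture, Product), histogram, HistFactory, mathematical (Generic, Polynomial,
Bernstein polynomial) and physics (Crystal Ball, asymmetric Crystal Ball,
ARGUS) distributions. Includes both standard HS3 distributions and
CMS-specific extensions.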
"""

from __future__ import annotations

from typing import Annotated

from pydantic import Field

# Import modules instead of individual classes
from pyhs3.collections import NamedCollection
from pyhs3.distributions import (
    basic,
    cms,
    composite,
    histfactory,
    histogram,
    mathematical,
    physics,
)
from pyhs3.distributions.core import Distribution, LogProbTerms
from pyhs3.exceptions import custom_error_msg

# Export distribution classes for backwards compatibility
# Each name below is bound directly to the class object defined in its family
# submodule, so e.g. ``pyhs3.distributions.GaussianDist`` is the very same
# class as ``pyhs3.distributions.basic.GaussianDist``.
# Keep these as plain ``Name = module.Class`` assignments: PEP 484 treats such
# assignments as type aliases, so the re-exported names stay usable in type
# annotations (tuple unpacking or dynamic ``globals()`` updates would not).
# Basic distributions
GaussianDist = basic.GaussianDist
UniformDist = basic.UniformDist
PoissonDist = basic.PoissonDist
ExponentialDist = basic.ExponentialDist
LogNormalDist = basic.LogNormalDist
LandauDist = basic.LandauDist

# Composite distributions
MixtureDist = composite.MixtureDist
ProductDist = composite.ProductDist

# Histogram distributions
HistogramDist = histogram.HistogramDist

# HistFactory distributions
HistFactoryDistChannel = histfactory.HistFactoryDistChannel

# Mathematical distributions
GenericDist = mathematical.GenericDist
PolynomialDist = mathematical.PolynomialDist
BernsteinPolyDist = mathematical.BernsteinPolyDist

# Physics distributions
CrystalBallDist = physics.CrystalBallDist
AsymmetricCrystalBallDist = physics.AsymmetricCrystalBallDist
ArgusDist = physics.ArgusDist

# CMS-specific distributions
FastVerticalInterpHistPdf2Dist = cms.FastVerticalInterpHistPdf2Dist
GGZZBackgroundDist = cms.GGZZBackgroundDist
QQZZBackgroundDist = cms.QQZZBackgroundDist
FastVerticalInterpHistPdf2D2Dist = cms.FastVerticalInterpHistPdf2D2Dist

# Export all distribution classes
# Public API of ``pyhs3.distributions`` (what ``from pyhs3.distributions
# import *`` binds): every re-exported distribution class, the ``Distribution``
# type, ``LogProbTerms``, the ``Distributions`` collection and the combined
# ``registered_distributions`` mapping.
# Entries are kept in plain code-point (ASCII) order, so uppercase names sort
# before lowercase ones (``GGZZBackgroundDist`` before ``GaussianDist``,
# ``registered_distributions`` last).
# NOTE(review): ``DistributionType`` is not listed here; confirm whether the
# union alias is intended to be part of the public API.
__all__ = [
    "ArgusDist",
    "AsymmetricCrystalBallDist",
    "BernsteinPolyDist",
    "CrystalBallDist",
    "Distribution",
    "Distributions",
    "ExponentialDist",
    "FastVerticalInterpHistPdf2D2Dist",
    "FastVerticalInterpHistPdf2Dist",
    "GGZZBackgroundDist",
    "GaussianDist",
    "GenericDist",
    "HistFactoryDistChannel",
    "HistogramDist",
    "LandauDist",
    "LogNormalDist",
    "LogProbTerms",
    "MixtureDist",
    "PoissonDist",
    "PolynomialDist",
    "ProductDist",
    "QQZZBackgroundDist",
    "UniformDist",
    "registered_distributions",
]

# Combine all distribution registries.
# Each family submodule exposes its own ``distributions`` mapping from a string
# key (presumably the HS3 ``type`` tag; confirm in the submodules) to the
# distribution class. The families are folded in the order listed below: a key
# keeps the position of its first occurrence, and if two families ever
# registered the same key, the class from the later family would win.
registered_distributions: dict[str, type[Distribution]] = {
    type_key: dist_class
    for family_module in (
        basic,
        composite,
        histfactory,
        histogram,
        mathematical,
        physics,
        cms,
    )
    for type_key, dist_class in family_module.distributions.items()
}

# Type alias for all distribution types using discriminated union.
# Pydantic validates each entry by reading its ``type`` field and dispatching
# to the matching class, so every distribution class that may appear in an HS3
# document must be listed here. Keep this union in sync with the family
# registries merged into ``registered_distributions``; a class missing here is
# rejected with a ``union_tag_invalid`` error even though it is registered.
DistributionType = Annotated[
    # Basic distributions
    basic.GaussianDist
    | basic.UniformDist
    | basic.PoissonDist
    | basic.ExponentialDist
    | basic.LogNormalDist
    | basic.LandauDist
    # Composite distributions
    | composite.MixtureDist
    | composite.ProductDist
    # Histogram distributions (previously missing from the union, although
    # the class is re-exported and registered)
    | histogram.HistogramDist
    # HistFactory distributions
    | histfactory.HistFactoryDistChannel
    # Mathematical distributions
    | mathematical.GenericDist
    | mathematical.PolynomialDist
    | mathematical.BernsteinPolyDist
    # Physics distributions
    | physics.CrystalBallDist
    | physics.AsymmetricCrystalBallDist
    | physics.ArgusDist
    # CMS distributions
    | cms.FastVerticalInterpHistPdf2Dist
    | cms.GGZZBackgroundDist
    | cms.QQZZBackgroundDist
    | cms.FastVerticalInterpHistPdf2D2Dist,
    Field(discriminator="type"),
]


class Distributions(NamedCollection[DistributionType]):
    """
    Collection of distributions for a probabilistic model.

    Manages a set of distribution instances, providing dict-like access by
    distribution name. Each entry is validated from its configuration
    dictionary against the ``DistributionType`` discriminated union, which
    selects the concrete distribution class from the entry's ``type`` field.
    An unrecognized ``type`` is reported with a custom error message that
    names the offending tag and lists the expected distribution types.

    HS3 Reference: :hs3:label:`distributions <hs3.sec:distributions>`
    """

    # Flag read by ``NamedCollection``; presumably makes the collection reject
    # duplicate distribution names. TODO: confirm against NamedCollection.
    _enforce_unique_names: bool = True

    # The list of distributions. ``custom_error_msg`` replaces pydantic's
    # default ``union_tag_invalid`` message; ``{tag}`` and ``{expected_tags}``
    # are filled from that error's context.
    root: Annotated[
        list[DistributionType],
        custom_error_msg(
            {
                "union_tag_invalid": "Unknown distribution type '{tag}' does not match any of the expected distributions: {expected_tags}"
            }
        ),
    ] = Field(default_factory=list)