# Source code for pyhs3.distributions.core

"""
Core distribution classes and utilities.

Provides the base Distribution class and common utilities used by both
standard and CMS-specific distribution implementations.
"""

from __future__ import annotations

from abc import ABC, abstractmethod
from collections.abc import Mapping
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, cast

import pytensor.tensor as pt
from pydantic import PrivateAttr
from pytensor.graph.replace import graph_replace

from pyhs3.base import Evaluable
from pyhs3.context import Context
from pyhs3.normalization import gauss_legendre_integral
from pyhs3.typing.aliases import TensorVar

if TYPE_CHECKING:
    from pyhs3.distributions import Distributions


@dataclass
class LogProbTerms:
    """
    Bundle of one channel's contributions to :attr:`pyhs3.model.Model.log_prob`.

    A distribution reports *what* it contributes by returning one of these
    from :meth:`Distribution.log_prob_terms`.  The model, which owns the
    channel-dataset pairing, decides *how* the pieces are combined into the
    joint log-likelihood (event weighting, summation over events, global
    deduplication of constraints).

    Every field defaults to an empty container, so a distribution only needs
    to fill in the categories it actually contributes to.

    Attributes:
        per_event: Log-density terms in which the observable is a free
            pt.vector input; they broadcast to shape (N, M) for N events and
            a parameter batch of size M.  The model sums them over the event
            axis, applying per-event weights when present.
        channel: Scalar terms that enter once per channel and broadcast onto
            the (M,) parameter batch axis (e.g. the -nu yield term of an
            extended mixture).
        constraints: Log-terms keyed by factor name that depend only on
            scalar nuisance parameters.  The model adds each name exactly
            once globally, so a constraint shared by several channels is not
            double-counted.
    """

    # Summed over the event axis by the model.
    per_event: list[TensorVar] = field(default_factory=list)
    # Added once per channel.
    channel: list[TensorVar] = field(default_factory=list)
    # Added once per factor name across the whole model.
    constraints: dict[str, TensorVar] = field(default_factory=dict)


class Distribution(Evaluable, ABC):
    """
    Abstract base class for HS3 probability distributions.

    Parameter processing is inherited from Evaluable; this class adds the
    machinery shared by every distribution for building PyTensor expressions:

    * ``likelihood()`` - the main probability model (abstract).
    * ``extended_likelihood()`` - optional extra terms such as constraints.
    * automatic normalization of the main likelihood over the observables
      present in the context.

    The complete probability is the product of the (normalized) main
    likelihood and the extended likelihood.  A subclass opts out of
    normalization by setting ``_normalizable = False``.
    """

    _normalizable: bool = PrivateAttr(default=True)

    @abstractmethod
    def likelihood(self, context: Context) -> TensorVar:
        """
        Main probability model of the distribution.

        This is the core probability density function (e.g. the Poisson
        probability of the observed data, a Gaussian PDF, ...).  It is
        abstract: every concrete subclass must provide an implementation.

        Args:
            context: Mapping of names to pytensor variables

        Returns:
            TensorVar: Main probability density
        """

    def normalization_expression(
        self, _context: Context, _observable_name: str
    ) -> TensorVar | None:
        """
        Antiderivative used for analytical normalization, if available.

        Subclasses override this to return a symbolic antiderivative F(x)
        such that the integral of f(x) from a to b equals F(b) - F(a).  The
        base implementation ignores both arguments and returns None, which
        makes the caller fall back to numerical integration.

        Args:
            _context: Mapping of names to pytensor variables
            _observable_name: Name of the observable to integrate over

        Returns:
            Symbolic antiderivative expression, or None for numerical
            fallback.
        """
        return None

    def _normalization_integral(
        self, context: Context, obs_name: str, lower: TensorVar, upper: TensorVar
    ) -> TensorVar | None:
        """
        Definite integral F(upper) - F(lower) built from the antiderivative.

        F is whatever normalization_expression() returns for ``obs_name``;
        both bounds are substituted into it via graph_replace.

        Args:
            context: Mapping of names to pytensor variables
            obs_name: Name of the observable to integrate over
            lower: Lower integration bound
            upper: Upper integration bound

        Returns:
            Symbolic integral expression, or None when
            normalization_expression() returns None.
        """
        primitive = self.normalization_expression(context, obs_name)  # pylint: disable=assignment-from-none
        if primitive is None:
            return None

        # Substitute on the leaf variable rather than on a view of it, so
        # that graph_replace reaches every ExpandDims(leaf) view that occurs
        # inside the antiderivative.
        observable = context.parameters[obs_name]
        at_upper = pt.as_tensor_variable([upper], dtype=observable.dtype)
        at_lower = pt.as_tensor_variable([lower], dtype=observable.dtype)

        f_upper = cast(TensorVar, graph_replace(primitive, [(observable, at_upper)]))
        f_lower = cast(TensorVar, graph_replace(primitive, [(observable, at_lower)]))
        return cast(TensorVar, f_upper - f_lower)

    def _apply_normalization(
        self,
        raw: TensorVar,
        context: Context,
    ) -> TensorVar:
        """
        Divide a raw likelihood by its integral over the matching observable.

        An observable "matches" when its name appears both in
        ``context.observables`` and in ``self.parameters``.  The outcome is:

        * ``_normalizable`` is false, or nothing matches: ``raw`` is returned
          unchanged.
        * more than one observable matches: NotImplementedError.
        * exactly one matches: the analytical integral from
          _normalization_integral() is used when available, otherwise
          Gauss-Legendre quadrature over the observable's bounds.

        Args:
            raw: Raw (unnormalized) likelihood expression
            context: Mapping of names to pytensor variables (includes
                observables)

        Returns:
            Normalized likelihood expression

        Raises:
            NotImplementedError: If more than one observable of the context
                is a parameter of this distribution.
        """
        # Explicit opt-out for distributions that should not be normalized.
        if not self._normalizable:
            return raw

        bounded = [
            (name, lower, upper)
            for name, (lower, upper) in context.observables.items()
            if name in self.parameters
        ]
        if not bounded:
            return raw

        if len(bounded) > 1:
            obs_names = [entry[0] for entry in bounded]
            msg = (
                f"Multi-dimensional normalization is not yet supported "
                f"(observables: {obs_names}). "
                f"See https://github.com/scipp-atlas/pyhs3/issues/214"
            )
            raise NotImplementedError(msg)

        # Exactly one observable from here on.
        obs_name, lower, upper = bounded[0]

        # Prefer the closed-form integral when the subclass provides one.
        analytic = self._normalization_integral(context, obs_name, lower, upper)
        if analytic is not None:
            return cast(TensorVar, raw / analytic)

        # Numerical fallback.  The leaf (not a view of it) is handed over so
        # that graph_replace substitutes through every ExpandDims(leaf) view
        # inside the integrand.
        numeric = gauss_legendre_integral(
            raw, context.parameters[obs_name], lower, upper
        )
        return cast(TensorVar, raw / numeric)

    def _expression(self, context: Context) -> TensorVar:
        """
        Complete probability: normalized likelihood times extended terms.

        The main likelihood() is normalized over the observables present in
        the context (unless ``_normalizable`` is false) and then multiplied
        by extended_likelihood().  Subclasses normally implement likelihood()
        and, optionally, extended_likelihood() rather than overriding this
        method.

        Args:
            context: Mapping of names to pytensor variables

        Returns:
            TensorVar: Complete probability density
        """
        # _apply_normalization checks the _normalizable flag itself.
        density = self._apply_normalization(self.likelihood(context), context)
        extended = self.extended_likelihood(context)
        return cast(TensorVar, density * extended)

    def log_expression(self, context: Context) -> TensorVar:
        """
        Log-probability: log of normalized likelihood plus log of extended terms.

        Mathematically this equals log(likelihood * extended_likelihood), but
        summing the two logarithms can be more numerically stable.  The main
        likelihood is normalized exactly as in _expression().

        Args:
            context: Mapping of names to pytensor variables

        Returns:
            TensorVar: Log-probability density
        """
        # _apply_normalization checks the _normalizable flag itself.
        density = self._apply_normalization(self.likelihood(context), context)
        log_density = pt.log(density)
        log_extended = pt.log(self.extended_likelihood(context))
        return cast(TensorVar, log_density + log_extended)

    def extended_likelihood(
        self, _context: Context, _data: TensorVar | None = None
    ) -> TensorVar:
        """
        Extended likelihood contribution in normal (non-log) space.

        Override only when the extra terms belong in the distribution's
        per-event density, like constraint terms (HistFactory).

        Terms that enter the likelihood once per channel (e.g. the Poisson
        yield term of an extended MixtureDist, which involves the observed
        event count) must not use this hook: ``_expression()`` multiplies the
        result into the density, so ``Model.log_prob`` would count it once
        per event when summing over data.  Such terms are assembled once per
        channel by ``Model.log_prob``, which owns the channel-dataset
        pairing.

        The base implementation ignores both arguments and contributes
        nothing, i.e. it returns the constant 1.0.

        Args:
            _context: Mapping of names to pytensor variables
            _data: Optional data tensor for data-dependent terms

        Returns:
            TensorVar: Likelihood contribution (default: 1.0 = no
            contribution)
        """
        return pt.constant(1.0)

    def log_prob_terms(
        self,
        expressions: Mapping[str, TensorVar],
        _distributions: Distributions,
    ) -> LogProbTerms:
        """
        Structured contributions of this distribution to the joint log-likelihood.

        Called by :attr:`pyhs3.model.Model.log_prob` for each channel after
        the model graph is built.  Override when not all of the
        distribution's terms enter the likelihood as a per-event log-density,
        e.g. once-per-channel yield terms (extended
        :class:`~pyhs3.distributions.MixtureDist`) or globally-deduplicated
        constraint factors (:class:`~pyhs3.distributions.ProductDist`).

        The base implementation returns a single per-event term: the log of
        the expression registered under this distribution's own name.

        Args:
            expressions: Compiled symbolic expressions for all distributions,
                keyed by name (``model.distributions``).
            _distributions: Distribution objects keyed by name, so composite
                distributions can delegate to their components' hooks.
                Unused by the base implementation.

        Returns:
            LogProbTerms: per-event, per-channel, and constraint
            contributions.
        """
        own_log_pdf = pt.log(expressions[self.name])
        return LogProbTerms(per_event=[own_log_pdf])