Source code for pyhs3.likelihoods

"""
HS3 Likelihood implementations.

Provides Pydantic classes for handling HS3 likelihood specifications
including likelihood mappings between distributions and data.
"""

from __future__ import annotations

from typing import TYPE_CHECKING, Annotated, cast

import numpy as np
import numpy.typing as npt
from pydantic import Field, model_validator

from pyhs3.collections import NamedCollection, NamedModel
from pyhs3.data import Data, Datum

if TYPE_CHECKING:
    from pyhs3.workspace import Workspace
from pyhs3.distributions import Distributions
from pyhs3.distributions.core import Distribution
from pyhs3.typing.annotations import (
    FKListSchema,
    FKListSerializer,
    make_fk_list_validator,
)


[docs] class Likelihood(NamedModel): """ Likelihood specification mapping distributions to observations. Represents a likelihood function that combines parameterized distributions with observations to generate a likelihood function L(θ₁, θ₂, ...). The likelihood is the product of PDFs evaluated at observed data points. Attributes: name: Custom string identifier for the likelihood distributions: Array of strings referencing distributions data: Array of strings referencing data or inline values for constraints aux_distributions: Optional array of auxiliary distributions for regularization """ distributions: Annotated[ list[str] | Distributions, make_fk_list_validator(Distribution), FKListSerializer, FKListSchema, ] = Field(..., repr=False) data: Annotated[ list[str] | Data, make_fk_list_validator(Datum), FKListSerializer, FKListSchema, ] = Field(..., repr=False) aux_distributions: Annotated[ list[str] | Distributions | None, make_fk_list_validator(Distribution), FKListSerializer, FKListSchema, ] = Field(default=None, repr=False) def validate_unique_axis_names(self, workspace: Workspace | None = None) -> None: """Raise ValueError if any observable axis name appears more than once. When *workspace* is provided, unresolved string FK references in ``data`` are resolved via ``workspace.data`` before checking. Without a workspace, string entries are skipped. """ seen: dict[str, str] = {} duplicates: list[str] = [] for entry in self.data: if isinstance(entry, str): if workspace is None or workspace.data is None: continue datum = workspace.data.get(entry) if datum is None: continue else: datum = entry for axis in datum.axes or []: if axis.name in seen: duplicates.append( f"'{axis.name}' in '{datum.name}' and '{seen[axis.name]}'" ) else: seen[axis.name] = datum.name if duplicates: msg = ( f"Likelihood '{self.name}' has duplicate observable axis names: " + ", ".join(duplicates) ) raise ValueError(msg) def data_arrays(self) -> dict[str, npt.NDArray[np.float64]]: """Observable data as numpy arrays keyed by axis name. Returns a dict mapping each observable axis name to a 1-D float64 array of event values. Only data entries with both ``axes`` and ``entries`` are included (i.e. :class:`~pyhs3.data.UnbinnedData`). Suitable for passing directly to compiled or JAX functions:: fn(**likelihood.data_arrays(), **params) """ result: dict[str, npt.NDArray[np.float64]] = {} # self.data is guaranteed FK-resolved (no string entries after workspace construction). for datum in cast(Data, self.data): if datum.axes is None: continue entries = getattr(datum, "entries", None) if entries is None: continue entries_arr = np.asarray(entries, dtype=np.float64) n_axes = len(datum.axes) if entries_arr.size == 0: entries_arr = entries_arr.reshape(0, n_axes) for ax_idx, axis in enumerate(datum.axes): result[axis.name] = entries_arr[:, ax_idx] return result @model_validator(mode="after") def validate_distributions_data_pairing(self) -> Likelihood: """Validate that distributions and data are properly paired.""" if len(self.distributions) != len(self.data): msg = ( f"Likelihood '{self.name}': distributions and data must have the same length, " f"got {len(self.distributions)} distributions and {len(self.data)} data entries" ) raise ValueError(msg) if len(self.distributions) == 0 and not self.aux_distributions: msg = ( f"Likelihood '{self.name}': must have at least one distribution/data pair " f"or provide aux_distributions" ) raise ValueError(msg) return self
[docs] class Likelihoods(NamedCollection[Likelihood]): """ Collection of HS3 likelihood specifications. Manages a set of likelihood instances that define mappings between distributions and observations for statistical inference. Provides dict-like access to likelihoods by name. """ root: list[Likelihood] = Field(default_factory=list)