Source code for pyhs3.axes

"""
HS3 Axis implementations.

Provides Pydantic classes for handling HS3 axis specifications
including unbinned axes (with min/max bounds) and binned axes
(with regular or irregular binning).
"""

from __future__ import annotations

from itertools import pairwise
from typing import Annotated, Any, Literal, TypeAlias

import hist
import numpy as np
from pydantic import (
    ConfigDict,
    Discriminator,
    Field,
    Tag,
    model_validator,
)

from pyhs3.collections import NamedCollection, NamedModel
from pyhs3.exceptions import custom_error_msg


class Axis(NamedModel):
    """
    Common base class for the axis specifications in this module.

    Declares no fields of its own; the concrete axis types derive from it.

    Attributes:
        name: Name of the axis/variable (not declared here; presumably
            provided by ``NamedModel`` -- confirm against that class)
    """
class BoundedAxis(Axis):
    """
    Axis that must be given both a lower and an upper bound.

    Only the ordering of the bounds is validated here; finiteness is not
    checked by this class.

    Attributes:
        name: Name of the axis/variable
        min: Lower bound (required)
        max: Upper bound (required)
    """

    min: float = Field(..., repr=False)
    max: float = Field(..., repr=False)

    @model_validator(mode="after")
    def check_min_le_max(self) -> BoundedAxis:
        """Reject bounds where ``max`` lies below ``min`` (equal bounds pass)."""
        if not self.max < self.min:
            return self
        msg = f"{type(self).__name__} '{self.name}': max ({self.max}) must be >= min ({self.min})"
        raise ValueError(msg)
class UnbinnedAxis(BoundedAxis):
    """
    Bounded axis describing unbinned data.

    Attributes:
        name: Name of the axis/variable
        min: Lower bound (required, inherited from BoundedAxis)
        max: Upper bound (required, inherited from BoundedAxis)
    """

    def to_hist(self) -> Any:
        """
        Refuse conversion to a hist.axis object.

        An unbinned axis carries bounds but no binning, so there is nothing
        to build a histogram axis from; this method never returns.

        Raises:
            ValueError: Always, because the axis has no binning information
        """
        raise ValueError(
            f"UnbinnedAxis '{self.name}' does not have binning information for histogram conversion"
        )
class ConstantAxis(Axis):
    """
    Axis marking a constant quantity.

    The model is frozen and rejects unknown extra fields.

    Attributes:
        name: Name of the axis/variable
        const: Literal ``True``; defaults to ``True`` when omitted
    """

    model_config = ConfigDict(extra="forbid", frozen=True)

    const: Literal[True] = Field(default=True, repr=False, init=False)
class RegularAxis(BoundedAxis):
    """
    Bounded axis divided into ``nbins`` evenly spaced bins.

    Attributes:
        name: Name of the axis/variable
        min: Minimum value (inherited from BoundedAxis)
        max: Maximum value (inherited from BoundedAxis)
        nbins: Number of bins (must be positive)
    """

    nbins: int = Field(repr=False)

    @model_validator(mode="after")
    def check_binning(self) -> RegularAxis:
        """Validate that the number of bins is strictly positive.

        Raises:
            ValueError: If ``nbins`` is zero or negative
        """
        if self.nbins <= 0:
            msg = f"RegularAxis '{self.name}' must have positive number of bins, got {self.nbins}"
            raise ValueError(msg)
        return self

    @property
    def edges(self) -> list[float]:
        """Get the bin edges for this axis.

        Returns:
            List of ``nbins + 1`` evenly spaced edges from ``min`` to ``max``
            (both included), generated with ``np.linspace``.
        """
        # tolist() converts the NumPy scalars to built-in floats so the
        # result really is a list[float], as the annotation promises.
        return np.linspace(self.min, self.max, self.nbins + 1).tolist()

    def to_hist(self) -> hist.axis.Regular:
        """
        Convert this axis to a hist.axis object.

        Returns:
            A hist.axis.Regular object built from ``nbins``, ``min``, ``max``
            and carrying this axis's name
        """
        return hist.axis.Regular(self.nbins, self.min, self.max, name=self.name)
class IrregularAxis(Axis):
    """
    Axis whose bins are defined by an explicit list of edges.

    Attributes:
        name: Name of the axis/variable
        edges: Bin edges array (length n+1 for n bins)
    """

    edges: list[float] = Field(repr=False)

    @model_validator(mode="after")
    def validate_binning(self) -> IrregularAxis:
        """Require at least two edges, each strictly greater than the previous one."""
        if len(self.edges) < 2:
            raise ValueError(f"IrregularAxis '{self.name}' must have at least 2 edges")

        # Equal neighbours are rejected too, so no bin can have zero width.
        out_of_order = any(upper <= lower for lower, upper in pairwise(self.edges))
        if out_of_order:
            raise ValueError(
                f"IrregularAxis '{self.name}' edges must be in ascending order"
            )
        return self

    @property
    def min(self) -> float:
        """First (lowest) edge."""
        return self.edges[0]

    @property
    def max(self) -> float:
        """Last (highest) edge."""
        return self.edges[-1]

    @property
    def nbins(self) -> int:
        """Number of bins, i.e. one fewer than the number of edges."""
        return len(self.edges) - 1

    def to_hist(self) -> hist.axis.Variable:
        """
        Build the equivalent hist.axis object.

        Returns:
            A hist.axis.Variable object created from ``edges`` and carrying
            this axis's name
        """
        return hist.axis.Variable(self.edges, name=self.name)
class DomainCoordinateAxis(Axis):
    """
    Axis for domain coordinates with optional bounds.

    Represents a coordinate axis in a parameter domain, where bounds may be
    fully specified, partially specified, or unbounded (infinite).

    Attributes:
        name: Name of the axis/variable
        min: Minimum value (optional, defaults to -inf)
        max: Maximum value (optional, defaults to +inf)

    Examples:
        Create an unbounded domain axis:

        >>> from pyhs3.axes import DomainCoordinateAxis
        >>> axis = DomainCoordinateAxis(name="x")
        >>> axis
        DomainCoordinateAxis(x ∈ (-∞, +∞))

        Create a lower-bounded domain:

        >>> axis = DomainCoordinateAxis(name="x", min=-5)
        >>> axis
        DomainCoordinateAxis(x ∈ [-5, +∞))

        Create an upper-bounded domain:

        >>> axis = DomainCoordinateAxis(name="x", max=10)
        >>> axis
        DomainCoordinateAxis(x ∈ (-∞, 10])

        Create a fully bounded domain:

        >>> axis = DomainCoordinateAxis(name="x", min=0, max=1)
        >>> axis
        DomainCoordinateAxis(x ∈ [0, 1])

        Integers are displayed without trailing .0:

        >>> axis = DomainCoordinateAxis(name="x", min=0.0, max=5.0)
        >>> axis
        DomainCoordinateAxis(x ∈ [0, 5])
    """

    model_config = ConfigDict(serialize_by_alias=True)

    # The stored fields are v_min / v_max so that `min` / `max` can be
    # properties substituting +/-inf; the aliases keep "min" / "max" as the
    # external field names. Unset bounds are excluded via exclude_if.
    v_min: float | None = Field(
        default=None, alias="min", repr=False, exclude_if=lambda v: v is None
    )
    v_max: float | None = Field(
        default=None, alias="max", repr=False, exclude_if=lambda v: v is None
    )

    @property
    def min(self) -> float:
        """Returns defined minimum or (negative) :data:`np.inf <numpy.inf>`"""
        return -np.inf if self.v_min is None else self.v_min

    @property
    def max(self) -> float:
        """Returns defined maximum or (positive) :data:`np.inf <numpy.inf>`"""
        return np.inf if self.v_max is None else self.v_max

    @model_validator(mode="after")
    def check_min_le_max(self) -> DomainCoordinateAxis:
        """Validate that max >= min.

        Uses the ``min`` / ``max`` properties, so a missing bound compares
        as -inf / +inf and can never trigger the error on its own.

        Raises:
            ValueError: If ``max`` is smaller than ``min``
        """
        if self.max < self.min:
            msg = f"DomainCoordinateAxis '{self.name}': max ({self.max}) must be >= min ({self.min})"
            raise ValueError(msg)
        return self

    def __repr__(self) -> str:
        """Render as ``DomainCoordinateAxis(<name> ∈ <interval>)``.

        A bound that was given is shown with a square bracket and formatted
        with ``:g``; a missing bound is shown as an open, infinite side.
        """
        # Determine interval brackets
        left_bracket = "[" if self.v_min is not None else "("
        right_bracket = "]" if self.v_max is not None else ")"

        # Determine displayed bounds
        min_str = f"{self.v_min:g}" if self.v_min is not None else "-∞"
        max_str = f"{self.v_max:g}" if self.v_max is not None else "+∞"

        # The " ∈ " separator is required to match the documented examples.
        return f"DomainCoordinateAxis({self.name} ∈ {left_bracket}{min_str}, {max_str}{right_bracket})"
def _binned_axis_discriminator(v: Any) -> str | None:
    """
    Pick the union tag for a binned-axis input.

    Args:
        v: Raw input (typically a dict) or an already-constructed axis model

    Returns:
        ``"irregular"`` for a dict with ``edges`` but no ``nbins`` or for an
        ``IrregularAxis`` instance; ``"regular"`` for a dict with ``nbins``
        but no ``edges`` or for a ``RegularAxis`` instance; ``None`` for
        anything else, including a dict with both keys or neither.
    """
    if isinstance(v, dict):
        has_edges = "edges" in v
        has_nbins = "nbins" in v
        # Exactly one of the two keys must be present to choose a tag.
        if has_edges != has_nbins:
            return "irregular" if has_edges else "regular"
        return None

    # Already-constructed model case
    if isinstance(v, IrregularAxis):
        return "irregular"
    if isinstance(v, RegularAxis):
        return "regular"
    return None


# Tagged union of the two binned axis kinds; the tag comes from
# _binned_axis_discriminator. The custom_error_msg mapping appears to key
# replacement messages by pydantic error type -- confirm against
# pyhs3.exceptions.custom_error_msg.
BinnedAxis = Annotated[
    (
        Annotated[RegularAxis, Tag("regular")]
        | Annotated[IrregularAxis, Tag("irregular")]
    ),
    Discriminator(_binned_axis_discriminator),
    custom_error_msg(
        {
            "union_tag_not_found": "Unknown axis '{input}'. You must specify either regular binning (nbins/min/max) or irregular binning (edges).",
            "missing": "{input_value['name']} is missing {loc}",
        }
    ),
]
class BinnedAxes(NamedCollection[BinnedAxis]):
    """
    Collection of binned axes (regular or irregular).
    """

    def get_total_bins(self) -> int:
        """Return the product of ``nbins`` over all axes (1 if there are none)."""
        product = 1
        for ax in self:
            product = product * ax.nbins
        return product
# Axis types accepted in a domain: a (possibly unbounded) coordinate axis
# or a constant axis.
DomainAxis: TypeAlias = DomainCoordinateAxis | ConstantAxis
class UnbinnedAxes(NamedCollection[UnbinnedAxis]):
    """
    Collection of UnbinnedAxis items; empty when no axes are supplied.
    """

    root: list[UnbinnedAxis] = Field(default_factory=list)
class Axes(NamedCollection[BinnedAxis | UnbinnedAxis]):
    """
    Mixed collection whose items are each a BinnedAxis or an UnbinnedAxis;
    empty when no axes are supplied.
    """

    root: list[BinnedAxis | UnbinnedAxis] = Field(default_factory=list)
class DomainAxes(NamedCollection[DomainAxis]):
    """
    Collection of DomainAxis items (DomainCoordinateAxis | ConstantAxis);
    empty when no axes are supplied.
    """

    root: list[DomainAxis] = Field(default_factory=list)