# Source code for pyhs3.domains

"""
HS3 Domain implementations.

Provides Pydantic classes for handling HS3 domain specifications including
axes and product domains for defining parameter spaces and integration regions.
"""

from __future__ import annotations

from typing import Annotated, Any, Literal

from pydantic import Field, PrivateAttr, model_validator

from pyhs3.axes import ConstantAxis, DomainAxes, DomainAxis, DomainCoordinateAxis
from pyhs3.collections import NamedCollection, NamedModel
from pyhs3.exceptions import custom_error_msg


class Domain(NamedModel):
    """
    Base model shared by every HS3 domain specification.

    Concrete domains subclass this model and override the axis-related hooks
    below. Domains define parameter spaces for integration, constraints, and
    likelihood evaluation. On their own, the defaults here behave like an
    empty domain:
    - it has no axes, so ``len()`` is 0 and membership tests are False;
    - ``get`` falls back to the supplied default;
    - item access always fails.

    Parameters:
        name: Name identifier for the domain
        type: Domain type identifier
    """

    type: str = Field(..., repr=False)

    @property
    def dimension(self) -> int:
        """Dimensionality of the domain; concrete subclasses provide this."""
        raise NotImplementedError

    @property
    def axis_names(self) -> list[str]:
        """Axis names of the domain; not every domain type provides this."""
        raise NotImplementedError

    def __len__(self) -> int:
        """Count the axes in the domain (the base domain has none)."""
        return 0

    def __contains__(self, _axis_name: str) -> bool:
        """Test for an axis by name (the base domain never has one)."""
        return False

    def get(self, _axis_name: str, default: Any = None) -> Any:
        """Look up bounds for an axis; the base class always yields ``default``."""
        return default

    def __getitem__(self, axis_name: str) -> Any:
        """Dict-style bounds lookup; the base class always raises ``KeyError``."""
        raise KeyError(axis_name)
class ProductDomain(Domain):
    """
    Product domain specification for multi-dimensional parameter spaces.

    Defines a Cartesian product of axes to create multi-dimensional parameter
    domains. Used for specifying integration regions, parameter constraints,
    and likelihood evaluation domains in HS3 specifications.

    The domain represents the Cartesian product: axis₁ x axis₂ x ... x axisₙ
    where each axis defines a one-dimensional range.

    Parameters:
        name: Name identifier for the domain
        type: Domain type identifier (always "product_domain")
        axes: List of axis specifications defining each dimension
    """

    type: Literal["product_domain"] = Field(default="product_domain", repr=False)
    axes: DomainAxes = Field(default_factory=lambda: DomainAxes([]), repr=False)
    # Name -> axis lookup table used by __contains__, get and __getitem__.
    # It is filled in by ``initialize_axes_map`` once model validation finishes.
    _axes_map: dict[str, DomainAxis] = PrivateAttr(default_factory=dict)

    @model_validator(mode="after")
    def initialize_axes_map(self) -> ProductDomain:
        """Initialize the internal axes mapping for fast lookup."""
        self._axes_map = {axis.name: axis for axis in self.axes}
        return self

    @model_validator(mode="after")
    def validate_unique_axis_names(self) -> ProductDomain:
        """
        Validate that all axis names are unique within the domain.

        Raises:
            ValueError: If two or more axes share the same name; the message
                lists the set of repeated names.
        """
        seen: set[str] = set()
        duplicates: set[str] = set()
        # One pass with set membership keeps this linear in the number of
        # axes (a per-name ``list.count`` would rescan the list each time).
        for axis in self.axes:
            if axis.name in seen:
                duplicates.add(axis.name)
            else:
                seen.add(axis.name)
        if duplicates:
            msg = f"Domain '{self.name}' contains duplicate axis names: {duplicates}"
            raise ValueError(msg)
        return self

    @property
    def dimension(self) -> int:
        """Number of dimensions (axes) in this domain."""
        return len(self.axes)

    @property
    def axis_names(self) -> list[str]:
        """List of axis names in this domain, in declaration order."""
        return [axis.name for axis in self.axes]

    def __len__(self) -> int:
        """Number of axes in this domain."""
        return len(self.axes)

    def __contains__(self, axis_name: str) -> bool:
        """Check if an axis with the given name exists in this domain."""
        return axis_name in self._axes_map

    def get(
        self, axis_name: str, default: tuple[float | None, float | None] = (None, None)
    ) -> tuple[float | None, float | None]:
        """
        Get axis bounds for a parameter name.

        Args:
            axis_name: Name of the axis to get bounds for.
            default: Value returned when the axis is missing or is not a
                ``DomainCoordinateAxis`` (e.g. a constant axis).

        Returns:
            Tuple of (min, max) bounds if a coordinate axis exists, otherwise default.
        """
        axis = self._axes_map.get(axis_name)
        # isinstance() is False for None, so this also covers a missing axis.
        if isinstance(axis, DomainCoordinateAxis):
            return (axis.min, axis.max)
        return default

    def __getitem__(self, axis_name: str) -> tuple[float | None, float | None]:
        """
        Get axis bounds for a parameter name (dict-like access).

        Raises:
            ValueError: If the named axis is a ``ConstantAxis`` (it has no bounds).
            KeyError: If no axis with that name exists in this domain.
        """
        axis = self._axes_map.get(axis_name)
        if axis is not None:
            if isinstance(axis, ConstantAxis):
                msg = f"Axis '{axis_name}' is a constant axis with no min or max."
                raise ValueError(msg)
            return (axis.min, axis.max)
        msg = f"No axis named '{axis_name}' found in domain '{self.name}'"
        raise KeyError(msg)
# Discriminated union of every concrete domain model, keyed on the ``type`` field.
# ProductDomain is the only member today; additional domain classes are meant
# to be added to this alias as they are introduced.
DomainType = Annotated[
    ProductDomain,
    Field(discriminator="type"),
]
class Domains(NamedCollection[DomainType]):
    """
    Collection of HS3 domains for parameter space definitions.

    Holds the domain instances that define parameter spaces, integration
    regions, and constraints. Name-based (dict-like) access and construction
    from configuration data are presumably inherited from ``NamedCollection``.
    That base class is not shown here, so confirm the behaviour there.

    Attributes:
        root: List of domain instances. Each entry is discriminated on its
            ``type`` field. An unrecognised tag presumably produces the
            "Unknown domain type ..." message configured via
            ``custom_error_msg`` for the ``union_tag_invalid`` error.
    """

    root: Annotated[
        list[DomainType],
        custom_error_msg(
            {
                "union_tag_invalid": "Unknown domain type '{tag}' does not match any of the expected domains: {expected_tags}"
            }
        ),
    ]