Source code for pyhs3.parameter_points
"""
HS3 Parameter Point implementations.
Provides Pydantic classes for handling HS3 parameter point specifications including
individual parameters and parameter sets for defining model parameter values.
"""
from __future__ import annotations
from collections.abc import Callable, Iterator
from typing import Any
import pytensor.tensor as pt
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, RootModel
from pyhs3.typing.aliases import TensorVar
[docs]
class ParameterPoint(BaseModel):
"""
Individual parameter specification with name and value.
Represents a single parameter specification with its value and optional configuration.
Used within parameter sets to specify concrete parameter values for model
evaluation and fitting.
Parameters:
name: Name identifier for the parameter
value: Numeric value of the parameter
const: Whether parameter is constant (optional, defaults to False)
nbins: Number of bins for binned parameters (optional)
kind: Type of tensor to create (optional, defaults to pt.scalar)
"""
model_config = ConfigDict()
name: str = Field(..., repr=True)
value: float = Field(..., repr=False)
const: bool = Field(default=False, repr=False)
nbins: int | None = Field(default=None, repr=False)
kind: Callable[..., TensorVar] = Field(default=pt.scalar, exclude=True, repr=False)
[docs]
class ParameterSet(BaseModel):
"""
Named collection of parameter specifications (matches HS3Spec structure).
Represents a complete set of parameter values that can be used to
evaluate a model. Each parameter set contains multiple individual
parameter points with their names and values.
Parameters:
name: Name identifier for the parameter set
parameters: List of ParameterPoint specifications
"""
model_config = ConfigDict()
name: str = Field(..., repr=True)
parameters: list[ParameterPoint] = Field(default_factory=list, repr=False)
@property
def points(self) -> dict[str, ParameterPoint]:
"""Compatibility property for core.py access."""
return {param.name: param for param in self.parameters}
def __len__(self) -> int:
"""Number of parameters in this set."""
return len(self.parameters)
def __contains__(self, param_name: str) -> bool:
"""Check if a parameter with the given name exists in this set."""
return param_name in self.points
def __getitem__(self, item: str | int) -> ParameterPoint:
"""Get a parameter by name or index."""
if isinstance(item, int):
return self.parameters[item]
return self.points[item]
def get(
self, param_name: str, default: ParameterPoint | None = None
) -> ParameterPoint | None:
"""Get a parameter by name, returning default if not found."""
return self.points.get(param_name, default)
def __iter__(self) -> Iterator[ParameterPoint]: # type: ignore[override]
"""Iterate over the parameters."""
return iter(self.parameters)
[docs]
class ParameterPoints(RootModel[list[ParameterSet]]):
"""
Collection of HS3 parameter sets for model configuration.
Manages a set of parameter set instances that define parameter values
for model evaluation. Provides dict-like access to parameter sets by
name and handles parameter set creation from configuration dictionaries.
Attributes:
parameter_sets: Mapping from parameter set names to ParameterSet instances.
"""
root: list[ParameterSet] = Field(default_factory=list)
_map: dict[str, ParameterSet] = PrivateAttr(default_factory=dict)
def model_post_init(self, __context: Any, /) -> None:
"""Initialize computed collections after Pydantic validation."""
self._map = {param_set.name: param_set for param_set in self.root}
def __getitem__(self, item: str | int) -> ParameterSet:
if isinstance(item, int):
return self.root[item]
return self._map[item]
def get(
self, item: str, default: ParameterSet | None = None
) -> ParameterSet | None:
"""Get a parameter set by name, returning default if not found."""
return self._map.get(item, default)
def __contains__(self, item: str) -> bool:
return item in self._map
def __iter__(self) -> Iterator[ParameterSet]: # type: ignore[override] # https://github.com/pydantic/pydantic/issues/8872
return iter(self.root)
def __len__(self) -> int:
return len(self.root)