"""
HS3 Functions implementation.
Provides classes for handling HS3 functions including product functions,
generic functions with mathematical expressions, and interpolation functions.
"""
from __future__ import annotations
import logging
from collections.abc import Iterator
from typing import Annotated, Any, Literal, TypeVar, cast
import pytensor.tensor as pt
import sympy as sp
from pydantic import (
BaseModel,
ConfigDict,
Field,
PrivateAttr,
RootModel,
model_validator,
)
from pyhs3.exceptions import UnknownInterpolationCodeError, custom_error_msg
from pyhs3.generic_parse import analyze_sympy_expr, parse_expression, sympy_to_pytensor
from pyhs3.typing.aliases import TensorVar
log = logging.getLogger(__name__)
FuncT = TypeVar("FuncT", bound="Function")
[docs]
class Function(BaseModel):
"""Base class for HS3 functions."""
model_config = ConfigDict(serialize_by_alias=True)
name: str
type: str
_parameters: dict[str, str] = PrivateAttr(default_factory=dict)
@property
def parameters(self) -> dict[str, str]:
"""Access to parameter mapping."""
return self._parameters
def expression(self, _: dict[str, TensorVar]) -> TensorVar:
"""
Evaluate the function expression.
Args:
context: Mapping of names to pytensor variables
Returns:
PyTensor expression representing the function result
"""
msg = f"Function type {self.type} not implemented"
raise NotImplementedError(msg)
[docs]
class SumFunction(Function):
"""Sum function that adds summands together."""
type: Literal["sum"] = "sum"
summands: list[str]
@model_validator(mode="after")
def process_parameters(self) -> SumFunction:
"""Build the parameters dict from summands."""
self._parameters = {name: name for name in self.summands}
return self
def expression(self, context: dict[str, TensorVar]) -> TensorVar:
"""
Evaluate the sum function.
Args:
context: Mapping of names to PyTensor variables.
Returns:
TensorVar: PyTensor expression representing the sum of all summands.
"""
if not self.summands:
return pt.constant(0.0)
result = context[self.summands[0]]
for summand in self.summands[1:]:
result = result + context[summand]
return result
[docs]
class ProductFunction(Function):
"""Product function that multiplies factors together."""
type: Literal["product"] = "product"
factors: list[str]
@model_validator(mode="after")
def process_parameters(self) -> ProductFunction:
"""Build the parameters dict from factors."""
self._parameters = {name: name for name in self.factors}
return self
def expression(self, context: dict[str, TensorVar]) -> TensorVar:
"""
Evaluate the product function.
Args:
context: Mapping of names to PyTensor variables.
Returns:
TensorVar: PyTensor expression representing the product of all factors.
"""
if not self.factors:
return pt.constant(1.0)
result = context[self.factors[0]]
for factor in self.factors[1:]:
result = result * context[factor]
return result
[docs]
class GenericFunction(Function):
"""
Generic function with custom mathematical expression.
Evaluates arbitrary mathematical expressions using SymPy parsing
and PyTensor computation. Supports common mathematical operations
including arithmetic, trigonometric, exponential, and logarithmic functions.
The expression is parsed once during initialization and converted to
a PyTensor computation graph for efficient evaluation.
Parameters:
name (str): Name of the function.
expression (str): Mathematical expression string to evaluate.
Examples:
>>> func = GenericFunction(name="quadratic", expression="x**2 + 2*x + 1")
>>> func = GenericFunction(name="sinusoid", expression="sin(x) * exp(-t)")
"""
model_config = ConfigDict(arbitrary_types_allowed=True, serialize_by_alias=True)
type: Literal["generic_function"] = "generic_function"
expression_str: str = Field(alias="expression")
_sympy_expr: sp.Expr = PrivateAttr(default=None)
_dependent_vars: list[str] = PrivateAttr(default_factory=list)
@model_validator(mode="after")
def setup_expression(self) -> GenericFunction:
"""Parse and analyze the expression during initialization."""
# Parse and analyze the expression during initialization
self._sympy_expr = parse_expression(self.expression_str)
# Analyze the expression to determine dependencies
analysis = analyze_sympy_expr(self._sympy_expr)
independent_vars = [str(symbol) for symbol in analysis["independent_vars"]]
self._dependent_vars = [str(symbol) for symbol in analysis["dependent_vars"]]
# Set parameters based on the analyzed expression
self._parameters = {var: var for var in independent_vars}
return self
def expression(self, context: dict[str, TensorVar]) -> TensorVar:
"""
Evaluate the generic function expression.
Args:
context: Mapping of names to PyTensor variables.
Returns:
TensorVar: PyTensor expression representing the parsed mathematical expression.
"""
# Get required variables using the parameters determined during initialization
variables = [context[name] for name in self._parameters.values()]
# Convert using the pre-parsed sympy expression
return sympy_to_pytensor(self._sympy_expr, variables)
[docs]
class InterpolationFunction(Function):
r"""
Piecewise interpolation function implementation.
Implements ROOT's PiecewiseInterpolation logic to morph between nominal
and variation distributions based on nuisance parameter values.
Supports multiple interpolation codes (0-6) for different mathematical approaches.
Mathematical Formulations:
For **additive** interpolation modes (codes 0, 2, 3, 4):
.. math::
\text{result} = \text{nominal} + \sum_i I_i(\theta_i; \text{low}_i, \text{nominal}, \text{high}_i)
For **multiplicative** interpolation modes (codes 1, 5, 6):
.. math::
\text{result} = \text{nominal} \times \prod_i [1 + I_i(\theta_i; \text{low}_i/\text{nominal}, 1, \text{high}_i/\text{nominal})]
Parameters:
name: Name of the function
high: High variation parameter names
low: Low variation parameter names
nom: Nominal parameter name
interpolationCodes: Interpolation method codes (0-6)
positiveDefinite: Whether function should be positive definite
vars: Variable names this function depends on (nuisance parameters)
"""
type: Literal["interpolation"] = "interpolation"
high: list[str]
low: list[str]
nom: str
interpolationCodes: list[int]
positiveDefinite: bool
vars: list[str]
@model_validator(mode="after")
def process_parameters(self) -> InterpolationFunction:
"""Build the parameters dict and validate interpolation codes."""
# Validate interpolation codes
valid_codes = {0, 1, 2, 3, 4, 5, 6}
for code in self.interpolationCodes:
if code not in valid_codes:
msg = f"Unknown interpolation code {code} in function '{self.name}'. Valid codes are 0-6."
raise UnknownInterpolationCodeError(msg)
# Build parameters dict - all high, low, nom, and vars parameters
all_params = [*self.high, *self.low, self.nom, *self.vars]
self._parameters = {name: name for name in all_params}
return self
def _flexible_interp_single(
self,
interp_code: int,
low_val: TensorVar,
high_val: TensorVar,
boundary: float,
nominal: TensorVar,
param_val: TensorVar,
) -> TensorVar:
r"""
Implement flexible interpolation for a single parameter.
Based on ROOT's flexibleInterpSingle method with support for
interpolation codes 0-6. This method computes the interpolation
contribution :math:`I_i(\theta_i)` for a single nuisance parameter.
Args:
interp_code: Interpolation code (0-6) determining the mathematical approach
low_val: Low variation value (used when :math:`\theta < 0`)
high_val: High variation value (used when :math:`\theta \geq 0`)
boundary: Boundary value for switching between interpolation and extrapolation (typically 1.0)
nominal: Nominal value (baseline)
param_val: Parameter value :math:`\theta` (nuisance parameter)
Returns:
Interpolated contribution :math:`I_i(\theta_i)` to be added (additive modes)
or multiplied (multiplicative modes) with the result
Note:
The returned value interpretation depends on the interpolation code:
- Codes 0,2,3,4: Direct additive contribution
- Codes 1,5,6: Multiplicative factor (subtract 1 before use)
"""
# Codes 0, 2, 3, 4 are additive modes
# Codes 1, 5, 6 are multiplicative modes
if interp_code == 0:
# Linear interpolation/extrapolation (additive)
return cast(
TensorVar,
pt.switch(
param_val >= 0,
param_val * (high_val - nominal),
param_val * (nominal - low_val),
),
)
if interp_code == 1:
# Exponential interpolation/extrapolation (multiplicative)
ratio_high = high_val / nominal
ratio_low = low_val / nominal
return cast(
TensorVar,
pt.switch(
param_val >= 0,
cast(TensorVar, pt.power(ratio_high, param_val)) - 1.0, # type: ignore[no-untyped-call]
cast(TensorVar, pt.power(ratio_low, -param_val)) - 1.0, # type: ignore[no-untyped-call]
),
)
if interp_code == 2:
# Exponential interpolation, linear extrapolation (additive)
return cast(
TensorVar,
pt.switch(
pt.abs(param_val) <= boundary,
# Exponential interpolation for |theta| <= 1
pt.switch(
param_val >= 0,
(high_val - nominal) * (pt.exp(param_val) - 1),
(nominal - low_val) * (pt.exp(-param_val) - 1),
),
# Linear extrapolation for |theta| > 1
pt.switch(
param_val >= 0,
(high_val - nominal)
* (
pt.exp(boundary)
- 1
+ (param_val - boundary) * pt.exp(boundary)
),
(nominal - low_val)
* (
pt.exp(boundary)
- 1
+ (-param_val - boundary) * pt.exp(boundary)
),
),
),
)
if interp_code == 3:
# Similar to code 2 but with different extrapolation
return cast(
TensorVar,
pt.switch(
pt.abs(param_val) <= boundary,
# Exponential interpolation for |theta| <= 1
pt.switch(
param_val >= 0,
(high_val - nominal) * (pt.exp(param_val) - 1),
(nominal - low_val) * (pt.exp(-param_val) - 1),
),
# Linear extrapolation for |theta| > 1
pt.switch(
param_val >= 0,
param_val * (high_val - nominal),
param_val * (nominal - low_val),
),
),
)
if interp_code == 4:
# Polynomial interpolation + linear extrapolation (additive)
return cast(
TensorVar,
pt.switch(
pt.abs(param_val) >= boundary,
# Linear extrapolation for |theta| >= 1
pt.switch(
param_val >= 0,
param_val * (high_val - nominal),
param_val * (nominal - low_val),
),
# 6th order polynomial interpolation for |theta| < 1
pt.switch(
param_val >= 0,
param_val
* (high_val - nominal)
* (
1
+ param_val * param_val * (-3 + param_val * param_val) / 16
),
param_val
* (nominal - low_val)
* (
1
+ param_val * param_val * (-3 + param_val * param_val) / 16
),
),
),
)
if interp_code == 5:
# Polynomial interpolation + exponential extrapolation (multiplicative)
ratio_high = high_val / nominal
ratio_low = low_val / nominal
return cast(
TensorVar,
pt.switch(
pt.abs(param_val) >= boundary,
# Exponential extrapolation for |theta| >= 1
pt.switch(
param_val >= 0,
cast(TensorVar, pt.power(ratio_high, param_val)) - 1.0, # type: ignore[no-untyped-call]
cast(TensorVar, pt.power(ratio_low, -param_val)) - 1.0, # type: ignore[no-untyped-call]
),
# 6th order polynomial interpolation for |theta| < 1
pt.switch(
param_val >= 0,
param_val
* (ratio_high - 1.0)
* (
1
+ param_val * param_val * (-3 + param_val * param_val) / 16
),
param_val
* (ratio_low - 1.0)
* (
1
+ param_val * param_val * (-3 + param_val * param_val) / 16
),
),
),
)
# Code 6: Polynomial interpolation + linear extrapolation (multiplicative)
ratio_high = high_val / nominal
ratio_low = low_val / nominal
return cast(
TensorVar,
pt.switch(
pt.abs(param_val) >= boundary,
# Linear extrapolation for |theta| >= 1
pt.switch(
param_val >= 0,
param_val * (ratio_high - 1.0),
param_val * (ratio_low - 1.0),
),
# 6th order polynomial interpolation for |theta| < 1
pt.switch(
param_val >= 0,
param_val
* (ratio_high - 1.0)
* (1 + param_val * param_val * (-3 + param_val * param_val) / 16),
param_val
* (ratio_low - 1.0)
* (1 + param_val * param_val * (-3 + param_val * param_val) / 16),
),
),
)
def expression(self, context: dict[str, TensorVar]) -> TensorVar:
r"""
Evaluate the interpolation function.
Implements ROOT's PiecewiseInterpolation algorithm following the mathematical
formulations described in the class docstring. The algorithm proceeds as:
1. Start with nominal value: :math:`\text{result} = \text{nominal}`
2. For each nuisance parameter :math:`\theta_i`, compute interpolation contribution :math:`I_i(\theta_i)`
3. Combine contributions based on interpolation mode:
- **Additive modes** (codes 0,2,3,4): :math:`\text{result} += I_i(\theta_i)`
- **Multiplicative modes** (codes 1,5,6): :math:`\text{result} \times= (1 + I_i(\theta_i))`
4. Apply positive definite constraint: :math:`\text{result} = \max(\text{result}, 0)` if requested
Args:
context: Mapping of names to pytensor variables containing:
- Nominal parameter (referenced by `nom`)
- High/low variation parameters (referenced by `high`/`low` lists)
- Nuisance parameters (referenced by `vars` list)
Returns:
PyTensor expression representing the interpolated result
Note:
The evaluation order ensures that all interpolation contributions are properly
combined according to their mathematical modes before applying constraints.
"""
# Start with nominal value
nominal = context[self.nom]
result = nominal
# Apply interpolation for each nuisance parameter
for i, var_name in enumerate(self.vars):
if (
i >= len(self.high)
or i >= len(self.low)
or i >= len(self.interpolationCodes)
):
log.warning(
"Parameter index %d exceeds variation lists for function %s",
i,
self.name,
)
continue
param_val = context[var_name]
low_val = context[self.low[i]]
high_val = context[self.high[i]]
interp_code = self.interpolationCodes[i]
# Calculate interpolated contribution
contribution = self._flexible_interp_single(
interp_code=interp_code,
low_val=low_val,
high_val=high_val,
boundary=1.0,
nominal=nominal,
param_val=param_val,
)
# Add contribution based on interpolation mode
if interp_code in [0, 2, 3, 4]: # Additive modes
result = result + contribution
else: # Multiplicative modes (1, 5, 6)
result = result * (1.0 + contribution)
# Apply positive definite constraint if requested
if self.positiveDefinite:
result = pt.maximum(result, 0.0)
return result
[docs]
class ProcessNormalizationFunction(Function):
r"""
Process normalization function with systematic variations.
Implements the CMS Combine ProcessNormalization class which computes
a normalization factor based on a nominal value and systematic variations.
Parameters:
name: Name of the function
expression: Expression identifier (typically same as name)
nominalValue: Base normalization value
thetaList: Names of symmetric variation nuisance parameters
logKappa: Symmetric log-normal variation factors
asymmThetaList: Names of asymmetric variation nuisance parameters
logAsymmKappa: Asymmetric [low, high] log-normal variation factors
otherFactorList: Names of additional multiplicative factors
"""
type: Literal["CMS::process_normalization"] = "CMS::process_normalization"
expression_name: str = Field(alias="expression")
nominalValue: float
thetaList: list[str]
logKappa: list[float]
asymmThetaList: list[str]
logAsymmKappa: list[list[float]]
otherFactorList: list[str]
@model_validator(mode="after")
def process_parameters(self) -> ProcessNormalizationFunction:
"""Build the parameters dict from all parameter lists."""
# All parameters this function depends on
all_params = [*self.thetaList, *self.asymmThetaList, *self.otherFactorList]
self._parameters = {name: name for name in all_params}
return self
def _asym_interpolation(
self, theta: TensorVar, kappa_sum: float, kappa_diff: float
) -> TensorVar:
"""
Implement asymmetric interpolation function.
Based on CMS Combine's _asym_interpolation function.
Args:
theta: Nuisance parameter value
kappa_sum: Sum of low and high kappa values
kappa_diff: Difference of high and low kappa values
Returns:
Interpolated value
"""
abs_theta = pt.abs(theta)
# Polynomial coefficients for smooth interpolation
# Based on _asym_poly = jnp.array([3.0, -10.0, 15.0, 0.0]) / 8.0
poly_result = (3.0 * theta**4 - 10.0 * theta**2 + 15.0) / 8.0
# Choose between linear extrapolation (|theta| > 1) and polynomial interpolation (|theta| <= 1)
smooth_function = pt.switch(abs_theta > 1.0, abs_theta, poly_result)
# Apply asymmetric interpolation formula
morph = 0.5 * (kappa_diff * theta + kappa_sum * smooth_function)
return cast(TensorVar, morph)
def expression(self, context: dict[str, TensorVar]) -> TensorVar:
"""
Evaluate the process normalization function.
Args:
context: Mapping of names to PyTensor variables.
Returns:
TensorVar: PyTensor expression representing the normalization factor.
"""
# Start with the nominal value
result = pt.constant(self.nominalValue)
# Add symmetric variations
sym_shift = pt.constant(0.0)
for theta_name, kappa in zip(self.thetaList, self.logKappa, strict=False):
theta = context[theta_name]
sym_shift = sym_shift + kappa * theta
# Add asymmetric variations
asym_shift = pt.constant(0.0)
for theta_name, kappa_pair in zip(
self.asymmThetaList, self.logAsymmKappa, strict=False
):
theta = context[theta_name]
kappa_lo, kappa_hi = kappa_pair
kappa_sum = kappa_hi + kappa_lo
kappa_diff = kappa_hi - kappa_lo
asym_contribution = self._asym_interpolation(theta, kappa_sum, kappa_diff)
asym_shift = asym_shift + asym_contribution
# Apply exponential of total shift
result = result * pt.exp(sym_shift + asym_shift)
# Multiply by additional factors
for factor_name in self.otherFactorList:
factor = context[factor_name]
result = result * factor
return cast(TensorVar, result)
# Define the union type for all function configurations
FunctionConfig = (
SumFunction
| ProductFunction
| GenericFunction
| InterpolationFunction
| ProcessNormalizationFunction
)
registered_functions: dict[str, type[Function]] = {
"sum": SumFunction,
"product": ProductFunction,
"generic_function": GenericFunction,
"interpolation": InterpolationFunction,
"CMS::process_normalization": ProcessNormalizationFunction,
}
# Type alias for all function types using discriminated union
FunctionType = Annotated[
SumFunction
| ProductFunction
| GenericFunction
| InterpolationFunction
| ProcessNormalizationFunction,
Field(discriminator="type"),
]
[docs]
class Functions(RootModel[list[FunctionType]]):
"""
Collection of HS3 functions for parameter computation.
Manages a set of function instances that compute parameter values
based on other parameters. Functions can be products, generic
mathematical expressions, or interpolation functions.
Provides dict-like access to functions by name and handles
function creation from configuration dictionaries.
Attributes:
funcs: Mapping from function names to Function instances.
"""
root: Annotated[
list[FunctionType],
custom_error_msg(
{
"union_tag_invalid": "Unknown function type '{tag}' does not match any of the expected functions: {expected_tags}"
}
),
] = Field(default_factory=list)
_map: dict[str, Function] = PrivateAttr(default_factory=dict)
def model_post_init(self, __context: Any, /) -> None:
"""Initialize computed collections after Pydantic validation."""
self._map = {func.name: func for func in self.root}
def __getitem__(self, item: str) -> Function:
return self._map[item]
def __contains__(self, item: str) -> bool:
return item in self._map
def __iter__(self) -> Iterator[Function]: # type: ignore[override] # https://github.com/pydantic/pydantic/issues/8872
return iter(self.root)
def __len__(self) -> int:
return len(self.root)