# Source code for pyhs3.functions

"""
HS3 Functions implementation.

Provides classes for handling HS3 functions including product functions,
generic functions with mathematical expressions, and interpolation functions.
"""

from __future__ import annotations

import logging
from collections.abc import Iterator
from typing import Any, Generic, TypeVar, cast

import pytensor.tensor as pt

from pyhs3 import typing as T
from pyhs3.exceptions import UnknownInterpolationCodeError
from pyhs3.generic_parse import analyze_sympy_expr, parse_expression, sympy_to_pytensor
from pyhs3.typing import function as TF

log = logging.getLogger(__name__)


# Type variable for function classes; bound (as a forward reference) to
# ``Function`` parameterised with the generic HS3 function configuration type.
# NOTE(review): not referenced by any definition visible in this module.
FuncT = TypeVar("FuncT", bound="Function[T.Function]")
# Type variable for the configuration type a ``Function`` subclass is
# parameterised with (for example ``TF.ProductFunction``); bound to ``T.Function``.
FuncConfigT = TypeVar("FuncConfigT", bound=T.Function)


class Function(Generic[FuncConfigT]):
    """Common base class shared by all HS3 function kinds."""

    def __init__(self, *, name: str, kind: str, parameters: list[str]):
        """
        Record the metadata every function carries.

        Args:
            name: Name of the function
            kind: Type of the function (product, generic_function, interpolation)
            parameters: List of parameter/function names this function depends on
        """
        self.name, self.kind, self.parameters = name, kind, parameters

    def expression(self, _: dict[str, T.TensorVar]) -> T.TensorVar:
        """
        Build the expression for this function.

        The base class has no expression of its own, so this always raises;
        subclasses override it.

        Args:
            _: Mapping of names to pytensor variables (ignored here)

        Returns:
            PyTensor expression representing the function result

        Raises:
            NotImplementedError: Always, naming the unsupported ``kind``.
        """
        raise NotImplementedError(f"Function type {self.kind} not implemented")

    @classmethod
    def from_dict(cls, config: dict[str, Any]) -> Function[FuncConfigT]:
        """
        Create a Function instance from dictionary configuration.

        Raises:
            NotImplementedError: Always; subclasses provide the real constructor.
        """
        raise NotImplementedError()
class ProductFunction(Function[TF.ProductFunction]):
    """Function whose value is the product of its named factors."""

    def __init__(self, *, name: str, factors: list[str]):
        """
        Initialize a ProductFunction.

        Args:
            name: Name of the function
            factors: List of factor names to multiply together
        """
        self.factors = factors
        # The factors double as the dependency list held by the base class.
        super().__init__(name=name, kind="product", parameters=factors)

    @classmethod
    def from_dict(cls, config: dict[str, Any]) -> ProductFunction:
        """Build a ProductFunction from its ``name`` and ``factors`` entries."""
        return cls(name=config["name"], factors=config["factors"])

    def expression(self, context: dict[str, T.TensorVar]) -> T.TensorVar:
        """
        Multiply together the context entries named by ``factors``.

        Args:
            context: Mapping of names to PyTensor variables.

        Returns:
            T.TensorVar: The product of all factors, or the constant ``1.0``
            when there are no factors.
        """
        if not self.factors:
            return pt.constant(1.0)

        head, *tail = self.factors
        product = context[head]
        for factor_name in tail:
            product = product * context[factor_name]
        return product
class GenericFunction(Function[TF.GenericFunction]):
    """
    Function defined by a free-form mathematical expression.

    The expression string is parsed with ``parse_expression`` when the object
    is created, and the names reported as independent variables by
    ``analyze_sympy_expr`` become this function's parameters. Evaluation hands
    the stored parse result to ``sympy_to_pytensor``.

    Parameters:
        name (str): Name of the function.
        expression (str): Mathematical expression string to evaluate.

    Examples:
        >>> func = GenericFunction(name="quadratic", expression="x**2 + 2*x + 1")
        >>> func = GenericFunction(name="sinusoid", expression="sin(x) * exp(-t)")
    """

    def __init__(self, *, name: str, expression: str):
        """
        Initialize a GenericFunction.

        Args:
            name: Name of the function
            expression: Mathematical expression string
        """
        self.expression_str = expression

        # Parse once up front; the result is reused on every evaluation.
        self.sympy_expr = parse_expression(expression)
        independent_vars = analyze_sympy_expr(self.sympy_expr)["independent_vars"]

        super().__init__(
            name=name,
            kind="generic_function",
            parameters=list(map(str, independent_vars)),
        )

    @classmethod
    def from_dict(cls, config: dict[str, Any]) -> GenericFunction:
        """Build a GenericFunction from its ``name`` and ``expression`` entries."""
        return cls(name=config["name"], expression=config["expression"])

    def expression(self, context: dict[str, T.TensorVar]) -> T.TensorVar:
        """
        Convert the stored expression into a PyTensor expression.

        Args:
            context: Mapping of names to PyTensor variables.

        Returns:
            T.TensorVar: Result of ``sympy_to_pytensor`` for the pre-parsed
            expression, given the context variables in ``parameters`` order.
        """
        return sympy_to_pytensor(
            self.sympy_expr,
            [context[param_name] for param_name in self.parameters],
        )
class InterpolationFunction(Function[TF.InterpolationFunction]):
    r"""
    Piecewise interpolation function implementation.

    Modelled on ROOT's PiecewiseInterpolation logic to morph between nominal and
    variation distributions based on nuisance parameter values. Supports multiple
    interpolation codes (0-6) for different mathematical approaches.

    Mathematical Formulations:
        For **additive** interpolation modes (codes 0, 2, 3, 4):

        .. math::

            \text{result} = \text{nominal} + \sum_i I_i(\theta_i; \text{low}_i, \text{nominal}, \text{high}_i)

        For **multiplicative** interpolation modes (codes 1, 5, 6):

        .. math::

            \text{result} = \text{nominal} \times \prod_i [1 + I_i(\theta_i; \text{low}_i/\text{nominal}, 1, \text{high}_i/\text{nominal})]

    Interpolation Code Definitions:
        **Code 0** - Linear Interpolation/Extrapolation (Additive):

        .. math::

            I_0(\theta) = \begin{cases}
                \theta(\text{high} - \text{nom}) & \text{if } \theta \geq 0 \\
                \theta(\text{nom} - \text{low}) & \text{if } \theta < 0
            \end{cases}

        **Code 1** - Exponential Interpolation/Extrapolation (Multiplicative):

        .. math::

            I_1(\theta) = \begin{cases}
                \left(\frac{\text{high}}{\text{nom}}\right)^{\theta} - 1 & \text{if } \theta \geq 0 \\
                \left(\frac{\text{low}}{\text{nom}}\right)^{-\theta} - 1 & \text{if } \theta < 0
            \end{cases}

        **Code 2** - Exponential Interpolation + Linear Extrapolation (Additive):
            Uses :math:`\exp(\theta)` behavior for :math:`|\theta| \leq 1`, linear
            extrapolation for :math:`|\theta| > 1` with smooth transition at
            :math:`\theta = \pm 1`.

        **Code 3** - Exponential Interpolation + Different Linear Extrapolation (Additive):
            Uses :math:`\exp(\theta)` behavior for :math:`|\theta| \leq 1`, different
            linear extrapolation for :math:`|\theta| > 1` compared to code 2
            (the code-0 linear term).

        **Code 4** - 6th Order Polynomial Interpolation + Linear Extrapolation (Additive):

        .. math::

            I_4(\theta) = \begin{cases}
                \text{linear extrapolation} & \text{if } |\theta| \geq 1 \\
                \theta \times (1 + \theta^2(-3 + \theta^2)/16) \times (\text{high} - \text{nom}) & \text{if } \theta \geq 0, |\theta| < 1 \\
                \theta \times (1 + \theta^2(-3 + \theta^2)/16) \times (\text{nom} - \text{low}) & \text{if } \theta < 0, |\theta| < 1
            \end{cases}

        **Code 5** - 6th Order Polynomial Interpolation + Exponential Extrapolation (Multiplicative):
            Uses exponential extrapolation for :math:`|\theta| \geq 1`, 6th order
            polynomial for :math:`|\theta| < 1`. Recommended for normalization factors.

        **Code 6** - 6th Order Polynomial Interpolation + Linear Extrapolation (Multiplicative):
            Uses linear extrapolation for :math:`|\theta| \geq 1`, 6th order
            polynomial for :math:`|\theta| < 1`. Recommended for normalization
            factors (no roots outside :math:`|\theta| < 1`).

    Args:
        name: Name of the function
        high: High variation parameter names
        low: Low variation parameter names
        nom: Nominal parameter name
        interpolationCodes: Interpolation method codes (0-6)
        positiveDefinite: Whether function should be positive definite
        parameters: Variable names this function depends on (nuisance parameters)

    Note:
        - At :math:`\theta_i = 0`, all codes return the nominal value
        - Parameters are combined one after another in ``parameters`` order, each
          using its own interpolation code
        - Based on A.Bukin, Budker INP, Novosibirsk and ROOT's RooFit implementation
    """

    # Every code accepted by the constructor.
    _VALID_CODES = frozenset({0, 1, 2, 3, 4, 5, 6})
    # Codes whose contribution is added to the running result; the remaining
    # valid codes (1, 5, 6) scale the running result instead.
    _ADDITIVE_CODES = (0, 2, 3, 4)

    def __init__(
        self,
        *,
        name: str,
        high: list[str],
        low: list[str],
        nom: str,
        interpolationCodes: list[int],
        positiveDefinite: bool,
        parameters: list[str],
    ):
        """
        Initialize an InterpolationFunction.

        Args:
            name: Name of the function
            high: High variation parameter names
            low: Low variation parameter names
            nom: Nominal parameter name
            interpolationCodes: Interpolation method codes (0-6)
            positiveDefinite: Whether function should be positive definite
            parameters: Variable names this function depends on (nuisance parameters)

        Raises:
            UnknownInterpolationCodeError: If any interpolation code is not in range 0-6
        """
        super().__init__(name=name, kind="interpolation", parameters=parameters)

        # Reject unknown codes here so evaluation never sees one.
        for code in interpolationCodes:
            if code not in self._VALID_CODES:
                msg = f"Unknown interpolation code {code} in function '{name}'. Valid codes are 0-6."
                raise UnknownInterpolationCodeError(msg)

        self.high = high
        self.low = low
        self.nom = nom
        self.interpolationCodes = interpolationCodes
        self.positiveDefinite = positiveDefinite

    @classmethod
    def from_dict(cls, config: dict[str, Any]) -> InterpolationFunction:
        """
        Create an InterpolationFunction from dictionary configuration.

        The nuisance parameter names are read from the ``"vars"`` key.
        """
        return cls(
            name=config["name"],
            high=config["high"],
            low=config["low"],
            nom=config["nom"],
            interpolationCodes=config["interpolationCodes"],
            positiveDefinite=config["positiveDefinite"],
            parameters=config["vars"],
        )

    @staticmethod
    def _linear_term(
        param_val: T.TensorVar, rise: T.TensorVar, fall: T.TensorVar
    ) -> T.TensorVar:
        """Return ``theta * rise`` where ``theta >= 0`` and ``theta * fall`` elsewhere."""
        return cast(
            T.TensorVar,
            pt.switch(param_val >= 0, param_val * rise, param_val * fall),
        )

    @staticmethod
    def _power_term(
        param_val: T.TensorVar, ratio_high: T.TensorVar, ratio_low: T.TensorVar
    ) -> T.TensorVar:
        """Return ``ratio_high**theta - 1`` where ``theta >= 0``, else ``ratio_low**(-theta) - 1``."""
        return cast(
            T.TensorVar,
            pt.switch(
                param_val >= 0,
                cast(T.TensorVar, pt.power(ratio_high, param_val)) - 1.0,  # type: ignore[no-untyped-call]
                cast(T.TensorVar, pt.power(ratio_low, -param_val)) - 1.0,  # type: ignore[no-untyped-call]
            ),
        )

    @staticmethod
    def _exponential_term(
        param_val: T.TensorVar, rise: T.TensorVar, fall: T.TensorVar
    ) -> T.TensorVar:
        """Return ``rise * (exp(theta) - 1)`` where ``theta >= 0``, else ``fall * (exp(-theta) - 1)``."""
        return cast(
            T.TensorVar,
            pt.switch(
                param_val >= 0,
                rise * (pt.exp(param_val) - 1),
                fall * (pt.exp(-param_val) - 1),
            ),
        )

    @staticmethod
    def _polynomial_term(
        param_val: T.TensorVar, rise: T.TensorVar, fall: T.TensorVar
    ) -> T.TensorVar:
        """Return the linear term scaled by ``1 + theta**2 * (theta**2 - 3) / 16``."""
        damping = 1 + param_val * param_val * (-3 + param_val * param_val) / 16
        return cast(
            T.TensorVar,
            pt.switch(
                param_val >= 0,
                param_val * rise * damping,
                param_val * fall * damping,
            ),
        )

    def _flexible_interp_single(
        self,
        interp_code: int,
        low_val: T.TensorVar,
        high_val: T.TensorVar,
        boundary: float,
        nominal: T.TensorVar,
        param_val: T.TensorVar,
    ) -> T.TensorVar:
        r"""
        Implement flexible interpolation for a single parameter.

        Modelled on ROOT's flexibleInterpSingle method with support for
        interpolation codes 0-6. This method computes the interpolation
        contribution :math:`I_i(\theta_i)` for a single nuisance parameter.

        Args:
            interp_code: Interpolation code (0-6) determining the mathematical
                approach; any value other than 0-5 is handled as code 6
            low_val: Low variation value (used when :math:`\theta < 0`)
            high_val: High variation value (used when :math:`\theta \geq 0`)
            boundary: Value of :math:`|\theta|` at which codes 2-6 switch between
                interpolation and extrapolation (the caller passes 1.0)
            nominal: Nominal value (baseline)
            param_val: Parameter value :math:`\theta` (nuisance parameter)

        Returns:
            Interpolated contribution :math:`I_i(\theta_i)`.

        Note:
            The returned value interpretation depends on the interpolation code:
            - Codes 0,2,3,4: Direct additive contribution
            - Codes 1,5,6: Relative shift, i.e. the multiplicative factor minus 1;
              the caller multiplies the running result by ``1 + value``
        """
        # NOTE(review): the formulas below are kept exactly as found. Points to
        # confirm against the RooFit reference implementation:
        #   * codes 2 and 3, theta < 0: the interpolation term is
        #     (nom - low) * (exp(-theta) - 1), which has the opposite sign of the
        #     code-0 term theta * (nom - low);
        #   * codes 5 and 6, theta < 0: the linear/polynomial terms use
        #     (low/nom - 1), the opposite sign of code 0's (nom - low);
        #   * the polynomial factor equals 0.875 at |theta| = 1, so codes 4-6 do
        #     not meet their extrapolation branch at the boundary.
        if interp_code == 0:
            # Linear interpolation/extrapolation (additive)
            return self._linear_term(param_val, high_val - nominal, nominal - low_val)

        if interp_code == 1:
            # Exponential interpolation/extrapolation (multiplicative)
            return self._power_term(param_val, high_val / nominal, low_val / nominal)

        if interp_code in (2, 3):
            # Exponential interpolation inside the boundary (additive); the two
            # codes differ only in how they extrapolate beyond it.
            rise = high_val - nominal
            fall = nominal - low_val
            if interp_code == 2:
                # Tangent line of the exponential at |theta| == boundary.
                extrapolated = pt.switch(
                    param_val >= 0,
                    rise
                    * (
                        pt.exp(boundary)
                        - 1
                        + (param_val - boundary) * pt.exp(boundary)
                    ),
                    fall
                    * (
                        pt.exp(boundary)
                        - 1
                        + (-param_val - boundary) * pt.exp(boundary)
                    ),
                )
            else:
                extrapolated = self._linear_term(param_val, rise, fall)
            return cast(
                T.TensorVar,
                pt.switch(
                    pt.abs(param_val) <= boundary,
                    self._exponential_term(param_val, rise, fall),
                    extrapolated,
                ),
            )

        if interp_code == 4:
            # Polynomial interpolation + linear extrapolation (additive)
            rise = high_val - nominal
            fall = nominal - low_val
            return cast(
                T.TensorVar,
                pt.switch(
                    pt.abs(param_val) >= boundary,
                    self._linear_term(param_val, rise, fall),
                    self._polynomial_term(param_val, rise, fall),
                ),
            )

        # Codes 5 and 6: polynomial interpolation on the ratios to nominal
        # (multiplicative). Code 5 extrapolates exponentially; code 6, which is
        # also the fall-through for any other value, extrapolates linearly.
        ratio_high = high_val / nominal
        ratio_low = low_val / nominal
        rise = ratio_high - 1.0
        fall = ratio_low - 1.0
        if interp_code == 5:
            extrapolated = self._power_term(param_val, ratio_high, ratio_low)
        else:
            extrapolated = self._linear_term(param_val, rise, fall)
        return cast(
            T.TensorVar,
            pt.switch(
                pt.abs(param_val) >= boundary,
                extrapolated,
                self._polynomial_term(param_val, rise, fall),
            ),
        )

    def expression(self, context: dict[str, T.TensorVar]) -> T.TensorVar:
        r"""
        Evaluate the interpolation function.

        The algorithm proceeds as:

        1. Start with nominal value: :math:`\text{result} = \text{nominal}`
        2. For each nuisance parameter :math:`\theta_i`, compute interpolation
           contribution :math:`I_i(\theta_i)` relative to the nominal value
        3. Combine contributions based on interpolation mode:

           - **Additive modes** (codes 0,2,3,4): :math:`\text{result} += I_i(\theta_i)`
           - **Multiplicative modes** (codes 1,5,6): :math:`\text{result} \times= (1 + I_i(\theta_i))`

        4. Apply positive definite constraint:
           :math:`\text{result} = \max(\text{result}, 0)` if requested

        A parameter whose index has no matching entry in ``high``, ``low`` or
        ``interpolationCodes`` is skipped with a logged warning.

        Args:
            context: Mapping of names to pytensor variables containing:
                - Nominal parameter (referenced by `nom`)
                - High/low variation parameters (referenced by `high`/`low` lists)
                - Nuisance parameters (referenced by `parameters` list)

        Returns:
            PyTensor expression representing the interpolated result
        """
        nominal = context[self.nom]
        result = nominal

        # Only indices present in all three per-parameter lists can be evaluated.
        n_variations = min(
            len(self.high), len(self.low), len(self.interpolationCodes)
        )

        for i, var_name in enumerate(self.parameters):
            if i >= n_variations:
                log.warning(
                    "Parameter index %d exceeds variation lists for function %s",
                    i,
                    self.name,
                )
                continue

            interp_code = self.interpolationCodes[i]
            contribution = self._flexible_interp_single(
                interp_code=interp_code,
                low_val=context[self.low[i]],
                high_val=context[self.high[i]],
                boundary=1.0,
                nominal=nominal,
                param_val=context[var_name],
            )

            if interp_code in self._ADDITIVE_CODES:
                result = result + contribution
            else:
                # Multiplicative modes (1, 5, 6) return factor - 1.
                result = result * (1.0 + contribution)

        # Clamp at zero only after every contribution has been combined.
        if self.positiveDefinite:
            result = pt.maximum(result, 0.0)

        return result
# Maps the ``type`` value of a function configuration to the class that builds it.
registered_functions: dict[str, type[Function[Any]]] = {
    "product": ProductFunction,
    "generic_function": GenericFunction,
    "interpolation": InterpolationFunction,
}


class FunctionSet:
    """
    Named collection of HS3 function objects.

    Each configuration entry is turned into an instance of the class
    registered for its ``type`` (product, generic_function or interpolation)
    and stored under the function's name. The set supports lookup by name,
    membership tests, iteration over the function objects and ``len()``.

    Attributes:
        funcs (dict[str, Function[Any]]): Mapping from function names to Function instances.
    """

    def __init__(self, funcs: list[T.Function]) -> None:
        """
        Build the collection from function configurations.

        Args:
            funcs: List of function configurations from HS3 spec

        Raises:
            ValueError: If a configuration's ``type`` is not registered.
        """
        self.funcs: dict[str, Function[Any]] = {}
        for spec in funcs:
            type_name = spec["type"]
            # The base class doubles as the "not registered" sentinel.
            builder = registered_functions.get(type_name, Function)
            if builder is Function:
                raise ValueError(f"Unknown function type: {type_name}")
            # "type" only selects the class; it is not a constructor argument.
            init_config = {
                key: value for key, value in spec.items() if key != "type"
            }
            built = builder.from_dict(init_config)
            self.funcs[built.name] = built

    def __getitem__(self, item: str) -> Function[Any]:
        """Return the function stored under ``item``."""
        return self.funcs[item]

    def __contains__(self, item: str) -> bool:
        """Report whether a function named ``item`` is present."""
        return item in self.funcs

    def __iter__(self) -> Iterator[Function[Any]]:
        """Iterate over the stored function objects (not their names)."""
        return iter(self.funcs.values())

    def __len__(self) -> int:
        """Return the number of stored functions."""
        return len(self.funcs)