Source code for pyhs3.core

from __future__ import annotations

import json
import logging
import os
import sys
from collections import Counter
from collections.abc import Callable, Mapping
from pathlib import Path
from typing import Any, Literal, TypeAlias, TypeVar, cast

import numpy as np
import numpy.typing as npt
import pytensor.tensor as pt
from pydantic import BaseModel, ConfigDict, Field, ValidationError
from pytensor.compile.function import function
from pytensor.graph.traversal import applys_between, explicit_graph_inputs
from rich.progress import (
    BarColumn,
    Progress,
    SpinnerColumn,
    TaskProgressColumn,
    TextColumn,
    TimeElapsedColumn,
    TimeRemainingColumn,
)

from pyhs3.analyses import Analyses
from pyhs3.context import Context
from pyhs3.data import Data
from pyhs3.distributions import Distributions
from pyhs3.domains import Domain, Domains, ProductDomain
from pyhs3.exceptions import WorkspaceValidationError
from pyhs3.functions import Functions
from pyhs3.likelihoods import Likelihoods
from pyhs3.metadata import Metadata
from pyhs3.networks import build_dependency_graph
from pyhs3.parameter_points import ParameterPoints, ParameterSet
from pyhs3.typing.aliases import TensorVar

log = logging.getLogger(__name__)

TDefault = TypeVar("TDefault")

Axis: TypeAlias = tuple[float | None, float | None]


class Workspace(BaseModel):
    """
    Workspace for managing HS3 model specifications.

    A workspace contains parameter points, distributions, domains, and functions
    that define a probabilistic model. It provides methods to construct Model
    objects with specific parameter values and domain constraints.

    Attributes:
        metadata: Required metadata containing HS3 version and optional attribution
        distributions: List of distribution configurations
        functions: List of function configurations
        domains: List of domain configurations
        parameter_points: List of parameter point configurations
        data: Data specifications for observations
        likelihoods: Likelihood specifications mapping distributions to data
        analyses: Analysis configurations for automated analyses
        misc: Arbitrary user-created information

    HS3 Reference:
        See :hs3:label:`HS3 file format specification <hs3.file-format>` for the
        complete workspace structure.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    # Required field
    metadata: Metadata

    # Optional fields using discriminated unions
    distributions: Distributions | None = Field(
        default_factory=lambda: Distributions([])
    )
    functions: Functions | None = Field(default_factory=lambda: Functions([]))
    domains: Domains | None = Field(default_factory=lambda: Domains([]))
    parameter_points: ParameterPoints | None = Field(
        default_factory=lambda: ParameterPoints([])
    )
    data: Data | None = Field(default_factory=lambda: Data([]))
    likelihoods: Likelihoods | None = Field(default_factory=lambda: Likelihoods([]))
    analyses: Analyses | None = Field(default_factory=lambda: Analyses([]))
    misc: dict[str, Any] | None = Field(default_factory=dict)

    @classmethod
    def load(
        cls,
        path: str | os.PathLike[str],
        *,
        verbose: bool = False,
        suppress_traceback: bool = True,
    ) -> Workspace:
        """
        Load workspace from a JSON file.

        Args:
            path: Path to the JSON file containing the HS3 specification
            verbose: If True, show all errors. If False, show first 20 and summarize rest.
            suppress_traceback: If True, suppress traceback on validation errors (default True).

        Returns:
            Workspace: The loaded workspace instance

        Raises:
            OSError: If the file cannot be opened.
            json.JSONDecodeError: If the file is not valid JSON.
            WorkspaceValidationError: If the JSON does not validate as a workspace;
                the message is a summary built by ``_format_validation_error``.
        """
        path_obj = Path(path)
        with path_obj.open("r", encoding="utf-8") as f:
            spec_dict = json.load(f)

        try:
            return cls(**spec_dict)
        except ValidationError as e:
            error_summary = cls._format_validation_error(e, path, verbose)
            if suppress_traceback:
                # NOTE(review): ``sys.tracebacklimit`` is interpreter-wide and is
                # never restored here, so tracebacks of every later uncaught
                # exception in this process are suppressed too.
                sys.tracebacklimit = 0
            # ``from None`` drops the (very long) pydantic error chain.
            raise WorkspaceValidationError(error_summary) from None

    @classmethod
    def _format_validation_error(
        cls,
        validation_error: ValidationError,
        path: str | os.PathLike[str],
        verbose: bool,
        max_errors: int = 20,
    ) -> str:
        """
        Format a ValidationError into a readable error summary.

        Args:
            validation_error: The ValidationError to format
            path: Path to the file that caused the error
            verbose: If True, show all errors. If False, show first ``max_errors``
                and summarize rest.
            max_errors: Number of detailed errors listed when ``verbose`` is False
                (default 20).

        Returns:
            Formatted error message string
        """
        # ``errors()`` builds a fresh list on every call, so fetch it once.
        errors = validation_error.errors()
        error_count = len(errors)

        error_types: Counter[str] = Counter(error["type"] for error in errors)
        # Collapse list indices to "#" so errors in different elements of the
        # same collection are grouped under one component.
        loc_errors: Counter[tuple[str, ...]] = Counter(
            tuple("#" if isinstance(key, int) else key for key in error["loc"])
            for error in errors
        )

        # Create a concise error message
        error_summary = (
            f"Workspace validation failed with {error_count} errors from {path}\n"
        )

        error_summary += "\nError breakdown by type:\n"
        for error_type, count in error_types.most_common():
            error_summary += f" {error_type}: {count}\n"

        error_summary += "\nError breakdown by component:\n"
        for loc, count in loc_errors.most_common():
            loc_str = ".".join(loc)
            error_summary += f" {loc_str}: {count}\n"

        # Show detailed errors with improved formatting
        errors_to_show = errors if verbose else errors[:max_errors]
        shown_label = "all" if verbose else f"first {max_errors}"
        error_summary += f"\nErrors for debugging ({shown_label}):\n"

        for i, error in enumerate(errors_to_show):
            # Format location more readably: ints become "[n]" index suffixes
            loc_parts = [
                f"[{part}]" if isinstance(part, int) else str(part)
                for part in error.get("loc", [])
            ]

            # Build readable location string
            readable_loc = ""
            for j, part in enumerate(loc_parts):
                if j == 0:
                    readable_loc = part
                elif part.startswith("["):
                    readable_loc += part  # Index directly follows
                else:
                    readable_loc += f" -> {part}"

            # Add name from input if available
            input_data: Any = error.get("input", {})
            if isinstance(input_data, dict) and "name" in input_data:
                name = input_data["name"]
                if readable_loc and not readable_loc.endswith("]"):
                    readable_loc += f"('{name}')"

            msg = error.get("msg", "Unknown error")
            error_summary += f" {i + 1}. {readable_loc}: {msg}\n"

        if not verbose and error_count > max_errors:
            error_summary += f" ... and {error_count - max_errors} more errors (use verbose=True to see all)\n"

        return error_summary

    def model(
        self,
        *,
        domain: int | str | Domain = 0,
        parameter_set: int | str | ParameterSet = 0,
        progress: bool = True,
        mode: str = "FAST_RUN",
    ) -> Model:
        """
        Constructs a `Model` object using the provided domain and parameter set.

        Args:
            domain (int | str | Domain): Identifier or object specifying the domain to use.
                Ignored in favour of a default ``ProductDomain`` when the workspace
                has no domains.
            parameter_set (int | str | ParameterSet): Identifier or object specifying
                the parameter values to use. Ignored in favour of an empty default
                set when the workspace has no parameter points.
            progress (bool): Whether to show progress bar during dependency graph construction. Defaults to True.
            mode (str): PyTensor compilation mode. Defaults to "FAST_RUN".
                Options: "FAST_RUN" (apply all rewrites, use C implementations),
                "FAST_COMPILE" (few rewrites, Python implementations),
                "NUMBA" (compile using Numba), "JAX" (compile using JAX),
                "PYTORCH" (compile using PyTorch), "DebugMode" (debugging),
                "NanGuardMode" (NaN detection).

        Returns:
            Model: The constructed model object.
        """
        selected_domain = (
            domain
            if isinstance(domain, Domain)
            else self.domains[domain]
            if self.domains
            else ProductDomain(name="default")
        )

        parameterset = (
            parameter_set
            if isinstance(parameter_set, ParameterSet)
            else self.parameter_points[parameter_set]
            if self.parameter_points
            else ParameterSet(name="default", parameters=[])
        )

        return Model(
            parameterset=parameterset or ParameterSet(name="default"),
            distributions=self.distributions or Distributions(),
            domain=selected_domain or Domain(name="default", type="unknown"),
            functions=self.functions or Functions(),
            progress=progress,
            mode=mode,
        )
class Model:
    """
    Probabilistic model with compiled tensor operations.

    A model represents a specific instantiation of a workspace with concrete
    parameter values and domain constraints. It builds symbolic computation
    graphs for distributions and functions, with optional compilation for
    performance optimization.

    The model handles dependency resolution between parameters, functions, and
    distributions, ensuring proper evaluation order through topological sorting
    of the computation graph.

    HS3 Reference:
        Models are computational representations of
        :hs3:label:`HS3 workspaces <hs3.file-format>`.
    """

    def __init__(
        self,
        *,
        parameterset: ParameterSet,
        distributions: Distributions,
        domain: Domain,
        functions: Functions,
        progress: bool = True,
        mode: str = "FAST_RUN",
    ):
        """
        Represents a probabilistic model composed of parameters, domains, distributions, and functions.

        Args:
            parameterset (ParameterSet): The parameter set used in the model.
            distributions (Distributions): Set of distributions to include.
            domain (Domain): Domain constraints for parameters.
            functions (Functions): Set of functions that compute parameter values.
            progress (bool): Whether to show progress bar during dependency graph construction.
            mode (str): PyTensor compilation mode. Defaults to "FAST_RUN".
                Options: "FAST_RUN" (apply all rewrites, use C implementations),
                "FAST_COMPILE" (few rewrites, Python implementations),
                "NUMBA" (compile using Numba), "JAX" (compile using JAX),
                "PYTORCH" (compile using PyTorch), "DebugMode" (debugging),
                "NanGuardMode" (NaN detection).

        Attributes:
            domain (Domain): The original domain with constraints for parameters.
            parameterset (ParameterSet): The original parameter set with parameter values.
            distributions (dict[str, pytensor.tensor.variable.TensorVariable]): Symbolic distribution expressions.
            parameters (dict[str, pytensor.tensor.variable.TensorVariable]): Symbolic parameter variables.
            functions (dict[str, pytensor.tensor.variable.TensorVariable]): Computed function values.
            modifiers (dict[str, pytensor.tensor.variable.TensorVariable]): Evaluated modifier expressions.
            mode (str): PyTensor compilation mode.
            _compiled_functions (dict[str, Callable[..., npt.NDArray[np.float64]]]): Cache of compiled PyTensor functions.
        """
        self.parameterset = parameterset
        self.domain = domain
        self._distribution_objects = distributions  # Store original distribution objects
        self._function_objects = functions  # Store original function objects

        self.parameters: dict[str, TensorVar] = {}
        self.functions: dict[str, TensorVar] = {}
        self.distributions: dict[str, TensorVar] = {}
        self.modifiers: dict[str, TensorVar] = {}
        self.mode = mode
        self._compiled_functions: dict[str, Callable[..., npt.NDArray[np.float64]]] = {}
        # Input variables of each compiled function, in positional order.
        self._compiled_inputs: dict[str, list[TensorVar]] = {}

        # Build dependency graph with proper entity identification
        self._build_dependency_graph(functions, distributions, progress)

    @staticmethod
    def _ensure_array(
        value: float | list[float] | npt.NDArray[np.float64],
    ) -> npt.NDArray[np.float64]:
        """
        Ensure a value is a numpy array with dtype float64.

        Converts scalars to 0-d arrays and lists to 1-d arrays. Existing
        numpy arrays are converted to float64 dtype if needed.

        Args:
            value: Input value (scalar, list, or array)

        Returns:
            NumPy array with dtype float64
        """
        return np.asarray(value, dtype=np.float64)

    def _build_dependency_graph(
        self,
        functions: Functions,
        distributions: Distributions,
        progress: bool = True,
    ) -> None:
        """
        Build and evaluate dependency graph for functions and distributions.

        This method properly handles cross-references between functions,
        distributions, and parameters by building a complete dependency graph
        first, then evaluating in topological order.

        Args:
            functions: Function definitions, looked up by node name.
            distributions: Distribution definitions, looked up by node name.
            progress: Whether to display a (transient) rich progress bar.
        """
        # Build dependency graph using the networks module
        graph, constants_map, modifiers_map = build_dependency_graph(
            self.parameterset, functions, distributions
        )

        # Get topological order (handles cycle detection internally)
        sorted_nodes = graph.topological_sort()

        # Truncate long names to prevent jumpiness of the progress description
        max_name_length = 60

        # Evaluate nodes in topological order with optional progress bar
        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}", style="cyan"),
            BarColumn(),
            TaskProgressColumn(),
            TimeElapsedColumn(),
            TimeRemainingColumn(),
            expand=True,
            transient=True,  # Progress bar disappears when finished
            disable=not progress,  # Disable progress bar if progress=False
        ) as progress_bar:
            task = progress_bar.add_task(
                "Building expressions...", total=len(sorted_nodes)
            )

            for node_idx in sorted_nodes:
                node_data = graph[node_idx]
                node_type: Literal[
                    "parameter", "constant", "function", "distribution", "modifier"
                ] = node_data["type"]
                node_name = node_data["name"]

                display_name = node_name
                if len(node_name) > max_name_length:
                    display_name = node_name[: max_name_length - 3] + "..."

                # Update progress description with current entity (fixed width)
                progress_bar.update(
                    task,
                    description=f"Building {node_type:<12}: {display_name:<{max_name_length}}",
                )

                # Build context with all currently available entities; later
                # categories win on name collisions.
                context_data = {
                    **self.parameters,
                    **self.functions,
                    **self.distributions,
                    **self.modifiers,
                }
                context = Context(parameters=context_data)

                if node_type == "parameter":
                    # Create parameter tensor with domain bounds applied
                    domain_bounds = (
                        self.domain.get(node_name, (None, None))
                        if self.domain
                        else (None, None)
                    )
                    param_point = (
                        self.parameterset.get(node_name) if self.parameterset else None
                    )

                    # Default to vector for observed data parameters, scalar for others
                    if param_point is not None:
                        param_kind = param_point.kind
                    elif "_observed" in node_name:
                        param_kind = pt.vector
                    else:
                        param_kind = pt.scalar

                    self.parameters[node_name] = create_bounded_tensor(
                        node_name, domain_bounds, param_kind
                    )

                elif node_type == "constant":
                    # Constants are pre-created by distributions - add to parameters
                    self.parameters[node_name] = constants_map[node_name]

                elif node_type == "function":
                    # Functions are evaluated by design
                    self.functions[node_name] = functions[node_name].expression(context)

                elif node_type == "modifier":
                    # Modifiers are evaluated and stored for later use by distributions
                    # Use pre-built modifiers map for efficient O(1) lookup
                    self.modifiers[node_name] = modifiers_map[node_name].expression(
                        context
                    )

                else:  # node_type == "distribution"
                    # Distributions are evaluated by design
                    self.distributions[node_name] = distributions[node_name].expression(
                        context
                    )

                # Advance progress
                progress_bar.advance(task)

    def _get_compiled_function(
        self, name: str
    ) -> Callable[..., npt.NDArray[np.float64]]:
        """
        Get or create a compiled PyTensor function for the specified distribution.

        The distribution expression already includes both the main likelihood
        and extended likelihood terms, so no additional combination is needed.
        Compiled functions (and their ordered inputs) are cached per name.

        Args:
            name (str): Name of the distribution.

        Returns:
            Callable: Compiled PyTensor function.

        Raises:
            KeyError: If ``name`` is not a distribution of this model.
        """
        if name not in self._compiled_functions:
            # Get the distribution expression (already includes extended_likelihood)
            dist_expression = self.distributions[name]

            # Only named graph inputs are exposed as keyword parameters
            inputs = [
                var
                for var in explicit_graph_inputs([dist_expression])
                if var.name is not None
            ]

            # Cache the inputs list for consistent ordering
            self._compiled_inputs[name] = cast(list[TensorVar], inputs)

            self._compiled_functions[name] = cast(
                Callable[..., npt.NDArray[np.float64]],
                function(
                    inputs=inputs,
                    outputs=dist_expression,
                    mode=self.mode,
                    on_unused_input="ignore",
                    name=name,
                    # Skip per-call input validation for speed; callers must pass
                    # correctly typed arrays.
                    trust_input=True,
                ),
            )

        return self._compiled_functions[name]

    def pdf_unsafe(
        self,
        name: str,
        **parametervalues: float | list[float] | npt.NDArray[np.float64],
    ) -> npt.NDArray[np.float64]:
        """
        Evaluates the PDF with automatic type conversion (convenience method).

        This method automatically converts parameter values to numpy arrays
        before evaluation. Use this for convenience in testing or interactive use.
        For performance-critical code, prefer :meth:`pdf` with pre-converted numpy arrays.

        Args:
            name (str): Name of the distribution to evaluate.
            **parametervalues: Values for each parameter (floats, lists, or arrays).

        Returns:
            npt.NDArray[np.float64]: The evaluated PDF value.

        See Also:
            :meth:`pdf`: Type-safe version requiring numpy arrays
            :meth:`logpdf_unsafe`: Log PDF with automatic type conversion

        Example:
            >>> model.pdf_unsafe("gauss", x=1.5, mu=0.0, sigma=1.0)  # floats ok  # doctest: +SKIP
            >>> model.pdf_unsafe("gauss", x=[1.5], mu=0.0, sigma=1.0)  # lists ok  # doctest: +SKIP
        """
        # Convert all parameter values to numpy arrays
        converted_params = {
            key: self._ensure_array(value) for key, value in parametervalues.items()
        }
        return self.pdf(name, **converted_params)

    def pdf(
        self, name: str, **parametervalues: npt.NDArray[np.float64]
    ) -> npt.NDArray[np.float64]:
        """
        Evaluates the probability density function of the specified distribution.

        All parameter values are expected to be numpy arrays with dtype float64.
        For automatic type conversion, use :meth:`pdf_unsafe` instead.

        Args:
            name (str): Name of the distribution to evaluate.
            **parametervalues: Values for each parameter as numpy arrays. Extra
                keys not used by the distribution are ignored.

        Returns:
            npt.NDArray[np.float64]: The evaluated PDF value.

        Raises:
            KeyError: If ``name`` is not a distribution of this model, or if a
                parameter required by the distribution is missing.

        Note:
            The compiled function is built with ``trust_input=True``, so value
            types are not validated before the call; passing non-arrays or a
            wrong dtype may surface as a backend-specific error.

        See Also:
            :meth:`pdf_unsafe`: Convenience version with automatic type conversion
            :meth:`logpdf`: Log PDF with strict type checking

        Example:
            >>> import numpy as np
            >>> model.pdf("gauss", x=np.array(1.5), mu=np.array(0.0), sigma=np.array(1.0))  # doctest: +SKIP
        """
        # Use compiled function for better performance
        func = self._get_compiled_function(name)
        positional_values = self._reorder_params(name, parametervalues)
        return func(*positional_values)

    def logpdf_unsafe(
        self,
        name: str,
        **parametervalues: float | list[float] | npt.NDArray[np.float64],
    ) -> npt.NDArray[np.float64]:
        """
        Evaluates the log PDF with automatic type conversion (convenience method).

        This method automatically converts parameter values to numpy arrays
        before evaluation. Use this for convenience in testing or interactive use.
        For performance-critical code, prefer :meth:`logpdf` with pre-converted numpy arrays.

        Args:
            name (str): Name of the distribution to evaluate.
            **parametervalues: Values for each parameter (floats, lists, or arrays).

        Returns:
            npt.NDArray[np.float64]: The log of the PDF.

        See Also:
            :meth:`logpdf`: Type-safe version requiring numpy arrays
            :meth:`pdf_unsafe`: PDF with automatic type conversion

        Example:
            >>> model.logpdf_unsafe("gauss", x=1.5, mu=0.0, sigma=1.0)  # floats ok  # doctest: +SKIP
        """
        return np.log(self.pdf_unsafe(name, **parametervalues))

    def logpdf(
        self, name: str, **parametervalues: npt.NDArray[np.float64]
    ) -> npt.NDArray[np.float64]:
        """
        Evaluates the natural logarithm of the PDF.

        All parameter values are expected to be numpy arrays with dtype float64.
        For automatic type conversion, use :meth:`logpdf_unsafe` instead.

        Args:
            name (str): Name of the distribution to evaluate.
            **parametervalues: Values for each parameter as numpy arrays.

        Returns:
            npt.NDArray[np.float64]: The log of the PDF (computed as
            ``np.log`` of :meth:`pdf`).

        Raises:
            KeyError: If ``name`` is not a distribution of this model, or if a
                parameter required by the distribution is missing.

        See Also:
            :meth:`logpdf_unsafe`: Convenience version with automatic type conversion
            :meth:`pdf`: PDF with strict type checking

        Example:
            >>> import numpy as np
            >>> model.logpdf("gauss", x=np.array(1.5), mu=np.array(0.0), sigma=np.array(1.0))  # doctest: +SKIP
        """
        return np.log(self.pdf(name, **parametervalues))

    def pars(self, name: str) -> list[str]:
        """
        Get the ordered list of input parameter names for a distribution.

        This method returns the parameter names in the exact order expected by
        the compiled PDF function. This is useful when you need to know the
        order of parameters for programmatic access. Compiles the distribution
        on first use.

        Args:
            name: Distribution name

        Returns:
            List of parameter names in the order expected by pdf()

        Example:
            >>> model.pars("model_singlechannel")  # doctest: +SKIP
            ['uncorr_bkguncrt_1', 'uncorr_bkguncrt_0', 'model_singlechannel_observed', 'mu', 'Lumi']
        """
        if name not in self._compiled_inputs:
            # Trigger compilation to populate cache
            self._get_compiled_function(name)
        return [var.name for var in self._compiled_inputs[name] if var.name is not None]

    def parsort(self, name: str, names: list[str]) -> list[int]:
        """
        Similar to numpy's argsort, returns the indices that would sort the parameters.

        Args:
            name: Distribution name
            names: Parameter names to sort

        Returns:
            List of indices that would sort the parameters

        Raises:
            ValueError: If a parameter required by the distribution is not in ``names``.

        Example:
            >>> model.parsort("model_singlechannel", ["mu", "Lumi", "uncorr_bkguncrt_0", "uncorr_bkguncrt_1", "model_singlechannel_observed"])  # doctest: +SKIP
            [3, 2, 4, 0, 1]
        """
        return [names.index(par) for par in self.pars(name)]

    def _reorder_params(
        self,
        name: str,
        params: Mapping[str, npt.NDArray[np.float64]],
    ) -> list[npt.NDArray[np.float64]]:
        """
        Reorder parameters to match the expected input order for a distribution.

        Args:
            name: Distribution name
            params: Dictionary of parameter values (numpy arrays)

        Returns:
            List of values in the correct order for the compiled function

        Raises:
            KeyError: If a required parameter is missing from ``params``.
        """
        input_order = self.pars(name)
        return [params[param_name] for param_name in input_order]

    def visualize_graph(
        self,
        name: str,
        fmt: str = "svg",
        outfile: str | None = None,
        path: str | None = None,
    ) -> str:
        """
        Visualize the computation graph for a distribution.

        Args:
            name (str): Distribution name.
            fmt (str): Output format ('svg', 'png', 'pdf'). Defaults to 'svg'.
            outfile (str | None): Output filename. If None, uses '{name}_graph.{fmt}'.
                When given, it is used as-is and ``path`` is ignored.
            path (str | None): Directory path for the default filename. If None,
                uses current working directory.

        Returns:
            str: Path to the generated visualization file.

        Raises:
            ImportError: If pydot is not installed.
            ValueError: If ``name`` is not a distribution of this model.
        """
        try:
            from pytensor.printing import (  # noqa: PLC0415  # pylint: disable=import-outside-toplevel
                pydotprint,
            )
        except ImportError as e:
            msg = "Graph visualization requires pydot. Install with: pip install pydot"
            raise ImportError(msg) from e

        if name not in self.distributions:
            msg = f"Distribution '{name}' not found in model"
            raise ValueError(msg)

        dist = self.distributions[name]

        if outfile is not None:
            filename = outfile
        else:
            base_filename = f"{name}_graph.{fmt}"
            if path is not None:
                filename = str(Path(path) / base_filename)
            else:
                filename = base_filename

        pydotprint(
            dist, outfile=filename, format=fmt, with_ids=True, high_contrast=True
        )
        return filename

    def __repr__(self) -> str:
        """Provide a concise overview of the model structure."""
        param_names = list(self.parameters.keys())
        dist_names = list(self.distributions.keys())
        func_names = list(self.functions.keys())

        param_display = ", ".join(param_names[:5]) + (
            "..." if len(param_names) > 5 else ""
        )
        dist_display = ", ".join(dist_names[:3]) + (
            "..." if len(dist_names) > 3 else ""
        )
        func_display = ", ".join(func_names[:3]) + (
            "..." if len(func_names) > 3 else ""
        )

        mode_status = self.mode

        return f"""Model( mode: {mode_status} parameters: {len(param_names)} ({param_display}) distributions: {len(dist_names)} ({dist_display}) functions: {len(func_names)} ({func_display}) )"""

    def graph_summary(self, name: str) -> str:
        """
        Get a summary of the computation graph structure.

        Args:
            name (str): Distribution name.

        Returns:
            str: Summary of the graph structure, including the compilation mode
            and whether a compiled function is already cached for ``name``.

        Raises:
            ValueError: If ``name`` is not a distribution of this model.
        """
        if name not in self.distributions:
            msg = f"Distribution '{name}' not found in model"
            raise ValueError(msg)

        dist = self.distributions[name]
        inputs = list(explicit_graph_inputs([dist]))

        # Count different types of operations
        applies = list(applys_between(inputs, [dist]))
        op_types: dict[str, int] = {}
        for apply in applies:
            op_name = type(apply.op).__name__
            op_types[op_name] = op_types.get(op_name, 0) + 1

        # A function is cached by _get_compiled_function in every mode
        # (including FAST_COMPILE), so cache membership alone means "compiled".
        is_compiled = name in self._compiled_functions
        compile_info = f"\n Mode: {self.mode}\n Compiled: {'Yes' if is_compiled else 'No'}"

        return f"""Distribution '{name}': Input variables: {len(inputs)} Graph operations: {len(applies)} Operation types: {dict(sorted(op_types.items()))}{compile_info} """
def create_bounded_tensor(
    name: str, domain: Axis, kind: Callable[..., TensorVar] = pt.scalar
) -> TensorVar:
    """
    Create a named tensor variable, clipped to a domain when bounds are given.

    Args:
        name: Name of the parameter.
        domain (tuple): ``(min, max)`` range. Use None for an unbounded side,
            e.g. ``(0.0, None)`` for a lower bound only or ``(None, 1.0)`` for an
            upper bound only. If both bounds are None, the raw tensor is returned.
        kind: Tensor constructor, ``pt.scalar`` for scalars or ``pt.vector`` for
            vectors (default: ``pt.scalar``).

    Returns:
        pytensor.tensor.variable.TensorVariable: The tensor variable, wrapped in
        ``pt.clip`` (and given the same name) if any bound exists.

    Examples:
        >>> sigma = create_bounded_tensor("sigma", (0.0, None))  # sigma >= 0 (scalar)
        >>> fraction = create_bounded_tensor("fraction", (0.0, 1.0))  # 0 <= fraction <= 1 (scalar)
        >>> temperatures = create_bounded_tensor("temperatures", (None, 100.0), pt.vector)  # vector <= 100
        >>> unbounded = create_bounded_tensor("unbounded", (None, None))  # no bounds applied
    """
    lower, upper = domain
    base = kind(name)

    if lower is None and upper is None:
        # Nothing to enforce: hand back the raw symbolic input.
        return base

    # A missing side becomes the matching infinity so one clip covers all cases.
    lower_const = pt.constant(-np.inf if lower is None else lower)
    upper_const = pt.constant(np.inf if upper is None else upper)

    bounded = pt.clip(base, lower_const, upper_const)
    bounded.name = base.name  # keep the input's name on the clipped output
    return cast(TensorVar, bounded)