from __future__ import annotations
import logging
import sys
from collections import OrderedDict
from collections.abc import Iterator
from dataclasses import dataclass, field
from pathlib import Path
from typing import Callable, TypeVar, cast
import numpy as np
import numpy.typing as npt
import pytensor.tensor as pt
import rustworkx as rx
from pytensor.compile.function import function
from pytensor.graph.basic import applys_between, graph_inputs
from rich.progress import (
BarColumn,
Progress,
SpinnerColumn,
TaskProgressColumn,
TextColumn,
TimeElapsedColumn,
TimeRemainingColumn,
)
from pyhs3 import typing as T
from pyhs3.distributions import DistributionSet
from pyhs3.functions import FunctionSet
from pyhs3.typing_compat import TypeAlias
log = logging.getLogger(__name__)
TDefault = TypeVar("TDefault")
if sys.version_info >= (3, 10):
Axis: TypeAlias = tuple[float | None, float | None]
else:
Axis: TypeAlias = tuple["float | None", "float | None"]
[docs]
class Workspace:
"""
Workspace for managing HS3 model specifications.
A workspace contains parameter points, distributions, domains, and functions
that define a probabilistic model. It provides methods to construct Model
objects with specific parameter values and domain constraints.
Attributes:
parameter_collection (ParameterCollection): Named parameter sets.
distribution_set (DistributionSet): Available distributions.
domain_collection (DomainCollection): Domain constraints for parameters.
function_set (FunctionSet): Available functions for parameter computation.
"""
[docs]
def __init__(self, spec: T.HS3Spec):
"""
Manages the overall structure of the model including parameters, domains, and distributions.
Args:
spec (dict): A dictionary containing model definitions including parameter points, distributions,
and domains.
Attributes:
parameter_collection (ParameterCollection): Set of named parameter points.
distribution_set (DistributionSet): All distributions used in the workspace.
domain_collection (DomainCollection): Domain definitions for all parameters.
"""
self.parameter_collection = ParameterCollection(
spec.get("parameter_points", [])
)
self.distribution_set = DistributionSet(spec.get("distributions", []))
self.domain_collection = DomainCollection(spec.get("domains", []))
self.function_set = FunctionSet(spec.get("functions", []))
def model(
self,
*,
domain: int | str | DomainSet = 0,
parameter_point: int | str | ParameterSet = 0,
progress: bool = True,
mode: str = "FAST_RUN",
) -> Model:
"""
Constructs a `Model` object using the provided domain and parameter point.
Args:
domain (int | str | DomainSet): Identifier or object specifying the domain to use.
parameter_point (int | str | ParameterSet): Identifier or object specifying the parameter values to use.
progress (bool): Whether to show progress bar during dependency graph construction. Defaults to True.
mode (str): PyTensor compilation mode. Defaults to "FAST_RUN".
Options: "FAST_RUN" (apply all rewrites, use C implementations),
"FAST_COMPILE" (few rewrites, Python implementations),
"NUMBA" (compile using Numba), "JAX" (compile using JAX),
"PYTORCH" (compile using PyTorch), "DebugMode" (debugging),
"NanGuardMode" (NaN detection).
Returns:
Model: The constructed model object.
"""
domainset = (
domain if isinstance(domain, DomainSet) else self.domain_collection[domain]
)
parameterset = (
parameter_point
if isinstance(parameter_point, ParameterSet)
else self.parameter_collection[parameter_point]
)
# Verify that domains are a subset of parameters (not all parameters need bounds)
param_names = set(parameterset.points.keys())
domain_names = set(domainset.domains.keys())
assert domain_names.issubset(param_names), (
f"Domain names must be a subset of parameter names. "
f"Extra domains: {domain_names - param_names}"
)
return Model(
parameterset=parameterset,
distributions=self.distribution_set,
domains=domainset,
functions=self.function_set,
progress=progress,
mode=mode,
)
[docs]
class Model:
"""
Probabilistic model with compiled tensor operations.
A model represents a specific instantiation of a workspace with concrete
parameter values and domain constraints. It builds symbolic computation
graphs for distributions and functions, with optional compilation for
performance optimization.
The model handles dependency resolution between parameters, functions,
and distributions, ensuring proper evaluation order through topological
sorting of the computation graph.
"""
[docs]
def __init__(
self,
*,
parameterset: ParameterSet,
distributions: DistributionSet,
domains: DomainSet,
functions: FunctionSet,
progress: bool = True,
mode: str = "FAST_RUN",
):
"""
Represents a probabilistic model composed of parameters, domains, distributions, and functions.
Args:
parameterset (ParameterSet): The parameter set used in the model.
distributions (DistributionSet): Set of distributions to include.
domains (DomainSet): Domain constraints for parameters.
functions (FunctionSet): Set of functions that compute parameter values.
progress (bool): Whether to show progress bar during dependency graph construction.
mode (str): PyTensor compilation mode. Defaults to "FAST_RUN".
Options: "FAST_RUN" (apply all rewrites, use C implementations),
"FAST_COMPILE" (few rewrites, Python implementations),
"NUMBA" (compile using Numba), "JAX" (compile using JAX),
"PYTORCH" (compile using PyTorch), "DebugMode" (debugging),
"NanGuardMode" (NaN detection).
Attributes:
parameters (dict[str, pytensor.tensor.variable.TensorVariable]): Symbolic parameter variables.
parameterset (ParameterSet): The original set of parameter values.
distributions (dict[str, pytensor.tensor.variable.TensorVariable]): Symbolic distribution expressions.
functions (dict[str, pytensor.tensor.variable.TensorVariable]): Computed function values.
mode (str): PyTensor compilation mode.
_compiled_functions (dict[str, Callable[..., npt.NDArray[np.float64]]]): Cache of compiled PyTensor functions.
"""
self.parameters = {}
self.parameterset = parameterset
self.functions: dict[str, T.TensorVar] = {}
self.mode = mode
self._compiled_functions: dict[str, Callable[..., npt.NDArray[np.float64]]] = {}
for parameter_point in parameterset:
# Create scalar parameter with domain bounds applied
domain = domains.domains.get(parameter_point.name, (None, None))
self.parameters[parameter_point.name] = boundedscalar(
parameter_point.name, domain
)
self.distributions: dict[str, T.TensorVar] = {}
# Build dependency graph with proper entity identification
self._build_dependency_graph(functions, distributions, progress)
def _build_dependency_graph(
self,
functions: FunctionSet,
distributions: DistributionSet,
progress: bool = True,
) -> None:
"""
Build and evaluate dependency graph for functions and distributions.
This method properly handles cross-references between functions, distributions,
and parameters by building a complete dependency graph first, then evaluating
in topological order.
"""
graph = rx.PyDiGraph()
nodes: dict[str, int] = {}
# Build entity type mapping for O(1) lookup
entity_types: dict[str, str] = {}
# Build constants mapping for O(1) lookup
constants_map: dict[str, T.TensorVar] = {}
# Map all parameter names
for param in self.parameterset:
entity_types[param.name] = "parameter"
# Map all function names
for func in functions:
entity_types[func.name] = "function"
# Map all distribution names and collect their constants
for dist in distributions:
entity_types[dist.name] = "distribution"
# Also map any constants generated by this distribution
for constant_name, constant_tensor in dist.constants.items():
entity_types[constant_name] = "constant"
constants_map[constant_name] = constant_tensor
# First pass: Add all nodes to the graph using entity_types
for entity_name, entity_type in entity_types.items():
node_idx = graph.add_node({"type": entity_type, "name": entity_name})
nodes[entity_name] = node_idx
# Second pass: Add edges by iterating through all computational entities
# Both functions and distributions have .parameters, so treat them uniformly
for entity in [*functions, *distributions]:
entity_idx = nodes[entity.name]
# Add dependencies (parameters this entity depends on)
for param_name in entity.parameters:
try:
param_idx = nodes[param_name]
except KeyError:
msg = (
f"Unknown entity referenced: '{param_name}' from '{entity.name}'. "
f"Not found in parameters, functions, or distributions."
)
raise ValueError(msg) from None
# Add edge: dependency -> entity (param/func/dist feeds into entity)
graph.add_edge(param_idx, entity_idx, None)
# Third pass: Evaluate in topological order
try:
sorted_nodes = rx.topological_sort(graph)
except rx.DAGHasCycle as e:
msg = "Circular dependency detected in model"
raise ValueError(msg) from e
# Evaluate nodes in topological order with optional progress bar
with Progress(
SpinnerColumn(),
TextColumn("[progress.description]{task.description}", style="cyan"),
BarColumn(),
TaskProgressColumn(),
TimeElapsedColumn(),
TimeRemainingColumn(),
expand=True,
transient=True, # Progress bar disappears when finished
disable=not progress, # Disable progress bar if progress=False
) as progress_bar:
task = progress_bar.add_task(
"Building expressions...", total=len(sorted_nodes)
)
for node_idx in sorted_nodes:
node_data = graph[node_idx]
node_type = node_data["type"]
node_name = node_data["name"]
# Truncate long names to prevent jumpiness
max_name_length = 60
display_name = node_name
if len(node_name) > max_name_length:
display_name = node_name[: max_name_length - 3] + "..."
# Update progress description with current entity (fixed width)
progress_bar.update(
task,
description=f"Building {node_type:<12}: {display_name:<{max_name_length}}",
)
# Build context with all currently available entities
context = {**self.parameters, **self.functions, **self.distributions}
if node_type == "parameter":
# Parameters are already created with bounds applied, nothing to do
pass
elif node_type == "constant":
# Constants are pre-created by distributions - add to parameters
self.parameters[node_name] = constants_map[node_name]
elif node_type == "function":
# Functions are evaluated by design
self.functions[node_name] = functions[node_name].expression(context)
elif node_type == "distribution":
# Distributions are evaluated by design
self.distributions[node_name] = distributions[node_name].expression(
context
)
# Advance progress
progress_bar.advance(task)
def _get_compiled_function(
self, name: str
) -> Callable[..., npt.NDArray[np.float64]]:
"""
Get or create a compiled PyTensor function for the specified distribution.
Args:
name (str): Name of the distribution.
Returns:
Callable: Compiled PyTensor function.
"""
if name not in self._compiled_functions:
dist = self.distributions[name]
inputs = [var for var in graph_inputs([dist]) if var.name is not None]
# Use the specified PyTensor mode
compilation_mode = self.mode
self._compiled_functions[name] = cast(
Callable[..., npt.NDArray[np.float64]],
function(
inputs=inputs,
outputs=dist,
mode=compilation_mode,
on_unused_input="ignore",
), # type: ignore[no-untyped-call]
)
return self._compiled_functions[name]
def pdf(self, name: str, **parametervalues: float) -> npt.NDArray[np.float64]:
"""
Evaluates the probability density function of the specified distribution.
Args:
name (str): Name of the distribution to evaluate.
**parametervalues (float): Values for each distribution parameter.
Returns:
npt.NDArray[np.float64]: The evaluated PDF value.
"""
if self.mode != "FAST_COMPILE":
# Use compiled function for better performance
func = self._get_compiled_function(name)
inputs = [
var
for var in graph_inputs([self.distributions[name]])
if var.name is not None
]
positional_values = []
for var in inputs:
assert var.name is not None
positional_values.append(parametervalues[var.name])
return func(*positional_values)
# Use original uncompiled approach
dist = self.distributions[name]
inputs = [var for var in graph_inputs([dist]) if var.name is not None]
keyword_values: dict[str, float] = {}
for var in inputs:
assert var.name is not None
keyword_values[var.name] = parametervalues[var.name]
func = cast(
Callable[..., npt.NDArray[np.float64]],
function(inputs=inputs, outputs=dist), # type: ignore[no-untyped-call]
)
return func(**keyword_values)
def logpdf(self, name: str, **parametervalues: float) -> npt.NDArray[np.float64]:
"""
Evaluates the natural logarithm of the PDF.
Args:
name (str): Name of the distribution to evaluate.
**parametervalues (float): Values for each distribution parameter.
Returns:
npt.NDArray[np.float64]: The log of the PDF.
"""
return np.log(self.pdf(name, **parametervalues))
def visualize_graph(
self,
name: str,
fmt: str = "svg",
outfile: str | None = None,
path: str | None = None,
) -> str:
"""
Visualize the computation graph for a distribution.
Args:
name (str): Distribution name.
fmt (str): Output format ('svg', 'png', 'pdf'). Defaults to 'svg'.
outfile (str | None): Output filename. If None, uses '{name}_graph.{fmt}'.
path (str | None): Directory path for output. If None, uses current working directory.
Returns:
str: Path to the generated visualization file.
Raises:
ImportError: If pydot is not installed.
"""
try:
from pytensor.printing import ( # noqa: PLC0415 # pylint: disable=import-outside-toplevel
pydotprint,
)
except ImportError as e:
msg = "Graph visualization requires pydot. Install with: pip install pydot"
raise ImportError(msg) from e
if name not in self.distributions:
msg = f"Distribution '{name}' not found in model"
raise ValueError(msg)
dist = self.distributions[name]
if outfile is not None:
filename = outfile
else:
base_filename = f"{name}_graph.{fmt}"
if path is not None:
filename = str(Path(path) / base_filename)
else:
filename = base_filename
pydotprint( # type: ignore[no-untyped-call]
dist, outfile=filename, format=fmt, with_ids=True, high_contrast=True
)
return filename
def __repr__(self) -> str:
"""Provide a concise overview of the model structure."""
param_names = list(self.parameters.keys())
dist_names = list(self.distributions.keys())
func_names = list(self.functions.keys())
param_display = ", ".join(param_names[:5]) + (
"..." if len(param_names) > 5 else ""
)
dist_display = ", ".join(dist_names[:3]) + (
"..." if len(dist_names) > 3 else ""
)
func_display = ", ".join(func_names[:3]) + (
"..." if len(func_names) > 3 else ""
)
mode_status = self.mode
return f"""Model(
mode: {mode_status}
parameters: {len(param_names)} ({param_display})
distributions: {len(dist_names)} ({dist_display})
functions: {len(func_names)} ({func_display})
)"""
def graph_summary(self, name: str) -> str:
"""
Get a summary of the computation graph structure.
Args:
name (str): Distribution name.
Returns:
str: Summary of the graph structure.
"""
if name not in self.distributions:
msg = f"Distribution '{name}' not found in model"
raise ValueError(msg)
dist = self.distributions[name]
inputs = list(graph_inputs([dist]))
# Count different types of operations
applies = list(applys_between(inputs, [dist]))
op_types: dict[str, int] = {}
for apply in applies:
op_name = type(apply.op).__name__
op_types[op_name] = op_types.get(op_name, 0) + 1
compile_info = f"\n Mode: {self.mode}\n Compiled: {'Yes' if self.mode != 'FAST_COMPILE' and name in self._compiled_functions else 'No'}"
return f"""Distribution '{name}':
Input variables: {len(inputs)}
Graph operations: {len(applies)}
Operation types: {dict(sorted(op_types.items()))}{compile_info}
"""
class ParameterCollection:
"""
Collection of named parameter sets for model configuration.
Manages multiple parameter sets, each containing a collection of
parameter points with specific names and values. Provides dict-like
access to parameter sets by name or index.
Attributes:
sets (dict[str, ParameterSet]): Mapping from parameter set names to ParameterSet objects.
"""
def __init__(self, parametersets: list[T.ParameterPoint]):
"""
A collection of named parameter sets.
Args:
parametersets (list): List of parameterset configurations.
Attributes:
sets (OrderedDict): Mapping from parameter set names to ParameterSet objects.
"""
self.sets: dict[str, ParameterSet] = OrderedDict()
for parameterset_config in parametersets:
parameterset = ParameterSet(
parameterset_config["name"], parameterset_config["parameters"]
)
self.sets[parameterset.name] = parameterset
def __getitem__(self, item: str | int) -> ParameterSet:
key = list(self.sets.keys())[item] if isinstance(item, int) else item
return self.sets[key]
def get(
self, item: str, default: TDefault | None = None
) -> ParameterSet | TDefault | None:
"""Get a parameter set by name, returning default if not found."""
return self.sets.get(item, default)
def __contains__(self, item: str) -> bool:
return item in self.sets
def __iter__(self) -> Iterator[ParameterSet]:
return iter(self.sets.values())
def __len__(self) -> int:
return len(self.sets)
class ParameterSet:
"""
Named collection of parameter points with specific values.
Represents a single configuration of parameter values that can be
used to evaluate a model. Each parameter set contains multiple
parameter points, each with a name and numeric value.
Attributes:
name (str): Name of the parameter set.
points (dict[str, ParameterPoint]): Mapping of parameter names to ParameterPoint objects.
"""
def __init__(self, name: str, points: list[T.Parameter]):
"""
Represents a single named set of parameter values.
Args:
name (str): Name of the parameter set.
points (list): List of parameter point configurations.
Attributes:
name (str): Name of the parameter set.
points (dict[str, ParameterPoint]): Mapping of parameter names to ParameterPoint objects.
"""
self.name = name
self.points: dict[str, ParameterPoint] = OrderedDict()
for points_config in points:
point = ParameterPoint(points_config["name"], points_config["value"])
self.points[point.name] = point
def __getitem__(self, item: str | int) -> ParameterPoint:
key = list(self.points.keys())[item] if isinstance(item, int) else item
return self.points[key]
def get(
self, item: str, default: TDefault | None = None
) -> ParameterPoint | TDefault | None:
"""Get a parameter point by name, returning default if not found."""
return self.points.get(item, default)
def __contains__(self, item: str) -> bool:
return item in self.points
def __iter__(self) -> Iterator[ParameterPoint]:
return iter(self.points.values())
def __len__(self) -> int:
return len(self.points)
@dataclass
class ParameterPoint:
"""
Represents a single parameter point.
Attributes:
name (str): Name of the parameter.
value (float): Value of the parameter.
"""
name: str
value: float
class DomainCollection:
"""
Collection of domain constraints for model parameters.
Manages domain sets that define valid ranges for model parameters.
Each domain set specifies minimum and maximum bounds for parameters,
which are used to create bounded tensor variables.
Attributes:
domains (dict[str, DomainSet]): Mapping from domain names to DomainSet objects.
"""
def __init__(self, domainsets: list[T.Domain]):
"""
Collection of named domain sets.
Args:
domainsets (list): List of domain set configurations.
Attributes:
domains (OrderedDict): Mapping of domain names to DomainSet objects.
"""
self.domains: dict[str, DomainSet] = OrderedDict()
for domain_config in domainsets:
domain = DomainSet(
domain_config["axes"], domain_config["name"], domain_config["type"]
)
self.domains[domain.name] = domain
def __getitem__(self, item: str | int) -> DomainSet:
key = list(self.domains.keys())[item] if isinstance(item, int) else item
return self.domains[key]
def get(
self, item: str, default: TDefault | None = None
) -> DomainSet | TDefault | None:
"""Get a domain set by name, returning default if not found."""
return self.domains.get(item, default)
def __contains__(self, item: str) -> bool:
return item in self.domains
def __iter__(self) -> Iterator[DomainSet]:
return iter(self.domains.values())
def __len__(self) -> int:
return len(self.domains)
@dataclass
class DomainPoint:
"""
Represents a valid domain (axis) for a single parameter.
Attributes:
name (str): Name of the parameter.
min (float): Minimum value.
max (float): Maximum value.
range (tuple): Computed range as (min, max), not included in serialization.
"""
name: str
min: float
max: float
range: tuple[float, float] = field(init=False, repr=False)
def __post_init__(self) -> None:
self.range = (self.min, self.max)
def to_dict(self) -> T.Axis:
"""
to dictionary
"""
return {"name": self.name, "min": self.min, "max": self.max}
class DomainSet:
"""
Set of parameter domain constraints with bounds.
Defines valid ranges for multiple parameters, specifying minimum
and maximum bounds for each. Used to create bounded tensor variables
that are automatically clipped to their valid ranges.
Attributes:
name (str): Name of the domain set.
kind (str): Type of the domain set.
domains (dict[str, Axis]): Mapping of parameter names to (min, max) tuples.
"""
def __init__(self, axes: list[T.Axis], name: str, kind: str):
"""
Represents a set of valid domains for parameters.
Args:
axes (list): List of domain configurations.
name (str): Name of the domain set.
kind (str): Type of the domain.
Attributes:
domains (OrderedDict): Mapping of parameter names to allowed ranges.
"""
self.name = name
self.kind = kind
self.domains: dict[str, Axis] = OrderedDict()
for axis_config in axes:
domain = DomainPoint(
axis_config["name"], axis_config["min"], axis_config["max"]
)
self.domains[domain.name] = domain.range
def __getitem__(self, item: int | str) -> Axis:
key = list(self.domains.keys())[item] if isinstance(item, int) else item
return self.domains[key]
def get(
self, item: str, default: TDefault | Axis = (None, None)
) -> Axis | TDefault:
"""Get domain bounds for a parameter, returning default if not found."""
return self.domains.get(item, default)
def __contains__(self, item: str) -> bool:
return item in self.domains
def __iter__(self) -> Iterator[Axis]:
return iter(self.domains.values())
def __len__(self) -> int:
return len(self.domains)
def boundedscalar(name: str, domain: Axis) -> T.TensorVar:
"""
Creates a scalar tensor variable with optional domain constraints.
Args:
name: Name of the scalar parameter.
domain (tuple): Tuple specifying (min, max) range. Use None for unbounded sides.
For example: (0.0, None) for lower bound only, (None, 1.0) for upper bound only.
If both bounds are None, returns an unbounded scalar.
Returns:
pytensor.tensor.variable.TensorVariable: The scalar tensor, clipped to domain if bounds exist.
Examples:
>>> boundedscalar("sigma", (0.0, None)) # sigma >= 0
>>> boundedscalar("fraction", (0.0, 1.0)) # 0 <= fraction <= 1
>>> boundedscalar("temperature", (None, 100.0)) # temperature <= 100
>>> boundedscalar("unbounded", (None, None)) # no bounds applied
"""
min_bound, max_bound = domain
# Create the base scalar tensor
tensor = pt.scalar(name)
# If both bounds are None, return unbounded scalar
if min_bound is None and max_bound is None:
return cast(T.TensorVar, tensor)
# Use infinity constants for unbounded sides
min_val = pt.constant(-np.inf) if min_bound is None else pt.constant(min_bound)
max_val = pt.constant(np.inf) if max_bound is None else pt.constant(max_bound)
clipped = pt.clip(tensor, min_val, max_val)
clipped.name = tensor.name # Preserve the original name
return cast(T.TensorVar, clipped)