from __future__ import annotations
import json
import logging
import os
from collections.abc import Callable
from pathlib import Path
from typing import Any, TypeAlias, TypeVar, cast
import numpy as np
import numpy.typing as npt
import pytensor.tensor as pt
from pydantic import BaseModel, ConfigDict, Field
from pytensor.compile.function import function
from pytensor.graph.basic import applys_between, graph_inputs
from rich.progress import (
BarColumn,
Progress,
SpinnerColumn,
TaskProgressColumn,
TextColumn,
TimeElapsedColumn,
TimeRemainingColumn,
)
from pyhs3.distributions import Distributions
from pyhs3.domains import Domain, Domains, ProductDomain
from pyhs3.functions import Functions
from pyhs3.metadata import Metadata
from pyhs3.networks import build_dependency_graph
from pyhs3.parameter_points import ParameterPoints, ParameterSet
from pyhs3.typing.aliases import TensorVar
log = logging.getLogger(__name__)
TDefault = TypeVar("TDefault")
Axis: TypeAlias = tuple[float | None, float | None]
class Workspace(BaseModel):
    """
    Workspace for managing HS3 model specifications.

    A workspace contains parameter points, distributions, domains, and functions
    that define a probabilistic model. It provides methods to construct Model
    objects with specific parameter values and domain constraints.

    Attributes:
        metadata: Required metadata containing HS3 version and optional attribution.
        distributions (Distributions | None): Available distributions.
        functions (Functions | None): Available functions for parameter computation.
        domains (Domains | None): Domain constraints for parameters.
        parameter_points (ParameterPoints | None): Named parameter sets.
        data: List of data configurations, kept as raw dictionaries.
        likelihoods: List of likelihood configurations, kept as raw dictionaries.
        analyses: List of analysis configurations, kept as raw dictionaries.
        misc: Arbitrary user-created information.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    # Required field
    metadata: Metadata

    # Optional fields using discriminated unions
    distributions: Distributions | None = Field(
        default_factory=lambda: Distributions([])
    )
    functions: Functions | None = Field(default_factory=lambda: Functions([]))
    domains: Domains | None = Field(default_factory=lambda: Domains([]))
    parameter_points: ParameterPoints | None = Field(
        default_factory=lambda: ParameterPoints([])
    )
    data: list[dict[str, Any]] | None = Field(default_factory=list)
    likelihoods: list[dict[str, Any]] | None = Field(default_factory=list)
    analyses: list[dict[str, Any]] | None = Field(default_factory=list)
    misc: dict[str, Any] | None = Field(default_factory=dict)

    @classmethod
    def load(cls, path: str | os.PathLike[str]) -> Workspace:
        """
        Load workspace from a JSON file.

        Args:
            path: Path to the JSON file containing the HS3 specification.

        Returns:
            Workspace: The loaded workspace instance.

        Raises:
            OSError: If the file cannot be opened.
            json.JSONDecodeError: If the file is not valid JSON.
            pydantic.ValidationError: If the content does not match the schema.
        """
        path_obj = Path(path)
        with path_obj.open("r", encoding="utf-8") as f:
            spec_dict = json.load(f)
        return cls(**spec_dict)

    def model(
        self,
        *,
        domain: int | str | Domain = 0,
        parameter_set: int | str | ParameterSet = 0,
        progress: bool = True,
        mode: str = "FAST_RUN",
    ) -> Model:
        """
        Constructs a `Model` object using the provided domain and parameter set.

        When ``domain`` / ``parameter_set`` is an identifier but the workspace has
        no domains / parameter points (empty or None), the identifier is ignored
        and an empty placeholder named "default" is used instead.

        Args:
            domain (int | str | Domain): Identifier or object specifying the domain to use.
            parameter_set (int | str | ParameterSet): Identifier or object specifying the parameter values to use.
            progress (bool): Whether to show progress bar during dependency graph construction. Defaults to True.
            mode (str): PyTensor compilation mode. Defaults to "FAST_RUN".
                Options: "FAST_RUN" (apply all rewrites, use C implementations),
                "FAST_COMPILE" (few rewrites, Python implementations),
                "NUMBA" (compile using Numba), "JAX" (compile using JAX),
                "PYTORCH" (compile using PyTorch), "DebugMode" (debugging),
                "NanGuardMode" (NaN detection).

        Returns:
            Model: The constructed model object.
        """
        if isinstance(domain, Domain):
            selected_domain = domain
        elif self.domains:
            selected_domain = self.domains[domain]
        else:
            selected_domain = ProductDomain(name="default")

        if isinstance(parameter_set, ParameterSet):
            parameterset = parameter_set
        elif self.parameter_points:
            parameterset = self.parameter_points[parameter_set]
        else:
            parameterset = ParameterSet(name="default", parameters=[])

        # Fallbacks reuse the exact "empty" constructions used above and by the
        # field default factories, so every default is built the same way.
        return Model(
            parameterset=parameterset
            or ParameterSet(name="default", parameters=[]),
            distributions=self.distributions or Distributions([]),
            domain=selected_domain or ProductDomain(name="default"),
            functions=self.functions or Functions([]),
            progress=progress,
            mode=mode,
        )
class Model:
    """
    Probabilistic model with compiled tensor operations.

    A model represents a specific instantiation of a workspace with concrete
    parameter values and domain constraints. It builds symbolic computation
    graphs for distributions and functions, with optional compilation for
    performance optimization.

    The model handles dependency resolution between parameters, functions,
    and distributions, ensuring proper evaluation order through topological
    sorting of the computation graph.
    """

    def __init__(
        self,
        *,
        parameterset: ParameterSet,
        distributions: Distributions,
        domain: Domain,
        functions: Functions,
        progress: bool = True,
        mode: str = "FAST_RUN",
    ):
        """
        Represents a probabilistic model composed of parameters, domains, distributions, and functions.

        Args:
            parameterset (ParameterSet): The parameter set used in the model.
            distributions (Distributions): Set of distributions to include.
            domain (Domain): Domain constraints for parameters.
            functions (Functions): Set of functions that compute parameter values.
            progress (bool): Whether to show progress bar during dependency graph construction.
            mode (str): PyTensor compilation mode. Defaults to "FAST_RUN".
                Options: "FAST_RUN" (apply all rewrites, use C implementations),
                "FAST_COMPILE" (few rewrites, Python implementations),
                "NUMBA" (compile using Numba), "JAX" (compile using JAX),
                "PYTORCH" (compile using PyTorch), "DebugMode" (debugging),
                "NanGuardMode" (NaN detection).

        Attributes:
            domain (Domain): The original domain with constraints for parameters.
            parameterset (ParameterSet): The original parameter set with parameter values.
            distributions (dict[str, pytensor.tensor.variable.TensorVariable]): Symbolic distribution expressions.
            parameters (dict[str, pytensor.tensor.variable.TensorVariable]): Symbolic parameter variables.
            functions (dict[str, pytensor.tensor.variable.TensorVariable]): Computed function values.
            mode (str): PyTensor compilation mode.
            _compiled_functions (dict[str, Callable[..., npt.NDArray[np.float64]]]): Cache of compiled PyTensor functions.
            _compiled_input_names (dict[str, list[str]]): Positional input names of each cached compiled function.
        """
        self.parameterset = parameterset
        self.domain = domain
        self.parameters: dict[str, TensorVar] = {}
        self.functions: dict[str, TensorVar] = {}
        self.distributions: dict[str, TensorVar] = {}
        self.mode = mode
        self._compiled_functions: dict[
            str, Callable[..., npt.NDArray[np.float64]]
        ] = {}
        # Recorded at compile time so pdf() does not have to re-walk the graph
        # (graph_inputs) on every evaluation.
        self._compiled_input_names: dict[str, list[str]] = {}
        # Build dependency graph with proper entity identification
        self._build_dependency_graph(functions, distributions, progress)

    def _expression_context(self) -> dict[str, TensorVar]:
        """Return a fresh dict of all entities built so far (distributions > functions > parameters on name clashes)."""
        return {**self.parameters, **self.functions, **self.distributions}

    def _build_dependency_graph(
        self,
        functions: Functions,
        distributions: Distributions,
        progress: bool = True,
    ) -> None:
        """
        Build and evaluate dependency graph for functions and distributions.

        This method properly handles cross-references between functions, distributions,
        and parameters by building a complete dependency graph first, then evaluating
        in topological order. Nodes of any other type are skipped.
        """
        # Build dependency graph using the networks module
        graph, constants_map = build_dependency_graph(
            self.parameterset, functions, distributions
        )
        # Get topological order (handles cycle detection internally)
        sorted_nodes = graph.topological_sort()
        # Names longer than this are truncated so the description keeps a fixed width.
        max_name_length = 60
        # Evaluate nodes in topological order with optional progress bar
        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}", style="cyan"),
            BarColumn(),
            TaskProgressColumn(),
            TimeElapsedColumn(),
            TimeRemainingColumn(),
            expand=True,
            transient=True,  # Progress bar disappears when finished
            disable=not progress,  # Disable progress bar if progress=False
        ) as progress_bar:
            task = progress_bar.add_task(
                "Building expressions...", total=len(sorted_nodes)
            )
            for node_idx in sorted_nodes:
                node_data = graph[node_idx]
                node_type = node_data["type"]
                node_name = node_data["name"]
                display_name = node_name
                if len(node_name) > max_name_length:
                    display_name = node_name[: max_name_length - 3] + "..."
                # Update progress description with current entity (fixed width)
                progress_bar.update(
                    task,
                    description=f"Building {node_type:<12}: {display_name:<{max_name_length}}",
                )
                if node_type == "parameter":
                    # Create parameter tensor with domain bounds applied
                    domain_bounds = (
                        self.domain.get(node_name, (None, None))
                        if self.domain
                        else (None, None)
                    )
                    param_point = (
                        self.parameterset.get(node_name)
                        if self.parameterset
                        else None
                    )
                    param_kind = param_point.kind if param_point else pt.scalar
                    self.parameters[node_name] = create_bounded_tensor(
                        node_name, domain_bounds, param_kind
                    )
                elif node_type == "constant":
                    # Constants are pre-created by distributions - add to parameters
                    self.parameters[node_name] = constants_map[node_name]
                elif node_type == "function":
                    # The context is only needed here and for distributions, so it is
                    # not rebuilt for parameter/constant nodes.
                    self.functions[node_name] = functions[node_name].expression(
                        self._expression_context()
                    )
                elif node_type == "distribution":
                    self.distributions[node_name] = distributions[
                        node_name
                    ].expression(self._expression_context())
                # Advance progress
                progress_bar.advance(task)

    def _named_inputs(self, name: str) -> list[Any]:
        """Return the named graph inputs of distribution ``name`` in graph order (KeyError if unknown)."""
        return [
            var
            for var in graph_inputs([self.distributions[name]])
            if var.name is not None
        ]

    def _compile(
        self, name: str, inputs: list[Any]
    ) -> Callable[..., npt.NDArray[np.float64]]:
        """Compile distribution ``name`` as a function of ``inputs`` using ``self.mode``."""
        return cast(
            Callable[..., npt.NDArray[np.float64]],
            function(
                inputs=inputs,
                outputs=self.distributions[name],
                mode=self.mode,
                on_unused_input="ignore",
            ),  # type: ignore[no-untyped-call]
        )

    @staticmethod
    def _ordered_values(
        name: str, input_names: list[str], parametervalues: dict[str, float]
    ) -> list[float]:
        """Return the values for ``input_names`` in order; raise KeyError listing any missing ones."""
        missing = [key for key in input_names if key not in parametervalues]
        if missing:
            msg = (
                f"Missing values for input(s) of distribution '{name}': "
                + ", ".join(missing)
            )
            raise KeyError(msg)
        return [parametervalues[key] for key in input_names]

    def _get_compiled_function(
        self, name: str
    ) -> Callable[..., npt.NDArray[np.float64]]:
        """
        Get or create a compiled PyTensor function for the specified distribution.

        The function takes the distribution's named graph inputs positionally,
        in the order recorded in ``self._compiled_input_names[name]``.

        Args:
            name (str): Name of the distribution.

        Returns:
            Callable: Compiled PyTensor function.
        """
        if name not in self._compiled_functions:
            inputs = self._named_inputs(name)
            self._compiled_functions[name] = self._compile(name, inputs)
            self._compiled_input_names[name] = [str(var.name) for var in inputs]
        return self._compiled_functions[name]

    def pdf(self, name: str, **parametervalues: float) -> npt.NDArray[np.float64]:
        """
        Evaluates the probability density function of the specified distribution.

        Every named input of the distribution's graph must be supplied as a
        keyword argument; extra keywords are ignored. Compilation always uses
        ``self.mode``; in "FAST_COMPILE" mode the function is recompiled on every
        call (nothing is cached), otherwise it is cached per distribution.

        Args:
            name (str): Name of the distribution to evaluate.
            **parametervalues (float): Values for each distribution parameter.

        Returns:
            npt.NDArray[np.float64]: The evaluated PDF value.

        Raises:
            KeyError: If ``name`` is not a distribution of this model, or a value
                for one of its inputs is missing.
        """
        if self.mode != "FAST_COMPILE":
            # Use cached compiled function for better performance
            func = self._get_compiled_function(name)
            values = self._ordered_values(
                name, self._compiled_input_names[name], parametervalues
            )
            return func(*values)
        # Uncached path: validate inputs first so a missing value does not cost a compile.
        inputs = self._named_inputs(name)
        values = self._ordered_values(
            name, [str(var.name) for var in inputs], parametervalues
        )
        return self._compile(name, inputs)(*values)

    def logpdf(self, name: str, **parametervalues: float) -> npt.NDArray[np.float64]:
        """
        Evaluates the natural logarithm of the PDF.

        Args:
            name (str): Name of the distribution to evaluate.
            **parametervalues (float): Values for each distribution parameter.

        Returns:
            npt.NDArray[np.float64]: The log of the PDF.
        """
        return np.log(self.pdf(name, **parametervalues))

    def visualize_graph(
        self,
        name: str,
        fmt: str = "svg",
        outfile: str | None = None,
        path: str | None = None,
    ) -> str:
        """
        Visualize the computation graph for a distribution.

        Args:
            name (str): Distribution name.
            fmt (str): Output format ('svg', 'png', 'pdf'). Defaults to 'svg'.
            outfile (str | None): Output filename. If None, uses '{name}_graph.{fmt}'.
            path (str | None): Directory the output file is placed in (also applied
                to a relative ``outfile``). If None, uses current working directory.

        Returns:
            str: Path to the generated visualization file.

        Raises:
            ImportError: If pydot is not installed.
            ValueError: If ``name`` is not a distribution of this model.
        """
        try:
            from pytensor.printing import (  # noqa: PLC0415 # pylint: disable=import-outside-toplevel
                pydotprint,
            )
        except ImportError as e:
            msg = "Graph visualization requires pydot. Install with: pip install pydot"
            raise ImportError(msg) from e
        if name not in self.distributions:
            msg = f"Distribution '{name}' not found in model"
            raise ValueError(msg)
        dist = self.distributions[name]
        filename = outfile if outfile is not None else f"{name}_graph.{fmt}"
        if path is not None:
            # pathlib keeps an absolute ``outfile`` unchanged when joined.
            filename = str(Path(path) / filename)
        pydotprint(  # type: ignore[no-untyped-call]
            dist, outfile=filename, format=fmt, with_ids=True, high_contrast=True
        )
        return filename

    def __repr__(self) -> str:
        """Provide a concise overview of the model structure."""
        param_names = list(self.parameters.keys())
        dist_names = list(self.distributions.keys())
        func_names = list(self.functions.keys())
        param_display = ", ".join(param_names[:5]) + (
            "..." if len(param_names) > 5 else ""
        )
        dist_display = ", ".join(dist_names[:3]) + (
            "..." if len(dist_names) > 3 else ""
        )
        func_display = ", ".join(func_names[:3]) + (
            "..." if len(func_names) > 3 else ""
        )
        mode_status = self.mode
        return f"""Model(
    mode: {mode_status}
    parameters: {len(param_names)} ({param_display})
    distributions: {len(dist_names)} ({dist_display})
    functions: {len(func_names)} ({func_display})
)"""

    def graph_summary(self, name: str) -> str:
        """
        Get a summary of the computation graph structure.

        Args:
            name (str): Distribution name.

        Returns:
            str: Summary of the graph structure.

        Raises:
            ValueError: If ``name`` is not a distribution of this model.
        """
        if name not in self.distributions:
            msg = f"Distribution '{name}' not found in model"
            raise ValueError(msg)
        dist = self.distributions[name]
        inputs = list(graph_inputs([dist]))
        # Count different types of operations
        applies = list(applys_between(inputs, [dist]))
        op_types: dict[str, int] = {}
        for apply in applies:
            op_name = type(apply.op).__name__
            op_types[op_name] = op_types.get(op_name, 0) + 1
        compile_info = f"\n  Mode: {self.mode}\n  Compiled: {'Yes' if self.mode != 'FAST_COMPILE' and name in self._compiled_functions else 'No'}"
        return f"""Distribution '{name}':
  Input variables: {len(inputs)}
  Graph operations: {len(applies)}
  Operation types: {dict(sorted(op_types.items()))}{compile_info}
"""
def create_bounded_tensor(
    name: str, domain: Axis, kind: Callable[..., TensorVar] = pt.scalar
) -> TensorVar:
    """
    Build a named tensor variable, clipped to ``domain`` when any bound is set.

    Args:
        name: Name of the parameter; given to the tensor and copied onto the clipped result.
        domain (tuple): ``(min, max)`` pair; ``None`` marks an unbounded side.
            For example: (0.0, None) for lower bound only, (None, 1.0) for upper bound only.
            If both bounds are None, the plain tensor is returned without clipping.
        kind: Tensor constructor, e.g. pt.scalar (default) or pt.vector.

    Returns:
        pytensor.tensor.variable.TensorVariable: ``kind(name)`` when unbounded, otherwise
        that tensor passed through ``pt.clip`` (a missing side becomes -inf / +inf).

    Examples:
        >>> sigma = create_bounded_tensor("sigma", (0.0, None))  # sigma >= 0 (scalar)
        >>> fraction = create_bounded_tensor("fraction", (0.0, 1.0))  # 0 <= fraction <= 1 (scalar)
        >>> temperatures = create_bounded_tensor("temperatures", (None, 100.0), pt.vector)  # vector <= 100
        >>> unbounded = create_bounded_tensor("unbounded", (None, None))  # no bounds applied
    """
    lower, upper = domain
    raw = kind(name)

    if lower is None and upper is None:
        return raw

    # An open side becomes an infinite constant so one clip handles every case.
    lower_const = pt.constant(-np.inf if lower is None else lower)
    upper_const = pt.constant(np.inf if upper is None else upper)

    bounded = pt.clip(raw, lower_const, upper_const)
    bounded.name = raw.name  # clipped output keeps the parameter's name
    return cast(TensorVar, bounded)