from __future__ import annotations
import json
import logging
import os
import sys
from collections import Counter
from collections.abc import Iterable
from functools import singledispatchmethod
from pathlib import Path
from typing import Any, cast
from pydantic import BaseModel, ConfigDict, Field, ValidationError
from pyhs3.analyses import Analyses, Analysis
from pyhs3.data import Data, DataType
from pyhs3.distributions import Distributions, DistributionType, HistFactoryDistChannel
from pyhs3.distributions.histfactory.modifiers import (
ParameterModifier,
ParametersModifier,
)
from pyhs3.domains import Domain, Domains, DomainType, ProductDomain
from pyhs3.exceptions import WorkspaceValidationError
from pyhs3.functions import Functions
from pyhs3.likelihoods import Likelihood, Likelihoods
from pyhs3.metadata import Metadata
from pyhs3.model import Model
from pyhs3.parameter_points import ParameterPoints, ParameterSet
# Module-level logger; Workspace uses it to warn about data entries without axes.
log = logging.getLogger(__name__)
class Workspace(BaseModel):
    """
    Workspace for managing HS3 model specifications.

    A workspace contains parameter points, distributions, domains, and functions
    that define a probabilistic model. It provides methods to construct Model
    objects with specific parameter values and domain constraints.

    After field validation, string references between sections (likelihoods
    naming distributions and data, analyses naming likelihoods and domains) are
    replaced by the referenced objects; dangling references or failed
    consistency checks raise ``WorkspaceValidationError``.

    Attributes:
        metadata: Required metadata containing HS3 version and optional attribution
        distributions: Distribution configurations (the available distributions)
        functions: Function configurations available for parameter computation
        domains: Domain configurations (constraints/bounds for parameters)
        parameter_points: Named parameter sets (parameter point configurations)
        data: Data specifications for observations
        likelihoods: Likelihood specifications mapping distributions to data
        analyses: Analysis configurations for automated analyses
        misc: Arbitrary user-created information

    Every section except ``metadata`` is optional: an omitted section becomes an
    empty collection (an empty dict for ``misc``), while an explicit ``null`` is
    accepted and stored as ``None``.

    HS3 Reference:
        See :hs3:label:`HS3 file format specification <hs3.file-format>` for the complete workspace structure.
    """
    # NOTE(review): presumably needed because some field types are not
    # pydantic models themselves -- confirm against the collection classes.
    model_config = ConfigDict(arbitrary_types_allowed=True)
    # Required field
    metadata: Metadata
    # Optional fields using discriminated unions
    # (omitted -> empty collection via default_factory; explicit null -> None)
    distributions: Distributions | None = Field(
        default_factory=lambda: Distributions([])
    )
    functions: Functions | None = Field(default_factory=lambda: Functions([]))
    domains: Domains | None = Field(default_factory=lambda: Domains([]))
    parameter_points: ParameterPoints | None = Field(
        default_factory=lambda: ParameterPoints([])
    )
    data: Data | None = Field(default_factory=lambda: Data([]))
    likelihoods: Likelihoods | None = Field(default_factory=lambda: Likelihoods([]))
    analyses: Analyses | None = Field(default_factory=lambda: Analyses([]))
    misc: dict[str, Any] | None = Field(default_factory=dict)
def model_post_init(self, __context: Any, /) -> None:
"""Resolve foreign key references after workspace construction."""
self._resolve_foreign_keys()
def _resolve_foreign_keys(self) -> None:
"""Resolve string references to actual objects with referential integrity checking."""
errors: list[str] = []
# Resolve Likelihood fields first (analyses reference likelihoods)
if self.likelihoods is not None:
for likelihood in self.likelihoods:
self._resolve_likelihood_fields(likelihood, errors)
# Validate observable axis uniqueness after FK resolution
if self.likelihoods is not None:
for likelihood in self.likelihoods:
try:
likelihood.validate_unique_axis_names(self)
except ValueError as exc:
errors.append(str(exc))
# Validate HFDC constraint consistency across channels
if self.likelihoods is not None:
for likelihood in self.likelihoods:
try:
self._validate_hfdc_constraints(likelihood)
except ValueError as exc:
errors.append(str(exc))
# Resolve Analysis fields
if self.analyses is not None:
for analysis in self.analyses:
self._resolve_analysis_fields(analysis, errors)
if errors:
msg = "Workspace has unresolved references:\n" + "\n".join(
f" - {e}" for e in errors
)
raise WorkspaceValidationError(msg)
def _resolve_fk_list(
    self,
    refs: Iterable[Any],
    collection: Distributions | Data | Domains,
    parent_label: str,
    entity_label: str,
    errors: list[str],
) -> list[Any]:
    """Map each name in *refs* to the same-named object in *collection*.

    Items that are not ``str`` are assumed to be resolved already and pass
    through untouched. A name that *collection* cannot find produces one
    message in *errors* and is omitted from the result, so the returned list
    can be shorter than *refs*.

    Args:
        refs: Names and/or already-resolved objects, in order.
        collection: Lookup target; only ``collection.get(name)`` is used, and
            a ``None`` result means "not found".
        parent_label: Owner description used as the message prefix.
        entity_label: Kind of thing being referenced, used in the message.
        errors: Accumulator receiving one message per unknown name.

    Returns:
        The resolved items in their original order.
    """
    out: list[Any] = []
    for item in refs:
        if not isinstance(item, str):
            out.append(item)
            continue
        found = collection.get(item)
        if found is None:
            errors.append(
                f"{parent_label} references unknown {entity_label} '{item}'"
            )
        else:
            out.append(found)
    return out
@staticmethod
def _check_hfdc_modifier(
    modifier: object,
    where: str,
    channel_name: str,
    param_constraint: dict[str, tuple[str, str]],
    shapesys_owners: dict[str, str],
    staterror_owners: dict[str, str],
) -> None:
    """Validate one HistFactory modifier against what earlier modifiers declared.

    Modifiers without a ``constraint`` attribute, or whose constraint is
    ``None``, are ignored.

    * ``ParameterModifier``: its parameter must keep the same constraint type
      everywhere. ``param_constraint`` records the latest
      ``(constraint, where)`` seen for each parameter.
    * Any other constrained modifier is treated as a ``ParametersModifier``.
      Each of its parameters must stay within a single channel. Ownership is
      tracked in ``shapesys_owners`` for type ``"shapesys"`` and in
      ``staterror_owners`` for every other type.

    Raises:
        ValueError: On a conflicting constraint type or a cross-channel share.
    """
    constraint = getattr(modifier, "constraint", None)
    if constraint is None:
        return

    if isinstance(modifier, ParameterModifier):
        name = modifier.parameter
        earlier = param_constraint.get(name)
        if earlier is not None:
            earlier_constraint, earlier_where = earlier
            if earlier_constraint != constraint:
                msg = (
                    f"Parameter '{name}' has conflicting constraint types: "
                    f"'{earlier_where}' declared '{earlier_constraint}', "
                    f"'{where}' declares '{constraint}'"
                )
                raise ValueError(msg)
        param_constraint[name] = (constraint, where)
        return

    multi = cast(ParametersModifier, modifier)
    kind = multi.type
    # NOTE(review): every non-"shapesys" constrained multi-parameter modifier
    # shares the staterror table; this assumes only shapesys/staterror carry
    # constraints -- confirm.
    owners = shapesys_owners if kind == "shapesys" else staterror_owners
    for name in multi.parameters:
        owner = owners.get(name)
        if owner is not None and owner != channel_name:
            msg = (
                f"{kind} parameter '{name}' appears in both "
                f"'{owner}' and '{channel_name}'; "
                f"{kind} is per-channel and may not be shared."
            )
            raise ValueError(msg)
        owners[name] = channel_name
def _validate_hfdc_constraints(self, likelihood: Likelihood) -> None:
    """Check constraint consistency across the HFDC channels of one likelihood.

    Rules enforced (via ``_check_hfdc_modifier`` on every modifier of every
    sample of every ``HistFactoryDistChannel`` in the likelihood):

    - a nuisance parameter may not carry conflicting constraint types (e.g.
      Gauss in one channel, LogNormal in another);
    - shapesys parameters may not be shared across channels (per-channel by
      design);
    - staterror parameters may not be shared across channels (same reason).

    Other distribution types, and any leftover string references, are skipped.
    Intended to run after FK resolution, when ``likelihood.distributions``
    holds objects.

    Raises:
        ValueError: From the first modifier that violates a rule.
    """
    seen_constraints: dict[str, tuple[str, str]] = {}
    shapesys_channel: dict[str, str] = {}
    staterror_channel: dict[str, str] = {}

    # A str is never an HFDC instance, so this filter also drops unresolved names.
    channels = (
        dist
        for dist in likelihood.distributions
        if isinstance(dist, HistFactoryDistChannel)
    )
    for channel in channels:
        for sample in channel.samples:
            for modifier in sample.modifiers:
                self._check_hfdc_modifier(
                    modifier,
                    f"{channel.name}/{sample.name}/{modifier.name}",
                    channel.name,
                    seen_constraints,
                    shapesys_channel,
                    staterror_channel,
                )
def _resolve_likelihood_fields(
    self, likelihood: Likelihood, errors: list[str]
) -> None:
    """Swap a likelihood's distribution/data names for workspace objects.

    ``distributions`` and (when present) ``aux_distributions`` are resolved
    against ``self.distributions``; ``data`` is resolved against
    ``self.data``. Resolved lists are written back to the likelihood. Unknown
    names are reported in *errors*. If the workspace section itself is
    ``None``, a single generic message is recorded for that section instead.
    """
    owner = f"Likelihood '{likelihood.name}'"

    known_dists = self.distributions
    if known_dists is None:
        errors.append(f"{owner} references unknown distributions")
    else:
        main = self._resolve_fk_list(
            likelihood.distributions, known_dists, owner, "distribution", errors
        )
        likelihood.distributions = Distributions(
            cast(list[DistributionType], main)
        )
        if likelihood.aux_distributions is not None:
            aux = self._resolve_fk_list(
                likelihood.aux_distributions,
                known_dists,
                owner,
                "aux_distribution",
                errors,
            )
            likelihood.aux_distributions = Distributions(
                cast(list[DistributionType], aux)
            )

    known_data = self.data
    if known_data is None:
        errors.append(f"{owner} references unknown data")
        return
    datasets = self._resolve_fk_list(
        likelihood.data, known_data, owner, "data", errors
    )
    likelihood.data = Data(cast(list[DataType], datasets))
def _resolve_analysis_fields(self, analysis: Analysis, errors: list[str]) -> None:
    """Swap an analysis' likelihood/domain names for workspace objects.

    A string ``analysis.likelihood`` is looked up in ``self.likelihoods``; an
    already-resolved likelihood is left alone. ``analysis.domains`` is
    resolved against ``self.domains`` and written back. Problems are appended
    to *errors*. A ``None`` workspace section produces one generic message.
    """
    owner = f"Analysis '{analysis.name}'"

    if self.likelihoods is None:
        errors.append(
            f"{owner} references unknown likelihood '{analysis.likelihood}'"
        )
    elif isinstance(analysis.likelihood, str):
        found = self.likelihoods.get(analysis.likelihood)
        if found is None:
            errors.append(
                f"{owner} references unknown likelihood '{analysis.likelihood}'"
            )
        else:
            analysis.likelihood = found

    if self.domains is None:
        errors.append(f"{owner} references unknown domains")
        return
    resolved = self._resolve_fk_list(
        analysis.domains, self.domains, owner, "domain", errors
    )
    analysis.domains = Domains(cast(list[DomainType], resolved))
@classmethod
def load(
    cls,
    path: str | os.PathLike[str],
    *,
    verbose: bool = False,
    suppress_traceback: bool = True,
) -> Workspace:
    """
    Read an HS3 JSON document from *path* and build a workspace from it.

    Args:
        path: Location of the JSON file holding the HS3 specification.
        verbose: Forwarded to the error formatter. ``True`` lists every
            validation error; ``False`` lists a leading subset and counts the rest.
        suppress_traceback: When ``True`` (the default), set
            ``sys.tracebacklimit = 0`` before raising a validation failure.

    Returns:
        Workspace: The constructed workspace.

    Raises:
        WorkspaceValidationError: If pydantic rejects the document. The
            message is the formatted summary of the validation errors.
            File-access and JSON-decoding errors propagate unchanged.

    Note:
        ``sys.tracebacklimit`` is interpreter-wide and is not restored here.
        After a suppressed failure, later uncaught exceptions in the same
        process also print without a traceback.
    """
    source = Path(path)
    with source.open("r", encoding="utf-8") as handle:
        document = json.load(handle)
    try:
        return cls(**document)
    except ValidationError as exc:
        summary = cls._format_validation_error(exc, path, verbose)
        if suppress_traceback:
            # NOTE(review): process-global side effect; see docstring.
            sys.tracebacklimit = 0
        raise WorkspaceValidationError(summary) from None
@classmethod
def _format_validation_error(
    cls,
    validation_error: ValidationError,
    path: str | os.PathLike[str],
    verbose: bool,
    *,
    max_errors: int = 20,
) -> str:
    """
    Format a ValidationError into a readable error summary.

    The summary has three sections:

    * counts per error type;
    * counts per component location, with list indices collapsed to ``#`` so
      that repeated failures inside one collection group together;
    * a numbered list of individual errors, annotated with the ``name`` of the
      offending input when one is available.

    Args:
        validation_error: The ValidationError to format.
        path: Path to the file that caused the error (used in the header).
        verbose: If True, list every error. If False, list the first
            ``max_errors`` errors and summarize the rest.
        max_errors: How many detailed errors to list when ``verbose`` is
            False (default 20, the previous fixed limit). Negative values are
            treated as 0.

    Returns:
        Formatted error message string.
    """
    errors = validation_error.errors()
    error_count = len(errors)
    shown_limit = max(max_errors, 0)

    error_types: Counter[str] = Counter()
    loc_errors: Counter[tuple[str, ...]] = Counter()
    for error in errors:
        error_types[error["type"]] += 1
        # Collapse list indices so e.g. all failing distributions share one key.
        loc_errors[
            tuple(
                "#" if isinstance(key, int) else key
                for key in error.get("loc", ())
            )
        ] += 1

    # Collect pieces in a list and join once at the end.
    parts = [
        f"Workspace validation failed with {error_count} errors from {path}\n",
        "\nError breakdown by type:\n",
    ]
    parts.extend(
        f" {error_type}: {count}\n"
        for error_type, count in error_types.most_common()
    )
    parts.append("\nError breakdown by component:\n")
    parts.extend(
        f" {'.'.join(loc)}: {count}\n" for loc, count in loc_errors.most_common()
    )

    errors_to_show = errors if verbose else errors[:shown_limit]
    label = "all" if verbose else f"first {shown_limit}"
    parts.append(f"\nErrors for debugging ({label}):\n")
    for i, error in enumerate(errors_to_show, start=1):
        # Render ("a", 0, "b") as "a[0] -> b": indices attach directly to the
        # preceding part, names are joined with arrows.
        loc_parts = [
            f"[{part}]" if isinstance(part, int) else str(part)
            for part in error.get("loc", ())
        ]
        readable_loc = loc_parts[0] if loc_parts else ""
        for part in loc_parts[1:]:
            readable_loc += part if part.startswith("[") else f" -> {part}"

        # Add the input's name, if it has one, unless the location already
        # ends in an index.
        input_data: Any = error.get("input", {})
        if (
            isinstance(input_data, dict)
            and "name" in input_data
            and readable_loc
            and not readable_loc.endswith("]")
        ):
            readable_loc += f"('{input_data['name']}')"

        msg = error.get("msg", "Unknown error")
        parts.append(f" {i}. {readable_loc}: {msg}\n")

    hidden = error_count - len(errors_to_show)
    if hidden > 0:
        parts.append(
            f" ... and {hidden} more errors (use verbose=True to see all)\n"
        )
    return "".join(parts)
def _compute_observables(self) -> dict[str, tuple[float, float]]:
    """
    Collect observable bounds from the data referenced by every likelihood.

    Each axis of each dataset paired with a likelihood contributes
    ``axis.name -> (axis.min, axis.max)``. Data entries that are still names
    are looked up in ``self.data``. Datasets without axes are skipped with a
    warning. If an axis name appears more than once, the last occurrence wins.
    Composite distributions are not inspected.

    Returns:
        Dictionary mapping observable names to (min, max) tuples. It is empty
        when the workspace has no likelihoods or no data.
    """
    found: dict[str, tuple[float, float]] = {}
    if not (self.likelihoods and self.data):
        return found

    for likelihood in self.likelihoods:
        for entry in likelihood.data:
            # FK resolution normally leaves objects here; a leftover name is
            # looked up directly in the workspace data.
            datum = self.data[entry] if isinstance(entry, str) else entry
            axes = datum.axes
            if axes is None:
                log.warning(
                    "The likelihood '%s' references data '%s' without axes. This cannot be used to normalize any distribution.",
                    likelihood.name,
                    datum.name,
                )
                continue
            found.update((axis.name, (axis.min, axis.max)) for axis in axes)
    return found
@staticmethod
def _extract_observables(likelihood: Likelihood) -> dict[str, tuple[float, float]]:
    """Map every data-axis name in *likelihood* to its ``(min, max)`` bounds.

    Data entries that are still plain names, and datasets whose ``axes`` is
    falsy, contribute nothing. A repeated axis name keeps the last bounds seen.
    """
    bounds: dict[str, tuple[float, float]] = {}
    for datum in likelihood.data:
        if isinstance(datum, str):
            continue
        for axis in datum.axes or []:
            bounds[axis.name] = (axis.min, axis.max)
    return bounds
def _select_parameterset(
    self,
    parameter_set: int | str | ParameterSet | None,
    *,
    fallback_first: bool = True,
) -> ParameterSet:
    """Turn *parameter_set* into a concrete :class:`~pyhs3.parameter_points.ParameterSet`.

    Resolution order:

    1. a ``ParameterSet`` instance is returned unchanged;
    2. an int or str is used as a key into ``self.parameter_points``;
    3. ``None`` yields ``parameter_points[0]`` when ``fallback_first`` is true
       and parameter points exist, otherwise an empty set named ``"default"``.

    Args:
        parameter_set: Explicit choice (instance, index or name) or ``None``.
        fallback_first: Whether ``None`` may fall back to the first workspace
            parameter set. The ``Analysis`` path passes ``False`` because it
            applies its own ``init`` fallback.

    Raises:
        ValueError: If a key is given but the workspace has no parameter points.
    """
    if isinstance(parameter_set, ParameterSet):
        return parameter_set

    points = self.parameter_points
    if parameter_set is None:
        if fallback_first and points:
            return points[0]
        return ParameterSet(name="default", parameters=[])

    if not points:
        msg = f"parameter_set={parameter_set!r} was requested but no parameter_points are available in this workspace"
        raise ValueError(msg)
    return points[parameter_set]
def _select_domain(
    self,
    domain: int | str | Domain | None,
    default_index: int | str | None = None,
) -> Domain:
    """Turn *domain* into a concrete :class:`~pyhs3.domains.Domain`.

    Resolution order:

    1. a ``Domain`` instance is returned unchanged;
    2. an int or str *domain* is used as a key into ``self.domains``;
    3. otherwise, if the workspace has domains: *default_index* when given,
       else the domain named ``"default_domain"``, else the first domain;
    4. with no domains at all, a ``ProductDomain`` named ``"default"``.

    Args:
        domain: Explicit choice (instance, index or name) or ``None``.
        default_index: Key used when *domain* is ``None``.

    Raises:
        ValueError: If *domain* is a key but the workspace has no domains.
    """
    if isinstance(domain, Domain):
        return domain

    domains = self.domains
    if domain is not None:
        if not domains:
            msg = f"domain={domain!r} was requested but no domains are available in this workspace"
            raise ValueError(msg)
        return domains[domain]

    if not domains:
        return ProductDomain(name="default")
    if default_index is not None:
        return domains[default_index]
    named_default = domains.get("default_domain")
    return named_default if named_default is not None else domains[0]
@singledispatchmethod
def model(
    self,
    target: int,
    *,
    domain: int | str | Domain | None = None,
    parameter_set: int | str | ParameterSet | None = None,
    progress: bool = True,
    mode: str = "FAST_RUN",
) -> Model:
    """
    Constructs a :class:`~pyhs3.model.Model` rooted at ``target``.

    Dispatch is based on the type of ``target``:

    - :class:`~pyhs3.analyses.Analysis` — all context (domain, parameter
      set, observables) is derived from the analysis; gains access to
      :attr:`~pyhs3.model.Model.log_prob`, :attr:`~pyhs3.model.Model.data`,
      and :attr:`~pyhs3.model.Model.free_params`.
    - :class:`~pyhs3.likelihoods.Likelihood` — observable bounds are derived
      from the likelihood's data; ``domain`` falls back to ``default_domain``
      then index 0, and ``parameter_set`` falls back to workspace defaults
      unless overridden.
    - ``str`` — searches analyses then likelihoods by name; delegates to the
      appropriate registered path. Falls back to legacy domain indexing if
      the name is not found in either.
    - ``int`` — legacy path: ``target`` indexes into workspace domains. This
      base implementation also receives any type without a registered
      overload.

    Args:
        target: Dispatch key. Pass an
            :class:`~pyhs3.analyses.Analysis` or
            :class:`~pyhs3.likelihoods.Likelihood` for the modern paths,
            a name string to search analyses/likelihoods, or an ``int``
            domain index for the legacy path.
        domain: Override domain (legacy and Likelihood paths only). On the
            legacy path an explicit ``domain`` takes precedence over ``target``.
        parameter_set: Override parameter set (legacy and Likelihood paths only).
        progress: Whether to show a progress bar during graph construction.
        mode: PyTensor compilation mode (default ``"FAST_RUN"``).

    Returns:
        :class:`~pyhs3.model.Model`: The constructed model.

    Raises:
        ValueError: If ``domain`` / ``parameter_set`` is given as a key but the
            workspace has no domains / parameter points.
    """
    # Legacy int path: target indexes into workspace domains.
    selected_domain = self._select_domain(domain, default_index=target)
    parameterset = self._select_parameterset(parameter_set)
    return Model(
        parameterset=parameterset,
        distributions=self.distributions or Distributions(),
        # _select_domain always returns a Domain (a ProductDomain named
        # "default" when none exist). Pass it through unchanged instead of
        # letting a falsy (e.g. empty) domain be swapped for an untyped
        # placeholder.
        domain=selected_domain,
        functions=self.functions or Functions(),
        progress=progress,
        mode=mode,
        observables=self._compute_observables(),
        likelihood=None,
    )
@model.register
def _(
    self,
    target: Analysis,
    *,
    parameter_set: int | str | ParameterSet | None = None,
    progress: bool = True,
    mode: str = "FAST_RUN",
) -> Model:
    """Build a model for an analysis, taking the domain and defaults from it.

    Domain: a single analysis domain is used as-is. Otherwise the axes of all
    analysis domains are concatenated into one ``ProductDomain`` named
    ``"<analysis name>_merged"``.

    Parameter set: an explicit ``parameter_set`` wins. Otherwise the set named
    by ``target.init`` is used, and failing that an empty set named
    ``"default"``. The workspace's first parameter set is deliberately *not*
    used as a fallback.

    Raises:
        ValueError: If ``target.init`` is set but the workspace has no
            parameter points or no set with that name. This check runs even
            when ``parameter_set`` is given.
    """
    # model_post_init already replaced the analysis' likelihood/domain names
    # with objects; the casts only inform the type checker.
    likelihood_obj = cast(Likelihood, target.likelihood)
    domains = cast(Domains, target.domains)

    analysis_domain: Domain
    if len(domains) != 1:
        merged_axes: list[Any] = []
        for dom in domains:
            merged_axes.extend(getattr(dom, "axes", []))
        analysis_domain = ProductDomain(name=f"{target.name}_merged", axes=merged_axes)  # type: ignore[arg-type]
    else:
        analysis_domain = domains[0]

    init_set: ParameterSet | None = None
    if target.init:
        points = self.parameter_points
        if points is None:
            msg = f"Analysis '{target.name}' requires parameter set '{target.init}' but workspace has no parameter_points"
            raise ValueError(msg)
        init_set = points.get(target.init)
        if init_set is None:
            msg = f"Analysis '{target.name}' references unknown parameter set '{target.init}'"
            raise ValueError(msg)

    if parameter_set is None:
        # No parameter_points[0] fallback: that would impose workspace
        # defaults the caller did not request.
        parameterset = init_set or ParameterSet(name="default", parameters=[])
    else:
        parameterset = self._select_parameterset(
            parameter_set, fallback_first=False
        )

    return Model(
        parameterset=parameterset,
        distributions=self.distributions or Distributions(),
        domain=analysis_domain,
        functions=self.functions or Functions(),
        progress=progress,
        mode=mode,
        observables=self._extract_observables(likelihood_obj),
        likelihood=likelihood_obj,
    )
@model.register
def _(
    self,
    target: Likelihood,
    *,
    domain: int | str | Domain | None = None,
    parameter_set: int | str | ParameterSet | None = None,
    progress: bool = True,
    mode: str = "FAST_RUN",
) -> Model:
    """Build a model for a single likelihood.

    Observable bounds come from the axes of the likelihood's data.

    Without a ``domain`` override, the workspace's ``"default_domain"`` is
    used, then its first domain, then a ``ProductDomain`` named ``"default"``.
    Without a ``parameter_set`` override, the first workspace parameter set is
    used, or an empty default.

    Raises:
        ValueError: If an override key is given but the workspace has no
            domains / parameter points to select from.
    """
    selected_domain = self._select_domain(domain)
    parameterset = self._select_parameterset(parameter_set)
    return Model(
        parameterset=parameterset,
        distributions=self.distributions or Distributions(),
        # _select_domain always returns a Domain. Pass it through unchanged
        # instead of letting a falsy (e.g. empty) domain be swapped for an
        # untyped placeholder.
        domain=selected_domain,
        functions=self.functions or Functions(),
        progress=progress,
        mode=mode,
        observables=self._extract_observables(target),
        likelihood=target,
    )
@model.register
def _(
    self,
    target: str,
    *,
    domain: int | str | Domain | None = None,
    parameter_set: int | str | ParameterSet | None = None,
    progress: bool = True,
    mode: str = "FAST_RUN",
) -> Model:
    """Build a model from a name.

    Lookup order:

    1. an analysis with that name (delegates to the ``Analysis`` overload);
    2. a likelihood with that name (delegates to the ``Likelihood`` overload);
    3. otherwise the name is used as a domain key on the legacy path.

    Raises:
        ValueError: If ``target`` names an analysis and ``domain`` is given
            (analyses supply their own domains), or if an override key is
            given but the workspace has no domains / parameter points.
    """
    # Search analyses first, then likelihoods; fall back to legacy domain indexing.
    if self.analyses:
        analysis = self.analyses.get(target)
        if analysis is not None:
            if domain is not None:
                msg = "domain override not supported when target resolves to an analysis"
                raise ValueError(msg)
            return self.model(
                analysis,
                parameter_set=parameter_set,
                progress=progress,
                mode=mode,
            )
    if self.likelihoods:
        likelihood = self.likelihoods.get(target)
        if likelihood is not None:
            return self.model(
                likelihood,
                domain=domain,
                parameter_set=parameter_set,
                progress=progress,
                mode=mode,
            )
    # Legacy fallback: treat target as a domain name.
    selected_domain = self._select_domain(domain, default_index=target)
    parameterset = self._select_parameterset(parameter_set)
    return Model(
        parameterset=parameterset,
        distributions=self.distributions or Distributions(),
        # _select_domain always returns a Domain. Pass it through unchanged
        # instead of letting a falsy (e.g. empty) domain be swapped for an
        # untyped placeholder.
        domain=selected_domain,
        functions=self.functions or Functions(),
        progress=progress,
        mode=mode,
        observables=self._compute_observables(),
        likelihood=None,
    )