from __future__ import annotations
import logging
from functools import lru_cache
from typing import TYPE_CHECKING, Any, cast
import pytensor.tensor as pt
import sympy as sp
from pydantic import ConfigDict, Field, PrivateAttr, model_validator
from pytensor.tensor.type import TensorType
from sympy.parsing import sympy_parser
from sympy.parsing.sympy_parser import (
auto_number,
auto_symbol,
factorial_notation,
lambda_notation,
repeated_decimals,
)
from pyhs3.exceptions import ExpressionEvaluationError, ExpressionParseError
if TYPE_CHECKING:
from pyhs3.context import Context
from pyhs3.typing.aliases import TensorVar
log = logging.getLogger(__name__)

# Token transformations handed to ``sympy_parser.parse_expr``; SymPy applies
# them one after another in exactly the order listed here.
_transformations = (
    auto_symbol, lambda_notation, repeated_decimals, auto_number, factorial_notation,
)
[docs]
def analyze_sympy_expr(sympy_expr: sp.Expr) -> dict[str, Any]:
    """
    Inspect a SymPy expression and describe its makeup on the debug log.

    Args:
        sympy_expr: Expression to inspect.

    Returns:
        Mapping with three entries:
        - 'expression': ``sympy_expr`` itself
        - 'independent_vars': the expression's free symbols (``free_symbols``)
        - 'dependent_vars': every ``sp.Function`` atom, i.e. each function
          application appearing in the expression
    """
    report = {
        "expression": sympy_expr,
        "independent_vars": sympy_expr.free_symbols,
        "dependent_vars": sympy_expr.atoms(sp.Function),
    }

    # ``Logger.debug`` performs this same level check internally; doing it once
    # up front also keeps the comparatively expensive pretty-printer off the
    # hot path when debug logging is disabled.
    if log.isEnabledFor(logging.DEBUG):
        log.debug("Expression: %s", report["expression"])
        log.debug("Independent Variables: %s", report["independent_vars"])
        log.debug("Dependent Variables: %s", report["dependent_vars"])
        log.debug("Expression Structure:\n%s", sp.pretty(sympy_expr))

    return report
[docs]
def parse_expression(expr_str: str) -> sp.Expr:
    """
    Turn a mathematical expression written as text into a SymPy expression.

    Parsing is delegated to an LRU-cached helper, so a string that has been
    parsed before is not parsed again.

    Args:
        expr_str: Expression source text, e.g. ``"a * exp(-x / tau)"``.

    Returns:
        The parsed SymPy expression.

    Raises:
        ExpressionParseError: Wraps any exception raised while parsing; the
            original exception is chained as ``__cause__``.
    """
    try:
        parsed = _parse_expr_cached(expr_str)
    except Exception as exc:
        raise ExpressionParseError(
            f"Failed to parse expression '{expr_str}': {exc}"
        ) from exc
    return parsed
[docs]
def _fold_pairwise(binary_op: Any) -> Any:
    """Lift a two-argument op to any number of arguments by folding left to right."""

    def folded(first: Any, *rest: Any) -> Any:
        result = first
        for operand in rest:
            result = binary_op(result, operand)
        return result

    return folded


def _pytensor_function_map() -> dict[str, Any]:
    """
    Build the namespace ``sp.lambdify`` uses to resolve names in generated code.

    For SymPy functions the key is the SymPy class name (``asin``, ``ceiling``,
    ...), because lambdify prints dict-module entries under that name. ``abs``,
    ``min`` and ``max`` are the builtin names SymPy's Python printer emits for
    ``Abs``, ``Min`` and ``Max``.
    """
    return {
        "sin": pt.math.sin,
        "cos": pt.math.cos,
        "tan": pt.math.tan,
        "sinh": pt.math.sinh,
        "cosh": pt.math.cosh,
        "tanh": pt.math.tanh,
        "asin": pt.math.arcsin,
        "acos": pt.math.arccos,
        "atan": pt.math.arctan,
        "atan2": pt.math.arctan2,
        "asinh": pt.math.arcsinh,
        "acosh": pt.math.arccosh,
        "atanh": pt.math.arctanh,
        "exp": pt.math.exp,
        "log": pt.math.log,
        "sqrt": pt.math.sqrt,
        "abs": pt.math.abs,
        "erf": pt.math.erf,
        "erfc": pt.math.erfc,
        "floor": pt.math.floor,
        "ceiling": pt.math.ceil,
        # SymPy's Min/Max are n-ary and print as ``min(a, b, c)``, while the
        # PyTensor ops are binary; folding keeps 2-argument calls identical
        # and makes 3+ arguments work instead of failing.
        "min": _fold_pairwise(pt.math.minimum),
        "max": _fold_pairwise(pt.math.maximum),
    }


def sympy_to_pytensor(
    sympy_expr: sp.Expr,
    variables: list[pt.variable.TensorVariable[TensorType, Any]],
) -> pt.variable.TensorVariable[TensorType, Any]:
    """
    Converts a SymPy expression into a PyTensor computational graph using lambdify.

    Each PyTensor variable is bound to the SymPy symbol carrying the same
    ``name``. The expression is lambdified against PyTensor implementations of
    sin, cos, tan, sinh, cosh, tanh, asin, acos, atan, atan2, asinh, acosh,
    atanh, exp, log, sqrt, Abs, erf, erfc, floor, ceiling, Min and Max (the
    last two with any number of arguments). The resulting function is then
    called with the variables.

    Args:
        sympy_expr: The SymPy expression object.
        variables: List of PyTensor variables; their ``name`` attributes must
            match the symbols used in ``sympy_expr``.

    Returns:
        PyTensor expression. A result that is not already a PyTensor variable
        (e.g. a constant expression) is wrapped with ``pt.constant``.

    Raises:
        ExpressionEvaluationError: If the expression cannot be converted or
            contains unsupported operations.
    """
    try:
        custom_modules = _pytensor_function_map()

        # One Symbol per variable name, in the same order as ``variables`` so
        # the positional call below lines up. Variables sharing a name collapse
        # into one symbol, which makes that call fail with an argument-count
        # error (reported as ExpressionEvaluationError).
        sympy_vars = {var.name: sp.Symbol(var.name) for var in variables}

        # Log the expression for debugging
        analyze_sympy_expr(sympy_expr)

        pytensor_func = sp.lambdify(
            list(sympy_vars.values()), sympy_expr, modules=custom_modules
        )
        result = pytensor_func(*variables)

        # Constant expressions evaluate to plain numbers rather than tensors
        if not isinstance(result, pt.variable.TensorVariable):
            result = pt.constant(result)
        return cast(pt.variable.TensorVariable[TensorType, Any], result)
    except Exception as exc:
        msg = f"Failed to convert expression to PyTensor: {sympy_expr}. {exc}"
        raise ExpressionEvaluationError(msg) from exc
@lru_cache(maxsize=2048)
def _parse_expr_cached(expr_str: str) -> sp.Expr:
    """
    Parse *expr_str* with the module's transformation pipeline, memoised.

    SymPy expressions are immutable, so returning one cached object to many
    callers is safe. Exceptions are not cached; a failing string is re-parsed
    (and fails again) on every call.

    NOTE(review): ``parse_expr`` evaluates the transformed text with Python's
    ``eval``; do not feed it strings from untrusted sources. Names that SymPy
    already defines (e.g. ``E``, ``I``, ``pi``, ``gamma``) are presumably
    resolved to those objects rather than turned into plain symbols by
    ``auto_symbol`` -- confirm if parameter names may collide with them.
    """
    parsed = sympy_parser.parse_expr(
        expr_str,
        transformations=_transformations,
    )
    return parsed
class GenericExpressionMixin:
    """
    Mixin for pydantic models that evaluate a user-supplied math expression string.

    Inherit this mixin **before** the pydantic ``BaseModel`` (or subclass such as
    :class:`~pyhs3.base.Evaluable`) in the MRO, e.g.::

        class MyDist(GenericExpressionMixin, Distribution): ...

    The mixin automatically provides:

    - ``model_config`` — allows ``sp.Expr`` private attributes and alias-based
      serialisation.
    - ``expression_str`` — the raw expression string, aliased to ``"expression"``
      in serialised form.
    - ``setup_expression`` — a ``@model_validator(mode="after")`` that parses and
      analyses the expression at construction time.

    Subclasses do **not** need to redeclare any of these.

    Attributes:
        model_config: Pydantic config that permits ``sp.Expr`` private attrs.
        expression_str: The raw mathematical expression string (aliased to
            ``"expression"`` in serialised form).
    """

    model_config = ConfigDict(arbitrary_types_allowed=True, serialize_by_alias=True)

    expression_str: str = Field(alias="expression", repr=False)
    # Parsed form of ``expression_str``; ``None`` until ``setup_expression`` runs.
    _sympy_expr: sp.Expr = PrivateAttr(default=None)

    @model_validator(mode="after")
    def setup_expression(self) -> GenericExpressionMixin:
        """
        Parse and analyze the expression during initialization.

        Stores the parsed expression in ``_sympy_expr`` and records every free
        symbol name as a parameter (mapped to itself) in ``_parameters``.

        Raises:
            ExpressionParseError: If ``expression_str`` cannot be parsed.
        """
        self._sympy_expr = parse_expression(self.expression_str)
        analysis = analyze_sympy_expr(self._sympy_expr)
        # ``free_symbols`` is a set whose iteration order depends on string
        # hashing, which is randomised per interpreter run; sort by name so the
        # parameter order (and the argument order in ``_eval_expression``) is
        # deterministic.
        independent_vars = sorted(
            str(symbol) for symbol in analysis["independent_vars"]
        )
        # _parameters is provided by Evaluable (for distributions/functions)
        self._parameters = {var: var for var in independent_vars}
        return self

    def _eval_expression(self, context: Context) -> TensorVar:
        """
        Evaluate the parsed expression using variables from *context*.

        NOTE(review): ``sympy_to_pytensor`` binds symbols by each tensor's
        ``.name``, so this assumes ``context[name].name == name`` for every
        parameter -- confirm that Context guarantees this.
        """
        variables = [context[name] for name in self._parameters.values()]
        return cast("TensorVar", sympy_to_pytensor(self._sympy_expr, variables))