Source code for pyhs3.generic_parse

from __future__ import annotations

import logging
from functools import lru_cache, reduce
from typing import TYPE_CHECKING, Any, cast

import pytensor.tensor as pt
import sympy as sp
from pydantic import ConfigDict, Field, PrivateAttr, model_validator
from pytensor.tensor.type import TensorType
from sympy.parsing import sympy_parser
from sympy.parsing.sympy_parser import (
    auto_number,
    auto_symbol,
    factorial_notation,
    lambda_notation,
    repeated_decimals,
)

from pyhs3.exceptions import ExpressionEvaluationError, ExpressionParseError

if TYPE_CHECKING:
    from pyhs3.context import Context
    from pyhs3.typing.aliases import TensorVar

# Token transformations handed to SymPy's ``parse_expr`` by
# ``_parse_expr_cached``. SymPy applies them in sequence, so keep the order
# stable when editing this tuple. No implicit-multiplication transformation is
# listed here, so presumably input such as ``2x`` is rejected and products must
# be written with an explicit ``*`` — confirm against SymPy's parser docs.
_transformations = (
    auto_symbol,
    lambda_notation,
    repeated_decimals,
    auto_number,
    factorial_notation,
)

# Module-level logger; analyze_sympy_expr writes its debug output through it.
log = logging.getLogger(__name__)


def analyze_sympy_expr(sympy_expr: sp.Expr) -> dict[str, Any]:
    """
    Summarise a SymPy expression and emit debug logging describing it.

    Args:
        sympy_expr: The SymPy expression to inspect.

    Returns:
        Dictionary with three entries:
            - 'expression': ``sympy_expr`` itself
            - 'independent_vars': its free symbols (``sympy_expr.free_symbols``)
            - 'dependent_vars': the function applications it contains
              (``sympy_expr.atoms(sp.Function)``, e.g. ``exp(x)``)
    """
    summary = {
        "expression": sympy_expr,
        "independent_vars": sympy_expr.free_symbols,
        "dependent_vars": sympy_expr.atoms(sp.Function),
    }

    # %-style arguments keep formatting lazy: nothing is rendered unless DEBUG
    # records are actually emitted.
    debug_lines = (
        ("Expression: %s", "expression"),
        ("Independent Variables: %s", "independent_vars"),
        ("Dependent Variables: %s", "dependent_vars"),
    )
    for template, key in debug_lines:
        log.debug(template, summary[key])

    # sp.pretty is evaluated eagerly as an argument, so guard it explicitly.
    if log.isEnabledFor(logging.DEBUG):
        log.debug("Expression Structure:\n%s", sp.pretty(sympy_expr))

    return summary
def parse_expression(expr_str: str) -> sp.Expr:
    """
    Parse a mathematical expression string into a SymPy expression.

    Parsing is delegated to the memoised ``_parse_expr_cached`` helper, so a
    given string is only parsed once per process (within the cache's bounds).

    Args:
        expr_str: The mathematical expression as a string.

    Returns:
        SymPy expression object.

    Raises:
        ExpressionParseError: If the expression cannot be parsed, or if it
            parses to something that is not a SymPy object (for example
            ``"1, 2"`` evaluates to a plain Python tuple).
    """
    try:
        parsed = _parse_expr_cached(expr_str)
    except Exception as exc:
        msg = f"Failed to parse expression '{expr_str}': {exc}"
        raise ExpressionParseError(msg) from exc

    # parse_expr evaluates the (transformed) string as Python, so comma-separated
    # input yields a tuple and quoted input yields a str. Downstream code relies
    # on SymPy API (.free_symbols, .atoms), so reject non-SymPy results here with
    # a clear parse error instead of an AttributeError later on.
    if not isinstance(parsed, sp.Basic):
        msg = (
            f"Failed to parse expression '{expr_str}': expected a SymPy "
            f"expression, got {type(parsed).__name__}"
        )
        raise ExpressionParseError(msg)
    return parsed
def _variadic_minimum(*args: Any) -> Any:
    """
    Element-wise minimum over two or more operands.

    SymPy's ``Min`` is n-ary and lambdify prints it as ``min(a, b, c, ...)``,
    whereas ``pt.math.minimum`` takes exactly two operands, so fold pairwise.
    """
    return reduce(pt.math.minimum, args)


def _variadic_maximum(*args: Any) -> Any:
    """Element-wise maximum over two or more operands (pairwise fold of ``pt.math.maximum``)."""
    return reduce(pt.math.maximum, args)


def sympy_to_pytensor(
    sympy_expr: sp.Expr,
    variables: list[pt.variable.TensorVariable[TensorType, Any]],
) -> pt.variable.TensorVariable[TensorType, Any]:
    """
    Converts a SymPy expression into a PyTensor computational graph using lambdify.

    Each tensor in *variables* is bound to the SymPy symbol whose name equals
    the tensor's ``.name``; the lambdified function is called with the tensors
    in the order given.

    Function names emitted by lambdify are resolved through a name -> PyTensor
    mapping passed as ``modules``. Functions outside that mapping are left to
    lambdify's defaults, which are generally not tensor-aware; any resulting
    failure is reported as ``ExpressionEvaluationError``.

    Args:
        sympy_expr: The SymPy expression object.
        variables: List of PyTensor variables; their ``.name`` values are used
            as the symbol names.

    Returns:
        PyTensor expression. A non-tensor result (e.g. an expression with no
        tensor inputs) is wrapped with ``pt.constant``.

    Raises:
        ExpressionEvaluationError: If the expression cannot be converted or
            contains unsupported operations.
    """
    try:
        # Keys are the names lambdify prints (e.g. ``ceiling`` -> ``ceil``,
        # ``loggamma`` -> ``lgamma``, ``Abs`` -> ``abs``, ``Min`` -> ``min``).
        # Built inside ``try`` so a lookup problem surfaces as an
        # ExpressionEvaluationError rather than at import time.
        custom_modules = {
            "sin": pt.math.sin,
            "cos": pt.math.cos,
            "tan": pt.math.tan,
            "asin": pt.math.arcsin,
            "acos": pt.math.arccos,
            "atan": pt.math.arctan,
            "atan2": pt.math.arctan2,
            "sinh": pt.math.sinh,
            "cosh": pt.math.cosh,
            "tanh": pt.math.tanh,
            "asinh": pt.math.arcsinh,
            "acosh": pt.math.arccosh,
            "atanh": pt.math.arctanh,
            "exp": pt.math.exp,
            "expm1": pt.math.expm1,
            "log": pt.math.log,
            "log1p": pt.math.log1p,
            "log2": pt.math.log2,
            "log10": pt.math.log10,
            "sqrt": pt.math.sqrt,
            "abs": pt.math.abs,
            "erf": pt.math.erf,
            "erfc": pt.math.erfc,
            "gamma": pt.math.gamma,
            "lgamma": pt.math.gammaln,
            "floor": pt.math.floor,
            "ceil": pt.math.ceil,
            "min": _variadic_minimum,
            "max": _variadic_maximum,
        }

        # One SymPy symbol per tensor name; dict insertion order keeps the
        # positional order aligned with ``variables`` for the call below.
        sympy_vars = {var.name: sp.Symbol(var.name) for var in variables}

        # Debug logging only; the returned analysis is not needed here.
        analyze_sympy_expr(sympy_expr)

        pytensor_func = sp.lambdify(
            list(sympy_vars.values()), sympy_expr, modules=custom_modules
        )

        result = pytensor_func(*variables)

        # A constant expression evaluates to a plain number, not a tensor.
        if not isinstance(result, pt.variable.TensorVariable):
            result = pt.constant(result)

        return cast(pt.variable.TensorVariable[TensorType, Any], result)
    except Exception as exc:
        msg = f"Failed to convert expression to PyTensor: {sympy_expr}. {exc}"
        raise ExpressionEvaluationError(msg) from exc
@lru_cache(maxsize=2048)
def _parse_expr_cached(expr_str: str) -> sp.Expr:
    """
    Parse *expr_str* with SymPy, memoising up to 2048 distinct strings.

    Results are shared between callers; SymPy expression objects are immutable,
    so sharing is safe.
    """
    # NOTE(security): ``sympy_parser.parse_expr`` evaluates the transformed
    # string with Python's ``eval``. Only pass expression strings from trusted
    # sources (e.g. workspace files you control); it is not a sandbox.
    return sympy_parser.parse_expr(expr_str, transformations=_transformations)


class GenericExpressionMixin:
    """
    Mixin for pydantic models that evaluate a user-supplied math expression string.

    Inherit this mixin **before** the pydantic ``BaseModel`` (or subclass such as
    :class:`~pyhs3.base.Evaluable`) in the MRO, e.g.::

        class MyDist(GenericExpressionMixin, Distribution): ...

    The mixin automatically provides:

    - ``model_config`` — allows ``sp.Expr`` private attributes and alias-based
      serialisation.
    - ``expression_str`` — the raw expression string, aliased to ``"expression"``
      in serialised form.
    - ``setup_expression`` — a ``@model_validator(mode="after")`` that parses and
      analyses the expression at construction time and records its free
      symbols, sorted by name, in ``_parameters``.
    - ``_eval_expression`` — builds the PyTensor graph for the parsed expression.

    Subclasses do **not** need to redeclare any of these.

    Attributes:
        model_config: Pydantic config that permits ``sp.Expr`` private attrs.
        expression_str: The raw mathematical expression string (aliased to
            ``"expression"`` in serialised form).
    """

    model_config = ConfigDict(arbitrary_types_allowed=True, serialize_by_alias=True)

    expression_str: str = Field(alias="expression", repr=False)
    # Parsed form of expression_str; None until setup_expression runs.
    _sympy_expr: sp.Expr = PrivateAttr(default=None)

    @model_validator(mode="after")
    def setup_expression(self) -> GenericExpressionMixin:
        """
        Parse and analyze the expression during initialization.

        Stores the parsed expression in ``_sympy_expr`` and maps each free
        symbol name to itself in ``_parameters``.

        Raises:
            ExpressionParseError: Propagated from ``parse_expression`` when the
                expression string cannot be parsed.
        """
        self._sympy_expr = parse_expression(self.expression_str)
        analysis = analyze_sympy_expr(self._sympy_expr)
        # free_symbols is a set whose iteration order follows string hashing,
        # which Python randomises per process; sort so the parameter order is
        # reproducible across runs.
        independent_vars = sorted(
            str(symbol) for symbol in analysis["independent_vars"]
        )
        # _parameters is provided by Evaluable (for distributions/functions)
        self._parameters = {var: var for var in independent_vars}
        return self

    def _eval_expression(self, context: Context) -> TensorVar:
        """
        Evaluate the parsed expression using variables from *context*.

        Args:
            context: Lookup from parameter name to PyTensor variable.

        Returns:
            The PyTensor graph produced by ``sympy_to_pytensor``.
        """
        # NOTE(review): sympy_to_pytensor binds tensors to symbols via each
        # tensor's ``.name``, so this assumes context[name].name == name for
        # every parameter — confirm for derived/intermediate variables.
        variables = [context[name] for name in self._parameters.values()]
        return cast("TensorVar", sympy_to_pytensor(self._sympy_expr, variables))