Source code for pyhs3.generic_parse

from __future__ import annotations

import logging
from typing import Any, cast

import pytensor.tensor as pt
import sympy as sp
from pytensor.tensor.type import TensorType
from sympy.parsing import sympy_parser
from sympy.parsing.sympy_parser import (
    auto_number,
    auto_symbol,
    factorial_notation,
    lambda_notation,
    repeated_decimals,
)

from pyhs3.exceptions import ExpressionEvaluationError, ExpressionParseError

log = logging.getLogger(__name__)


[docs] def analyze_sympy_expr(sympy_expr: sp.Expr) -> dict[str, Any]: """ Analyzes a SymPy expression and logs its independent variables, dependent variables, and structure for debugging. Args: sympy_expr: The SymPy expression to analyze. Returns: Dictionary containing analysis results with keys: - 'expression': The original expression - 'independent_vars': Set of independent variables (symbols) - 'dependent_vars': Set of dependent variables (functions) """ # Independent variables (symbols in the expression) independent_vars = sympy_expr.free_symbols # Dependent variables (functions of symbols) dependent_vars = sympy_expr.atoms(sp.Function) # Log information for debugging log.debug("Expression: %s", sympy_expr) log.debug("Independent Variables: %s", independent_vars) log.debug("Dependent Variables: %s", dependent_vars) log.debug("Expression Structure:\n%s", sp.pretty(sympy_expr)) return { "expression": sympy_expr, "independent_vars": independent_vars, "dependent_vars": dependent_vars, }
[docs] def parse_expression(expr_str: str) -> sp.Expr: """ Parse a mathematical expression string into a SymPy expression. Args: expr_str: The mathematical expression as a string. Returns: SymPy expression object. Raises: ExpressionParseError: If the expression cannot be parsed. """ transformations = ( auto_symbol, lambda_notation, repeated_decimals, auto_number, factorial_notation, ) try: return sympy_parser.parse_expr(expr_str, transformations=transformations) except Exception as exc: msg = f"Failed to parse expression '{expr_str}': {exc}" raise ExpressionParseError(msg) from exc
[docs] def sympy_to_pytensor( sympy_expr: sp.Expr, variables: list[pt.variable.TensorVariable[TensorType, Any]], ) -> pt.variable.TensorVariable[TensorType, Any]: """ Converts a SymPy expression into a PyTensor computational graph using lambdify. Args: sympy_expr: The SymPy expression object. variables: List of PyTensor variables. Returns: PyTensor expression. Raises: ExpressionEvaluationError: If the expression cannot be converted or contains unsupported operations. """ try: # Define the mapping for SymPy functions to PyTensor functions (using pt.math) custom_modules = { "sin": pt.math.sin, "cos": pt.math.cos, "tan": pt.math.tan, "exp": pt.math.exp, "log": pt.math.log, "sqrt": pt.math.sqrt, "abs": pt.math.abs, "erf": pt.math.erf, "min": pt.math.min, "max": pt.math.max, } # Convert variable names to SymPy symbols sympy_vars = {var.name: sp.Symbol(var.name) for var in variables} # Log the expression for debugging analyze_sympy_expr(sympy_expr) # Convert SymPy expression to a PyTensor-compatible function pytensor_func = sp.lambdify( list(sympy_vars.values()), sympy_expr, modules=custom_modules ) # Apply the function to PyTensor variables result = pytensor_func(*variables) # Handle case where result is a constant (not a PyTensor variable) if not isinstance(result, pt.variable.TensorVariable): result = pt.constant(result) return cast(pt.variable.TensorVariable[TensorType, Any], result) except Exception as exc: msg = f"Failed to convert expression to PyTensor: {sympy_expr}. {exc}" raise ExpressionEvaluationError(msg) from exc