API-Reference

SimpliPy Engine

SimplificationStatistics dataclass

SimplificationStatistics(rule_application_counts: defaultdict[tuple, int] = (lambda: defaultdict(int))(), explicit_rule_applications: int = 0, pattern_rule_applications: int = 0, post_operand_rule_applications: int = 0, constant_folding_count: int = 0, rule_match_attempts: int = 0, rule_match_hits: int = 0, cancellation_events: list[dict[str, Any]] = list(), iterations_used: int = 0, converged: bool = False, result_rejected: bool = False, per_iteration_lengths: list[dict[str, int]] = list(), stage_timings: dict[str, float] = (lambda: {'cancel_terms': 0.0, 'apply_rules': 0.0, 'sort_operands': 0.0, 'mask_literals': 0.0})())

Collects detailed statistics about a simplification run.

This dataclass is populated by SimpliPyEngine.simplify when collect_statistics=True. It replaces the former rule_application_statistics dict with a richer set of metrics that cover every stage of the simplification pipeline.

ATTRIBUTE DESCRIPTION
rule_application_counts

How many times each (rule_index, pattern, replacement) rule fired. rule_index is the position of the rule in the original simplification_rules list (-1 for rules not found in the list).

TYPE: defaultdict[tuple, int]

explicit_rule_applications

Total number of explicit (no-wildcard) rule applications.

TYPE: int

pattern_rule_applications

Total number of wildcard-pattern rule applications.

TYPE: int

post_operand_rule_applications

Number of rule applications that fired only after the node's children had been simplified.

TYPE: int

constant_folding_count

How often the all-operands-are-<constant> short-circuit fired.

TYPE: int

rule_match_attempts

Total match_pattern calls made.

TYPE: int

rule_match_hits

How many of those attempts succeeded.

TYPE: int

cancellation_events

One entry per term cancellation, with keys 'class' ('add' or 'mult'), 'subtree', 'multiplicity_sum', and 'neutral_insertions'.

TYPE: list[dict[str, Any]]

iterations_used

Number of simplification iterations that were executed.

TYPE: int

converged

Whether the loop stopped before reaching max_iter.

TYPE: bool

result_rejected

Whether the simplified result was longer than the input and therefore discarded.

TYPE: bool

per_iteration_lengths

For each iteration, a dict with keys 'after_cancel' and 'after_rules' holding the expression length at that point.

TYPE: list[dict[str, int]]

stage_timings

Cumulative wall-clock seconds keyed by stage name: 'cancel_terms', 'apply_rules', 'sort_operands', 'mask_literals'.

TYPE: dict[str, float]
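
The fields above lend themselves to a few derived ratios. The following is a minimal sketch, not part of the library: `summarize` is a hypothetical helper, and a `SimpleNamespace` with made-up values stands in for a real SimplificationStatistics instance so the snippet runs without the engine.

```python
from types import SimpleNamespace


def summarize(stats):
    """Derive a few ratios from the documented statistics fields (hypothetical helper)."""
    attempts = stats.rule_match_attempts
    # Guard against division by zero when no match was attempted.
    hit_rate = stats.rule_match_hits / attempts if attempts else 0.0
    total_time = sum(stats.stage_timings.values())
    # Name of the stage with the largest cumulative time, or None if nothing was timed.
    slowest = max(stats.stage_timings, key=stats.stage_timings.get) if total_time else None
    return {'hit_rate': hit_rate, 'slowest_stage': slowest, 'converged': stats.converged}


# Stand-in object with invented values; a real run would supply the dataclass instead.
example = SimpleNamespace(
    rule_match_attempts=200,
    rule_match_hits=50,
    converged=True,
    stage_timings={'cancel_terms': 0.01, 'apply_rules': 0.04, 'sort_operands': 0.005, 'mask_literals': 0.0},
)
summary = summarize(example)
```

Because the helper only reads attributes, it should work unchanged on any object that exposes the same field names.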

SimpliPyEngine

SimpliPyEngine(operators: dict[str, dict[str, Any]], rules: list[tuple] | None = None)

Manages and manipulates symbolic expressions.

This class provides a comprehensive toolkit for parsing, transforming, and simplifying mathematical expressions. It operates on expressions in prefix notation (a list of tokens) and uses a customizable set of operators and simplification rules.

PARAMETER DESCRIPTION
operators

A dictionary defining the operators. Each key is the operator's canonical name (e.g., 'add', 'sin'), and the value is another dictionary specifying its properties like 'arity', 'realization' (the corresponding Python function), 'inverse', etc.

TYPE: dict[str, dict[str, Any]]

rules

A list of simplification rules. Each rule is a tuple containing two lists of strings: the pattern to match and the replacement expression, both in prefix notation. If None, the engine is initialized with no rules.

TYPE: list[tuple] or None DEFAULT: None

ATTRIBUTE DESCRIPTION
operator_tokens

A list of all defined operator names.

TYPE: list[str]

operator_arity

A mapping from operator names to their arity (number of arguments).

TYPE: dict[str, int]

simplification_rules

The list of simplification rules loaded into the engine.

TYPE: list[tuple]

simplification_rules_patterns

A compiled version of rules that involve pattern variables (e.g., _0), organized for efficient matching.

TYPE: dict

simplification_rules_no_patterns

A compiled version of explicit rules without pattern variables.

TYPE: dict
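
To make the shape of the two constructor arguments concrete, here is a small sketch. The operator names, aliases, and rules are invented for illustration and are not taken from a shipped configuration; only the property names ('arity', 'realization', 'alias', 'inverse', 'commutative', 'precedence') come from the constructor source. The two comprehensions mirror how the constructor derives its lookup tables, without importing the engine.

```python
# Hypothetical operator table; the values are illustrative assumptions.
operators = {
    '+': {'arity': 2, 'realization': '+', 'alias': ['add'], 'inverse': '-', 'commutative': True, 'precedence': 1},
    '*': {'arity': 2, 'realization': '*', 'alias': ['mul'], 'inverse': '/', 'commutative': True, 'precedence': 2},
    'sin': {'arity': 1, 'realization': 'np.sin', 'alias': [], 'inverse': 'asin', 'commutative': False},
}

# Each rule is a (pattern, replacement) pair in prefix notation.
rules = [
    (['+', '_0', '0'], ['_0']),   # wildcard rule: x + 0 -> x
    (['sin', '0'], ['0']),        # explicit rule: sin(0) -> 0
]

# The same comprehensions the constructor uses for its lookup tables.
operator_arity = {name: props['arity'] for name, props in operators.items()}
operator_aliases = {alias: name for name, props in operators.items() for alias in props['alias']}
```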

Source code in src/simplipy/engine.py
def __init__(self, operators: dict[str, dict[str, Any]], rules: list[tuple] | None = None) -> None:
    # Cache operator metadata for quick access during parsing and evaluation.
    self.operator_tokens = list(operators.keys())
    self.operator_aliases = {alias: operator for operator, properties in operators.items() for alias in properties['alias']}
    self.operator_inverses = {k: v["inverse"] for k, v in operators.items() if v.get("inverse") is not None}

    self.inverse_base = {'*': ['inv', '/', '1'], '+': ['neg', '-', '0']}
    self.inverse_unary = {v[0]: [k, v[1], v[2]] for k, v in self.inverse_base.items()}
    self.inverse_binary = {v[1]: [k, v[0], v[2]] for k, v in self.inverse_base.items()}

    self.unary_mult_div_operators = {k: v["inverse"] for k, v in operators.items() if k.startswith('mult') or k.startswith('div')}
    self.commutative_operators = [k for k, v in operators.items() if v.get("commutative", False)]

    self.operator_realizations = {k: v["realization"] for k, v in operators.items()}
    self.realization_to_operator = {v: k for k, v in self.operator_realizations.items()}

    self.operator_precedence_compat = {k: v.get("precedence", i) for i, (k, v) in enumerate(operators.items())}
    self.operator_precedence_compat['**'] = 3
    self.operator_precedence_compat['sqrt'] = 3

    self.operator_arity = {k: v["arity"] for k, v in operators.items()}
    self.operator_arity_compat = deepcopy(self.operator_arity)
    self.operator_arity_compat['**'] = 2
    self.operators = list(self.operator_arity.keys())

    self.max_power = max([int(op[3:]) for op in self.operator_tokens if re.match(r'pow\d+(?!\_)', op)] + [0])
    self.max_fractional_power = max([int(op[5:]) for op in self.operator_tokens if re.match(r'pow1_\d+', op)] + [0])

    self.modules = get_used_modules(''.join(f"{op}(" for op in self.operator_realizations.values()))
    self.import_modules()

    self.connection_classes = {'add': (['+', '-'], "0"), 'mult': (['*', '/'], "1")}
    self.operator_to_class = {'+': 'add', '-': 'add', '*': 'mult', '/': 'mult'}
    self.connection_classes_inverse = {'add': "neg", 'mult': "inv"}
    self.connection_classes_hyper = {'add': "mult", 'mult': "pow"}
    self.binary_connectable_operators = {'+', '-', '*', '/'}

    # Normalize the incoming rule list and eliminate duplicate patterns.
    dummy_variables = [f'x{i}' for i in range(100)]
    if rules is None:
        self.simplification_rules = []
    else:
        self.simplification_rules = deduplicate_rules(rules, dummy_variables=dummy_variables)

    # Build the compiled lookup tables that power rule application.
    self.compile_rules()
    self.simplification_statistics: SimplificationStatistics | None = None

compile_rules

compile_rules() -> None

Compiles the text-based rules into an efficient internal format.

This method processes the self.simplification_rules list, separating them into rules with patterns (like '_0', '_1') and explicit rules. It then converts the patterns into a tree-based structure optimized for fast matching against expression subtrees.

Source code in src/simplipy/engine.py
def compile_rules(self) -> None:
    """Compiles the text-based rules into an efficient internal format.

    This method processes the `self.simplification_rules` list,
    separating them into rules with patterns (like '_0', '_1') and
    explicit rules. It then converts the patterns into a tree-based
    structure optimized for fast matching against expression subtrees.
    """
    simplification_rules_patterns = []
    simplification_rules_no_patterns = []
    self._rule_index_lookup: dict[tuple[tuple, tuple], int] = {}
    for idx, r in enumerate(self.simplification_rules):
        key = (tuple(r[0]), tuple(r[1]))
        self._rule_index_lookup[key] = idx
        if any(_WILDCARD_RE.match(t) for t in r[0]):
            simplification_rules_patterns.append(r)
        else:
            simplification_rules_no_patterns.append(r)
    self.max_pattern_length = 0
    self.simplification_rules_patterns: dict[tuple, list[tuple[list, list]]] = self.construct_rule_patterns(simplification_rules_patterns)
    self.simplification_rules_no_patterns: dict[tuple, tuple] = {tuple(r[0]): tuple(r[1]) for r in simplification_rules_no_patterns}
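
The partitioning step can be sketched in isolation. This is a simplified illustration rather than the engine's code: `WILDCARD_RE` is an assumed stand-in for the module-level `_WILDCARD_RE` (whose exact definition is not shown here), and `split_rules` is a hypothetical helper that skips the tree construction done by `construct_rule_patterns`.

```python
import re

# Assumption: wildcard tokens look like '_0', '_1', ...; the real _WILDCARD_RE may differ.
WILDCARD_RE = re.compile(r'_\d+$')


def split_rules(rules):
    """Separate wildcard-pattern rules from explicit rules (hypothetical helper)."""
    pattern_rules = []
    explicit_rules = {}
    for pattern, replacement in rules:
        if any(WILDCARD_RE.match(token) for token in pattern):
            pattern_rules.append((pattern, replacement))
        else:
            # Explicit rules become a direct tuple -> tuple lookup.
            explicit_rules[tuple(pattern)] = tuple(replacement)
    return pattern_rules, explicit_rules


pattern_rules, explicit_rules = split_rules([
    (['+', '_0', '0'], ['_0']),
    (['sin', '0'], ['0']),
])
```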

prune_redundant_rules

prune_redundant_rules(verbose: bool = False) -> int

Remove explicit rules that are subsumed by wildcard-pattern rules.

An explicit rule (e, r_e) is redundant if the engine still simplifies e to r_e when that single rule is removed. This happens when a wildcard-pattern rule already covers the same transformation, or when constant folding / term cancellation achieve the same result.

Rules are tested and removed serially: once a rule is found redundant it stays removed for all subsequent tests. This avoids over-pruning in the case where two explicit rules each appear redundant in the presence of the other but neither is covered by a pattern rule alone.

PARAMETER DESCRIPTION
verbose

If True, shows a progress bar and prints a summary. Defaults to False.

TYPE: bool DEFAULT: False

RETURNS DESCRIPTION
int

The number of rules that were pruned.

Source code in src/simplipy/engine.py
def prune_redundant_rules(self, verbose: bool = False) -> int:
    """Remove explicit rules that are subsumed by wildcard-pattern rules.

    An explicit rule ``(e, r_e)`` is *redundant* if the engine still
    simplifies ``e`` to ``r_e`` when that single rule is removed.  This
    happens when a wildcard-pattern rule already covers the same
    transformation, or when constant folding / term cancellation achieve
    the same result.

    Rules are tested and removed serially: once a rule is found redundant
    it stays removed for all subsequent tests.  This avoids over-pruning
    in the case where two explicit rules each appear redundant in the
    presence of the other but neither is covered by a pattern rule alone.

    Parameters
    ----------
    verbose : bool, optional
        If True, shows a progress bar and prints a summary.
        Defaults to False.

    Returns
    -------
    int
        The number of rules that were pruned.
    """
    # Collect indices of explicit (non-pattern) rules
    explicit_indices = [
        i for i, (lhs, _rhs) in enumerate(self.simplification_rules)
        if not any(_WILDCARD_RE.match(t) for t in lhs)
    ]

    n_pruned = 0
    pruned_indices: set[int] = set()

    for idx in tqdm(explicit_indices, desc='Pruning redundant rules', disable=not verbose):
        lhs, rhs = self.simplification_rules[idx]
        lhs_key = tuple(lhs)

        # Remove this explicit rule from the compiled dict
        saved = self.simplification_rules_no_patterns.pop(lhs_key, None)

        result = self.simplify(list(lhs), mask_elementary_literals=False)
        if tuple(result) == tuple(rhs):
            # Rule is redundant — keep it removed
            pruned_indices.add(idx)
            n_pruned += 1
        else:
            # Rule is needed — restore it
            if saved is not None:
                self.simplification_rules_no_patterns[lhs_key] = saved

    if pruned_indices:
        self.simplification_rules = [
            rule for i, rule in enumerate(self.simplification_rules)
            if i not in pruned_indices
        ]
        self.compile_rules()

    if verbose:
        print(f'Pruned {n_pruned} redundant explicit rules '
              f'({len(self.simplification_rules)} rules remaining)')

    return n_pruned
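
The serial remove-test-restore loop can be illustrated with a toy simplifier. Everything below is a sketch under stated assumptions: `prune_serially` and `toy_simplify` are hypothetical names, and the toy simplifier knows a single hard-coded wildcard rule instead of the engine's full pipeline.

```python
def prune_serially(explicit_rules, fallback_simplify):
    """Drop explicit rules whose result is reproduced without them.

    explicit_rules: dict mapping pattern tuples to replacement tuples (mutated in place).
    fallback_simplify: callable(expression, explicit_rules) -> tuple, standing in for the engine.
    """
    pruned = 0
    for pattern in list(explicit_rules):
        replacement = explicit_rules.pop(pattern)
        if fallback_simplify(pattern, explicit_rules) == replacement:
            pruned += 1  # redundant: stays removed for all later tests
        else:
            explicit_rules[pattern] = replacement  # needed: restore it
    return pruned


def toy_simplify(expression, explicit_rules):
    # Explicit lookup first, then one hard-coded wildcard rule: ('+', x, '0') -> (x,).
    if expression in explicit_rules:
        return explicit_rules[expression]
    if len(expression) == 3 and expression[0] == '+' and expression[2] == '0':
        return (expression[1],)
    return expression


toy_rules = {('+', 'x', '0'): ('x',), ('sin', '0'): ('0',)}
n_pruned = prune_serially(toy_rules, toy_simplify)
```

In this toy setup the first rule is covered by the wildcard rule and is pruned, while the second has no other derivation and is restored.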

import_modules

import_modules() -> None

Imports Python modules required by operator realizations.

The engine inspects the 'realization' strings of all operators (e.g., 'np.sin') to identify necessary modules (e.g., 'numpy') and imports them into the global namespace to make them available for expression evaluation.

Source code in src/simplipy/engine.py
def import_modules(self) -> None:
    """Imports Python modules required by operator realizations.

    The engine inspects the 'realization' strings of all operators
    (e.g., 'np.sin') to identify necessary modules (e.g., 'numpy') and
    imports them into the global namespace to make them available for
    expression evaluation.
    """
    for module in self.modules:
        if module not in globals():
            globals()[module] = importlib.import_module(module)
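
The same import-on-demand idea can be shown with an explicit namespace dictionary instead of `globals()`. This is a sketch, not the engine's behavior: `import_into` is a hypothetical helper, and the standard-library `math` module is used so the snippet runs without NumPy.

```python
import importlib


def import_into(namespace, module_names):
    """Import each named module into the given namespace dict if it is missing."""
    for name in module_names:
        if name not in namespace:
            namespace[name] = importlib.import_module(name)
    return namespace


namespace = import_into({}, ['math'])
# An expression string can now be evaluated against the prepared namespace.
value = eval('math.sqrt(16.0)', namespace)
```

Passing a dedicated dictionary keeps the imported names out of the module's global scope, which may be preferable when several differently configured evaluators coexist.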

from_config classmethod

from_config(config_path: str) -> SimpliPyEngine

Creates a SimpliPyEngine instance from a JSON configuration file.

The configuration file should specify the operators and can optionally provide a path to a rules file.

PARAMETER DESCRIPTION
config_path

The absolute or relative path to the JSON configuration file.

TYPE: str

RETURNS DESCRIPTION
SimpliPyEngine

A new instance of the engine configured as per the file.

Source code in src/simplipy/engine.py
@classmethod
def from_config(cls, config_path: str) -> "SimpliPyEngine":
    """Creates a SimpliPyEngine instance from a JSON configuration file.

    The configuration file should specify the `operators` and can
    optionally provide a path to a `rules` file.

    Parameters
    ----------
    config_path : str
        The absolute or relative path to the JSON configuration file.

    Returns
    -------
    SimpliPyEngine
        A new instance of the engine configured as per the file.
    """
    config_path = os.path.abspath(config_path)
    config = load_config(config_path)
    rules = []
    rules_file = config.get('rules')
    if rules_file:
        if not os.path.isabs(rules_file):
            config_dir = os.path.dirname(config_path)
            rules_path = os.path.join(config_dir, rules_file)
        else:
            rules_path = rules_file
        if os.path.exists(rules_path):
            with open(rules_path, 'r') as f:
                rules = json.load(f)
        else:
            warnings.warn(f"Rules file '{rules_path}' specified in config not found.", UserWarning)
    return cls(operators=config['operators'], rules=rules)
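
The path handling above resolves a relative rules path against the directory of the configuration file, not the current working directory. A minimal sketch of just that step, where `resolve_rules_path` and the file names are hypothetical:

```python
import os


def resolve_rules_path(config_path, rules_file):
    """Resolve a rules path the way from_config does: relative to the config file."""
    config_path = os.path.abspath(config_path)
    if os.path.isabs(rules_file):
        return rules_file
    return os.path.join(os.path.dirname(config_path), rules_file)


# Illustrative file names; no files are read or created here.
resolved = resolve_rules_path(os.path.join('configs', 'engine.json'), 'rules.json')
```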

load classmethod

load(path: str, install: bool = False, local_dir: Path | str | None = None, repo_id: str | None = None, manifest_filename: str | None = None) -> SimpliPyEngine

Loads a pre-defined engine configuration from the asset manager.

This provides a convenient way to load standard engine configurations distributed with the simplipy package.

PARAMETER DESCRIPTION
path

The name of the configuration to load (e.g., 'default').

TYPE: str

install

If True, forces the download of the asset if not found locally. Defaults to False.

TYPE: bool DEFAULT: False

local_dir

A local directory to search for the assets. Defaults to None, which uses the default asset directory.

TYPE: Path or str or None DEFAULT: None

repo_id

The Hugging Face repository ID where the manifest is stored. If None, the default repository ID is used.

TYPE: str or None DEFAULT: None

manifest_filename

The filename of the manifest file. If None, the default filename is used.

TYPE: str or None DEFAULT: None

RETURNS DESCRIPTION
SimpliPyEngine

A new instance of the engine.

Source code in src/simplipy/engine.py
@classmethod
def load(cls, path: str, install: bool = False, local_dir: Path | str | None = None, repo_id: str | None = None, manifest_filename: str | None = None) -> "SimpliPyEngine":
    """Loads a pre-defined engine configuration from the asset manager.

    This provides a convenient way to load standard engine configurations
    distributed with the `simplipy` package.

    Parameters
    ----------
    path : str
        The name of the configuration to load (e.g., 'default').
    install : bool, optional
        If True, forces the download of the asset if not found locally.
        Defaults to False.
    local_dir : Path or str or None, optional
        A local directory to search for the assets. Defaults to None,
        which uses the default asset directory.
    repo_id : str or None, optional
        The Hugging Face repository ID where the manifest is stored. If None, the default repository ID is used.
    manifest_filename : str or None, optional
        The filename of the manifest file. If None, the default filename is used.

    Returns
    -------
    SimpliPyEngine
        A new instance of the engine.
    """
    return cls.from_config(get_path(path, install=install, local_dir=local_dir, repo_id=repo_id, manifest_filename=manifest_filename))

is_valid

is_valid(prefix_expression: list[str], verbose: bool = False) -> bool

Checks if a prefix expression is syntactically valid.

An expression is valid if every operator has the correct number of operands according to its defined arity.

PARAMETER DESCRIPTION
prefix_expression

The expression in prefix notation.

TYPE: list[str]

verbose

If True, prints the reason for invalidity. Defaults to False.

TYPE: bool DEFAULT: False

RETURNS DESCRIPTION
bool

True if the expression is valid, False otherwise.

Source code in src/simplipy/engine.py
def is_valid(self, prefix_expression: list[str], verbose: bool = False) -> bool:
    """Checks if a prefix expression is syntactically valid.

    An expression is valid if every operator has the correct number of
    operands according to its defined arity.

    Parameters
    ----------
    prefix_expression : list[str]
        The expression in prefix notation.
    verbose : bool, optional
        If True, prints the reason for invalidity. Defaults to False.

    Returns
    -------
    bool
        True if the expression is valid, False otherwise.
    """
    stack: list[str] = []

    if len(prefix_expression) > 1 and prefix_expression[0] not in self.operator_arity:
        if verbose:
            print(f'Invalid expression {prefix_expression}: Variable must be leaf node')
        return False

    for token in reversed(prefix_expression):
        # Check if token is not a constant and numeric
        if token != '<constant>' and is_numeric_string(token):
            try:
                float(token)
            except ValueError:
                if verbose:
                    print(f'Invalid token {token} in expression {prefix_expression}')
                return False

        if token in self.operator_arity:
            if len(stack) < self.operator_arity[token]:
                if verbose:
                    print(f'Not enough operands for operator {token} in expression {prefix_expression}')
                return False

            # Consume the operands based on the arity of the operator
            for _ in range(self.operator_arity[token]):
                stack.pop()

        # Add the token to the stack
        stack.append(token)

    if len(stack) != 1:
        if verbose:
            print(f'Stack is not empty after parsing the expression {prefix_expression}')
        return False

    return True
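
The core of the check is a right-to-left scan that tracks how many complete subexpressions are available. The sketch below is a reduced variant under stated assumptions: it keeps only a counter instead of a stack, omits the numeric-token check and the verbose messages, and uses a made-up arity table.

```python
def is_valid_prefix(tokens, arity):
    """Return True if tokens form exactly one well-formed prefix expression."""
    depth = 0  # complete subexpressions seen so far, scanning from the right
    for token in reversed(tokens):
        needed = arity.get(token, 0)  # tokens without an arity entry are leaves
        if depth < needed:
            return False  # operator lacks operands
        depth += 1 - needed  # consume the operands, produce one result
    return depth == 1


toy_arity = {'+': 2, 'sin': 1}  # illustrative arity table
```

For example, `['+', 'x', 'sin', 'y']` is accepted, while `['+', 'x']` (missing operand) and `['x', 'y']` (two separate expressions) are rejected.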

prefix_to_infix

prefix_to_infix(tokens: list[str], power: Literal['func', '**'] = 'func', realization: bool = False) -> str

Converts a prefix expression to an infix string with minimal parentheses.

PARAMETER DESCRIPTION
tokens

The prefix expression to render.

TYPE: list[str]

power

Controls how power operators are emitted. 'func' keeps canonical engine names such as pow3(x), while '**' renders Python-style exponentiation.

TYPE: Literal['func', '**'] DEFAULT: 'func'

realization

If True, operator tokens are replaced with their runtime realizations (for example, 'sin' becomes 'np.sin'), so the output can be compiled directly.

TYPE: bool DEFAULT: False

RETURNS DESCRIPTION
str

The formatted infix expression.

RAISES DESCRIPTION
ValueError

If the provided tokens do not form a well-formed prefix expression.

Source code in src/simplipy/engine.py
def prefix_to_infix(self, tokens: list[str], power: Literal['func', '**'] = 'func', realization: bool = False) -> str:
    """Converts a prefix expression to an infix string with minimal parentheses.

    Parameters
    ----------
    tokens : list[str]
        The prefix expression to render.
    power : {'func', '**'}, optional
        Controls how power operators are emitted. ``'func'`` keeps canonical
        engine names such as ``pow3(x)``, while ``'**'`` renders Python-style
        exponentiation.
    realization : bool, optional
        If True, operator tokens are replaced with their runtime
        realizations (for example, ``'sin'`` becomes ``'np.sin'``), so the
        output can be compiled directly.

    Returns
    -------
    str
        The formatted infix expression.

    Raises
    ------
    ValueError
        If the provided tokens do not form a well-formed prefix expression.
    """

    if not tokens:
        return ''

    # Use the configured operator precedence as a baseline for deciding
    # when parentheses are necessary. Higher numbers mean higher precedence.
    op_precedence = self.operator_precedence_compat
    op_associativity = {
        '+': 'left',
        '-': 'left',
        '*': 'left',
        '/': 'left',
        '**': 'right',
        'pow': 'right',
    }

    FUNC_PRECEDENCE = float('inf')
    TERMINAL_PRECEDENCE = float('inf')

    # Stack elements are tuples of (rendered_str, precedence_value, root_operator)
    stack: list[tuple[str, float, str | None]] = []

    def right_allows_flatten(parent_op: str, child_root: str | None) -> bool:
        """Return True if a right operand with the same precedence can omit parentheses."""
        if child_root is None:
            return True

        flatten_map: dict[str, set[str]] = {
            '+': {'+', '-'},
            '*': {'*', '/'},
        }
        return child_root in flatten_map.get(parent_op, set())

    for token in reversed(tokens):
        operator = self.realization_to_operator.get(token, token)
        canonical_operator = self.operator_aliases.get(operator, operator)

        if (
            canonical_operator in self.operator_tokens
            or operator in self.operator_aliases
            or canonical_operator in self.operator_arity_compat
        ):
            arity = self.operator_arity_compat.get(canonical_operator, 1)

            if len(stack) < arity:
                raise ValueError(f"Invalid prefix expression: Not enough operands for operator '{operator}'")

            operands_data = [stack.pop() for _ in range(arity)]

            write_operator = (
                self.operator_realizations.get(canonical_operator, canonical_operator)
                if realization
                else canonical_operator
            )

            # Render realization strings that look like fully qualified callables
            if realization and ('.' in write_operator or self.operator_arity_compat.get(canonical_operator, 0) > 2):
                rendered = f"{write_operator}({', '.join(op_str for op_str, _, _ in operands_data)})"
                stack.append((rendered, FUNC_PRECEDENCE, canonical_operator))
                continue

            current_precedence = op_precedence.get(canonical_operator, op_precedence.get('pow', FUNC_PRECEDENCE))
            current_assoc = op_associativity.get(canonical_operator, 'left')

            if arity == 2:
                left_str, left_prec, left_root = operands_data[0]
                right_str, right_prec, right_root = operands_data[1]

                if canonical_operator == 'pow' and power == 'func':
                    rendered = f'{write_operator}({left_str}, {right_str})'
                    stack.append((rendered, FUNC_PRECEDENCE, canonical_operator))
                    continue

                if canonical_operator == 'pow' and power == '**':
                    write_operator = '**'
                    current_precedence = op_precedence.get('**', current_precedence)
                    current_assoc = 'right'

                if left_prec < current_precedence or (
                    left_prec == current_precedence and current_assoc == 'right'
                ):
                    left_str = f'({left_str})'

                if right_prec < current_precedence or (
                    right_prec == current_precedence and current_assoc == 'left'
                    and not right_allows_flatten(canonical_operator, right_root)
                ):
                    right_str = f'({right_str})'

                rendered = f'{left_str} {write_operator} {right_str}'
                stack.append((rendered, current_precedence, canonical_operator))
                continue

            if arity == 1:
                operand_str, operand_prec, operand_root = operands_data[0]
                is_pow_op = re.match(r'pow\d+(?!_)', canonical_operator)
                is_frac_pow_op = re.match(r'pow1_\d+', canonical_operator)

                if canonical_operator == 'neg':
                    if operand_prec < current_precedence:
                        operand_str = f'({operand_str})'
                    rendered = f'-{operand_str}'
                    stack.append((rendered, current_precedence, canonical_operator))
                    continue

                if canonical_operator == 'inv':
                    if operand_prec <= current_precedence:
                        operand_str = f'({operand_str})'
                    rendered = f'1/{operand_str}'
                    inv_precedence = op_precedence.get('/', current_precedence)
                    stack.append((rendered, inv_precedence, canonical_operator))
                    continue

                if power == '**' and (is_pow_op or is_frac_pow_op):
                    power_precedence = op_precedence.get('**', current_precedence)
                    if operand_prec <= power_precedence:
                        operand_str = f'({operand_str})'

                    if is_pow_op:
                        exponent = int(canonical_operator[3:])
                        rendered = f'{operand_str}**{exponent}'
                    else:
                        denominator = int(canonical_operator[5:])
                        rendered = f'{operand_str}**(1/{denominator})'

                    stack.append((rendered, power_precedence, canonical_operator))
                    continue

                rendered = f'{write_operator}({operand_str})'
                stack.append((rendered, FUNC_PRECEDENCE, canonical_operator))
                continue

            # Fallback for nullary or higher arity operators
            rendered = f"{write_operator}({', '.join(op_str for op_str, _, _ in operands_data)})"
            stack.append((rendered, FUNC_PRECEDENCE, canonical_operator))
        else:
            stack.append((token, TERMINAL_PRECEDENCE, None))

    if len(stack) != 1:
        raise ValueError(
            "Malformed prefix expression: too many operands remain after processing. "
            f"Stack: {[part for part, _, _ in stack]}"
        )

    return stack[0][0]
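
A much smaller renderer shows the precedence-driven parenthesization in isolation. This sketch is deliberately more conservative than the method above: it always parenthesizes a right operand of equal precedence, has no power, negation, alias, or realization handling, and uses illustrative `PRECEDENCE` and `ARITY` tables.

```python
PRECEDENCE = {'+': 1, '-': 1, '*': 2, '/': 2}          # illustrative values
ARITY = {'+': 2, '-': 2, '*': 2, '/': 2, 'sin': 1}     # illustrative values


def to_infix(tokens):
    """Render a prefix token list as an infix string (simplified sketch)."""
    stack = []  # entries are (rendered_text, precedence)
    for token in reversed(tokens):
        if token not in ARITY:
            stack.append((token, float('inf')))  # terminals never need parentheses
            continue
        if len(stack) < ARITY[token]:
            raise ValueError(f"Not enough operands for operator '{token}'")
        operands = [stack.pop() for _ in range(ARITY[token])]
        if ARITY[token] == 1:
            stack.append((f'{token}({operands[0][0]})', float('inf')))
            continue
        (left, left_prec), (right, right_prec) = operands
        prec = PRECEDENCE[token]
        if left_prec < prec:
            left = f'({left})'
        if right_prec <= prec:  # conservative: equal precedence on the right is wrapped
            right = f'({right})'
        stack.append((f'{left} {token} {right}', prec))
    if len(stack) != 1:
        raise ValueError('Malformed prefix expression')
    return stack[0][0]


rendered_product = to_infix(['*', '+', 'x', 'y', 'sin', 'z'])
rendered_difference = to_infix(['-', 'a', '-', 'b', 'c'])
```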

infix_to_prefix

infix_to_prefix(infix_expression: str) -> list[str]

Converts an infix expression string to prefix notation.

This method uses a standard algorithm (related to Shunting-yard) to parse the infix string, respecting operator precedence and parentheses.

PARAMETER DESCRIPTION
infix_expression

The mathematical expression in infix notation.

TYPE: str

RETURNS DESCRIPTION
list[str]

A list of tokens representing the expression in prefix notation.

Source code in src/simplipy/engine.py
def infix_to_prefix(self, infix_expression: str) -> list[str]:
    """Converts an infix expression string to prefix notation.

    This method uses a standard algorithm (related to Shunting-yard) to
    parse the infix string, respecting operator precedence and parentheses.

    Parameters
    ----------
    infix_expression : str
        The mathematical expression in infix notation.

    Returns
    -------
    list[str]
        A list of tokens representing the expression in prefix notation.
    """
    # Regex to tokenize expression properly (handles floating-point numbers and scientific notation)
    number_pattern = r'(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][+-]?\d+)?'
    # Include caret '^' as a distinct power token so users can write x ^ 3
    token_pattern = re.compile(rf'<constant>|{number_pattern}|[A-Za-z_][\w.]*|\*\*|[-+*/^()]')

    # Tokenize the infix expression
    tokens = token_pattern.findall(infix_expression.replace(' ', ''))

    stack: list[str] = []
    prefix_expr: list[str] = []

    # Reverse the tokens for right-to-left parsing
    tokens = tokens[::-1]

    i = 0
    while i < len(tokens):
        token = tokens[i]

        # Normalize alternative power symbol '^' to the canonical '**'
        if token == '^':
            token = '**'

        # Handle numbers (integers, floats, or scientific notation)
        if re.fullmatch(number_pattern, token):
            prefix_expr.append(token)
        elif re.match(r'[A-Za-z_][\w.]*', token) or token == '<constant>':  # Match functions and variables
            prefix_expr.append(token)
        elif token == ')':
            stack.append(token)
        elif token == '(':
            while stack and stack[-1] != ')':
                prefix_expr.append(stack.pop())
            if stack and stack[-1] == ')':
                stack.pop()  # Pop the ')'
        else:
            # Handle binary and unary operators
            if token == '-' and (i == len(tokens) - 1 or tokens[i + 1] == '(' or (tokens[i + 1]) in self.operator_precedence_compat):
                # Handle unary negation (not part of a number)
                token = 'neg'

            if stack and stack[-1] != ')' and token != ')':
                while stack and self.operator_precedence_compat.get(stack[-1], 0) >= self.operator_precedence_compat.get(token, 0):
                    prefix_expr.append(stack.pop())
                stack.append(token)
            else:

                if (token == 'neg' and not stack) or (stack and stack[-1] != ')'):
                    stack.insert(-1, token)
                else:
                    stack.append(token)

        i += 1

    while stack:
        prefix_expr.append(stack.pop())

    return prefix_expr[::-1]
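
The tokenization step can be tried on its own. The snippet below reuses the two regular expressions from the source above; the `tokenize` wrapper and the sample input are added for illustration only.

```python
import re

# Same patterns as in infix_to_prefix.
number_pattern = r'(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][+-]?\d+)?'
token_pattern = re.compile(rf'<constant>|{number_pattern}|[A-Za-z_][\w.]*|\*\*|[-+*/^()]')


def tokenize(infix_expression):
    """Split an infix string into tokens, ignoring spaces (illustrative wrapper)."""
    return token_pattern.findall(infix_expression.replace(' ', ''))


tokens = tokenize('x1 ** 2 + 3.5e-2 * sin(y)')
```

Because the `**` alternative is listed before the single-character operator class, a double asterisk is kept as one token instead of being split into two multiplication signs.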

convert_expression

convert_expression(prefix_expr: list[str]) -> list[str]

Normalizes an expression into the engine's standard internal format.

This method performs several key conversions:

1. Converts standard binary operators like ** into the engine's unary power operators (e.g., pow2, pow1_3).
2. Combines chained power operators (e.g., pow2(pow3(x)) becomes pow6(x)).
3. Handles unary negation, applying it directly to numbers where possible.

PARAMETER DESCRIPTION
prefix_expr

The prefix expression to convert.

TYPE: list[str]

RETURNS DESCRIPTION
list[str]

The normalized prefix expression.
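
The first conversion, mapping a numeric exponent to unary power operators, can be sketched as follows. This is an illustration under assumptions, not the engine's code: `power_operator_chain` is a hypothetical helper, it ignores the sign of the exponent, and it does not apply the engine's limit on numerator and denominator size. The operator naming follows the `pow2` and `pow1_3` examples given above.

```python
import fractions


def power_operator_chain(exponent):
    """Return power operator names, outermost first, for the magnitude of an exponent."""
    fraction = fractions.Fraction(abs(exponent)).limit_denominator()
    chain = []
    if fraction.denominator != 1:
        chain.append(f'pow1_{fraction.denominator}')  # root part, applied last
    if fraction.numerator != 1:
        chain.append(f'pow{fraction.numerator}')      # integer power, applied first
    return chain


two_thirds = power_operator_chain(2 / 3)
one_half = power_operator_chain(0.5)
cube = power_operator_chain(3.0)
```

Under this naming, x**(2/3) corresponds to the prefix fragment pow1_3 pow2 x, i.e. pow1_3(pow2(x)).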

Source code in src/simplipy/engine.py
def convert_expression(self, prefix_expr: list[str]) -> list[str]:
    """Normalizes an expression into the engine's standard internal format.

    This method performs several key conversions:
    1.  Converts standard binary operators like `**` into the engine's
        unary power operators (e.g., `pow2`, `pow1_3`).
    2.  Combines chained power operators (e.g., `pow2(pow3(x))` becomes
        `pow6(x)`).
    3.  Handles unary negation, applying it directly to numbers where
        possible.

    Parameters
    ----------
    prefix_expr : list[str]
        The prefix expression to convert.

    Returns
    -------
    list[str]
        The normalized prefix expression.
    """
    stack: list = []
    i = len(prefix_expr) - 1

    while i >= 0:
        token = prefix_expr[i]

        if token in self.operator_arity_compat or token in self.operator_aliases or re.match(r'pow\d+(?!\_)', token) or re.match(r'pow1_\d+', token):
            operator = self.operator_aliases.get(token, token)
            arity = self.operator_arity_compat[operator]

            if operator == 'neg':
                # If the operand of neg is a number, combine them
                if isinstance(stack[-1][0], str):
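                    if is_numeric_string(stack[-1][0]) and not stack[-1][0].startswith('-'):
                        # Non-negative literal: fold the negation into the number
                        stack[-1][0] = f'-{stack[-1][0]}'
                    elif is_numeric_string(stack[-1][0]):
                        # Already negative literal: the two signs cancel
                        stack[-1][0] = stack[-1][0][1:]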
                    else:
                        # General case: assemble operator and its operands
                        operands = [stack.pop() for _ in range(arity)]
                        stack.append([operator, operands])
                else:
                    # General case: assemble operator and its operands
                    operands = [stack.pop() for _ in range(arity)]
                    stack.append([operator, operands])

            elif operator == '**':
                # Check for floating-point exponent
                base = stack.pop()
                exponent = stack.pop()

                if len(exponent) == 1:
                    if re.match(r'-?\d+$', exponent[0]):  # Integer exponent
                        exponent_value: int | float = int(exponent[0])
                        pow_operator = f'pow{abs(exponent_value)}'
                        if exponent_value < 0:
                            stack.append(['inv', [[pow_operator, [base]]]])
                        else:
                            stack.append([pow_operator, [base]])
                    elif is_numeric_string(exponent[0]):  # Floating-point exponent
                        exponent_value = float(exponent[0])

                        # Try to convert the exponent into a fraction
                        abs_exponent_fraction = fractions.Fraction(abs(float(exponent[0]))).limit_denominator()
                        if abs_exponent_fraction.numerator <= 5 and abs_exponent_fraction.denominator <= 5:
                            # Format the fraction as a combination of power operators, i.e. "x**(2/3)" -> "pow1_3(pow2(x))"
                            new_expression = [base]
                            if abs_exponent_fraction.numerator != 1:
                                new_expression = [f'pow{abs_exponent_fraction.numerator}', new_expression]
                            if abs_exponent_fraction.denominator != 1:
                                new_expression = [f'pow1_{abs_exponent_fraction.denominator}', new_expression]
                            if exponent_value < 0:
                                new_expression = ['inv', new_expression]
                            stack.append(new_expression)
                        else:
                            stack.append(['pow', [base, exponent]])
                    else:
                        stack.append(['pow', [base, exponent]])

                elif len(exponent) == 2 and exponent[0][0] == '/' and is_numeric_string(exponent[1][0][0]) and is_numeric_string(exponent[1][1][0]):
                    # Handle fractional exponent, e.g. "x**(2/3)"
                    if re.match(r'-?\d+$', exponent[1][0][0]) and re.match(r'-?\d+$', exponent[1][1][0]):
                        # Integer fraction exponent
                        numerator = int(exponent[1][0][0])
                        denominator = int(exponent[1][1][0])
                        numerator_power = f'pow{abs(numerator)}'
                        denominator_power = f'pow1_{abs(denominator)}'
                        if numerator * denominator < 0:
                            stack.append(['inv', [[denominator_power, [[numerator_power, [base]]]]]])
                        else:
                            stack.append([denominator_power, [[numerator_power, [base]]]])
                    else:
                        exponent_value = float(exponent[1][0][0]) / float(exponent[1][1][0])
                        abs_exponent_fraction = fractions.Fraction(abs(exponent_value)).limit_denominator()
                        if abs_exponent_fraction.numerator <= 5 and abs_exponent_fraction.denominator <= 5:
                            # Format the fraction as a combination of power operators, i.e. "x**(2/3)" -> "pow1_3(pow2(x))"
                            new_expression = [base]
                            if abs_exponent_fraction.numerator != 1:
                                new_expression = [f'pow{abs_exponent_fraction.numerator}', new_expression]
                            if abs_exponent_fraction.denominator != 1:
                                new_expression = [f'pow1_{abs_exponent_fraction.denominator}', new_expression]
                            if exponent_value < 0:
                                new_expression = ['inv', new_expression]
                            stack.append(new_expression)
                        else:
                            stack.append(['pow', [base, exponent]])
                else:
                    stack.append(['pow', [base, exponent]])

            else:
                # General case: assemble operator and its operands
                operands = [stack.pop() for _ in range(arity)]
                stack.append([operator, operands])
        else:
            # Non-operator token (operand)
            stack.append([token])

        i -= 1

    need_to_convert_powers_expression = flatten_nested_list(stack)[::-1]

    stack = []
    i = len(need_to_convert_powers_expression) - 1

    while i >= 0:
        token = need_to_convert_powers_expression[i]

        if re.match(r'pow\d+(?!\_)', token) or re.match(r'pow1_\d+', token):
            operator = self.operator_aliases.get(token, token)
            arity = self.operator_arity_compat.get(operator, 1)
            operands = list(reversed(stack[-arity:]))

            # Identify chains of pow<i> xor pow1_<i> operators
            # Mixed chains are ignored
            operator_chain = [operator]
            current_operand = operands[0]

            operator_bases = ['pow1_', 'pow']
            operator_patterns = [r'pow1_\d+', r'pow\d+']
            operator_patterns_grouped = [r'pow1_(\d+)', r'pow(\d+)']
            max_powers = [self.max_fractional_power, self.max_power]
            for base, pattern, pattern_grouped, p in zip(operator_bases, operator_patterns, operator_patterns_grouped, max_powers):
                if re.match(pattern, operator):
                    operator_base = base
                    operator_pattern = pattern
                    operator_pattern_grouped = pattern_grouped
                    max_power = p
                    break

            while len(current_operand) == 2 and re.match(operator_pattern, current_operand[0]):
                operator_chain.append(current_operand[0])
                current_operand = current_operand[1]

            if len(operator_chain) > 0:
                p = prod(int(re.match(operator_pattern_grouped, op).group(1)) for op in operator_chain)  # type: ignore

                try:
                    p_factors = factorize_to_at_most(p, max_power)
                    new_operators = [f'{operator_base}{factor}' for factor in p_factors]

                    if len(new_operators) == 0:
                        new_chain = current_operand
                    else:
                        new_chain = [new_operators[-1], [current_operand]]
                        for op in new_operators[-2::-1]:
                            new_chain = [op, [new_chain]]
                except ValueError:
                    # Fall back to the original chain of operators when the
                    # exponent cannot be expressed using the configured
                    # unary power operators.
                    new_chain = [operator_chain[-1], [current_operand]]
                    for op in operator_chain[-2::-1]:
                        new_chain = [op, [new_chain]]

                _ = [stack.pop() for _ in range(arity)]
                stack.append(new_chain)
                i -= 1
                continue

        elif token in self.operator_arity_compat or token in self.operator_aliases:
            operator = self.operator_aliases.get(token, token)
            arity = self.operator_arity_compat[operator]
            operands = list(reversed(stack[-arity:]))

            _ = [stack.pop() for _ in range(arity)]
            stack.append([operator, operands])
            i -= 1
            continue

        else:
            stack.append([token])
            i -= 1

    return flatten_nested_list(stack)[::-1]
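
The second pass of the listing above multiplies the exponents of a chain of same-kind power operators and re-expresses the product with operators no larger than the configured maximum. The sketch below is a simplified, standalone illustration for the `pow<i>` case only. `split_into_factors` is a hypothetical stand-in for the engine's `factorize_to_at_most` helper, whose actual algorithm is not shown on this page.

```python
from __future__ import annotations

import re
from math import prod


def split_into_factors(p: int, max_factor: int) -> list[int]:
    """Hypothetical stand-in for factorize_to_at_most.

    Greedily splits p into factors no larger than max_factor and raises
    ValueError when that is impossible (e.g. a prime above max_factor).
    """
    factors = []
    while p > 1:
        for candidate in range(min(p, max_factor), 1, -1):
            if p % candidate == 0:
                factors.append(candidate)
                p //= candidate
                break
        else:
            raise ValueError(f'{p} has no factor in [2, {max_factor}]')
    return factors


def merge_power_chain(chain: list[str], max_power: int) -> list[str]:
    """Re-express a chain such as ['pow2', 'pow3'] with operators up to max_power."""
    total = prod(int(re.fullmatch(r'pow(\d+)', op).group(1)) for op in chain)
    try:
        return [f'pow{factor}' for factor in split_into_factors(total, max_power)]
    except ValueError:
        return chain  # keep the original chain, as the listing's fallback does


print(merge_power_chain(['pow2', 'pow3'], max_power=6))  # ['pow6']
print(merge_power_chain(['pow2', 'pow3'], max_power=5))  # ['pow3', 'pow2']
```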

parse

parse(infix_expression: str, convert_expression: bool = True, mask_numbers: bool = False) -> list[str]

Parses an infix string into a standardized prefix expression.

This is a high-level parsing utility that combines infix_to_prefix with optional canonicalization and number masking. The resulting token list is additionally cleaned up via remove_pow1 to drop redundant pow1_1 occurrences.

PARAMETER DESCRIPTION
infix_expression

The mathematical expression in infix notation.

TYPE: str

convert_expression

If True, the expression is normalized using convert_expression. Defaults to True.

TYPE: bool DEFAULT: True

mask_numbers

If True, all numerical literals in the expression are replaced with a generic '<constant>' token. Defaults to False.

TYPE: bool DEFAULT: False

RETURNS DESCRIPTION
list[str]

The processed prefix expression after conversion, masking (if enabled), and remove_pow1 cleanup.
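
To make the effect of mask_numbers concrete, here is a minimal standalone sketch of number masking on a prefix token list. `mask_numbers` and its `float()`-based literal check are assumptions for illustration only; the engine uses its own `numbers_to_constant` and `is_numeric_string` helpers, which may classify tokens differently.

```python
from __future__ import annotations


def mask_numbers(tokens: list[str]) -> list[str]:
    """Hypothetical stand-in: replace numeric literals with '<constant>'."""
    def is_number(token: str) -> bool:
        # Note: float() also accepts 'nan' and 'inf', so this check is
        # only an approximation of a real numeric-literal test.
        try:
            float(token)
            return True
        except ValueError:
            return False

    return ['<constant>' if is_number(token) else token for token in tokens]


print(mask_numbers(['+', 'x1', '*', '3.5', 'x2']))
# ['+', 'x1', '*', '<constant>', 'x2']
```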

Source code in src/simplipy/engine.py
def parse(
        self,
        infix_expression: str,
        convert_expression: bool = True,
        mask_numbers: bool = False) -> list[str]:
    """Parses an infix string into a standardized prefix expression.

    This is a high-level parsing utility that combines `infix_to_prefix`
    with optional canonicalization and number masking. The resulting token
    list is additionally cleaned up via `remove_pow1` to drop redundant
    ``pow1_1`` occurrences.

    Parameters
    ----------
    infix_expression : str
        The mathematical expression in infix notation.
    convert_expression : bool, optional
        If True, the expression is normalized using `convert_expression`.
        Defaults to True.
    mask_numbers : bool, optional
        If True, all numerical literals in the expression are replaced
        with a generic '<constant>' token. Defaults to False.

    Returns
    -------
    list[str]
        The processed prefix expression after conversion, masking (if
        enabled), and `remove_pow1` cleanup.
    """

    parsed_expression = self.infix_to_prefix(infix_expression)

    if convert_expression:
        parsed_expression = self.convert_expression(parsed_expression)
    if mask_numbers:
        parsed_expression = numbers_to_constant(parsed_expression, inplace=True)

    return remove_pow1(parsed_expression)  # HACK: Find a better place to put this

prefix_to_tree

prefix_to_tree(expression: list[str]) -> list

Converts a flat prefix expression into a nested tree structure.

The tree is represented as a nested list, where each subtree is a list of the form [operator, [operand1, operand2, ...]] and leaves are lists of the form [variable].

PARAMETER DESCRIPTION
expression

The expression in prefix notation.

TYPE: list[str]

RETURNS DESCRIPTION
list

The nested list representing the expression tree.
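
The nested-list shape described above can be reproduced with a short standalone sketch. The `ARITY` table here is a hypothetical stand-in for the engine's `operator_arity` mapping, and the function is a simplified re-implementation that assumes well-formed input, not the method itself.

```python
from __future__ import annotations

ARITY = {'+': 2, '*': 2, 'sin': 1}  # hypothetical operator table


def prefix_to_tree(tokens: list[str]) -> list:
    """Turn a flat prefix expression into [operator, [operands...]] nodes."""
    def build(index: int) -> tuple[list, int]:
        token = tokens[index]
        arity = ARITY.get(token, 0)
        if arity == 0:
            return [token], index + 1  # leaf: variable or constant
        operands = []
        index += 1
        for _ in range(arity):
            subtree, index = build(index)
            operands.append(subtree)
        return [token, operands], index

    tree, _ = build(0)
    return tree


print(prefix_to_tree(['+', 'x', '*', 'y', 'z']))
# ['+', [['x'], ['*', [['y'], ['z']]]]]
```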

Source code in src/simplipy/engine.py
def prefix_to_tree(self, expression: list[str]) -> list:
    """Converts a flat prefix expression into a nested tree structure.

    The tree is represented as a nested list, where each subtree is a
    list of the form `[operator, [operand1, operand2, ...]]` and leaves
    are lists of the form `[variable]`.

    Parameters
    ----------
    expression : list[str]
        The expression in prefix notation.

    Returns
    -------
    list
        The nested list representing the expression tree.
    """
    def build_tree(index: int) -> tuple[list | None, int]:
        if index >= len(expression):
            return None, index

        token = expression[index]

        # If token is not an operator or is an operator with arity 0
        if isinstance(token, dict) or token not in self.operator_arity or self.operator_arity[token] == 0:
            return [token], index + 1

        # If token is an operator
        operands = []
        current_index = index + 1

        # Process operands based on the operator's arity
        for _ in range(self.operator_arity[token]):
            if current_index >= len(expression):
                break

            subtree, current_index = build_tree(current_index)
            if subtree:
                operands.append(subtree)

        return [token, operands], current_index

    result, _ = build_tree(0)

    if result is None:
        raise ValueError(f'Failed to build tree from expression {expression}')

    return result

construct_rule_patterns

construct_rule_patterns(rules_list: list[tuple[tuple[str, ...], tuple[str, ...]]], verbose: bool = False) -> dict[tuple, list[tuple[list, list]]]

Transforms a list of rules into a structured dictionary of pattern trees.

This pre-processes rules for efficient matching. It groups rules by the length and root operator of their patterns and converts the flat prefix patterns into tree structures using prefix_to_tree.

PARAMETER DESCRIPTION
rules_list

A list of simplification rules to process.

TYPE: list[tuple]

verbose

If True, displays a progress bar. Defaults to False.

TYPE: bool DEFAULT: False

RETURNS DESCRIPTION
dict

A dictionary mapping (pattern_length, root_operator) tuples to a list of (pattern_tree, replacement_tree) tuples.
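
The grouping scheme can be sketched on flat rules as follows. This is a simplified illustration: it keeps patterns as flat tuples instead of converting them to trees, and the wildcard token `_0` as well as the example rules are assumptions, not rules shipped with the library.

```python
from __future__ import annotations

from collections import defaultdict

# Hypothetical rules: (pattern, replacement) in flat prefix notation
RULES = [
    (('+', '_0', 'neg', '_0'), ('0',)),
    (('+', '_0', '0'), ('_0',)),
    (('*', '_0', '1'), ('_0',)),
]


def group_rules(rules: list[tuple[tuple[str, ...], tuple[str, ...]]]) -> dict:
    """Index rules by (pattern_length, root_operator) for fast candidate lookup."""
    grouped = defaultdict(list)
    for pattern, replacement in sorted(rules, key=lambda rule: len(rule[0])):
        grouped[(len(pattern), pattern[0])].append((pattern, replacement))
    return dict(grouped)


grouped = group_rules(RULES)
# Only rules rooted at '+' with a pattern of length 3 need to be tried here:
print(grouped[(3, '+')])  # [(('+', '_0', '0'), ('_0',))]
```

With this index, a matcher only has to try the rules whose root operator equals the current node's operator and whose pattern is no longer than the subtree.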

Source code in src/simplipy/engine.py
def construct_rule_patterns(self, rules_list: list[tuple[tuple[str, ...], tuple[str, ...]]], verbose: bool = False) -> dict[tuple, list[tuple[list, list]]]:
    """Transforms a list of rules into a structured dictionary of pattern trees.

    This pre-processes rules for efficient matching. It groups rules by the
    length and root operator of their patterns and converts the flat
    prefix patterns into tree structures using `prefix_to_tree`.

    Parameters
    ----------
    rules_list : list[tuple]
        A list of simplification rules to process.
    verbose : bool, optional
        If True, displays a progress bar. Defaults to False.

    Returns
    -------
    dict
        A dictionary mapping `(pattern_length, root_operator)` tuples to a
        list of `(pattern_tree, replacement_tree)` tuples.
    """
    # Group the rules by the root operator of their pattern
    rules_list_of_operator: defaultdict[str, list] = defaultdict(list)
    for rule in rules_list:
        rules_list_of_operator[rule[0][0]].append(rule)
    rules_list_of_operator = dict(rules_list_of_operator)  # type: ignore

    # Sort the rules by length of the left-hand side to make matching more efficient
    for operator, rules_list_of_operator_list in rules_list_of_operator.items():
        rules_list_of_operator[operator] = sorted(rules_list_of_operator_list, key=lambda x: len(x[0]))

    # Construct the trees for pattern matching
    rules_trees = {operator: [
        (
            self.prefix_to_tree(list(rule[0])),
            self.prefix_to_tree(list(rule[1]))
        )
        for rule in rules_list_of_operator_a] for operator, rules_list_of_operator_a in tqdm(rules_list_of_operator.items(), desc='Constructing patterns', disable=not verbose)}

    rules_trees_organized: defaultdict[tuple, list] = defaultdict(list)
    for operator, rules in rules_trees.items():
        for (pattern, replacement) in rules:
            pattern_length = len(flatten_nested_list(pattern))
            rules_trees_organized[(pattern_length, operator,)].append((pattern, replacement))

            if pattern_length > self.max_pattern_length:
                self.max_pattern_length = pattern_length

    return rules_trees_organized

parse_subtree

parse_subtree(tokens: list[str] | tuple[str, ...], start_idx: int) -> tuple[list, int]

Parses a complete subtree from a token list starting at a given index.

Recursively consumes tokens corresponding to an operator and its operands to build a single expression tree.

PARAMETER DESCRIPTION
tokens

A sequence of tokens in prefix notation.

TYPE: list[str] or tuple[str, ...]

start_idx

The index in tokens where the subtree is assumed to start.

TYPE: int

RETURNS DESCRIPTION
subtree

The parsed subtree as a nested list.

TYPE: list

next_idx

The index of the token immediately following the parsed subtree.

TYPE: int
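
The role of start_idx and next_idx is easiest to see when consecutive subtrees are read from the same token list. The sketch below is a standalone approximation with a hypothetical `ARITY` table; unlike the method, it does not resolve operator aliases.

```python
from __future__ import annotations

ARITY = {'+': 2, '*': 2}  # hypothetical operator table


def parse_subtree(tokens: list[str], start_idx: int) -> tuple[list, int]:
    """Parse one complete subtree and report where the next one begins."""
    if start_idx >= len(tokens):
        raise ValueError(f'Start index {start_idx} is out of bounds')
    token = tokens[start_idx]
    idx = start_idx + 1
    if token not in ARITY:
        return [token], idx  # terminal (constant or variable)
    operands = []
    for _ in range(ARITY[token]):
        operand, idx = parse_subtree(tokens, idx)
        operands.append(operand)
    return [token, operands], idx


tokens = ['+', '*', 'a', 'b', 'c']
first, next_idx = parse_subtree(tokens, 1)         # left operand of '+'
second, end_idx = parse_subtree(tokens, next_idx)  # right operand of '+'
print(first, next_idx)   # ['*', [['a'], ['b']]] 4
print(second, end_idx)   # ['c'] 5
```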

Source code in src/simplipy/engine.py
def parse_subtree(self, tokens: list[str] | tuple[str, ...], start_idx: int) -> tuple[list, int]:
    """Parses a complete subtree from a token list starting at a given index.

    Recursively consumes tokens corresponding to an operator and its
    operands to build a single expression tree.

    Parameters
    ----------
    tokens : list[str] or tuple[str, ...]
        A sequence of tokens in prefix notation.
    start_idx : int
        The index in `tokens` where the subtree is assumed to start.

    Returns
    -------
    subtree : list
        The parsed subtree as a nested list.
    next_idx : int
        The index of the token immediately following the parsed subtree.
    """
    if start_idx >= len(tokens):
        raise ValueError(f"Start index {start_idx} is out of bounds for tokens {tokens}")

    token = tokens[start_idx]

    if token in self.operator_arity_compat or token in self.operator_aliases:
        operator = self.operator_aliases.get(token, token)
        arity = self.operator_arity_compat[operator]
        operands = []
        idx = start_idx + 1

        for _ in range(arity):
            operand, idx = self.parse_subtree(tokens, idx)
            operands.append(operand)

        return [operator, operands], idx
    else:
        # It's a terminal (constant or variable)
        return [token], start_idx + 1

apply_rules_top_down

apply_rules_top_down(subtree: list, max_pattern_length: int | None = None, collect_statistics: bool = False, verbose: bool = False) -> list

Recursively applies simplification rules to an expression tree.

It attempts to match rules at the current node (top-down). If no rule matches, it recursively calls itself on the node's children. After the children are simplified, it re-checks for matching rules at the current node, in case a child's simplification enables a new rule.

PARAMETER DESCRIPTION
subtree

The expression tree (nested list) to simplify.

TYPE: list

max_pattern_length

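The maximum length of a rule pattern to consider. The effective limit is additionally capped by the subtree length and the engine's max_pattern_length. Defaults to None (no extra limit).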

TYPE: int or None DEFAULT: None

collect_statistics

If True, records which rules are successfully applied. Defaults to False.

TYPE: bool DEFAULT: False

verbose

If True, prints detailed information about rule applications. Defaults to False.

TYPE: bool DEFAULT: False

RETURNS DESCRIPTION
list

The simplified expression tree.
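
The control flow described above (try the node, then simplify the children, then re-check the node) can be sketched with explicit rules only. Everything in this block is a simplified assumption for illustration: the `ARITY` table, the two example rules, and the helpers `parse`, `flatten`, and `apply_rules` are not part of the library, and wildcard patterns and statistics are omitted.

```python
from __future__ import annotations

ARITY = {'+': 2, '*': 2}  # hypothetical operator table

# Hypothetical explicit rules: flat prefix pattern -> flat prefix replacement
RULES = {
    ('*', '1', 'x'): ('x',),
    ('+', 'x', 'x'): ('*', '2', 'x'),
}


def parse(tokens, idx=0):
    """Build [operator, operands] or [leaf] from prefix tokens; return (tree, next_idx)."""
    token = tokens[idx]
    operands = []
    idx += 1
    for _ in range(ARITY.get(token, 0)):
        operand, idx = parse(tokens, idx)
        operands.append(operand)
    return ([token, operands] if operands else [token]), idx


def flatten(tree):
    """Return the tree as a flat prefix token list."""
    if len(tree) == 1:
        return [tree[0]]
    return [tree[0]] + [t for operand in tree[1] for t in flatten(operand)]


def apply_rules(tree):
    if len(tree) == 1:
        return tree  # leaf: nothing to rewrite

    # 1. Try a rule at the current node first (top-down).
    replacement = RULES.get(tuple(flatten(tree)))
    if replacement is not None:
        return apply_rules(parse(replacement)[0])

    # 2. Otherwise simplify the children ...
    simplified = [tree[0], [apply_rules(operand) for operand in tree[1]]]

    # 3. ... and re-check the node, since a child may have enabled a rule.
    replacement = RULES.get(tuple(flatten(simplified)))
    if replacement is not None:
        return apply_rules(parse(replacement)[0])
    return simplified


tree, _ = parse(['+', 'x', '*', '1', 'x'])
print(flatten(apply_rules(tree)))  # ['*', '2', 'x']
```

In this run no rule matches x + 1*x at the root. The child 1*x is rewritten to x first, and only the re-check in step 3 lets the x + x rule fire.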

Source code in src/simplipy/engine.py
def apply_rules_top_down(self, subtree: list, max_pattern_length: int | None = None, collect_statistics: bool = False, verbose: bool = False) -> list:
    """Recursively applies simplification rules to an expression tree.

    It attempts to match rules at the current node (top-down). If no rule
    matches, it recursively calls itself on the node's children. After
    the children are simplified, it re-checks for matching rules at the
    current node, in case a child's simplification enables a new rule.

    Parameters
    ----------
    subtree : list
        The expression tree (nested list) to simplify.
    max_pattern_length : int or None, optional
        The maximum length of a rule pattern to consider. Defaults to None.
    collect_statistics : bool, optional
        If True, records which rules are successfully applied. Defaults to False.
    verbose : bool, optional
        If True, prints detailed information about rule applications. Defaults to False.

    Returns
    -------
    list
        The simplified expression tree.
    """
    if len(subtree) == 1:
        # Terminal node, no rules to apply
        return subtree

    stats = self.simplification_statistics if collect_statistics else None

    operator = subtree[0]
    operands = subtree[1]

    # First, check if all operands are constants
    if all(len(operand) == 1 and operand[0] == '<constant>' for operand in operands):
        if stats is not None:
            stats.constant_folding_count += 1
        return ['<constant>']

    # Convert subtree to flat form for rule matching
    flat_subtree = tuple(flatten_nested_list(subtree)[::-1])
    subtree_length = len(flat_subtree)

    if verbose:
        print(f'Checking if explicit rule applies to subtree: {flat_subtree} with length {subtree_length}')

    # Check explicit rules first
    replacement = self.simplification_rules_no_patterns.get(flat_subtree, None)
    if verbose:
        print(f'Explicit rule found: {flat_subtree} -> {replacement}' if replacement else 'No explicit rule found')
    if replacement is not None:
        if stats is not None:
            rule_idx = self._rule_index_lookup.get((flat_subtree, replacement), -1)
            stats.rule_application_counts[(rule_idx, flat_subtree, replacement)] += 1
            stats.explicit_rule_applications += 1
        if verbose:
            print(f'Applied explicit rule\t{flat_subtree} ->\n\t\t{replacement}\nto subtree\t{subtree}\n')
        # Parse and recursively simplify the replacement
        parsed_replacement, _ = self.parse_subtree(list(replacement), 0)
        return self.apply_rules_top_down(parsed_replacement, max_pattern_length, collect_statistics, verbose)

    # Check pattern rules, starting with the largest patterns
    if max_pattern_length is None:
        subtree_max_pattern_length = min(subtree_length, self.max_pattern_length)
    else:
        subtree_max_pattern_length = min(max_pattern_length, subtree_length, self.max_pattern_length)

    for pattern_length in reversed(range(1, subtree_max_pattern_length + 1)):
        if verbose:
            print(f'Checking pattern rules for operator {operator} with subtree length {pattern_length}')
        for rule in self.simplification_rules_patterns.get((pattern_length, operator,), []):
            if stats is not None:
                stats.rule_match_attempts += 1
            does_match, mapping = match_pattern(subtree, rule[0], mapping=None)
            if does_match:
                if stats is not None:
                    stats.rule_match_hits += 1
                # Apply the mapping to get the replacement
                replacement_tree = apply_mapping(deepcopy(rule[1]), mapping)
                if stats is not None:
                    rule_key = (
                        tuple(flatten_nested_list(rule[0])[::-1]),
                        tuple(flatten_nested_list(rule[1])[::-1]))
                    rule_idx = self._rule_index_lookup.get(rule_key, -1)
                    stats.rule_application_counts[(rule_idx, *rule_key)] += 1
                    stats.pattern_rule_applications += 1
                if verbose:
                    print(f'Applied pattern rule\t{rule[0]} ->\n\t\t{rule[1]}\nto subtree\t{subtree}\nwith mapping\t{mapping}\n')
                # Recursively simplify the replacement
                return self.apply_rules_top_down(replacement_tree, max_pattern_length, collect_statistics, verbose)

    # No rule applied at this level, recursively simplify operands
    simplified_operands = [self.apply_rules_top_down(operand, max_pattern_length, collect_statistics, verbose) for operand in operands]
    simplified_subtree = [operator, simplified_operands]

    # After simplifying operands, check again if a rule now applies
    # (This handles cases where simplification of operands enables a rule)
    flat_simplified = tuple(flatten_nested_list(simplified_subtree)[::-1])

    # Check explicit rules again
    replacement = self.simplification_rules_no_patterns.get(flat_simplified, None)
    if replacement is not None:
        if stats is not None:
            rule_idx = self._rule_index_lookup.get((flat_simplified, replacement), -1)
            stats.rule_application_counts[(rule_idx, flat_simplified, replacement)] += 1
            stats.explicit_rule_applications += 1
            stats.post_operand_rule_applications += 1
        if verbose:
            print(f'Applied explicit rule (after operand simplification)\t{flat_simplified} ->\n\t\t{replacement}\nto subtree\t{simplified_subtree}\n')
        parsed_replacement, _ = self.parse_subtree(list(replacement), 0)
        return self.apply_rules_top_down(parsed_replacement, max_pattern_length, collect_statistics, verbose)

    # Check pattern rules again
    for pattern_length in reversed(range(1, subtree_max_pattern_length + 1)):
        for rule in self.simplification_rules_patterns.get((pattern_length, operator,), []):
            if stats is not None:
                stats.rule_match_attempts += 1
            does_match, mapping = match_pattern(simplified_subtree, rule[0], mapping=None)
            if does_match:
                if stats is not None:
                    stats.rule_match_hits += 1
                replacement_tree = apply_mapping(deepcopy(rule[1]), mapping)
                if stats is not None:
                    rule_key = (
                        tuple(flatten_nested_list(rule[0])[::-1]),
                        tuple(flatten_nested_list(rule[1])[::-1]))
                    rule_idx = self._rule_index_lookup.get(rule_key, -1)
                    stats.rule_application_counts[(rule_idx, *rule_key)] += 1
                    stats.pattern_rule_applications += 1
                    stats.post_operand_rule_applications += 1
                if verbose:
                    print(f'Applied pattern rule (after operand simplification)\t{rule[0]} ->\n\t\t{rule[1]}\nto subtree\t{simplified_subtree}\nwith mapping\t{mapping}\n')
                return self.apply_rules_top_down(replacement_tree, max_pattern_length, collect_statistics, verbose)

    return simplified_subtree

apply_simplifcation_rules

apply_simplifcation_rules(expression: list[str] | tuple[str, ...], max_pattern_length: int | None = None, collect_statistics: bool = False, verbose: bool = False) -> list[str]

Applies all loaded simplification rules to a prefix expression.

This method serves as a wrapper around apply_rules_top_down. It first converts the flat prefix expression into a tree, applies the rules recursively, and then flattens the resulting tree back into prefix notation.

PARAMETER DESCRIPTION
expression

The expression in prefix notation.

TYPE: list[str] or tuple[str, ...]

max_pattern_length

The maximum length of rule patterns to attempt to match.

TYPE: int or None DEFAULT: None

collect_statistics

If True, updates statistics on rule application counts.

TYPE: bool DEFAULT: False

verbose

If True, enables detailed logging of the simplification process.

TYPE: bool DEFAULT: False

RETURNS DESCRIPTION
list[str]

The simplified expression in prefix notation.

Source code in src/simplipy/engine.py
def apply_simplifcation_rules(self, expression: list[str] | tuple[str, ...], max_pattern_length: int | None = None, collect_statistics: bool = False, verbose: bool = False) -> list[str]:
    """Applies all loaded simplification rules to a prefix expression.

    This method serves as a wrapper around `apply_rules_top_down`. It
    first converts the flat prefix expression into a tree, applies the
    rules recursively, and then flattens the resulting tree back into
    prefix notation.

    Parameters
    ----------
    expression : list[str] or tuple[str, ...]
        The expression in prefix notation.
    max_pattern_length : int or None, optional
        The maximum length of rule patterns to attempt to match.
    collect_statistics : bool, optional
        If True, updates statistics on rule application counts.
    verbose : bool, optional
        If True, enables detailed logging of the simplification process.

    Returns
    -------
    list[str]
        The simplified expression in prefix notation.
    """
    if all(t == '<constant>' or t in self.operator_arity for t in expression):
        return ['<constant>']

    # Parse the entire expression into a tree
    tree, _ = self.parse_subtree(expression, 0)
    if tree is None:
        return list(expression)

    # Apply rules top-down
    simplified_tree = self.apply_rules_top_down(tree, max_pattern_length, collect_statistics, verbose)

    # Flatten back to prefix notation
    return flatten_nested_list(simplified_tree)[::-1]

collect_multiplicities

collect_multiplicities(expression: list[str] | tuple[str, ...], verbose: bool = False) -> tuple[list, list, list]

Traverses an expression tree to find subtrees that can be cancelled.

This method performs a bottom-up traversal of the expression, counting the occurrences of each unique subtree within additive (+, -) and multiplicative (*, /) contexts. For example, in (a*b) + (a*b), it identifies that the subtree (a*b) appears twice in an additive context.

PARAMETER DESCRIPTION
expression

The expression in prefix notation.

TYPE: list[str] or tuple[str, ...]

verbose

If True, prints detailed debugging information. Defaults to False.

TYPE: bool DEFAULT: False

RETURNS DESCRIPTION
expression_tree

A stack-based representation of the expression tree. Each entry is a nested list of the form [operator, operands] mirroring the structure consumed by cancel_terms.

TYPE: list

annotations_tree

A parallel stack holding multiplicity annotations for each subtree, organized by connection class.

TYPE: list

labels_tree

A parallel stack containing stable identifiers for every subtree, used to detect duplicates during cancellation.

TYPE: list
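
The counting idea for the additive context can be sketched as below. This is a deliberately simplified illustration: it tracks a single `[positive, negative]` pair per subtree and ignores the multiplicative class and the parallel annotation and label stacks, so the data layout is an assumption and differs from what the method returns.

```python
from __future__ import annotations


def flatten(tree: list) -> list[str]:
    """Return the tree as a flat prefix token list."""
    if len(tree) == 1:
        return [tree[0]]
    return [tree[0]] + [t for operand in tree[1] for t in flatten(operand)]


def additive_multiplicities(tree: list) -> dict[tuple, list[int]]:
    """Count [positive, negative] occurrences of each term under + and -."""
    if len(tree) == 2 and tree[0] in ('+', '-'):
        left, right = (additive_multiplicities(operand) for operand in tree[1])
        counts = {key: list(value) for key, value in left.items()}
        for key, (pos, neg) in right.items():
            if tree[0] == '-':
                pos, neg = neg, pos  # the right branch of '-' flips the sign
            entry = counts.setdefault(key, [0, 0])
            entry[0] += pos
            entry[1] += neg
        return counts
    # Any other node (leaf or non-additive operator) counts as one term
    return {tuple(flatten(tree)): [1, 0]}


ab = ['*', [['a'], ['b']]]
print(additive_multiplicities(['+', [ab, ab]]))
# {('*', 'a', 'b'): [2, 0]}
```

The sign flip for the right operand of a subtraction mirrors the `1 - p` index swap that the source listing applies when the operator is `-` or `/` and the branch is the second one.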

Source code in src/simplipy/engine.py
def collect_multiplicities(self, expression: list[str] | tuple[str, ...], verbose: bool = False) -> tuple[list, list, list]:
    """Traverses an expression tree to find subtrees that can be cancelled.

    This method performs a bottom-up traversal of the expression, counting
    the occurrences of each unique subtree within additive (`+`, `-`) and
    multiplicative (`*`, `/`) contexts. For example, in `(a*b) + (a*b)`,
    it identifies that the subtree `(a*b)` appears twice in an additive
    context.

    Parameters
    ----------
    expression : list[str] or tuple[str, ...]
        The expression in prefix notation.
    verbose : bool, optional
        If True, prints detailed debugging information. Defaults to False.

    Returns
    -------
    expression_tree : list
        A stack-based representation of the expression tree. Each entry is a
        nested list of the form ``[operator, operands]`` mirroring the
        structure consumed by `cancel_terms`.
    annotations_tree : list
        A parallel stack holding multiplicity annotations for each subtree,
        organized by connection class.
    labels_tree : list
        A parallel stack containing stable identifiers for every subtree,
        used to detect duplicates during cancellation.
    """
    stack: list = []
    stack_annotations: list = []
    stack_labels: list = []

    i = len(expression) - 1

    # Traverse the expression from right to left
    while i >= 0:
        token = expression[i]

        if token in self.binary_connectable_operators:
            operator = token
            arity = 2
            operands = list(reversed(stack[-arity:]))
            operands_annotations_dicts = list(reversed(stack_annotations[-arity:]))
            operands_labels = list(reversed(stack_labels[-arity:]))

            operator_annotation_dict: dict[str, dict[tuple[str, ...], list[int]]] = {cc: {} for cc in self.connection_classes}

            cc = self.operator_to_class[operator]

            # Carry over annotations from operand nodes
            if verbose:
                print(f'---- {token} ----')

            for branch, operand_annotations_dict in enumerate(operands_annotations_dicts):  # One dict for left and right branch
                if verbose:
                    print(branch)
                    pprint.pprint(operand_annotations_dict)
                for subtree_hash in operand_annotations_dict[0][cc]:  # All subtrees appearing in either branch (0 gets root node of the branch)
                    # Add to operator dict if not already present
                    if subtree_hash not in operator_annotation_dict[cc]:
                        if verbose:
                            print(f'Initializing {subtree_hash} for {cc}')
                        operator_annotation_dict[cc][subtree_hash] = [0, 0]

                    if operator in {'-', '/'} and branch == 1:
                        for p in range(2):
                            if verbose:
                                print(f'Adding {operand_annotations_dict[0][cc][subtree_hash][p]} to {operator_annotation_dict[cc][subtree_hash][1 - p]} at {1 - p} of {subtree_hash} (reversed)')
                            operator_annotation_dict[cc][subtree_hash][1 - p] += operand_annotations_dict[0][cc][subtree_hash][p]
                    else:
                        for p in range(2):
                            if verbose:
                                print(f'Adding {operand_annotations_dict[0][cc][subtree_hash][p]} to {operator_annotation_dict[cc][subtree_hash][p]} at {p} of {subtree_hash}')
                            operator_annotation_dict[cc][subtree_hash][p] += operand_annotations_dict[0][cc][subtree_hash][p]

            if verbose:
                print(f'/---- {token} ----')
                print()

            # Label each subtree with its own hash to know which to cancel later
            _ = [stack.pop() for _ in range(arity)]
            _ = [stack_annotations.pop() for _ in range(arity)]
            _ = [stack_labels.pop() for _ in range(arity)]
            stack.append([operator, operands])
            stack_annotations.append([operator_annotation_dict, operands_annotations_dicts])
            new_label = tuple(flatten_nested_list([operator, operands])[::-1])
            stack_labels.append([new_label, operands_labels])
            i -= 1
            continue

        if token in self.operator_arity:
            operator = token
            arity = self.operator_arity[token]
            operands = list(reversed(stack[-arity:]))
            operands_annotations_dicts = list(reversed(stack_annotations[-arity:]))
            operands_labels = list(reversed(stack_labels[-arity:]))

            # Label each subtree with its own hash to know which to cancel later
            _ = [stack.pop() for _ in range(arity)]
            _ = [stack_annotations.pop() for _ in range(arity)]
            _ = [stack_labels.pop() for _ in range(arity)]
            stack.append([operator, operands])
            stack_annotations.append([{cc: {} for cc in self.connection_classes}, operands_annotations_dicts])
            new_label = tuple(flatten_nested_list([operator, operands])[::-1])
            stack_labels.append([new_label, operands_labels])
            i -= 1
            continue

        stack.append([token])
        stack_annotations.append([{cc: {tuple([token]): [1, 0]} for cc in self.connection_classes}])
        stack_labels.append([tuple([token])])
        i -= 1

    if verbose:
        pprint.pprint(stack_annotations)
        print()

    return stack, stack_annotations, stack_labels
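
The right-to-left stack traversal used above can be illustrated in isolation. The following is a minimal sketch, not the engine's implementation: `prefix_to_tree` and the `ARITY` table are hypothetical names introduced here, and the sketch builds only the nested `[operator, operands]` structure, without the annotation and label stacks.

```python
# Hypothetical arity table; the engine reads arities from its own configuration.
ARITY = {'+': 2, '-': 2, '*': 2, '/': 2, 'neg': 1, 'sin': 1}


def prefix_to_tree(expression):
    """Build a nested [operator, operands] tree from a prefix token list."""
    stack = []
    # Walk right to left so that operands are on the stack before their operator
    for token in reversed(expression):
        if token in ARITY:
            arity = ARITY[token]
            # The most recently pushed entry is the leftmost operand
            operands = list(reversed(stack[-arity:]))
            del stack[-arity:]
            stack.append([token, operands])
        else:
            stack.append([token])  # Leaf: variable or constant
    return stack


tree = prefix_to_tree(['+', 'x', '*', 'y', 'z'])
print(tree)  # [['+', [['x'], ['*', [['y'], ['z']]]]]]
```

A well-formed expression leaves exactly one entry on the stack, which is the root of the tree.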

cancel_terms

cancel_terms(expression_tree: list, expression_annotations_tree: list, stack_labels: list, collect_statistics: bool = False, verbose: bool = False) -> list[str]

Reconstructs an expression, cancelling terms based on multiplicity counts.

Using the annotated tree from collect_multiplicities, this method identifies the best candidate for cancellation (e.g., a term that appears with both positive and negative signs). It then rebuilds the expression while replacing the cancelled terms with the appropriate neutral element ('0' for addition, '1' for multiplication) or a simplified form (e.g., x + x becomes 2 * x).

PARAMETER DESCRIPTION
expression_tree

The stack produced by collect_multiplicities, containing the nested expression structure.

TYPE: list

expression_annotations_tree

The parallel stack of multiplicity annotations returned by collect_multiplicities.

TYPE: list

stack_labels

The parallel stack of subtree labels returned by collect_multiplicities.

TYPE: list

collect_statistics

If True, records cancellation events in self.simplification_statistics. Defaults to False.

TYPE: bool DEFAULT: False

verbose

If True, prints detailed debugging information. Defaults to False.

TYPE: bool DEFAULT: False

RETURNS DESCRIPTION
list[str]

A simplified prefix expression with the detected duplicates merged or removed.
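
The idea of signed multiplicities can be sketched for the additive case. This is a simplified illustration under stated assumptions, not the library's algorithm: `additive_multiplicities` is a hypothetical helper that counts only leaf tokens in a chain of binary `+` and `-`, whereas the engine tracks whole subtrees per connection class.

```python
from collections import Counter


def additive_multiplicities(expression):
    """Return the net signed count of each leaf in a prefix '+'/'-' chain."""
    counts = Counter()
    position = 0

    def visit(sign):
        nonlocal position
        token = expression[position]
        position += 1
        if token == '+':
            visit(sign)
            visit(sign)
        elif token == '-':
            visit(sign)
            visit(-sign)  # The right operand of '-' flips the parity
        else:
            counts[token] += sign

    visit(1)
    return dict(counts)


# x + (y - x): x occurs once with each sign, so its net multiplicity is 0
print(additive_multiplicities(['+', 'x', '-', 'y', 'x']))  # {'x': 0, 'y': 1}
```

A net multiplicity of 0 corresponds to replacing every occurrence with the neutral element, while a net multiplicity of 2 (as in `x + x`) corresponds to merging the occurrences into a single scaled term.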

Source code in src/simplipy/engine.py
def cancel_terms(self, expression_tree: list, expression_annotations_tree: list, stack_labels: list, collect_statistics: bool = False, verbose: bool = False) -> list[str]:
    """Reconstructs an expression, cancelling terms based on multiplicity counts.

    Using the annotated tree from `collect_multiplicities`, this method
    identifies the best candidate for cancellation (e.g., a term that appears
    with both positive and negative signs). It then rebuilds the expression
    while replacing the cancelled terms with the appropriate neutral element
    ('0' for addition, '1' for multiplication) or a simplified form (e.g.,
    `x + x` becomes `2 * x`).

    Parameters
    ----------
    expression_tree : list
        The stack produced by `collect_multiplicities`, containing the
        nested expression structure.
    expression_annotations_tree : list
        The parallel stack of multiplicity annotations returned by
        `collect_multiplicities`.
    stack_labels : list
        The parallel stack of subtree labels returned by
        `collect_multiplicities`.
    collect_statistics : bool, optional
        If True, records cancellation events in
        ``self.simplification_statistics``. Defaults to False.
    verbose : bool, optional
        If True, prints detailed debugging information. Defaults to False.

    Returns
    -------
    list[str]
        A simplified prefix expression with the detected duplicates merged
        or removed.
    """
    stats = self.simplification_statistics if collect_statistics else None

    stack = expression_tree
    stack_annotations = expression_annotations_tree
    stack_parity = [{cc: 1 for cc in self.connection_classes} for _ in range(len(stack_labels))]
    stack_still_connected = [False]

    expression: list[str] = []

    cancellation_candidate = None
    n_replaced = 0
    still_connected = False

    while len(stack) > 0:
        subtree = stack.pop()
        subtree_annotation = stack_annotations.pop()
        subtree_labels = stack_labels.pop()
        subtree_parities = stack_parity.pop()
        still_connected = stack_still_connected.pop()

        if cancellation_candidate is not None:
            argmax_class, cancelled_subtree, cancelled_multiplicity_sum = cancellation_candidate
            still_connected = still_connected and (subtree[0] in self.connection_classes[argmax_class][0] or subtree[0] not in self.operator_arity)

            if still_connected:
                if cancelled_subtree == subtree_labels[0]:
                    neutral_element = self.connection_classes[argmax_class][1]

                    if cancelled_subtree == ('<constant>',):
                        first_replacement = ('<constant>',)
                        other_replacements: str | tuple[str, ...] = neutral_element
                    else:
                        current_parity = subtree_parities[argmax_class]
                        inverse_operator = self.connection_classes_inverse[argmax_class]

                        if verbose:
                            print()
                            print(f'Processing subtree {subtree_labels[0]} with current parity {current_parity} and total multiplicity sum {cancelled_multiplicity_sum}')

                        # FIXME
                        if current_parity * cancelled_multiplicity_sum >= 0:  # Negative parity and negative multiplicity cancel out
                            inverse_operator_prefix: tuple[str, ...] = ()
                            double_inverse_operator_prefix: tuple[str, ...] = (inverse_operator,)
                        else:
                            inverse_operator_prefix = (inverse_operator,)
                            double_inverse_operator_prefix = ()

                        if verbose:
                            print(f'Inverse operator prefix: {inverse_operator_prefix}, double inverse operator prefix: {double_inverse_operator_prefix}')

                        if cancelled_multiplicity_sum == 0:
                            # Term is cancelled entirely. Replace all occurrences with the neutral element
                            first_replacement = (neutral_element,)
                            other_replacements = neutral_element
                            if verbose:
                                print(f'Cancelled term {cancelled_subtree} entirely: first replacement {first_replacement}, other replacements {other_replacements}')

                        if abs(cancelled_multiplicity_sum) == 1:
                            # Term occurs once. Replace every occurrence after the first one with the neutral element
                            first_replacement = inverse_operator_prefix + cancelled_subtree
                            other_replacements = (neutral_element,)
                            if verbose:
                                print(f'Cancelled term {cancelled_subtree} once: first replacement {first_replacement}, other replacements {other_replacements}')

                        if abs(cancelled_multiplicity_sum) > 1:
                            # Term occurs multiple times. Replace the first occurrence with a multiplication or power of the term. Replace every occurrence after the first one with the neutral element
                            hyper_operator = self.connection_classes_hyper[argmax_class]
                            operator = self.connection_classes[argmax_class][0][0]  # Positive multiplicity
                            try:
                                if cancelled_multiplicity_sum > 5 and is_prime(abs(cancelled_multiplicity_sum)):
                                    powers = factorize_to_at_most(abs(cancelled_multiplicity_sum) - 1, self.max_power)
                                    first_replacement = inverse_operator_prefix + (operator,) + tuple(f'{hyper_operator}{p}' for p in powers) + cancelled_subtree + cancelled_subtree
                                else:
                                    powers = factorize_to_at_most(abs(cancelled_multiplicity_sum), self.max_power)
                                    first_replacement = inverse_operator_prefix + tuple(f'{hyper_operator}{p}' for p in powers) + cancelled_subtree
                            except ValueError:
                                # Fall back to a representation that stays within the
                                # available operator set. For additive contexts we use
                                # a binary multiplication with an explicit integer
                                # coefficient; for multiplicative contexts we use the
                                # binary ``pow`` operator with a numeric exponent.
                                magnitude = abs(cancelled_multiplicity_sum)
                                coefficient_token = str(magnitude)

                                if argmax_class == 'add':
                                    fallback_prefix = inverse_operator_prefix + ('*', coefficient_token)
                                    first_replacement = fallback_prefix + cancelled_subtree
                                else:
                                    # Multiplicative class
                                    fallback_prefix = inverse_operator_prefix + ('pow',)
                                    first_replacement = fallback_prefix + cancelled_subtree + (coefficient_token,)

                            other_replacements = (neutral_element,)

                            if verbose:
                                print(f'Cancelled term {cancelled_subtree} multiple times: first replacement {first_replacement}, other replacements {other_replacements}')

                            if verbose:
                                print(f'Cancelled term {cancelled_subtree} multiple times inverted: first replacement {first_replacement}, other replacements {other_replacements}')

                    if n_replaced == 0:
                        expression.extend(first_replacement)
                        if verbose:
                            print(f'{n_replaced}: Added first replacement {first_replacement} to expression')
                    else:
                        expression.extend(other_replacements)
                        if verbose:
                            print(f'{n_replaced}: Added other replacements {other_replacements} to expression')
                    n_replaced += 1
                    continue

        # Leaf node
        if len(subtree) == 1:
            expression.append(subtree[0])
            continue

        # Non-leaf node
        operator, operands = subtree
        _, operands_annotations_sets = subtree_annotation
        _, operands_labels = subtree_labels
        operator_parity = subtree_parities  # No operand parity information yet

        # TODO: Propagate parities of unary inverse operators

        if verbose:
            print(f'Operator {operator} with operands {operands} is still connected: {still_connected}')
            print(f'Operator parities: {operator_parity}')

        if operator in self.binary_connectable_operators:
            propagated_operand_parities: list[dict[str, int]] = [{}, {}]
            if still_connected:
                for cc, (operator_set, _) in self.connection_classes.items():
                    propagated_operand_parities[0][cc] = operator_parity[cc]
                    propagated_operand_parities[1][cc] = operator_parity[cc] * (-1 if operator == self.operator_inverses[operator_set[0]] else 1)
                if verbose:
                    print(f'Propagated operand parities: {propagated_operand_parities}')
            else:
                for cc, (operator_set, _) in self.connection_classes.items():
                    propagated_operand_parities[0][cc] = 1
                    propagated_operand_parities[1][cc] = (-1 if operator == self.operator_inverses[operator_set[0]] else 1)
                if verbose:
                    print(f'Reset parities to {propagated_operand_parities}')

            # If no cancellation candidate has been identified yet, try to find one in the current subtree
            if cancellation_candidate is None:
                for cc in self.connection_classes:
                    for subtree_hash, multiplicity in subtree_annotation[0][cc].items():
                        # Consider candidates where
                        # 1. there is something to cancel (i.e. the sum of the absolute multiplicities is greater than 1)
                        # 2. constants are allowed to be cancelled:
                        #   a. single constants <constant> can be cancelled
                        #   b. composite terms with constants cannot be cancelled with the current method (one <constant> needs to survive)
                        if sum(abs(m) for m in multiplicity) > 1 and ('<constant>' not in subtree_hash or len(subtree_hash) == 1):  # Cannot cancel terms with arbitrary constants
                            cancellation_candidate = (cc, subtree_hash, multiplicity[0] - multiplicity[1])
                            still_connected = True
                            if stats is not None:
                                neutral_insertions = sum(abs(m) for m in multiplicity) - max(1, abs(multiplicity[0] - multiplicity[1]))
                                stats.cancellation_events.append({
                                    'class': cc,
                                    'subtree': subtree_hash,
                                    'multiplicity_sum': multiplicity[0] - multiplicity[1],
                                    'neutral_insertions': neutral_insertions,
                                })

            # Add the operator to the expression
            expression.append(operator)

            # Add the children to the stack
            for operand, operand_an, operand_label, propagated_operand_parity in zip(
                    reversed(operands),
                    reversed(operands_annotations_sets),
                    reversed(operands_labels),
                    reversed(propagated_operand_parities)):
                stack.append(operand)
                stack_annotations.append(operand_an)
                stack_labels.append(operand_label)
                stack_parity.append(propagated_operand_parity)
                stack_still_connected.append(still_connected)

        else:
            # Add the operator to the expression
            expression.append(operator)

            # Add the children to the stack
            for operand, operand_an, operand_label in zip(reversed(operands), reversed(operands_annotations_sets), reversed(operands_labels)):
                stack.append(operand)
                stack_annotations.append(operand_an)
                stack_labels.append(operand_label)
                stack_parity.append({cc: 1 for cc in self.connection_classes})
                stack_still_connected.append(still_connected)

    return expression

sort_operands

sort_operands(expression: list[str] | tuple[str, ...]) -> list[str]

Sorts the operands of commutative operators to create a canonical form.

This method traverses the expression and, for any commutative operator (like + or *), it sorts its operands based on a consistent key. This ensures that expressions like b + a and a + b are treated as identical.

PARAMETER DESCRIPTION
expression

The expression in prefix notation.

TYPE: list[str] or tuple[str, ...]

RETURNS DESCRIPTION
list[str]

The expression with sorted operands, in prefix notation.
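
The effect of canonical operand ordering can be sketched as follows. This is an illustrative stand-in, not the engine's method: the names `parse`, `canonical`, `flatten`, and `sort_prefix` are hypothetical, only binary operators are handled, and `repr` is used as the sort key in place of the engine's `operand_key`.

```python
COMMUTATIVE = {'+', '*'}  # Assumed set of commutative binary operators
ARITY = {'+': 2, '*': 2, '-': 2, '/': 2}  # Binary operators only in this sketch


def parse(tokens, pos=0):
    """Parse one prefix subexpression; return (tree, next position)."""
    token = tokens[pos]
    if token not in ARITY:
        return token, pos + 1
    children = []
    pos += 1
    for _ in range(ARITY[token]):
        child, pos = parse(tokens, pos)
        children.append(child)
    return (token, *children), pos


def canonical(tree):
    """Sort the operands of every commutative chain and rebuild it right-nested."""
    if isinstance(tree, str):
        return tree
    op, left, right = tree[0], canonical(tree[1]), canonical(tree[2])
    if op not in COMMUTATIVE:
        return (op, left, right)
    # Collect all operands joined by the same commutative operator
    terms, pending = [], [left, right]
    while pending:
        node = pending.pop()
        if isinstance(node, tuple) and node[0] == op:
            pending.extend(node[1:])
        else:
            terms.append(node)
    terms.sort(key=repr)  # Any total, deterministic key works for this sketch
    result = terms[-1]
    for term in reversed(terms[:-1]):
        result = (op, term, result)
    return result


def flatten(tree):
    """Convert a tree back into a prefix token list."""
    if isinstance(tree, str):
        return [tree]
    return [tree[0]] + [t for child in tree[1:] for t in flatten(child)]


def sort_prefix(tokens):
    tree, _ = parse(list(tokens))
    return flatten(canonical(tree))


print(sort_prefix(['+', 'b', 'a']))            # ['+', 'a', 'b']
print(sort_prefix(['*', '*', 'c', 'a', 'b']))  # ['*', 'a', '*', 'b', 'c']
```

Because both orderings of a commutative chain map to the same token list, equality of canonical forms can be tested with a plain list comparison.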

Source code in src/simplipy/engine.py
def sort_operands(self, expression: list[str] | tuple[str, ...]) -> list[str]:
    """Sorts the operands of commutative operators to create a canonical form.

    This method traverses the expression and, for any commutative operator
    (like `+` or `*`), it sorts its operands based on a consistent key.
    This ensures that expressions like `b + a` and `a + b` are treated as
    identical.

    Parameters
    ----------
    expression : list[str] or tuple[str, ...]
        The expression in prefix notation.

    Returns
    -------
    list[str]
        The expression with sorted operands, in prefix notation.
    """
    stack: list = []
    i = len(expression) - 1

    while i >= 0:
        token = expression[i]

        if token in self.operator_arity_compat or token in self.operator_aliases:
            operator = self.operator_aliases.get(token, token)
            arity = self.operator_arity_compat[operator]
            operands = list(reversed(stack[-arity:]))

            if operator in self.commutative_operators:
                # Check for the pattern [*, *, A, B, C] -> [*, A, *, B, C] or [+, +, A, B, C] -> [+, A, +, B, C]
                if len(operands[0]) == 2 and operator == operands[0][0]:
                    _ = [stack.pop() for _ in range(arity)]
                    stack.append([operator, [operands[0][1][0], [operator, [operands[0][1][1], operands[1]]]]])
                    i -= 1
                    continue

                subtree = [operator, operands]

                # Traverse through the tree in breadth-first order
                queue = [subtree]
                commutative_paths: list[tuple] = [tuple()]
                commutative_positions = []
                while queue:
                    node = queue.pop(0)
                    current_path = commutative_paths.pop(0)
                    for child_index, child in enumerate(node[1]):  # Named `child_index` to avoid shadowing the outer loop counter `i`
                        if len(child) > 1:
                            if child[0] == node[0]:
                                # Continue: Same commutative operator
                                queue.append(child)
                                commutative_paths.append(current_path + (child_index,))
                            else:
                                # Stop: Different operator
                                commutative_positions.append(current_path + (child_index,))
                        else:
                            # Stop: Leaf
                            commutative_positions.append(current_path + (child_index,))

                # Sort the positions
                sorted_indices = sorted(range(len(commutative_positions)), key=lambda x: commutative_positions[x])

                commutative_paths = [commutative_positions[i] for i in sorted_indices]
                commutative_positions = [commutative_positions[i] for i in sorted_indices]

                operands_to_sort = []
                for position in commutative_positions:
                    node = subtree
                    for position_index in position:
                        node = node[1][position_index]
                    operands_to_sort.append(node)

                sorted_operands = sorted(operands_to_sort, key=self.operand_key)

                # Replace the operands in the tree
                new_subtree: list = deepcopy(subtree)

                for position, operand in zip(commutative_positions, sorted_operands):
                    node = new_subtree
                    for position_index in position:
                        node = node[1][position_index]
                    node[:] = operand

                operands = new_subtree[1]

                _ = [stack.pop() for _ in range(arity)]
                stack.append([operator, operands])
                i -= 1
                continue

            _ = [stack.pop() for _ in range(arity)]
            stack.append([operator, operands])

        else:
            stack.append([token])

        i -= 1

    return flatten_nested_list(stack)[::-1]

simplify

simplify(expression: str | list[str] | tuple[str, ...] | ndarray, max_iter: int = 5, max_pattern_length: int | None = None, mask_elementary_literals: bool = True, apply_simplification_rules: bool = True, inplace: bool = False, collect_statistics: bool = False, verbose: bool = False) -> str | list[str] | tuple[str, ...] | np.ndarray

Performs a full simplification of a mathematical expression.

This is the main public method for simplification. It iteratively applies term cancellation, rule-based simplification, and operand sorting until the expression stops changing or max_iter is reached.

PARAMETER DESCRIPTION
expression

The expression to simplify, given as an infix string, a prefix token list/tuple, or a one-dimensional numpy array of tokens.

TYPE: str or list[str] or tuple[str, ...] or ndarray

max_iter

The maximum number of simplification iterations. Defaults to 5.

TYPE: int DEFAULT: 5

max_pattern_length

The maximum length of a rule pattern to consider.

TYPE: int or None DEFAULT: None

mask_elementary_literals

If True, replaces literals like '0' and '1' that result from cancellation with a generic <constant> token. Defaults to True.

TYPE: bool DEFAULT: True

apply_simplification_rules

If False, skips the rule-based simplification step. Defaults to True.

TYPE: bool DEFAULT: True

inplace

If True and the input is a list, the list is modified in place and returned. Not supported for numpy arrays. Defaults to False.

TYPE: bool DEFAULT: False

collect_statistics

If True, populates self.simplification_statistics with a fresh :class:SimplificationStatistics instance containing detailed metrics about the simplification run. Defaults to False.

TYPE: bool DEFAULT: False

verbose

If True, prints the expression after each simplification step.

TYPE: bool DEFAULT: False

RETURNS DESCRIPTION
str or list[str] or tuple[str, ...] or ndarray

The simplified expression, in the same format as the input. If the simplified result contains more tokens than the input, the original expression is returned instead.
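
The control flow described above (iterate until nothing changes or `max_iter` is reached, then reject results that grew) can be sketched independently of the engine. Everything in this block is hypothetical: `simplify_until_stable` stands in for the iteration loop, and `drop_add_zero` is a toy rewrite rule used only to drive the example.

```python
def simplify_until_stable(expression, passes, max_iter=5):
    """Apply `passes` repeatedly until the token list stops changing."""
    original = list(expression)
    current = list(expression)
    converged = False
    for _ in range(max_iter):
        new = current
        for simplification_pass in passes:
            new = simplification_pass(new)
        if new == current:
            converged = True
            break
        current = new
    # Reject results that are longer than the input
    if len(current) > len(original):
        return original, converged
    return current, converged


def drop_add_zero(tokens):
    """Toy rule: rewrite the first prefix pattern ['+', '0', t] as [t]."""
    for i in range(len(tokens) - 2):
        if tokens[i] == '+' and tokens[i + 1] == '0':
            return tokens[:i] + tokens[i + 2:]
    return tokens


result, converged = simplify_until_stable(['+', '0', '+', '0', 'x'], [drop_add_zero])
print(result, converged)  # ['x'] True
```

In this sketch the toy rule fires once per iteration, so the third iteration is the one that detects that nothing changed. With `max_iter=1` the loop stops early and `converged` stays `False`.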

Source code in src/simplipy/engine.py
def simplify(
        self,
        expression: str | list[str] | tuple[str, ...] | np.ndarray,
        max_iter: int = 5,
        max_pattern_length: int | None = None,
        mask_elementary_literals: bool = True,
        apply_simplification_rules: bool = True,
        inplace: bool = False,
        collect_statistics: bool = False,
        verbose: bool = False) -> str | list[str] | tuple[str, ...] | np.ndarray:
    """Performs a full simplification of a mathematical expression.

    This is the main public method for simplification. It iteratively
    applies term cancellation, rule-based simplification, and operand
    sorting until the expression stops changing or `max_iter` is reached.

    Parameters
    ----------
    expression : str or list[str] or tuple[str, ...] or np.ndarray
        The expression to simplify, given as an infix string, a prefix
        token list/tuple, or a one-dimensional numpy array of tokens.
    max_iter : int, optional
        The maximum number of simplification iterations. Defaults to 5.
    max_pattern_length : int or None, optional
        The maximum length of a rule pattern to consider.
    mask_elementary_literals : bool, optional
        If True, replaces literals like '0' and '1' that result from
        cancellation with a generic `<constant>` token. Defaults to True.
    apply_simplification_rules : bool, optional
        If False, skips the rule-based simplification step. Defaults to True.
    inplace : bool, optional
        If True and the input is a list, the list is modified in place and
        returned. Not supported for numpy arrays. Defaults to False.
    collect_statistics : bool, optional
        If True, populates ``self.simplification_statistics`` with a fresh
        :class:`SimplificationStatistics` instance containing detailed
        metrics about the simplification run.  Defaults to False.
    verbose : bool, optional
        If True, prints the expression after each simplification step.

    Returns
    -------
    str or list[str] or tuple[str, ...] or np.ndarray
        The simplified expression, in the same format as the input. If the
        simplified result contains more tokens than the input, the original
        expression is returned instead.
    """
    if collect_statistics:
        self.simplification_statistics = SimplificationStatistics()
    else:
        self.simplification_statistics = None

    list_expression_ref: list[str] | None = None
    original_expression: str | list[str] | tuple[str, ...] | np.ndarray
    current_expression: list[str]

    if isinstance(expression, str):
        return_type = 'str'
        original_expression = "" + expression  # Create a copy
        current_expression = self.parse(expression, convert_expression=True, mask_numbers=False)
    elif isinstance(expression, tuple):
        return_type = 'tuple'
        original_expression = expression  # No need to copy immutable tuple
        current_expression = list(expression)
    elif isinstance(expression, np.ndarray):
        if expression.ndim != 1:
            raise ValueError('`simplify` expects a one-dimensional numpy array of tokens')
        if expression.dtype.kind not in {'U', 'S', 'O'}:
            raise ValueError('`simplify` expects a numpy array of string-like tokens')
        if inplace:
            raise ValueError('`inplace=True` is not supported when the expression is a numpy array')
        return_type = 'np_array'
        original_expression = expression.copy()
        current_expression = cast(list[str], expression.tolist())
    else:
        return_type = 'list'
        list_expression_ref = expression
        original_expression = expression.copy()
        current_expression = expression.copy()

    new_expression: list[str] = current_expression.copy()

    length_before = len(current_expression)

    if verbose:
        print(f'Initial expression: {new_expression}')

    # # Apply simplification rules and sort operands to get started
    # if apply_simplification_rules:
    #     new_expression = self.apply_simplifcation_rules(new_expression, max_pattern_length, collect_statistics=collect_statistics, verbose=verbose)

    # if verbose:
    #     print(f'_apply_simplifcation_rules: {new_expression}')

    iterations_used = 0
    converged = False

    for i in range(max_iter):
        iterations_used = i + 1

        # Cancel any terms
        t0 = time.perf_counter() if collect_statistics else 0.0
        expression_tree, annotated_expression_tree, stack_labels = self.collect_multiplicities(new_expression, verbose=verbose)
        new_expression = self.cancel_terms(expression_tree, annotated_expression_tree, stack_labels, collect_statistics=collect_statistics, verbose=verbose)
        if collect_statistics:
            self.simplification_statistics.stage_timings['cancel_terms'] += time.perf_counter() - t0  # type: ignore[union-attr]

        if verbose:
            print(f'{i}: cancel_terms: {new_expression}')

        iteration_lengths: dict[str, int] = {}
        if collect_statistics:
            iteration_lengths['after_cancel'] = len(new_expression)

        # Apply simplification rules
        if apply_simplification_rules:
            t0 = time.perf_counter() if collect_statistics else 0.0
            new_expression = self.apply_simplifcation_rules(new_expression, max_pattern_length, collect_statistics=collect_statistics, verbose=verbose)
            if collect_statistics:
                self.simplification_statistics.stage_timings['apply_rules'] += time.perf_counter() - t0  # type: ignore[union-attr]

        if verbose:
            print(f'{i}: _apply_simplifcation_rules: {new_expression}')

        if collect_statistics:
            iteration_lengths['after_rules'] = len(new_expression)
            self.simplification_statistics.per_iteration_lengths.append(iteration_lengths)  # type: ignore[union-attr]

        if new_expression == current_expression:
            converged = True
            break
        current_expression = new_expression

    # Sort operands
    t0 = time.perf_counter() if collect_statistics else 0.0
    new_expression = self.sort_operands(new_expression)
    if collect_statistics:
        self.simplification_statistics.stage_timings['sort_operands'] += time.perf_counter() - t0  # type: ignore[union-attr]

    if verbose:
        print(f'{i}: sort_operands: {new_expression}')

    if mask_elementary_literals:
        t0 = time.perf_counter() if collect_statistics else 0.0
        new_expression = mask_elementary_literals_fn(new_expression, inplace=inplace)
        if collect_statistics:
            self.simplification_statistics.stage_timings['mask_literals'] += time.perf_counter() - t0  # type: ignore[union-attr]

        if verbose:
            print(f'{i}: mask_elementary_literals: {new_expression}')

    result_rejected = len(new_expression) > length_before

    if collect_statistics:
        self.simplification_statistics.iterations_used = iterations_used  # type: ignore[union-attr]
        self.simplification_statistics.converged = converged  # type: ignore[union-attr]
        self.simplification_statistics.result_rejected = result_rejected  # type: ignore[union-attr]

    if result_rejected:
        # The expression has grown, which is not a simplification
        match return_type:
            case 'str':
                return original_expression
            case 'tuple':
                return tuple(original_expression)
            case 'np_array':
                original_np_expression = cast(np.ndarray, original_expression)
                return original_np_expression.copy()
            case 'list':
                if inplace and list_expression_ref is not None:
                    list_expression_ref[:] = original_expression
                    return list_expression_ref
                return original_expression
        return original_expression

    match return_type:
        case 'str':
            return self.prefix_to_infix(new_expression, realization=False, power='**')
        case 'tuple':
            return tuple(new_expression)
        case 'np_array':
            original_np_expression = cast(np.ndarray, original_expression)
            return np.array(new_expression, dtype=original_np_expression.dtype, copy=True)
        case 'list':
            if inplace and list_expression_ref is not None:
                list_expression_ref[:] = new_expression
                return list_expression_ref
            return new_expression

    return new_expression

exist_constants_that_fit

exist_constants_that_fit(expression: list[str] | tuple[str, ...], variables: list[str], X: ndarray, y_target: ndarray) -> bool

Checks if numerical constants exist to make an expression fit data.

Given an expression with <constant> placeholders, this method uses scipy.optimize.curve_fit to determine if there is a set of numerical values for these placeholders that makes the expression accurately model the relationship between input data X and target data y_target.

PARAMETER DESCRIPTION
expression

The prefix expression, potentially containing <constant> tokens.

TYPE: list[str] or tuple[str, ...]

variables

A list of variable names corresponding to the columns of X.

TYPE: list[str]

X

The input data, with shape (n_samples, n_variables).

TYPE: ndarray

y_target

The target data to be fitted.

TYPE: ndarray

RETURNS DESCRIPTION
bool

True if a set of constants is found that results in a close fit, False otherwise.
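
As a rough illustration of the idea (not the engine's own code path), the sketch below asks the same question for a hand-written model `c0 * x + c1`. The names `model` and `constants_exist` are hypothetical. The sketch assumes one-dimensional input and compares only the finite samples, using the default tolerances of `np.allclose`.

```python
import warnings

import numpy as np
from scipy.optimize import OptimizeWarning, curve_fit


def model(x: np.ndarray, c0: float, c1: float) -> np.ndarray:
    # Hypothetical stand-in for a compiled expression with two <constant> placeholders.
    return c0 * x + c1


def constants_exist(x: np.ndarray, y_target: np.ndarray, seed: int = 0) -> bool:
    rng = np.random.default_rng(seed)
    p0 = rng.normal(loc=0, scale=5, size=2)  # random initial guess for the constants
    valid = np.isfinite(x) & np.isfinite(y_target)  # ignore non-finite samples
    if valid.sum() < 2:
        return False  # fewer usable samples than constants to fit
    try:
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=OptimizeWarning)
            popt, _ = curve_fit(model, x[valid], y_target[valid], p0=p0)
    except RuntimeError:
        return False  # the optimizer did not converge
    # A fit only counts if the fitted model reproduces the target closely.
    return bool(np.allclose(y_target[valid], model(x[valid], *popt)))


x = np.linspace(-3.0, 3.0, 64)
fits_line = constants_exist(x, 2.0 * x - 1.0)  # a line can be matched exactly
fits_sine = constants_exist(x, np.sin(x))      # a line cannot reproduce a sine
```

Because the initial guess is random, a single failed fit does not prove that no constants exist, which is why callers may want to retry.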

Source code in src/simplipy/engine.py
def exist_constants_that_fit(self, expression: list[str] | tuple[str, ...], variables: list[str], X: np.ndarray, y_target: np.ndarray) -> bool:
    """Checks if numerical constants exist to make an expression fit data.

    Given an expression with `<constant>` placeholders, this method uses
    `scipy.optimize.curve_fit` to determine if there is a set of numerical
    values for these placeholders that makes the expression accurately
    model the relationship between input data `X` and target data `y_target`.

    Parameters
    ----------
    expression : list[str] or tuple[str, ...]
        The prefix expression, potentially containing `<constant>` tokens.
    variables : list[str]
        A list of variable names corresponding to the columns of `X`.
    X : np.ndarray
        The input data, with shape (n_samples, n_variables).
    y_target : np.ndarray
        The target data to be fitted.

    Returns
    -------
    bool
        True if a set of constants is found that results in a close fit,
        False otherwise.
    """
    if isinstance(expression, tuple):
        expression = list(expression)

    executable_prefix_expression = self.operators_to_realizations(expression)
    prefix_expression_with_constants, constants = explicit_constant_placeholders(executable_prefix_expression, convert_numbers_to_constant=False)
    code_string = self.prefix_to_infix(prefix_expression_with_constants, realization=True)
    code = codify(code_string, variables + constants)
    f = self.code_to_lambda(code)

    def pred_function(X: np.ndarray, *constants: np.ndarray | None) -> float | np.ndarray:
        if len(constants) == 0:
            y = safe_f(f, X)
        else:
            y = safe_f(f, X, constants)

        # If the numbers are complex, return nan
        if np.iscomplexobj(y):
            return np.full(X.shape[0], np.nan)

        return y

    p0 = np.random.normal(loc=0, scale=5, size=len(constants))

    is_valid = np.isfinite(X).all(axis=1) & np.isfinite(y_target)

    if not np.any(is_valid) or len(constants) > is_valid.sum():  # https://github.com/scipy/scipy/issues/13969
        return False

    try:
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=OptimizeWarning)
            popt, _ = curve_fit(pred_function, X[is_valid], y_target[is_valid].flatten(), p0=p0)
    except RuntimeError:
        return False

    y = f(*X.T, *popt)
    if not isinstance(y, np.ndarray):
        y = np.full(X.shape[0], y)  # type: ignore

    return np.allclose(y_target, y, equal_nan=True)

find_rule_worker

find_rule_worker(worker_id: int, work_queue: Queue, result_queue: Queue, X_shape: tuple, X_dtype: dtype, X_shm_name: str, expressions_of_length_and_variables: dict, dummy_variables: list[str], operator_arity: dict, constants_fit_challenges: int, constants_fit_retries: int) -> None

A worker process for discovering simplification rules in parallel.

This function runs in a separate process. It fetches work items of the form (expression, simplified_length, allowed_candidate_lengths) from work_queue, evaluates the expression on shared random data, and compares the result against a library of simpler candidate expressions. If a numerical equivalence is found, it is considered a potential new simplification rule and is placed on the result_queue; otherwise None is queued to signal that no rule was discovered. A sentinel None work item triggers a graceful shutdown.

Notes

This method is designed for internal use by the find_rules method and is not intended to be called directly.
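
The queue protocol described above can be sketched with the standard library alone. This is a simplified illustration under stated assumptions: it uses a thread instead of a separate process, no shared memory, and a placeholder check (any shorter candidate counts as a match) in place of the numerical comparison. The `worker` function and the work item layout are hypothetical.

```python
import queue
import threading


def worker(work_queue: queue.Queue, result_queue: queue.Queue) -> None:
    while True:
        item = work_queue.get()
        if item is None:  # sentinel: shut down gracefully
            break
        expression, candidates = item
        # Placeholder for the equivalence check: take the first shorter candidate.
        match = next((c for c in candidates if len(c) < len(expression)), None)
        # Exactly one result per work item: a rule, or None if nothing was found.
        result_queue.put((expression, match) if match is not None else None)


work_queue: queue.Queue = queue.Queue()
result_queue: queue.Queue = queue.Queue()
thread = threading.Thread(target=worker, args=(work_queue, result_queue))
thread.start()

work_queue.put((("+", "x0", "0"), [("x0",)]))  # a shorter candidate exists
work_queue.put((("x0",), [("x0",)]))           # nothing shorter, so None is queued
work_queue.put(None)                           # sentinel
thread.join()

results = [result_queue.get() for _ in range(2)]
```

Queuing one result per work item, including `None`, lets the caller track the number of outstanding tasks by counting results.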

Source code in src/simplipy/engine.py
def find_rule_worker(
        self,
        worker_id: int,
        work_queue: Queue,
        result_queue: Queue,
        X_shape: tuple,
        X_dtype: np.dtype,
        X_shm_name: str,
        expressions_of_length_and_variables: dict,
        dummy_variables: list[str],
        operator_arity: dict,
        constants_fit_challenges: int,
        constants_fit_retries: int) -> None:
    """A worker process for discovering simplification rules in parallel.

    This function runs in a separate process. It fetches work items of the
    form ``(expression, simplified_length, allowed_candidate_lengths)`` from
    `work_queue`, evaluates the expression on shared random data, and
    compares the result against a library of simpler candidate expressions.
    If a numerical equivalence is found, it is considered a potential new
    simplification rule and is placed on the `result_queue`; otherwise ``None``
    is queued to signal that no rule was discovered. A sentinel ``None`` work
    item triggers a graceful shutdown.

    Notes
    -----
    This method is designed for internal use by the `find_rules` method
    and is not intended to be called directly.
    """

    signal.signal(signal.SIGINT, signal.SIG_IGN)

    try:
        # Reconstruct arrays from shared memory
        X_shm = SharedMemory(name=X_shm_name)
        X: np.ndarray = np.ndarray(X_shape, dtype=X_dtype, buffer=X_shm.buf)

        # Main work loop
        while True:
            work_item = work_queue.get()

            # Check for sentinel
            if work_item is None:
                break

            expression, simplified_length, allowed_candidate_lengths = work_item

            if len(allowed_candidate_lengths) == 0 or max(allowed_candidate_lengths) <= 0 or simplified_length <= min(allowed_candidate_lengths):  # Request unrealistic simplification or already have better simplification than requested
                # No candidates allowed, skip this expression
                result_queue.put(None)
                continue

            # Check if purely numerical
            if all([t == '<constant>' or t in operator_arity for t in expression]) and len(expression) > 1:
                result_queue.put((expression, ('<constant>',)))
                continue

            expression_variables = list(set(expression) & set(dummy_variables))

            # Evaluate expression
            executable_prefix_expression = self.operators_to_realizations(expression)
            prefix_expression_with_constants, constants = explicit_constant_placeholders(executable_prefix_expression, convert_numbers_to_constant=False)
            code_string = self.prefix_to_infix(prefix_expression_with_constants, realization=True)
            code = codify(code_string, dummy_variables + constants)

            f = self.code_to_lambda(code)

            # Suppress warnings in worker
            with warnings.catch_warnings():
                warnings.filterwarnings("ignore", category=RuntimeWarning)

                found_simplifications = []

                # Check against all smaller expressions
                for candidate_length in allowed_candidate_lengths:
                    for candidate_variables, candidate_expressions in expressions_of_length_and_variables.get(candidate_length, {}).items():
                        if any(var not in expression_variables for var in candidate_variables):
                            # The candidate expression contains variables not in the original expression. It cannot be a simplification.
                            continue

                        for candidate_expression in candidate_expressions:
                            executable_candidate = self.operators_to_realizations(candidate_expression)
                            prefix_candidate_w_constants, candidate_constants = explicit_constant_placeholders(executable_candidate, convert_numbers_to_constant=False)
                            candidate_code = self.prefix_to_infix(prefix_candidate_w_constants, realization=True)
                            candidate_compiled = codify(candidate_code, dummy_variables + candidate_constants)
                            f_candidate = self.code_to_lambda(candidate_compiled)

                            # Check if expressions are equivalent
                            if len(candidate_constants) == 0:
                                y_candidate = safe_f(f_candidate, X)
                                if not isinstance(y_candidate, np.ndarray):
                                    y_candidate = np.full(X.shape[0], y_candidate)

                                # Resample constants to avoid false positives
                                # The expression is considered a match unless one of the challenges fails
                                expressions_match = True
                                for challenge_id in range(constants_fit_challenges):
                                    random_constants = np.random.normal(loc=0, scale=5, size=len(constants))
                                    # Try all combinations of positive and negative constants
                                    for positive_negative_constant_combination in product((-1, 0, 1), repeat=len(constants)):
                                        y = safe_f(f, X, np.abs(random_constants) * positive_negative_constant_combination)  # abs may be redundant here
                                        if not np.allclose(y, y_candidate, equal_nan=True):
                                            expressions_match = False
                                            break

                                    if not expressions_match:
                                        # A combination produced a different result, abort this candidate
                                        break

                            else:
                                # Resample constants to avoid false positives
                                # The expression is considered a match unless one of the challenges fails
                                expressions_match = True
                                for challenge_id in range(constants_fit_challenges):
                                    # Need to check if constants can be fitted
                                    random_constants = np.random.normal(loc=0, scale=5, size=len(constants))
                                    # Try all combinations of positive and negative constants
                                    for positive_negative_constant_combination in product((-1, 0, 1), repeat=len(constants)):
                                        y = safe_f(f, X, np.abs(random_constants) * positive_negative_constant_combination)  # abs may be redundant here
                                        for _ in range(constants_fit_retries):
                                            if self.exist_constants_that_fit(candidate_expression, dummy_variables, X, y):
                                                # Found a candidate that fits, next challenge please
                                                break
                                        else:
                                            # No candidate found that fits, not all challenges could be solved, abort this candidate
                                            expressions_match = False
                                            break

                                    if not expressions_match:
                                        # A combination produced a different result, abort this candidate
                                        break

                            if expressions_match:
                                found_simplifications.append(candidate_expression)
                                # Still check for further candidates of the same length

                    if found_simplifications:
                        # Found at least one simplification for the current length
                        # Every further candidate will be longer, so we can stop checking
                        break

            if not found_simplifications:
                # No simplification found
                result_queue.put(None)
            else:
                # Prefer candidates with fewer <constant> tokens; among ties, keep discovery order.
                # Lazily check the non-increasing wildcard multiplicity condition (no subtree duplication).
                found_simplifications.sort(key=lambda s: s.count('<constant>'))
                selected = None
                for candidate in found_simplifications:
                    if not violates_wildcard_multiplicity(expression, candidate):
                        selected = candidate
                        break
                if selected is not None:
                    result_queue.put((expression, selected))
                else:
                    # All candidates violate wildcard multiplicity
                    result_queue.put(None)

    except Exception as e:
        # Log exceptions to result queue
        result_queue.put(('ERROR', e, (expression, simplified_length, allowed_candidate_lengths)))
    finally:
        X_shm.close()

find_rules

find_rules(max_source_pattern_length: int = 7, max_target_pattern_length: int | None = None, dummy_variables: int | list[str] | None = None, extra_internal_terms: list[str] | None = None, X: ndarray | int | None = None, constants_fit_challenges: int = 5, constants_fit_retries: int = 5, output_file: str | None = None, save_every: int = 100, reset_rules: bool = True, prune: bool = False, verbose: bool = False, n_workers: int | None = None) -> None

Systematically discovers new simplification rules.

This powerful method automates the discovery of simplification rules. It operates in two phases:

1. Generation: It combinatorially generates all possible valid expressions up to max_source_pattern_length.
2. Verification: It uses a pool of worker processes to test each generated expression for equivalence with any shorter expression. Equivalences are found by evaluating both expressions on random numerical data.
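
The generation phase can be sketched as a length-by-length enumeration of prefix expressions. The function below is a simplified, hypothetical stand-in for the engine's own generator (which is not shown here). It assumes expressions are tuples of tokens and that every operator has a fixed arity.

```python
from itertools import product


def generate_prefix_expressions(
    leaves: list[str], operator_arity: dict[str, int], max_length: int
) -> dict[int, set[tuple[str, ...]]]:
    # Expressions of length 1 are exactly the leaf nodes.
    by_length: dict[int, set[tuple[str, ...]]] = {1: {(leaf,) for leaf in leaves}}
    for length in range(2, max_length + 1):
        by_length[length] = set()
        for operator, arity in operator_arity.items():
            # One token is the operator, so operand lengths must sum to length - 1.
            for operand_lengths in product(range(1, length), repeat=arity):
                if sum(operand_lengths) != length - 1:
                    continue
                pools = [by_length[n] for n in operand_lengths]
                for operands in product(*pools):
                    flat = tuple(token for operand in operands for token in operand)
                    by_length[length].add((operator,) + flat)
    return by_length


expressions = generate_prefix_expressions(["x0", "x1"], {"neg": 1, "+": 2}, max_length=3)
```

The number of expressions grows combinatorially with the length, which is why the verification phase is distributed over worker processes.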

Discovered rules are deduplicated, compiled into the running engine, and can optionally be saved to disk.
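
A minimal sketch of the numerical equivalence test follows, for a candidate without constants. It uses plain Python callables instead of compiled expressions. All names are hypothetical and the tolerances are assumptions. As in the worker source, constants are resampled for several challenges and tried with every combination of sign (negative, zero, positive).

```python
import math
import random
from itertools import product
from typing import Callable


def safe_call(f: Callable[..., float], *args: float) -> float:
    try:
        return f(*args)
    except (ValueError, ZeroDivisionError, OverflowError):
        return math.nan  # treat domain errors as NaN


def same_value(a: float, b: float) -> bool:
    if math.isnan(a) or math.isnan(b):
        return math.isnan(a) and math.isnan(b)  # NaN only matches NaN
    return math.isclose(a, b, rel_tol=1e-5, abs_tol=1e-8)


def matches_for_all_constants(
    source: Callable[..., float],
    candidate: Callable[..., float],
    n_constants: int,
    n_challenges: int = 5,
    n_samples: int = 256,
    seed: int = 0,
) -> bool:
    rng = random.Random(seed)
    xs = [rng.gauss(0, 5) for _ in range(n_samples)]  # random evaluation points
    for _ in range(n_challenges):
        magnitudes = [abs(rng.gauss(0, 5)) for _ in range(n_constants)]
        for signs in product((-1, 0, 1), repeat=n_constants):
            constants = [m * s for m, s in zip(magnitudes, signs)]
            for x in xs:
                if not same_value(safe_call(source, x, *constants), safe_call(candidate, x)):
                    return False  # one failed challenge rejects the candidate
    return True


doubling = matches_for_all_constants(lambda x: x + x, lambda x: 2 * x, n_constants=0)
cancelling = matches_for_all_constants(lambda x, c: (x + c) - c, lambda x: x, n_constants=1)
scaling = matches_for_all_constants(lambda x, c: c * x, lambda x: x, n_constants=1)
```

Agreement on random samples is evidence of equivalence, not a proof. Resampling the constants and varying their signs reduces the chance of accepting a rule that only holds for particular values.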

PARAMETER DESCRIPTION
max_source_pattern_length

The maximum length of expressions to generate and test.

TYPE: int DEFAULT: 7

max_target_pattern_length

The maximum length of a valid simplified expression. If None, any shorter expression is considered a valid simplification.

TYPE: int or None DEFAULT: None

dummy_variables

The variables to use when generating expressions.

TYPE: int or list[str] or None DEFAULT: None

extra_internal_terms

Additional leaf nodes (e.g., '<constant>') to include.

TYPE: list[str] or None DEFAULT: None

X

The numerical data for testing equivalence. If an int, specifies the number of samples to generate. If None, defaults to 1024 samples.

TYPE: ndarray or int or None DEFAULT: None

constants_fit_challenges

Number of random constant sets to test for equivalence.

TYPE: int DEFAULT: 5

constants_fit_retries

Number of retries for the curve fitting process.

TYPE: int DEFAULT: 5

output_file

If provided, saves the discovered rules to this JSON file.

TYPE: str or None DEFAULT: None

save_every

How often to save the rules to the output file.

TYPE: int DEFAULT: 100

reset_rules

If True, clears existing rules before starting.

TYPE: bool DEFAULT: True

prune

If True, runs prune_redundant_rules after discovery to remove explicit rules that are subsumed by wildcard-pattern rules. This can be expensive for large rule sets. Defaults to False.

TYPE: bool DEFAULT: False

verbose

If True, shows progress bars and status updates.

TYPE: bool DEFAULT: False

n_workers

Number of parallel processes to use. Defaults to the number of CPU cores.

TYPE: int or None DEFAULT: None
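
Two of the defaults above are easier to see in code. The helpers below are hypothetical names. Their bodies restate arithmetic from the source listing: how many dummy variables are created when dummy_variables is None, and which candidate lengths are tested for a given max_target_pattern_length.

```python
from typing import Optional


def default_dummy_variables(max_source_pattern_length: int) -> list[str]:
    # With binary operators only, an expression of length L has at most
    # L - (L - 1) / 2 leaf nodes (truncated to an integer).
    n_leaves = int(max_source_pattern_length - (max_source_pattern_length - 1) / 2)
    return [f"x{i}" for i in range(n_leaves)]


def allowed_candidate_lengths(
    simplified_length: int, max_target_pattern_length: Optional[int]
) -> tuple[int, ...]:
    if max_target_pattern_length is None:
        # Any strictly shorter expression is an acceptable target.
        return tuple(range(simplified_length))
    # Otherwise the target must also respect the configured maximum length.
    return tuple(range(min(simplified_length, max_target_pattern_length + 1)))


variables = default_dummy_variables(7)
unbounded = allowed_candidate_lengths(5, None)
bounded = allowed_candidate_lengths(5, 2)
```

The ranges start at 0. No expression has length 0, so that entry finds no candidates when it is looked up.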

Source code in src/simplipy/engine.py
def find_rules(
        self,
        max_source_pattern_length: int = 7,
        max_target_pattern_length: int | None = None,
        dummy_variables: int | list[str] | None = None,
        extra_internal_terms: list[str] | None = None,
        X: np.ndarray | int | None = None,
        constants_fit_challenges: int = 5,
        constants_fit_retries: int = 5,
        output_file: str | None = None,
        save_every: int = 100,
        reset_rules: bool = True,
        prune: bool = False,
        verbose: bool = False,
        n_workers: int | None = None) -> None:
    """Systematically discovers new simplification rules.

    This powerful method automates the discovery of simplification rules.
    It operates in two phases:
    1.  **Generation**: It combinatorially generates all possible valid
        expressions up to `max_source_pattern_length`.
    2.  **Verification**: It uses a pool of worker processes to test each
        generated expression for equivalence with any shorter expression.
        Equivalences are found by evaluating both expressions on random
        numerical data.

    Discovered rules are deduplicated, compiled into the running engine, and
    can optionally be saved to disk.

    Parameters
    ----------
    max_source_pattern_length : int, optional
        The maximum length of expressions to generate and test.
    max_target_pattern_length : int or None, optional
        The maximum length of a valid simplified expression. If None, any
        shorter expression is considered a valid simplification.
    dummy_variables : int or list[str] or None, optional
        The variables to use when generating expressions.
    extra_internal_terms : list[str] or None, optional
        Additional leaf nodes (e.g., '<constant>') to include.
    X : np.ndarray or int or None, optional
        The numerical data for testing equivalence. If an int, specifies
        the number of samples to generate. If None, defaults to 1024 samples.
    constants_fit_challenges : int, optional
        Number of random constant sets to test for equivalence.
    constants_fit_retries : int, optional
        Number of retries for the curve fitting process.
    output_file : str or None, optional
        If provided, saves the discovered rules to this JSON file.
    save_every : int, optional
        How often to save the rules to the output file.
    reset_rules : bool, optional
        If True, clears existing rules before starting.
    prune : bool, optional
        If True, runs :meth:`prune_redundant_rules` after discovery to
        remove explicit rules that are subsumed by wildcard-pattern rules.
        This can be expensive for large rule sets. Defaults to False.
    verbose : bool, optional
        If True, shows progress bars and status updates.
    n_workers : int or None, optional
        Number of parallel processes to use. Defaults to the number of CPU cores.
    """
    # Signal handler for main process
    interrupted = False

    def signal_handler(signum: Any, frame: Any) -> None:
        nonlocal interrupted
        interrupted = True
        print("\nInterrupt received, cleaning up...")

    # Set up signal handler in main process
    old_handler = signal.signal(signal.SIGINT, signal_handler)

    # All the initialization from the sequential version
    extra_internal_terms = extra_internal_terms or []

    if dummy_variables is None:
        max_leaf_nodes_if_operators_binary = int(max_source_pattern_length - (max_source_pattern_length - 1) / 2)
        dummy_variables = [f"x{i}" for i in range(max_leaf_nodes_if_operators_binary)]
        if verbose:
            print(f"Using {len(dummy_variables)} dummy variables: {dummy_variables}")
    elif isinstance(dummy_variables, int):
        dummy_variables = [f"x{i}" for i in range(dummy_variables)]

    if reset_rules:
        self.simplification_rules = []
        self.compile_rules()

    if X is None:
        X_data = np.random.normal(loc=0, scale=5, size=(1024, len(dummy_variables)))
    elif isinstance(X, int):
        X_data = np.random.normal(loc=0, scale=5, size=(X, len(dummy_variables)))
    else:
        # Use the caller-provided array as-is
        X_data = X

    leaf_nodes = dummy_variables + extra_internal_terms
    non_leaf_nodes = dict(sorted(self.operator_arity.items(), key=lambda x: x[1]))

    # --- Phase 1: Generate expressions ---
    if verbose:
        print(f"Phase 1: Generating all expressions up to length {max_source_pattern_length}")

    expressions_of_length: dict[int, set[tuple[str, ...]]] = defaultdict(set)
    new_expressions_of_length: defaultdict[int, set[tuple[str, ...]]] = defaultdict(set)

    # Initialize with leaf nodes
    for leaf in leaf_nodes:
        expressions_of_length[1].add((leaf,))

    # Generate expressions level by level
    new_sizes: set[int] = set()
    while max(expressions_of_length.keys()) < max_source_pattern_length:  # This means that every smaller size is already generated
        for expression in construct_expressions(expressions_of_length, non_leaf_nodes, must_have_sizes=new_sizes):
            new_expressions_of_length[len(expression)].add(expression)

        new_sizes = set()
        lengths_before = {k: len(v) for k, v in expressions_of_length.items()}
        for new_length, new_hashes in new_expressions_of_length.items():
            expressions_of_length[new_length].update(new_hashes)
        lengths_after = {k: len(v) for k, v in new_expressions_of_length.items()}

        for length in lengths_after.keys():
            if length not in lengths_before or lengths_after[length] > lengths_before[length]:
                new_sizes.add(length)

        if verbose:
            print(f'Constructed expressions of sizes {sorted(new_sizes)}:')
            for length, count in sorted(lengths_after.items()):
                print(f'  {length:2d}: {count:,} new expressions')

        # Move the new hashes to the main dictionary
        for length, new_hashes in new_expressions_of_length.items():
            expressions_of_length[length].update(new_hashes)

        new_expressions_of_length.clear()

    total_expressions = sum(len(v) for v in expressions_of_length.values())

    if verbose:
        print(f"Finished generating expressions up to size {max_source_pattern_length}. Total expressions: {total_expressions:,}")
        for length, expressions in sorted(expressions_of_length.items()):
            print(f"Size {length}: {len(expressions):,} expressions")

    expressions_of_length_and_variables: dict[int, dict[tuple[str, ...], set[tuple[str, ...]]]] = {}
    for length, expressions in expressions_of_length.items():
        expressions_of_length_and_variables[length] = defaultdict(set)
        for expression in expressions:
            expression_variables = list(set(expression) & set(dummy_variables))  # This gets the dummy variables used in the expression
            expressions_of_length_and_variables[length][tuple(sorted(expression_variables))].add(expression)

    # --- Phase 2: Parallel rule finding ---
    if n_workers is None:
        n_workers = mp.cpu_count()

    # Create shared memory for arrays
    X_shm = SharedMemory(create=True, size=X_data.nbytes)
    X_shared: np.ndarray = np.ndarray(X_data.shape, dtype=X_data.dtype, buffer=X_shm.buf)
    X_shared[:] = X_data[:]

    # Create queues
    work_queue: mp.Queue = mp.Queue()
    result_queue: mp.Queue = mp.Queue()

    # Start workers
    workers = []
    for i in range(n_workers):
        p = Process(
            target=self.find_rule_worker,
            args=(
                i, work_queue, result_queue,
                X_data.shape, X_data.dtype, X_shm.name,
                dict(expressions_of_length_and_variables),  # Make a copy for each worker
                dummy_variables,
                self.operator_arity,
                constants_fit_challenges,
                constants_fit_retries,
            )
        )
        p.daemon = True  # Make workers daemon processes
        p.start()
        workers.append(p)

    # Main processing loop
    n_scanned = 0
    active_tasks = 0

    # Create iterator over all work items
    work_items = [
        expression_to_simplify
        for _, expressions in sorted(expressions_of_length.items())  # We don't care about the variables here
        for expression_to_simplify in expressions
    ]
    work_iter = iter(work_items)

    pbar = tqdm(total=len(work_items), desc="Finding rules", disable=not verbose)

    current_length = 0

    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=RuntimeWarning)

        try:
            # Initial work distribution
            for _ in range(min(n_workers * 2, len(work_items))):  # Queue 2x workers for efficiency
                try:
                    expression_to_simplify = next(work_iter)
                    simplified_length = len(self.simplify(expression_to_simplify, max_iter=5))
                    # Skip expressions that already simplify via existing rules (Kruskal-style pruning)
                    if simplified_length < len(expression_to_simplify):
                        n_scanned += 1
                        pbar.update(1)
                        continue
                    if max_target_pattern_length is None:
                        allowed_candidate_lengths = tuple(range(simplified_length))
                    else:
                        allowed_candidate_lengths = tuple(range(min(simplified_length, max_target_pattern_length + 1)))
                    work_queue.put((expression_to_simplify, simplified_length, allowed_candidate_lengths))
                    active_tasks += 1
                except StopIteration:
                    break

            current_length = len(expression_to_simplify)

            # Process results and distribute new work
            try:
                while active_tasks > 0 and not interrupted:
                    # Get result with timeout to allow checking stop conditions
                    try:
                        result = result_queue.get(timeout=0.1)
                    except queue.Empty:
                        # Check if interrupted during wait
                        if interrupted:
                            break
                        continue

                    active_tasks -= 1

                    # Process result
                    if result is not None:
                        if result[0] == 'ERROR':
                            # result is ('ERROR', exception, work_item)
                            print(f"Error in worker: {result[1]!r} (work item: {result[2]})")
                            raise result[1]

                        self.simplification_rules.append(result)

                    # Send new work if available (but not if interrupted)
                    if not interrupted:
                        try:
                            expression_to_simplify = next(work_iter)

                            if len(expression_to_simplify) > current_length:
                                # This means that the collected rules can be applied to coming expressions
                                # To avoid redundant rules, we incorporate the rules into the simplification to raise the requirements for rules
                                if verbose:
                                    print(f'Increasing expression length from {current_length} to {len(expression_to_simplify)}')
                                self.simplification_rules = deduplicate_rules(self.simplification_rules, dummy_variables, verbose=verbose)
                                self.compile_rules()
                                if output_file is not None:
                                    if verbose:
                                        print("Saving rules after increasing expression length...")
                                    with open(output_file, 'w') as file:
                                        json.dump(self.simplification_rules, file, indent=4)
                                current_length = len(expression_to_simplify)

                            simplified_length = len(self.simplify(expression_to_simplify, max_iter=5))
                            # Skip expressions that already simplify via existing rules (Kruskal-style pruning)
                            if simplified_length < len(expression_to_simplify):
                                n_scanned += 1
                                pbar.update(1)
                                continue
                            if max_target_pattern_length is None:
                                allowed_candidate_lengths = tuple(range(simplified_length))
                            else:
                                allowed_candidate_lengths = tuple(range(min(simplified_length, max_target_pattern_length + 1)))
                            work_queue.put((expression_to_simplify, simplified_length, allowed_candidate_lengths))
                            active_tasks += 1
                        except StopIteration:
                            pass

                    n_scanned += 1
                    pbar.update(1)
                    # Calculate the display string for the last rule with truncation
                    last_rule = self.simplification_rules[-1] if self.simplification_rules else 'None'
                    last_rule_str = str(last_rule)[:64].ljust(64)  # Truncate and pad

                    # Format with fixed widths
                    pbar.set_postfix_str(
                        f"Rules: {len(self.simplification_rules):>6,}, "  # 6 chars, right-aligned
                        f"Active tasks: {active_tasks:>3}, "              # 3 chars, right-aligned
                        f"Last rule: {last_rule_str}"                     # 64 chars, truncated and padded
                    )

                    # Periodic saving
                    if output_file is not None and n_scanned % save_every == 0:
                        if verbose:
                            print(f"Saving rules after processing {n_scanned} expressions...")
                        self.simplification_rules = deduplicate_rules(self.simplification_rules, dummy_variables, verbose=verbose)
                        self.compile_rules()
                        with open(output_file, 'w') as file:
                            json.dump(self.simplification_rules, file, indent=4)
            except Exception as e:
                print(f"Error during processing: {e}")
                interrupted = True

        finally:
            # Restore original signal handler
            signal.signal(signal.SIGINT, old_handler)

            pbar.close()

            # Clean shutdown or force termination
            if interrupted:
                print("Force terminating workers...")
                for p in workers:
                    if p.is_alive():
                        p.terminate()
                        p.join(timeout=0.5)
                        if p.is_alive():
                            p.kill()
            else:
                # Normal shutdown
                print("Shutting down workers...")
                for _ in workers:
                    try:
                        work_queue.put(None, timeout=0.1)
                    except Exception as e:
                        print(e)

                for p in workers:
                    p.join(timeout=2)
                    if p.is_alive():
                        p.terminate()

            # Cleanup resources
            try:
                X_shm.close()
                X_shm.unlink()
            except Exception as e:
                print(e)

            # Close queues
            work_queue.close()
            result_queue.close()

            if output_file is not None:
                if verbose:
                    print("Saving results...")
                time.sleep(1)  # Give time for the user to interrupt the process
                self.simplification_rules = deduplicate_rules(self.simplification_rules, dummy_variables, verbose=verbose)
                self.compile_rules()
                if prune:
                    self.prune_redundant_rules(verbose=verbose)
                with open(output_file, 'w') as file:
                    json.dump(self.simplification_rules, file, indent=4)

operand_key

operand_key(operands: list) -> tuple

Generates a key for sorting operands of a commutative operator.

The key is a tuple designed to produce a consistent, canonical ordering. It places variables first, then numeric constants, and finally subtrees. Subtrees are ordered by their flattened length, then recursively by the keys of their operands, and finally by operator name.

PARAMETER DESCRIPTION
operands

The operand to generate a key for, represented as a tree node.

TYPE: list

RETURNS DESCRIPTION
tuple

A sortable key.
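As a rough illustration of the resulting order, the following standalone sketch mirrors the key logic outside the engine. It assumes that leaves are one-element lists such as ['x1'] and that nodes are [operator, [children]] pairs, which is how the source listing reads; flatten is a hypothetical stand-in for the engine's flatten_nested_list helper, and the sample operands are made up.

```python
def flatten(nested: list) -> list:
    """Hypothetical stand-in for the engine's flatten_nested_list helper."""
    flat = []
    for item in nested:
        if isinstance(item, list):
            flat.extend(flatten(item))
        else:
            flat.append(item)
    return flat


def operand_key(operand) -> tuple:
    """Standalone copy of the key logic, for illustration only."""
    if isinstance(operand, str):
        return (0, operand)
    if len(operand) > 1 and isinstance(operand[0], str):
        # Node: rank after all leaves, then by size, child keys, operator name
        child_keys = tuple(operand_key(child) for child in operand[1])
        return (2, len(flatten(operand)), child_keys, operand[0])
    if len(operand) == 1 and isinstance(operand[0], str):
        try:
            return (1, float(operand[0]))  # numeric leaf
        except ValueError:
            return (0, operand[0])         # variable leaf
    raise ValueError(f"Unsupported operand: {operand!r}")


# Operands of a commutative operator, in arbitrary order (made-up example)
operands = [["sin", [["x1"]]], ["2"], ["x2"], ["x1"]]
ordered = sorted(operands, key=operand_key)
print(ordered)
```

Because the first tuple element encodes the category, tuples from different categories are never compared past that element, so strings and floats are not compared with each other.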

Source code in src/simplipy/engine.py
def operand_key(self, operands: list) -> tuple:
    """Generates a key for sorting operands of a commutative operator.

    The key is a tuple designed to produce a consistent, canonical ordering.
    It places variables first, then numeric constants, and finally subtrees.
    Subtrees are ordered by their flattened length, then recursively by the
    keys of their operands, and finally by operator name.

    Parameters
    ----------
    operands : list
        The operand to generate a key for, represented as a tree node.

    Returns
    -------
    tuple
        A sortable key.
    """
    if len(operands) > 1 and isinstance(operands[0], str):
        # if operands[0] in self.operator_arity_compat or operands[0] in self.operator_aliases:
        # Node
        operand_keys = tuple(self.operand_key(op) for op in operands[1])
        return (2, len(flatten_nested_list(operands)), operand_keys, operands[0])

    # Leaf
    if len(operands) == 1 and isinstance(operands[0], str):
        try:
            return (1, float(operands[0]))
        except ValueError:
            return (0, operands[0])

    if isinstance(operands, str):
        return (0, operands)

    raise ValueError(f'None of the criteria matched for operands {operands}:\n1. ({len(operands) > 1}, {isinstance(operands[0], str)}, {operands[0] in self.operator_arity_compat or operands[0] in self.operator_aliases})\n2. ({len(operands) == 1}, {isinstance(operands[0], str)})\n3. ({isinstance(operands, str)})')

operators_to_realizations

operators_to_realizations(prefix_expression: list[str] | tuple[str, ...]) -> list[str] | tuple[str, ...]

Converts operator names in an expression to their Python realizations.

This method replaces tokens like 'add' or 'sin' with their executable counterparts like '+' or 'np.sin', making the expression ready for evaluation.

PARAMETER DESCRIPTION
prefix_expression

The prefix expression with canonical operator names.

TYPE: list[str] or tuple[str, ...]

RETURNS DESCRIPTION
list[str] or tuple[str, ...]

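The prefix expression with Python-executable operator realizations. The implementation builds the result with a list comprehension, so a list is returned even when a tuple is passed in.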

Source code in src/simplipy/engine.py
def operators_to_realizations(self, prefix_expression: list[str] | tuple[str, ...]) -> list[str] | tuple[str, ...]:
    """Converts operator names in an expression to their Python realizations.

    This method replaces tokens like 'add' or 'sin' with their executable
    counterparts like '+' or 'np.sin', making the expression ready for
    evaluation.

    Parameters
    ----------
    prefix_expression : list[str] or tuple[str, ...]
        The prefix expression with canonical operator names.

    Returns
    -------
    list[str] or tuple[str, ...]
        The prefix expression with Python-executable operator realizations.
    """
    return [self.operator_realizations.get(token, token) for token in prefix_expression]

realizations_to_operators

realizations_to_operators(prefix_expression: list[str]) -> list[str]

Converts Python realizations in an expression back to operator names.

This is the inverse of operators_to_realizations, replacing tokens like '+' or 'np.sin' with their canonical engine names like 'add' or 'sin'.

PARAMETER DESCRIPTION
prefix_expression

The prefix expression with Python-executable realizations.

TYPE: list[str]

RETURNS DESCRIPTION
list[str]

The prefix expression with canonical operator names.
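The round trip between the two conversions can be sketched as follows. The mapping below is a hypothetical two-entry table built from the 'add' and 'sin' examples above; the engine's actual tables come from its own configuration and are not shown here.

```python
# Hypothetical mapping for illustration; not the engine's real table
operator_realizations = {"add": "+", "sin": "np.sin"}
realization_to_operator = {
    realization: name for name, realization in operator_realizations.items()
}


def operators_to_realizations(prefix_expression):
    # Unknown tokens (variables, constants) pass through unchanged
    return [operator_realizations.get(token, token) for token in prefix_expression]


def realizations_to_operators(prefix_expression):
    return [realization_to_operator.get(token, token) for token in prefix_expression]


expression = ["add", "x1", "sin", "x2"]
realized = operators_to_realizations(expression)
restored = realizations_to_operators(realized)
print(realized, restored)
```

The inverse is only well defined when no two operator names share the same realization, which this sketch assumes.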

Source code in src/simplipy/engine.py
def realizations_to_operators(self, prefix_expression: list[str]) -> list[str]:
    """Converts Python realizations in an expression back to operator names.

    This is the inverse of `operators_to_realizations`, replacing tokens
    like '+' or 'np.sin' with their canonical engine names like 'add' or 'sin'.

    Parameters
    ----------
    prefix_expression : list[str]
        The prefix expression with Python-executable realizations.

    Returns
    -------
    list[str]
        The prefix expression with canonical operator names.
    """
    return [self.realization_to_operator.get(token, token) for token in prefix_expression]

code_to_lambda staticmethod

code_to_lambda(code: CodeType) -> Callable[..., float]

Converts a Python code object into an executable lambda function.

PARAMETER DESCRIPTION
code

The compiled code object to convert.

TYPE: CodeType

RETURNS DESCRIPTION
Callable[..., float]

An executable lambda function.
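A minimal sketch of how such a code object might be produced and converted, assuming it was compiled from a lambda expression in 'eval' mode. The source string and the '<expression>' filename are made up for illustration.

```python
from types import CodeType, FunctionType


def code_to_lambda(code: CodeType):
    # Wrap the code object in a zero-argument function and call it once;
    # evaluating the compiled expression yields the lambda itself.
    return FunctionType(code, globals())()


source = "lambda x1, x2: x1 * x2 + 1.0"            # hypothetical expression
code = compile(source, "<expression>", "eval")     # code object for the expression
f = code_to_lambda(code)
result = f(2.0, 3.0)
print(result)
```

Names used inside the lambda are looked up in the globals passed to FunctionType, so any modules the expression refers to must be available there.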

Source code in src/simplipy/engine.py
@staticmethod
def code_to_lambda(code: CodeType) -> Callable[..., float]:
    """Converts a Python code object into an executable lambda function.

    Parameters
    ----------
    code : CodeType
        The compiled code object to convert.

    Returns
    -------
    Callable[..., float]
        An executable lambda function.
    """
    return FunctionType(code, globals())()

Asset Management

get_default_cache_dir

get_default_cache_dir() -> Path

Get the default OS-appropriate cache directory for SimpliPy assets.

This function determines the standard cache location based on the user's operating system, following the XDG Base Directory Specification on Linux. It ensures the directory exists, creating it if necessary.

RETURNS DESCRIPTION
Path

The path to the cache directory.
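For readers unfamiliar with the XDG convention, the Linux behaviour can be approximated with the standard library alone. This is a sketch of the convention, not of platformdirs itself; xdg_cache_dir and its environ parameter are hypothetical names introduced for illustration.

```python
import os
import tempfile
from pathlib import Path


def xdg_cache_dir(appname: str, environ=os.environ) -> Path:
    """Resolve and create <cache home>/<appname> following XDG on Linux."""
    base = environ.get("XDG_CACHE_HOME", "").strip()
    if not base:
        # XDG default when the variable is unset or empty
        base = os.path.expanduser("~/.cache")
    cache_dir = Path(base) / appname
    cache_dir.mkdir(parents=True, exist_ok=True)
    return cache_dir


# Point the cache home at a temporary directory so nothing is left behind
with tempfile.TemporaryDirectory() as tmp:
    cache_dir = xdg_cache_dir("simplipy", environ={"XDG_CACHE_HOME": tmp})
    created = cache_dir.is_dir()
    relative = cache_dir.relative_to(tmp).as_posix()
    print(relative, created)
```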

Source code in src/simplipy/asset_manager.py
def get_default_cache_dir() -> Path:
    """Get the default OS-appropriate cache directory for SimpliPy assets.

    This function determines the standard cache location based on the user's
    operating system, following the XDG Base Directory Specification on Linux.
    It ensures the directory exists, creating it if necessary.

    Returns
    -------
    pathlib.Path
        The path to the cache directory.

    """
    cache_dir = Path(platformdirs.user_cache_dir(appname="simplipy"))
    cache_dir.mkdir(parents=True, exist_ok=True)
    return cache_dir

fetch_manifest

fetch_manifest(repo_id: str | None = None, manifest_filename: str | None = None) -> dict

Download the latest asset manifest from Hugging Face Hub.

The manifest is a JSON file that contains metadata for all official assets, including engines and test data. If the download fails with an HfHubHTTPError, the function prints the error and returns an empty dictionary instead of raising.

PARAMETER DESCRIPTION
repo_id

The Hugging Face repository ID where the manifest is stored. If None, the default repository ID is used.

TYPE: str DEFAULT: None

manifest_filename

The filename of the manifest file. If None, the default filename is used.

TYPE: str DEFAULT: None

RETURNS DESCRIPTION
dict

The parsed JSON manifest as a dictionary. Returns an empty dictionary if the download fails.

Source code in src/simplipy/asset_manager.py
def fetch_manifest(repo_id: str | None = None, manifest_filename: str | None = None) -> dict:
    """Download the latest asset manifest from Hugging Face Hub.

    The manifest is a JSON file that contains metadata for all official
    assets, including engines and test data. If the download fails with an
    HfHubHTTPError, the function prints the error and returns an empty
    dictionary instead of raising.

    Parameters
    ----------
    repo_id : str, optional
        The Hugging Face repository ID where the manifest is stored. If None, the default repository ID is used.
    manifest_filename : str, optional
        The filename of the manifest file. If None, the default filename is used.

    Returns
    -------
    dict
        The parsed JSON manifest as a dictionary. Returns an empty dictionary
        if the download fails.

    """
    try:
        manifest_path = hf_hub_download(
            repo_id=repo_id or HF_MANIFEST_REPO,
            filename=manifest_filename or HF_MANIFEST_FILENAME,
            repo_type="dataset",
        )
        with open(manifest_path, 'r') as f:
            return json.load(f)
    except HfHubHTTPError as e:
        print(f"Error: Could not download the asset manifest from Hugging Face: {e}")
        return {}

get_path

get_path(asset: str, install: bool = False, local_dir: Path | str | None = None, repo_id: str | None = None, manifest_filename: str | None = None) -> str

Resolve the local filesystem path to an asset's entrypoint file.

This function serves as a universal resolver for SimpliPy assets. It first checks if the asset string is a valid local path. If not, it treats it as an official asset name and looks it up in the manifest.

PARAMETER DESCRIPTION
asset

The identifier for the asset. This can be a direct path to a local file (e.g., './my_rules.yaml') or the name of an official asset (e.g., 'core-rules-v1').

TYPE: str

install

If True, automatically downloads and installs the asset from Hugging Face Hub if it is not found locally. Defaults to False.

TYPE: bool DEFAULT: False

local_dir

The directory to check for the asset or install it into. If None, the default cache directory is used. Defaults to None.

TYPE: Path | str | None DEFAULT: None

repo_id

The Hugging Face repository ID where the manifest is stored. If None, the default repository ID is used.

TYPE: str DEFAULT: None

manifest_filename

The filename of the manifest file. If None, the default filename is used.

TYPE: str DEFAULT: None

RETURNS DESCRIPTION
str

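The path to the asset's entrypoint file. If asset is an existing local path, it is returned unchanged; otherwise the path is built from local_dir and the manifest entry.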

RAISES DESCRIPTION
RuntimeError

If the asset manifest cannot be fetched from Hugging Face Hub or if the installation fails when install=True.

ValueError

If asset is empty or not a string, or if it is not a local path and is not a known asset name in the manifest.

FileNotFoundError

If the asset is not found locally and install is False.
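The lookup order described above can be sketched without network access. The manifest entry below is hypothetical: the field names 'directory' and 'entrypoint' follow the source listing, while the values and the resolve helper are made up for illustration.

```python
import tempfile
from pathlib import Path

# Hypothetical manifest entry; values are invented for this sketch
manifest = {
    "demo-engine": {
        "directory": "engines/demo-engine",
        "entrypoint": "config.yaml",
    }
}


def resolve(asset: str, local_dir: Path) -> str:
    """Mirror the lookup order: local path first, then manifest entry."""
    if Path(asset).exists():
        return asset
    info = manifest.get(asset)
    if not info:
        raise ValueError(f"Unknown asset: {asset!r}")
    entrypoint = local_dir / info["directory"] / info["entrypoint"]
    if not entrypoint.exists():
        raise FileNotFoundError(f"Asset {asset!r} is not installed.")
    return str(entrypoint)


with tempfile.TemporaryDirectory() as tmp:
    cache = Path(tmp)
    try:
        resolve("demo-engine", cache)
        missing_raises = False
    except FileNotFoundError:
        missing_raises = True

    # Simulate an installed asset by creating its entrypoint file
    target = cache / "engines" / "demo-engine" / "config.yaml"
    target.parent.mkdir(parents=True)
    target.write_text("rules: []\n")
    resolved_matches = resolve("demo-engine", cache) == str(target)
    print(missing_raises, resolved_matches)
```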

Source code in src/simplipy/asset_manager.py
def get_path(asset: str, install: bool = False, local_dir: Path | str | None = None, repo_id: str | None = None, manifest_filename: str | None = None) -> str:
    """Resolve the local filesystem path to an asset's entrypoint file.

    This function serves as a universal resolver for SimpliPy assets. It first
    checks if the `asset` string is a valid local path. If not, it treats it
    as an official asset name and looks it up in the manifest.

    Parameters
    ----------
    asset : str
        The identifier for the asset. This can be a direct path to a local
        file (e.g., './my_rules.yaml') or the name of an official asset
        (e.g., 'core-rules-v1').
    install : bool, optional
        If True, automatically downloads and installs the asset from
        Hugging Face Hub if it is not found locally. Defaults to False.
    local_dir : pathlib.Path | str | None, optional
        The directory to check for the asset or install it into. If None,
        the default cache directory is used. Defaults to None.
    repo_id : str, optional
        The Hugging Face repository ID where the manifest is stored. If None, the default repository ID is used.
    manifest_filename : str, optional
        The filename of the manifest file. If None, the default filename is used.

    Returns
    -------
    str
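        The path to the asset's entrypoint file. If `asset` is an existing
        local path, it is returned unchanged; otherwise the path is built
        from `local_dir` and the manifest entry.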

    Raises
    ------
    RuntimeError
        If the asset manifest cannot be fetched from Hugging Face Hub or if
        the installation fails when `install=True`.
    ValueError
        If `asset` is empty or not a string, or if it is not a local path
        and is not a known asset name in the manifest.
    FileNotFoundError
        If the asset is not found locally and `install` is False.

    """
    if not asset or not isinstance(asset, str):
        raise ValueError("Error: 'asset' must be a non-empty string.")

    # Check if 'asset' is a valid local path
    if Path(asset).exists():
        return asset

    # Otherwise, treat 'asset' as an official asset name
    manifest = fetch_manifest(repo_id=repo_id, manifest_filename=manifest_filename)
    if not manifest:
        raise RuntimeError("Could not fetch asset manifest.")

    asset_info = manifest.get(asset, {})
    if not asset_info:
        list_assets(asset_type='all')
        raise ValueError(f"Error: Unknown asset: '{asset}'. See above for available assets.")

    if local_dir is None:
        local_dir = get_default_cache_dir()
    elif isinstance(local_dir, str):
        local_dir = Path(local_dir)

    entrypoint_path = local_dir / asset_info['directory'] / asset_info['entrypoint']

    if entrypoint_path.exists():
        return str(entrypoint_path)

    if install:
        print(f"Asset '{asset}' is not installed. Installing.")
        if install_asset(asset, local_dir=local_dir, repo_id=repo_id, manifest_filename=manifest_filename):
            return str(entrypoint_path)
        else:
            raise RuntimeError(f"Failed to install asset '{asset}'.")

    raise FileNotFoundError(f"Asset '{asset}' is not installed. Use install=True to download it.")

install_asset

install_asset(asset: str, force: bool = False, local_dir: Path | str | None = None, repo_id: str | None = None, manifest_filename: str | None = None) -> bool

Install a SimpliPy asset from Hugging Face Hub.

Downloads all files associated with a given asset from its corresponding Hugging Face repository and places them in the specified local directory.

PARAMETER DESCRIPTION
asset

The name of the official asset to install.

TYPE: str

force

If True, any existing local version of the asset will be removed before the new version is installed. Defaults to False.

TYPE: bool DEFAULT: False

local_dir

The directory to install the asset into. If None, the default cache directory is used. Defaults to None.

TYPE: Path | str | None DEFAULT: None

repo_id

The Hugging Face repository ID where the manifest is stored. If None, the default repository ID is used.

TYPE: str DEFAULT: None

manifest_filename

The filename of the manifest file. If None, the default filename is used.

TYPE: str DEFAULT: None

RETURNS DESCRIPTION
bool

True if the installation was successful or if the asset was already installed. False if the manifest cannot be fetched, the asset name is unknown, or a download error occurs.
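The control flow behind these return values can be sketched with the download step injected as a callable, so it runs offline. Here install, counting_download, and the manifest entry are hypothetical, and OSError stands in for the Hub's HTTP error purely for illustration.

```python
import shutil
import tempfile
from pathlib import Path
from typing import Callable


def install(info: dict, local_dir: Path,
            download: Callable[[str, Path], None], force: bool = False) -> bool:
    """Sketch of the install flow; `download` replaces the real Hub call."""
    local_path = local_dir / info["directory"]
    if local_path.exists() and not force:
        return True                      # already installed, nothing to do
    if local_path.exists():
        shutil.rmtree(local_path)        # force: remove the old copy first
    try:
        for file in info["files"]:
            download(f"{info['directory']}/{file}", local_dir)
        return True
    except OSError:
        # Remove only this asset's directory, not the whole cache
        shutil.rmtree(local_path, ignore_errors=True)
        return False


calls = []


def counting_download(relative_name: str, local_dir: Path) -> None:
    """Fake download that writes a placeholder file and records the call."""
    calls.append(relative_name)
    target = local_dir / relative_name
    target.parent.mkdir(parents=True, exist_ok=True)
    target.write_text("placeholder")


info = {"directory": "demo", "files": ["a.json", "b.json"]}
with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    first = install(info, root, counting_download)
    second = install(info, root, counting_download)              # skipped
    third = install(info, root, counting_download, force=True)   # reinstalled
    n_calls = len(calls)
    print(first, second, third, n_calls)
```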

Source code in src/simplipy/asset_manager.py
def install_asset(asset: str, force: bool = False, local_dir: Path | str | None = None, repo_id: str | None = None, manifest_filename: str | None = None) -> bool:
    """Install a SimpliPy asset from Hugging Face Hub.

    Downloads all files associated with a given asset from its corresponding
    Hugging Face repository and places them in the specified local directory.

    Parameters
    ----------
    asset : str
        The name of the official asset to install.
    force : bool, optional
        If True, any existing local version of the asset will be removed
        before the new version is installed. Defaults to False.
    local_dir : pathlib.Path | str | None, optional
        The directory to install the asset into. If None, the default cache
        directory is used. Defaults to None.
    repo_id : str, optional
        The Hugging Face repository ID where the manifest is stored. If None, the default repository ID is used.
    manifest_filename : str, optional
        The filename of the manifest file. If None, the default filename is used.

    Returns
    -------
    bool
        True if the installation was successful or if the asset was already
        installed. False if the manifest cannot be fetched, the asset name
        is unknown, or a download error occurs.

    """
    manifest = fetch_manifest(repo_id=repo_id, manifest_filename=manifest_filename)
    if not manifest:
        return False

    asset_info = manifest.get(asset)
    if not asset_info:
        print(f"Error: Unknown asset: '{asset}'.")
        list_assets(asset_type='all')
        return False

    if local_dir is None:
        local_dir = get_default_cache_dir()
    elif isinstance(local_dir, str):
        local_dir = Path(local_dir)
    local_dir.mkdir(parents=True, exist_ok=True)
    local_path = local_dir / asset_info['directory']

    if local_path.exists() and not force:
        print(f"Asset '{asset}' is already installed at {local_path}.")
        print("Use force=True or --force to reinstall.")
        return True

    if local_path.exists() and force:
        print(f"Force option specified. Removing existing version of '{asset}'...")
        uninstall_asset(asset, quiet=True, local_dir=local_dir)

    print(f"Installing asset '{asset}' to {local_path}.")
    try:
        for file in asset_info['files']:
            hf_hub_download(
                repo_id=asset_info['repo_id'],
                filename=f"{asset_info['directory']}/{file}",
                repo_type="dataset",
                local_dir=local_dir,
            )
        print(f"Successfully installed '{asset}'.")
        return True
    except HfHubHTTPError as e:
        print(f"Error downloading asset '{asset}': {e}")
        # Clean up partial download (only this asset's directory, not the whole local_dir)
        if local_path.exists():
            shutil.rmtree(local_path)
        return False

uninstall_asset

uninstall_asset(asset: str, quiet: bool = False, local_dir: Path | str | None = None, repo_id: str | None = None, manifest_filename: str | None = None) -> bool

Remove a locally installed SimpliPy asset.

This function deletes the entire directory associated with the specified asset from the local filesystem.

PARAMETER DESCRIPTION
asset

The name of the asset to uninstall.

TYPE: str

quiet

If True, suppresses console output messages. Defaults to False.

TYPE: bool DEFAULT: False

local_dir

The directory from which to uninstall the asset. If None, the default cache directory is used. Defaults to None.

TYPE: Path | str | None DEFAULT: None

repo_id

The Hugging Face repository ID where the manifest is stored. If None, the default repository ID is used.

TYPE: str DEFAULT: None

manifest_filename

The filename of the manifest file. If None, the default filename is used.

TYPE: str DEFAULT: None

RETURNS DESCRIPTION
bool

True if the asset was successfully removed or was not installed to begin with. False if an OS error occurs during removal.

RAISES DESCRIPTION
ValueError

If asset is not a known asset name in the manifest.
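The removal step and its return values can be illustrated with the standard library alone. The helper name remove_asset_dir and the directory layout are hypothetical; the sketch only covers the filesystem part, not the manifest lookup.

```python
import shutil
import tempfile
from pathlib import Path


def remove_asset_dir(local_path: Path) -> bool:
    """Return True if the directory is gone afterwards, False on an OS error."""
    if not local_path.exists():
        return True  # nothing installed counts as success
    try:
        shutil.rmtree(local_path)
        return True
    except OSError:
        return False


with tempfile.TemporaryDirectory() as tmp:
    asset_dir = Path(tmp) / "demo"
    (asset_dir / "nested").mkdir(parents=True)
    (asset_dir / "nested" / "rules.json").write_text("[]")

    removed = remove_asset_dir(asset_dir)        # deletes the whole tree
    gone = not asset_dir.exists()
    removed_again = remove_asset_dir(asset_dir)  # second call is a no-op
    print(removed, gone, removed_again)
```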

Source code in src/simplipy/asset_manager.py
def uninstall_asset(asset: str, quiet: bool = False, local_dir: Path | str | None = None, repo_id: str | None = None, manifest_filename: str | None = None) -> bool:
    """Remove a locally installed SimpliPy asset.

    This function deletes the entire directory associated with the specified
    asset from the local filesystem.

    Parameters
    ----------
    asset : str
        The name of the asset to uninstall.
    quiet : bool, optional
        If True, suppresses console output messages. Defaults to False.
    local_dir : pathlib.Path | str | None, optional
        The directory from which to uninstall the asset. If None, the
        default cache directory is used. Defaults to None.
    repo_id : str, optional
        The Hugging Face repository ID where the manifest is stored. If None, the default repository ID is used.
    manifest_filename : str, optional
        The filename of the manifest file. If None, the default filename is used.

    Returns
    -------
    bool
        True if the asset was successfully removed or was not installed to
        begin with. False if an OS error occurs during removal.

    Raises
    ------
    ValueError
        If `asset` is not a known asset name in the manifest.

    """
    if local_dir is None:
        local_dir = get_default_cache_dir()
    elif isinstance(local_dir, str):
        local_dir = Path(local_dir)

    manifest = fetch_manifest(repo_id=repo_id, manifest_filename=manifest_filename)
    if manifest:
        asset_info = manifest.get(asset)
        if not asset_info:
            list_assets(asset_type='all', installed_only=True)
            raise ValueError(f"Error: Unknown asset: '{asset}'. See above for installed assets.")

        local_path = local_dir / asset_info['directory']
    else:
        local_path = local_dir / asset

    if not local_path.exists():
        if not quiet:
            print(f"Asset '{asset}' is not installed.")
        return True

    try:
        shutil.rmtree(local_path)
        if not quiet:
            print(f"Successfully removed '{asset}'.")
        return True
    except OSError as e:
        if not quiet:
            print(f"Error removing '{asset}': {e}")
        return False

list_assets

list_assets(asset_type: AssetType, installed_only: bool = False, local_dir: Path | str | None = None, repo_id: str | None = None, manifest_filename: str | None = None) -> None

List available or installed SimpliPy assets.

Fetches the asset manifest and checks the local filesystem to print a formatted list of assets, their descriptions, and their installation status to standard output.

PARAMETER DESCRIPTION
asset_type

The category of assets to list.

TYPE: {'engine', 'test-data', 'all'}

installed_only

If True, the list is filtered to show only assets that are currently installed locally. Defaults to False.

TYPE: bool DEFAULT: False

local_dir

The directory to check for installed assets. If None, the default cache directory is used. Defaults to None.

TYPE: Path | str | None DEFAULT: None

repo_id

The Hugging Face repository ID where the manifest is stored. If None, the default repository ID is used.

TYPE: str DEFAULT: None

manifest_filename

The filename of the manifest file. If None, the default filename is used.

TYPE: str DEFAULT: None
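The filtering and formatting can be sketched as a pure function that returns the lines instead of printing them. The manifest below is hypothetical; its field names ('type', 'directory', 'description') follow the source listing, and format_assets is a name introduced for this sketch.

```python
import tempfile
from pathlib import Path


def format_assets(manifest: dict, local_dir: Path,
                  asset_type: str = "all", installed_only: bool = False) -> list:
    """Return one formatted line per asset that passes the filters."""
    lines = []
    for name, info in manifest.items():
        if asset_type != "all" and info.get("type") != asset_type:
            continue
        is_installed = (local_dir / info["directory"]).exists()
        if installed_only and not is_installed:
            continue
        status = "[installed]" if is_installed else ""
        lines.append(f"- {name:<15} {status:<12} {info['description']}")
    return lines


# Hypothetical manifest with one engine and one test-data asset
manifest = {
    "demo-engine": {"type": "engine", "directory": "demo-engine",
                    "description": "Demo rules"},
    "demo-data": {"type": "test-data", "directory": "demo-data",
                  "description": "Demo expressions"},
}

with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    (root / "demo-engine").mkdir()   # pretend the engine is installed
    engines = format_assets(manifest, root, asset_type="engine")
    installed = format_assets(manifest, root, installed_only=True)
    everything = format_assets(manifest, root)
    print("\n".join(everything))
```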

Source code in src/simplipy/asset_manager.py
def list_assets(asset_type: AssetType, installed_only: bool = False, local_dir: Path | str | None = None, repo_id: str | None = None, manifest_filename: str | None = None) -> None:
    """List available or installed SimpliPy assets.

    Fetches the asset manifest and checks the local filesystem to print a
    formatted list of assets, their descriptions, and their installation
    status to standard output.

    Parameters
    ----------
    asset_type : {'engine', 'test-data', 'all'}
        The category of assets to list.
    installed_only : bool, optional
        If True, the list is filtered to show only assets that are currently
        installed locally. Defaults to False.
    local_dir : pathlib.Path | str | None, optional
        The directory to check for installed assets. If None, the default
        cache directory is used. Defaults to None.
    repo_id : str, optional
        The Hugging Face repository ID where the manifest is stored. If None, the default repository ID is used.
    manifest_filename : str, optional
        The filename of the manifest file. If None, the default filename is used.

    """
    manifest = fetch_manifest(repo_id=repo_id, manifest_filename=manifest_filename)
    if not manifest:
        return

    print(f"--- {'Installed' if installed_only else 'Available'} Assets ---")

    if local_dir is None:
        local_dir = get_default_cache_dir()
    elif isinstance(local_dir, str):
        local_dir = Path(local_dir)

    found_any = False
    for name, info in manifest.items():
        if asset_type != 'all' and info.get('type') != asset_type:
            continue
        local_path = local_dir / info['directory']
        is_installed = local_path.exists()

        if installed_only and not is_installed:
            continue

        status = "[installed]" if is_installed else ""
        print(f"- {name:<15} {status:<12} {info['description']}")
        found_any = True

    if not found_any:
        print(f"No {'assets' if asset_type == 'all' else asset_type + ' assets'} found.")

Operators

neg

neg(x: float) -> float

Return the element-wise negation of x.

Source code in src/simplipy/operators.py
def neg(x: float) -> float:
    """Return the element-wise negation of x."""
    return -x

inv

inv(x: float) -> float

Return the element-wise multiplicative inverse of x.

Source code in src/simplipy/operators.py
def inv(x: float) -> float:
    """Return the element-wise multiplicative inverse of x."""
    # numpy will handle the x = 0 case
    if isinstance(x, Iterable):
        return 1 / x

    # Manually handle scalar case
    if x == 0:
        return float('inf')

    # All safe
    return 1 / x

div

div(x: float, y: float) -> float

Return the element-wise division of x by y.
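For scalar inputs, the zero-divisor handling in the source listing can be summarized by the following sketch. scalar_div is a hypothetical name, and the array branches of the real function are deliberately left out.

```python
import math


def scalar_div(x, y):
    """Scalar-only sketch of the zero-divisor handling."""
    if y != 0:
        return x / y
    if not isinstance(x, complex):
        if x > 0:
            return math.inf
        if x < 0:
            return -math.inf
    # x is zero, NaN, or complex: the quotient is undefined
    return math.nan


print(scalar_div(6, 3), scalar_div(1.0, 0), scalar_div(-2, 0), scalar_div(0, 0))
```

Returning signed infinities and NaN instead of raising ZeroDivisionError matches the values the array branch yields for real inputs, so scalar and array evaluation agree.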

Source code in src/simplipy/operators.py
def div(x: float, y: float) -> float:
    """Return the element-wise division of x by y."""
    # numpy will handle the y = 0 case
    if isinstance(y, Iterable):
        return x / y

    # Manually handle scalar case
    if y == 0:
        # When x is an iterable, multiply with infinity to let the sign determine the result
        if isinstance(x, Iterable):
            return x * float('inf')

        # When x is a scalar, return inf or -inf depending on the sign of x
        if not isinstance(x, complex):
            if x > 0:
                return float('inf')
            elif x < 0:
                return float('-inf')

        # Remaining cases: x is zero (0 / 0), NaN, or complex.
        # Return NaN to indicate an undefined result
        return float('nan')

    # All safe
    return x / y

mult2

mult2(x: float) -> float

Multiply x by 2.

Source code in src/simplipy/operators.py
def mult2(x: float) -> float:
    """Multiply x by 2."""
    return 2 * x

mult3

mult3(x: float) -> float

Multiply x by 3.

Source code in src/simplipy/operators.py
def mult3(x: float) -> float:
    """Multiply x by 3."""
    return 3 * x

mult4

mult4(x: float) -> float

Multiply x by 4.

Source code in src/simplipy/operators.py
def mult4(x: float) -> float:
    """Multiply x by 4."""
    return 4 * x

mult5

mult5(x: float) -> float

Multiply x by 5.

Source code in src/simplipy/operators.py
def mult5(x: float) -> float:
    """Multiply x by 5."""
    return 5 * x

div2

div2(x: float) -> float

Divide x by 2.

Source code in src/simplipy/operators.py
def div2(x: float) -> float:
    """Divide x by 2."""
    return x / 2

div3

div3(x: float) -> float

Divide x by 3.

Source code in src/simplipy/operators.py
def div3(x: float) -> float:
    """Divide x by 3."""
    return x / 3

div4

div4(x: float) -> float

Divide x by 4.

Source code in src/simplipy/operators.py
def div4(x: float) -> float:
    """Divide x by 4."""
    return x / 4

div5

div5(x: float) -> float

Divide x by 5.

Source code in src/simplipy/operators.py
def div5(x: float) -> float:
    """Divide x by 5."""
    return x / 5

pow2

pow2(x: float) -> float

Return x raised to the power of 2.

Source code in src/simplipy/operators.py
def pow2(x: float) -> float:
    """Return x raised to the power of 2."""
    return x ** 2

pow3

pow3(x: float) -> float

Return x raised to the power of 3.

Source code in src/simplipy/operators.py
def pow3(x: float) -> float:
    """Return x raised to the power of 3."""
    return x ** 3

pow4

pow4(x: float) -> float

Return x raised to the power of 4.

Source code in src/simplipy/operators.py
def pow4(x: float) -> float:
    """Return x raised to the power of 4."""
    return x ** 4

pow5

pow5(x: float) -> float

Return x raised to the power of 5.

Source code in src/simplipy/operators.py
def pow5(x: float) -> float:
    """Return x raised to the power of 5."""
    return x ** 5

pow1_2

pow1_2(x: float) -> float

Return the square root of x.

Source code in src/simplipy/operators.py
def pow1_2(x: float) -> float:
    """Return the square root of x."""
    return x ** 0.5

pow1_3

pow1_3(x: float) -> float

Return the real-valued cube root of x.
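The reason for the special handling is that Python's power operator returns a complex principal root for negative floats. A scalar-only sketch of the sign trick follows; real_odd_root is a hypothetical helper that generalizes it to any odd root, and the array and tensor branches of the real function are omitted.

```python
def real_odd_root(x: float, n: int) -> float:
    """Real-valued n-th root for odd n, scalar inputs only."""
    if x < 0:
        # Take the root of the magnitude and restore the sign afterwards
        return -((-x) ** (1 / n))
    return x ** (1 / n)


naive = (-8.0) ** (1 / 3)        # complex principal root in Python 3
real = real_odd_root(-8.0, 3)    # approximately -2.0
print(type(naive).__name__, real)
```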

Source code in src/simplipy/operators.py
def pow1_3(x: float) -> float:
    """Return the real-valued cube root of x."""
    global _torch_module, _torch_checked
    if isinstance(x, np.ndarray):
        # Handle numpy arrays
        if np.iscomplexobj(x):
            # Handle complex numbers
            return x ** (1 / 3)
        x = np.asarray(x)
        x = np.where(x < 0, -(-x) ** (1 / 3), x ** (1 / 3))
        return x

    if type(x).__module__ == 'torch' and type(x).__name__ == 'Tensor':
        if not _torch_checked:
            try:
                import torch  # type:ignore
                _torch_module = torch
            except ImportError:
                _torch_module = None
            _torch_checked = True

        if _torch_module is None:
            raise ImportError("PyTorch is required to process torch tensors")

        # Handle torch tensors
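        # Use the cached module: the local name `torch` is unbound on calls that skip the import above
        if x.dtype == _torch_module.complex64 or x.dtype == _torch_module.complex128:  # type:ignore
            # Handle complex numbers
            return x ** (1 / 3)
        x = _torch_module.where(x < 0, -(-x) ** (1 / 3), x ** (1 / 3))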
        return x

    if not isinstance(x, complex) and x < 0:
        # Discard imaginary component
        return - (-x) ** (1 / 3)
    else:
        return x ** (1 / 3)

pow1_4

pow1_4(x: float) -> float

Return the fourth root of x.

Source code in src/simplipy/operators.py
def pow1_4(x: float) -> float:
    """Return the fourth root of x."""
    return x ** 0.25

pow1_5

pow1_5(x: float) -> float

Return the real-valued fifth root of x.

Source code in src/simplipy/operators.py
def pow1_5(x: float) -> float:
    """Return the real-valued fifth root of x."""
    global _torch_module, _torch_checked
    if isinstance(x, np.ndarray):
        # Handle numpy arrays
        if np.iscomplexobj(x):
            # Handle complex numbers
            return x ** (1 / 5)
        x = np.asarray(x)
        x = np.where(x < 0, -(-x) ** (1 / 5), x ** (1 / 5))
        return x

    if type(x).__module__ == 'torch' and type(x).__name__ == 'Tensor':
        if not _torch_checked:
            try:
                import torch  # type:ignore
                _torch_module = torch
            except ImportError:
                _torch_module = None
            _torch_checked = True

        if _torch_module is None:
            raise ImportError("PyTorch is required to process torch tensors")

        # Handle torch tensors
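        if x.dtype == _torch_module.complex64 or x.dtype == _torch_module.complex128:  # type:ignore
            # Handle complex numbers
            return x ** (1 / 5)
        # Use the cached module: the local name `torch` is unbound when the import was done by an earlier call
        x = _torch_module.where(x < 0, -(-x) ** (1 / 5), x ** (1 / 5))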
        return x

    if not isinstance(x, complex) and x < 0:
        # Discard imaginary component
        return - (-x) ** (1 / 5)
    else:
        return x ** (1 / 5)

abs

abs(x: float) -> float

Return the element-wise absolute value of x.

Source code in src/simplipy/operators.py
def abs(x: float) -> float:
    """Return the element-wise absolute value of x."""
    global _torch_module, _torch_checked
    if isinstance(x, np.ndarray):
        # Handle numpy arrays
        return np.abs(x)
    if type(x).__module__ == 'torch' and type(x).__name__ == 'Tensor':
        if not _torch_checked:
            try:
                import torch  # type:ignore
                _torch_module = torch
            except ImportError:
                _torch_module = None
            _torch_checked = True

        if _torch_module is None:
            raise ImportError("PyTorch is required to process torch tensors")

        # Handle torch tensors
        return _torch_module.abs(x)
    if isinstance(x, complex):
        # Handle complex numbers
        return (x.real ** 2 + x.imag ** 2) ** 0.5
    # Handle scalar case
    return x if x >= 0 else -x  # Ensure non-negative result
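
The operators in this module recognise torch tensors by type name and import PyTorch only on first use, caching the outcome in module-level globals. The sketch below shows that pattern in isolation using only the standard library. The names `optional_import`, `is_instance_of` and `_optional_modules` are assumptions made for this illustration and do not exist in simplipy.

```python
import importlib
from types import ModuleType
from typing import Optional

# Cache of import attempts: module name -> module, or None if unavailable.
_optional_modules: dict[str, Optional[ModuleType]] = {}


def optional_import(name: str) -> Optional[ModuleType]:
    """Import ``name`` at most once and remember the outcome (hypothetical helper)."""
    if name not in _optional_modules:
        try:
            _optional_modules[name] = importlib.import_module(name)
        except ImportError:
            _optional_modules[name] = None
    return _optional_modules[name]


def is_instance_of(x: object, module: str, type_name: str) -> bool:
    """Check a value's type by name, without importing the defining module."""
    return type(x).__module__ == module and type(x).__name__ == type_name


# Demonstration with a standard-library type instead of a torch tensor.
from fractions import Fraction

half = Fraction(1, 2)
looks_like_fraction = is_instance_of(half, "fractions", "Fraction")
missing = optional_import("surely_not_an_installed_module_xyz")
```

Checking the type by name keeps the optional dependency out of the import path until a value of that type is actually seen.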

sin

sin(x: float) -> float

Return the element-wise sine of x.

Source code in src/simplipy/operators.py
def sin(x: float) -> float:
    """Return the element-wise sine of x."""
    global _torch_module, _torch_checked
    if isinstance(x, np.ndarray):
        # Handle numpy arrays
        return np.sin(x)
    if type(x).__module__ == 'torch' and type(x).__name__ == 'Tensor':
        if not _torch_checked:
            try:
                import torch  # type:ignore
                _torch_module = torch
            except ImportError:
                _torch_module = None
            _torch_checked = True

        if _torch_module is None:
            raise ImportError("PyTorch is required to process torch tensors")

        # Handle torch tensors
        return _torch_module.sin(x)
    # Handle scalar case
    return np.sin(x)  # Use numpy for scalar sine calculation

cos

cos(x: float) -> float

Return the element-wise cosine of x.

Source code in src/simplipy/operators.py
def cos(x: float) -> float:
    """Return the element-wise cosine of x."""
    global _torch_module, _torch_checked
    if isinstance(x, np.ndarray):
        # Handle numpy arrays
        return np.cos(x)
    if type(x).__module__ == 'torch' and type(x).__name__ == 'Tensor':
        if not _torch_checked:
            try:
                import torch  # type:ignore
                _torch_module = torch
            except ImportError:
                _torch_module = None
            _torch_checked = True

        if _torch_module is None:
            raise ImportError("PyTorch is required to process torch tensors")

        # Handle torch tensors
        return _torch_module.cos(x)
    # Handle scalar case
    return np.cos(x)  # Use numpy for scalar cosine calculation

tan

tan(x: float) -> float

Return the element-wise tangent of x.

Source code in src/simplipy/operators.py
def tan(x: float) -> float:
    """Return the element-wise tangent of x."""
    global _torch_module, _torch_checked
    if isinstance(x, np.ndarray):
        # Handle numpy arrays
        return np.tan(x)
    if type(x).__module__ == 'torch' and type(x).__name__ == 'Tensor':
        if not _torch_checked:
            try:
                import torch  # type:ignore
                _torch_module = torch
            except ImportError:
                _torch_module = None
            _torch_checked = True

        if _torch_module is None:
            raise ImportError("PyTorch is required to process torch tensors")

        # Handle torch tensors
        return _torch_module.tan(x)
    # Handle scalar case
    return np.tan(x)  # Use numpy for scalar tangent calculation

asin

asin(x: float) -> float

Return the element-wise inverse sine of x.

Source code in src/simplipy/operators.py
def asin(x: float) -> float:
    """Return the element-wise inverse sine of x."""
    global _torch_module, _torch_checked
    if isinstance(x, np.ndarray):
        # Handle numpy arrays
        return np.arcsin(x)
    if type(x).__module__ == 'torch' and type(x).__name__ == 'Tensor':
        if not _torch_checked:
            try:
                import torch  # type:ignore
                _torch_module = torch
            except ImportError:
                _torch_module = None
            _torch_checked = True

        if _torch_module is None:
            raise ImportError("PyTorch is required to process torch tensors")

        # Handle torch tensors
        return _torch_module.asin(x)
    # Handle scalar case
    return np.arcsin(x)  # Use numpy for scalar arcsine calculation

acos

acos(x: float) -> float

Return the element-wise inverse cosine of x.

Source code in src/simplipy/operators.py
def acos(x: float) -> float:
    """Return the element-wise inverse cosine of x."""
    global _torch_module, _torch_checked
    if isinstance(x, np.ndarray):
        # Handle numpy arrays
        return np.arccos(x)
    if type(x).__module__ == 'torch' and type(x).__name__ == 'Tensor':
        if not _torch_checked:
            try:
                import torch  # type:ignore
                _torch_module = torch
            except ImportError:
                _torch_module = None
            _torch_checked = True

        if _torch_module is None:
            raise ImportError("PyTorch is required to process torch tensors")

        # Handle torch tensors
        return _torch_module.acos(x)
    # Handle scalar case
    return np.arccos(x)  # Use numpy for scalar arccosine calculation

atan

atan(x: float) -> float

Return the element-wise inverse tangent of x.

Source code in src/simplipy/operators.py
def atan(x: float) -> float:
    """Return the element-wise inverse tangent of x."""
    global _torch_module, _torch_checked
    if isinstance(x, np.ndarray):
        # Handle numpy arrays
        return np.arctan(x)
    if type(x).__module__ == 'torch' and type(x).__name__ == 'Tensor':
        if not _torch_checked:
            try:
                import torch  # type:ignore
                _torch_module = torch
            except ImportError:
                _torch_module = None
            _torch_checked = True

        if _torch_module is None:
            raise ImportError("PyTorch is required to process torch tensors")

        # Handle torch tensors
        return _torch_module.atan(x)
    # Handle scalar case
    return np.arctan(x)  # Use numpy for scalar arctangent calculation

sinh

sinh(x: float) -> float

Return the element-wise hyperbolic sine of x.

Source code in src/simplipy/operators.py
def sinh(x: float) -> float:
    """Return the element-wise hyperbolic sine of x."""
    global _torch_module, _torch_checked
    if isinstance(x, np.ndarray):
        # Handle numpy arrays
        return np.sinh(x)
    if type(x).__module__ == 'torch' and type(x).__name__ == 'Tensor':
        if not _torch_checked:
            try:
                import torch  # type:ignore
                _torch_module = torch
            except ImportError:
                _torch_module = None
            _torch_checked = True

        if _torch_module is None:
            raise ImportError("PyTorch is required to process torch tensors")

        # Handle torch tensors
        return _torch_module.sinh(x)
    # Handle scalar case
    return np.sinh(x)  # Use numpy for scalar hyperbolic sine calculation

cosh

cosh(x: float) -> float

Return the element-wise hyperbolic cosine of x.

Source code in src/simplipy/operators.py
def cosh(x: float) -> float:
    """Return the element-wise hyperbolic cosine of x."""
    global _torch_module, _torch_checked
    if isinstance(x, np.ndarray):
        # Handle numpy arrays
        return np.cosh(x)
    if type(x).__module__ == 'torch' and type(x).__name__ == 'Tensor':
        if not _torch_checked:
            try:
                import torch  # type:ignore
                _torch_module = torch
            except ImportError:
                _torch_module = None
            _torch_checked = True

        if _torch_module is None:
            raise ImportError("PyTorch is required to process torch tensors")

        # Handle torch tensors
        return _torch_module.cosh(x)
    # Handle scalar case
    return np.cosh(x)  # Use numpy for scalar hyperbolic cosine calculation

tanh

tanh(x: float) -> float

Return the element-wise hyperbolic tangent of x.

Source code in src/simplipy/operators.py
def tanh(x: float) -> float:
    """Return the element-wise hyperbolic tangent of x."""
    global _torch_module, _torch_checked
    if isinstance(x, np.ndarray):
        # Handle numpy arrays
        return np.tanh(x)
    if type(x).__module__ == 'torch' and type(x).__name__ == 'Tensor':
        if not _torch_checked:
            try:
                import torch  # type:ignore
                _torch_module = torch
            except ImportError:
                _torch_module = None
            _torch_checked = True

        if _torch_module is None:
            raise ImportError("PyTorch is required to process torch tensors")

        # Handle torch tensors
        return _torch_module.tanh(x)
    # Handle scalar case
    return np.tanh(x)  # Use numpy for scalar hyperbolic tangent calculation

asinh

asinh(x: float) -> float

Return the element-wise inverse hyperbolic sine of x.

Source code in src/simplipy/operators.py
def asinh(x: float) -> float:
    """Return the element-wise inverse hyperbolic sine of x."""
    global _torch_module, _torch_checked
    if isinstance(x, np.ndarray):
        # Handle numpy arrays
        return np.arcsinh(x)
    if type(x).__module__ == 'torch' and type(x).__name__ == 'Tensor':
        if not _torch_checked:
            try:
                import torch  # type:ignore
                _torch_module = torch
            except ImportError:
                _torch_module = None
            _torch_checked = True

        if _torch_module is None:
            raise ImportError("PyTorch is required to process torch tensors")

        # Handle torch tensors
        return _torch_module.asinh(x)
    # Handle scalar case
    return np.arcsinh(x)  # Use numpy for scalar inverse hyperbolic sine calculation

acosh

acosh(x: float) -> float

Return the element-wise inverse hyperbolic cosine of x.

Source code in src/simplipy/operators.py
def acosh(x: float) -> float:
    """Return the element-wise inverse hyperbolic cosine of x."""
    global _torch_module, _torch_checked
    if isinstance(x, np.ndarray):
        # Handle numpy arrays
        return np.arccosh(x)
    if type(x).__module__ == 'torch' and type(x).__name__ == 'Tensor':
        if not _torch_checked:
            try:
                import torch  # type:ignore
                _torch_module = torch
            except ImportError:
                _torch_module = None
            _torch_checked = True

        if _torch_module is None:
            raise ImportError("PyTorch is required to process torch tensors")

        # Handle torch tensors
        return _torch_module.acosh(x)
    # Handle scalar case
    return np.arccosh(x)  # Use numpy for scalar inverse hyperbolic cosine calculation

atanh

atanh(x: float) -> float

Return the element-wise inverse hyperbolic tangent of x.

Source code in src/simplipy/operators.py
def atanh(x: float) -> float:
    """Return the element-wise inverse hyperbolic tangent of x."""
    global _torch_module, _torch_checked
    if isinstance(x, np.ndarray):
        # Handle numpy arrays
        return np.arctanh(x)
    if type(x).__module__ == 'torch' and type(x).__name__ == 'Tensor':
        if not _torch_checked:
            try:
                import torch  # type:ignore
                _torch_module = torch
            except ImportError:
                _torch_module = None
            _torch_checked = True

        if _torch_module is None:
            raise ImportError("PyTorch is required to process torch tensors")

        # Handle torch tensors
        return _torch_module.atanh(x)
    # Handle scalar case
    return np.arctanh(x)  # Use numpy for scalar inverse hyperbolic tangent calculation

exp

exp(x: float) -> float

Return the element-wise exponential of x.

Source code in src/simplipy/operators.py
def exp(x: float) -> float:
    """Return the element-wise exponential of x."""
    global _torch_module, _torch_checked
    if isinstance(x, np.ndarray):
        # Handle numpy arrays
        return np.exp(x)
    if type(x).__module__ == 'torch' and type(x).__name__ == 'Tensor':
        if not _torch_checked:
            try:
                import torch  # type:ignore
                _torch_module = torch
            except ImportError:
                _torch_module = None
            _torch_checked = True

        if _torch_module is None:
            raise ImportError("PyTorch is required to process torch tensors")

        # Handle torch tensors
        return _torch_module.exp(x)
    # Handle scalar case
    return np.exp(x)  # Use numpy for scalar exponential calculation

log

log(x: float) -> float

Return the element-wise natural logarithm of x.

Source code in src/simplipy/operators.py
def log(x: float) -> float:
    """Return the element-wise natural logarithm of x."""
    global _torch_module, _torch_checked
    if isinstance(x, np.ndarray):
        # Handle numpy arrays
        return np.log(x)
    if type(x).__module__ == 'torch' and type(x).__name__ == 'Tensor':
        if not _torch_checked:
            try:
                import torch  # type:ignore
                _torch_module = torch
            except ImportError:
                _torch_module = None
            _torch_checked = True

        if _torch_module is None:
            raise ImportError("PyTorch is required to process torch tensors")

        # Handle torch tensors
        return _torch_module.log(x)
    # Handle scalar case
    return np.log(x)  # Use numpy for scalar logarithm calculation

pow

pow(x: float, y: float) -> float

Return x raised to the power of y, element-wise.

Source code in src/simplipy/operators.py
def pow(x: float, y: float) -> float:
    """Return x raised to the power of y, element-wise."""
    global _torch_module, _torch_checked
    if isinstance(x, np.ndarray) or isinstance(y, np.ndarray):
        # Handle numpy arrays
        with np.errstate(invalid='ignore'):
            return np.power(x, y)
    if type(x).__module__ == 'torch' and type(x).__name__ == 'Tensor':
        if not _torch_checked:
            try:
                import torch  # type:ignore
                _torch_module = torch
            except ImportError:
                _torch_module = None
            _torch_checked = True

        if _torch_module is None:
            raise ImportError("PyTorch is required to process torch tensors")

        # Handle torch tensors
        return _torch_module.pow(x, y)
    # Handle scalar case
    with np.errstate(invalid='ignore'):
        if isinstance(x, int):
            x = float(x)
        if isinstance(y, int):
            y = float(y)
        return np.power(x, y)  # Use numpy for scalar power calculation

Utilities

apply_on_nested

apply_on_nested(structure: list | dict, func: Callable) -> list | dict

Recursively apply a function to all non-structural values in a nested container.

This function traverses a nested dictionary or list and applies func to every value that is not itself a dict or list. The original structure is mutated; the same instance is returned for convenience. If structure is neither a list nor a dictionary, it is returned unchanged.

PARAMETER DESCRIPTION
structure

The nested list or dictionary to process.

TYPE: list or dict

func

The function to apply to each non-structural value.

TYPE: Callable

RETURNS DESCRIPTION
list or dict

The input structure with func applied to all terminal values.

Examples:

>>> data = {'a': 1, 'b': {'c': 2, 'd': [{'e': 3}, {'f': 4}, 3]}}
>>> result = apply_on_nested(data, lambda x: x * 10)
>>> result
{'a': 10, 'b': {'c': 20, 'd': [{'e': 30}, {'f': 40}, 30]}}
>>> data is result
True
Source code in src/simplipy/utils.py
def apply_on_nested(structure: list | dict, func: Callable) -> list | dict:
    """Recursively apply a function to all non-structural values in a nested container.

    This function traverses a nested dictionary or list and applies ``func`` to
    every value that is not itself a ``dict`` or ``list``. The original
    ``structure`` is mutated; the same instance is returned for convenience. If
    ``structure`` is neither a list nor a dictionary, it is returned unchanged.

    Parameters
    ----------
    structure : list or dict
        The nested list or dictionary to process.
    func : Callable
        The function to apply to each non-structural value.

    Returns
    -------
    list or dict
        The input ``structure`` with ``func`` applied to all terminal values.

    Examples
    --------
    >>> data = {'a': 1, 'b': {'c': 2, 'd': [{'e': 3}, {'f': 4}, 3]}}
    >>> result = apply_on_nested(data, lambda x: x * 10)
    >>> result
    {'a': 10, 'b': {'c': 20, 'd': [{'e': 30}, {'f': 40}, 30]}}
    >>> data is result
    True
    """
    if isinstance(structure, list):
        for i, value in enumerate(structure):
            if isinstance(value, (list, dict)):
                structure[i] = apply_on_nested(value, func)
            else:
                structure[i] = func(value)
        return structure

    if isinstance(structure, dict):
        for key, value in structure.items():
            if isinstance(value, (list, dict)):
                structure[key] = apply_on_nested(value, func)
            else:
                structure[key] = func(value)
        return structure

    return structure
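
Because the structure is mutated in place, callers who need to keep the original can pass a deep copy. The sketch below uses a compact standalone re-implementation of the behaviour documented above so that it runs without simplipy installed. It also shows that tuples are treated as terminal values and handed to `func` whole.

```python
import copy
from typing import Callable


def apply_on_nested(structure, func: Callable):
    """Compact re-implementation of the documented behaviour, for a standalone demo."""
    if isinstance(structure, list):
        for i, value in enumerate(structure):
            if isinstance(value, (list, dict)):
                structure[i] = apply_on_nested(value, func)
            else:
                structure[i] = func(value)
    elif isinstance(structure, dict):
        for key, value in structure.items():
            if isinstance(value, (list, dict)):
                structure[key] = apply_on_nested(value, func)
            else:
                structure[key] = func(value)
    return structure


original = {"a": 1, "b": [2, {"c": 3}], "t": (4, 5)}

# Work on a deep copy so the caller's data is left untouched.
scaled = apply_on_nested(copy.deepcopy(original), lambda v: v * 2)

print(scaled)
print(original)
```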

traverse_dict

traverse_dict(dict_: dict[str, Any]) -> Generator[tuple[str, Any], None, None]

Recursively traverse a nested dictionary and yield key-value pairs.

This generator function walks through a dictionary, descending into any nested dictionaries it finds. It yields the key and value for any value that is not a dictionary.

PARAMETER DESCRIPTION
dict_

The nested dictionary to traverse.

TYPE: dict[str, Any]

YIELDS DESCRIPTION
tuple[str, Any]

A tuple containing the key and its corresponding non-dictionary value.

Examples:

>>> data = {'a': 1, 'b': {'c': 2, 'd': 3}}
>>> list(traverse_dict(data))
[('a', 1), ('c', 2), ('d', 3)]
Source code in src/simplipy/utils.py
def traverse_dict(dict_: dict[str, Any]) -> Generator[tuple[str, Any], None, None]:
    """Recursively traverse a nested dictionary and yield key-value pairs.

    This generator function walks through a dictionary, descending into any
    nested dictionaries it finds. It yields the key and value for any
    value that is not a dictionary.

    Parameters
    ----------
    dict_ : dict[str, Any]
        The nested dictionary to traverse.

    Yields
    ------
    tuple[str, Any]
        A tuple containing the key and its corresponding non-dictionary value.

    Examples
    --------
    >>> data = {'a': 1, 'b': {'c': 2, 'd': 3}}
    >>> list(traverse_dict(data))
    [('a', 1), ('c', 2), ('d', 3)]
    """

    for key, value in dict_.items():
        if isinstance(value, dict):
            yield from traverse_dict(value)
        else:
            yield key, value
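
Two consequences of the listing above are worth seeing in a small example: yielded keys are not qualified by their parent keys, so the same key can appear more than once, and lists are yielded as values rather than descended into. The function is repeated here only so the sketch runs on its own.

```python
from typing import Any, Generator


def traverse_dict(dict_: dict[str, Any]) -> Generator[tuple[str, Any], None, None]:
    """Standalone copy of the generator documented above."""
    for key, value in dict_.items():
        if isinstance(value, dict):
            yield from traverse_dict(value)
        else:
            yield key, value


config = {"lr": 0.1, "model": {"lr": 0.01, "layers": [{"units": 8}]}}
pairs = list(traverse_dict(config))

# Collapsing the pairs into a dict keeps only the last value seen for a repeated key.
flat = dict(pairs)

print(pairs)
print(flat)
```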

codify

codify(code_string: str, variables: list[str] | None = None) -> CodeType

Compile a string expression into a Python code object.

This function takes a string representing a mathematical expression and compiles it into a code object that can be executed later using eval or converted into a lambda function. It wraps the expression in a lambda function signature.

PARAMETER DESCRIPTION
code_string

The mathematical expression string to compile.

TYPE: str

variables

A list of variable names to be used as arguments for the lambda function, by default None.

TYPE: list[str] or None DEFAULT: None

RETURNS DESCRIPTION
CodeType

The compiled code object, ready for execution.

Examples:

>>> code_obj = codify("x + y", variables=['x', 'y'])
>>> compiled_func = eval(code_obj)
>>> compiled_func(2, 3)
5
Source code in src/simplipy/utils.py
def codify(code_string: str, variables: list[str] | None = None) -> CodeType:
    """Compile a string expression into a Python code object.

    This function takes a string representing a mathematical expression and
    compiles it into a code object that can be executed later using `eval` or
    converted into a lambda function. It wraps the expression in a lambda
    function signature.

    Parameters
    ----------
    code_string : str
        The mathematical expression string to compile.
    variables : list[str] or None, optional
        A list of variable names to be used as arguments for the lambda
        function, by default None.

    Returns
    -------
    CodeType
        The compiled code object, ready for execution.

    Examples
    --------
    >>> code_obj = codify("x + y", variables=['x', 'y'])
    >>> compiled_func = eval(code_obj)
    >>> compiled_func(2, 3)
    5
    """
    if variables is None:
        variables = []
    func_string = f'lambda {", ".join(variables)}: {code_string}'
    filename = f'<lambdifygenerated-{time.time_ns()}>'
    return compile(func_string, filename, 'eval')
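
When the expression refers to names that are not lambda arguments, those names are resolved in the globals passed to `eval`. A minimal sketch, assuming a standalone re-implementation of `codify` with a fixed placeholder filename so it can run without simplipy:

```python
import math
from types import CodeType


def codify(code_string: str, variables=None) -> CodeType:
    """Standalone re-implementation for this demo (the filename is a placeholder)."""
    variables = variables or []
    func_string = f'lambda {", ".join(variables)}: {code_string}'
    return compile(func_string, "<codify-demo>", "eval")


code_obj = codify("math.sqrt(x**2 + y**2)", variables=["x", "y"])

# Names that are not lambda arguments are looked up in the globals given to eval.
hypot = eval(code_obj, {"math": math})
value = hypot(3.0, 4.0)

# With no variables, the result is a zero-argument lambda.
two = eval(codify("1 + 1"))()
```

As with any use of `eval`, the expression string should come from a trusted source.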

get_used_modules

get_used_modules(infix_expression: str) -> list[str]

Return the names of top-level Python modules referenced in an infix expression.

The function scans for dotted attribute accesses that look like module usages (for example numpy.sin(...) or math.cos(...)) and collects their leading module names. The module numpy is always included so that downstream evaluation logic can rely on it being available.

PARAMETER DESCRIPTION
infix_expression

The mathematical expression in infix notation.

TYPE: str

RETURNS DESCRIPTION
list[str]

Unique module names referenced in infix_expression. The order is derived from the underlying set and should be treated as arbitrary.

Examples:

>>> sorted(get_used_modules("numpy.sin(x) + math.exp(y)"))
['math', 'numpy']
Source code in src/simplipy/utils.py
def get_used_modules(infix_expression: str) -> list[str]:
    """Return the names of top-level Python modules referenced in an infix expression.

    The function scans for dotted attribute accesses that look like module
    usages (for example ``numpy.sin(...)`` or ``math.cos(...)``) and collects
    their leading module names. The module ``numpy`` is always included so that
    downstream evaluation logic can rely on it being available.

    Parameters
    ----------
    infix_expression : str
        The mathematical expression in infix notation.

    Returns
    -------
    list[str]
        Unique module names referenced in ``infix_expression``. The order is
        derived from the underlying ``set`` and should be treated as arbitrary.

    Examples
    --------
    >>> sorted(get_used_modules("numpy.sin(x) + math.exp(y)"))
    ['math', 'numpy']
    """
    # Match the expression against `module.submodule. ... .function(`
    pattern = re.compile(r'([a-zA-Z_][a-zA-Z0-9_]*(?:\.[a-zA-Z_][a-zA-Z0-9_]*)+)\(')

    # Find all matches in the whole expression
    matches = pattern.findall(infix_expression)

    # Return the unique matches
    modules_set = set(m.split('.')[0] for m in matches)

    modules_set.update(['numpy'])

    return list(modules_set)
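
The regular expression only matches dotted names that are immediately followed by an opening parenthesis. The sketch below, which repeats the function so it can run standalone, illustrates what that implies: aliases such as `np` are returned as written, while bare calls and dotted constants contribute nothing beyond the always-present `numpy`.

```python
import re


def get_used_modules(infix_expression: str) -> list[str]:
    """Standalone copy of the function documented above."""
    pattern = re.compile(r'([a-zA-Z_][a-zA-Z0-9_]*(?:\.[a-zA-Z_][a-zA-Z0-9_]*)+)\(')
    modules = {m.split('.')[0] for m in pattern.findall(infix_expression)}
    modules.add('numpy')
    return list(modules)


# Dotted calls contribute their leading name, exactly as written.
dotted = sorted(get_used_modules("np.linalg.norm(v) + scipy.special.erf(x)"))

# A bare call (sin) and a dotted constant (math.pi) are not picked up.
bare = sorted(get_used_modules("sin(x) + math.pi"))

print(dotted)
print(bare)
```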

substitude_constants

substitude_constants(prefix_expression: list[str], values: list | ndarray, constants: list[str] | None = None, inplace: bool = False) -> list[str]

Substitute placeholders in a prefix expression with numeric values.

This helper replaces constant placeholders such as "<constant>" or the tokens listed in constants with the values supplied in values. Values are consumed from left to right as matching tokens are encountered.

PARAMETER DESCRIPTION
prefix_expression

The prefix expression containing constant placeholders.

TYPE: list[str]

values

The numeric values to substitute into the expression.

TYPE: list or ndarray

constants

An explicit list of placeholder names to be replaced. When None, the function considers "<constant>" and C_i tokens. Defaults to None.

TYPE: list[str] or None DEFAULT: None

inplace

If True, modifies prefix_expression in-place; otherwise, works on a shallow copy. Defaults to False.

TYPE: bool DEFAULT: False

RETURNS DESCRIPTION
list[str]

The prefix expression with placeholders replaced by strings holding the given numeric values.

RAISES DESCRIPTION
IndexError

If there are more placeholders than supplied values.

Examples:

>>> expr = ['*', '<constant>', '+', 'x', '<constant>']
>>> substitude_constants(expr, [3.14, 2.71])
['*', '3.14', '+', 'x', '2.71']
>>> expr = ['*', 'C_0', '+', 'x', 'C_1']
>>> substitude_constants(expr, [3.14, 2.71], constants=['C_0', 'C_1'])
['*', '3.14', '+', 'x', '2.71']
>>> expr = ['*', 'k1', '+', 'x', 'k2']
>>> substitude_constants(expr, [3.14, 2.71], constants=['k1', 'k2'])
['*', '3.14', '+', 'x', '2.71']
Source code in src/simplipy/utils.py
def substitude_constants(prefix_expression: list[str], values: list | np.ndarray, constants: list[str] | None = None, inplace: bool = False) -> list[str]:
    """Substitute placeholders in a prefix expression with numeric values.

    This helper replaces constant placeholders such as ``"<constant>"`` or the
    tokens listed in ``constants`` with the values supplied in ``values``. Values
    are consumed from left to right as matching tokens are encountered.

    Parameters
    ----------
    prefix_expression : list[str]
        The prefix expression containing constant placeholders.
    values : list or np.ndarray
        The numeric values to substitute into the expression.
    constants : list[str] or None, optional
        An explicit list of placeholder names to be replaced. When ``None``,
        the function considers ``"<constant>"`` and ``C_i`` tokens. Defaults to
        ``None``.
    inplace : bool, optional
        If ``True``, modifies ``prefix_expression`` in-place; otherwise, works on
        a shallow copy. Defaults to ``False``.

    Returns
    -------
    list[str]
        The prefix expression with placeholders replaced by strings holding the
        given numeric values.

    Raises
    ------
    IndexError
        If there are more placeholders than supplied ``values``.

    Examples
    --------
    >>> expr = ['*', '<constant>', '+', 'x', '<constant>']
    >>> substitude_constants(expr, [3.14, 2.71])
    ['*', '3.14', '+', 'x', '2.71']

    >>> expr = ['*', 'C_0', '+', 'x', 'C_1']
    >>> substitude_constants(expr, [3.14, 2.71], constants=['C_0', 'C_1'])
    ['*', '3.14', '+', 'x', '2.71']

    >>> expr = ['*', 'k1', '+', 'x', 'k2']
    >>> substitude_constants(expr, [3.14, 2.71], constants=['k1', 'k2'])
    ['*', '3.14', '+', 'x', '2.71']
    """
    if inplace:
        modified_prefix_expression = prefix_expression
    else:
        modified_prefix_expression = prefix_expression.copy()

    constant_index = 0
    if constants is None:
        constants = []
    else:
        constants = list(constants)

    for i, token in enumerate(prefix_expression):
        if token == "<constant>" or re.match(r"C_\d+", token) or token in constants:
            modified_prefix_expression[i] = str(values[constant_index])
            constant_index += 1

    return modified_prefix_expression
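
Substitution is purely positional: the number in a `C_i` token is not used as an index into `values`, surplus values are ignored, and too few values raise `IndexError`. The sketch below demonstrates these cases with a standalone re-implementation of the listing above.

```python
import re


def substitude_constants(prefix_expression, values, constants=None, inplace=False):
    """Standalone re-implementation of the documented behaviour, for a demo."""
    result = prefix_expression if inplace else prefix_expression.copy()
    constants = list(constants or [])
    index = 0
    for i, token in enumerate(prefix_expression):
        if token == "<constant>" or re.match(r"C_\d+", token) or token in constants:
            result[i] = str(values[index])
            index += 1
    return result


expr = ['+', '<constant>', '*', 'C_0', 'x']

# Surplus values are ignored; the input list is not modified by default.
filled = substitude_constants(expr, [1.5, 2.0, 99.0])

# Values are consumed left to right, regardless of the index in C_i.
positional = substitude_constants(['+', 'C_1', 'C_0'], [10, 20])

# Too few values raise IndexError.
try:
    substitude_constants(expr, [1.5])
    too_few_raises = False
except IndexError:
    too_few_raises = True
```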

apply_variable_mapping

apply_variable_mapping(prefix_expression: list[str], variable_mapping: dict[str, str]) -> list[str]

Rename variables in a prefix expression using a mapping.

Applies a given mapping to rename variables within a prefix expression. Any token in the expression that is a key in the mapping will be replaced by its corresponding value.

PARAMETER DESCRIPTION
prefix_expression

The prefix expression to modify.

TYPE: list[str]

variable_mapping

A dictionary mapping original variable names to new names.

TYPE: dict[str, str]

RETURNS DESCRIPTION
list[str]

A new prefix expression with variables renamed.

Examples:

>>> expr = ['+', 'var1', 'var2']
>>> mapping = {'var1': 'x', 'var2': 'y'}
>>> apply_variable_mapping(expr, mapping)
['+', 'x', 'y']
Source code in src/simplipy/utils.py
def apply_variable_mapping(prefix_expression: list[str], variable_mapping: dict[str, str]) -> list[str]:
    """Rename variables in a prefix expression using a mapping.

    Applies a given mapping to rename variables within a prefix expression.
    Any token in the expression that is a key in the mapping will be
    replaced by its corresponding value.

    Parameters
    ----------
    prefix_expression : list[str]
        The prefix expression to modify.
    variable_mapping : dict[str, str]
        A dictionary mapping original variable names to new names.

    Returns
    -------
    list[str]
        A new prefix expression with variables renamed.

    Examples
    --------
    >>> expr = ['+', 'var1', 'var2']
    >>> mapping = {'var1': 'x', 'var2': 'y'}
    >>> apply_variable_mapping(expr, mapping)
    ['+', 'x', 'y']
    """
    return list(map(lambda token: variable_mapping.get(token, token), prefix_expression))
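
Since every token is looked up exactly once, a mapping that swaps two names works in a single pass without temporary names. A minimal sketch, with the function re-implemented as a list comprehension so it runs on its own:

```python
def apply_variable_mapping(prefix_expression, variable_mapping):
    """Standalone equivalent of the function documented above."""
    return [variable_mapping.get(token, token) for token in prefix_expression]


expr = ['-', 'x', '*', 'y', 'x']

# Each token is replaced at most once, so x and y are swapped cleanly.
swapped = apply_variable_mapping(expr, {'x': 'y', 'y': 'x'})

print(swapped)
print(expr)
```

The mapping is applied to every token, so a key that coincides with an operator name would rename that operator as well.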

numbers_to_constant

numbers_to_constant(prefix_expression: list[str], inplace: bool = False) -> list[str]

Replace all numeric literals in a prefix expression with '<constant>'.

This function standardizes an expression by replacing all tokens that can be interpreted as numbers with a generic <constant> placeholder. This is useful for structural comparison and rule matching.

PARAMETER DESCRIPTION
prefix_expression

The prefix expression to process.

TYPE: list[str]

inplace

If True, modifies the list in-place; otherwise, returns a new list. Defaults to False.

TYPE: bool DEFAULT: False

RETURNS DESCRIPTION
list[str]

The modified prefix expression.

Examples:

>>> expr = ['+', 'x', '3.14', '*', 'y', '-2']
>>> numbers_to_constant(expr)
['+', 'x', '<constant>', '*', 'y', '<constant>']
Source code in src/simplipy/utils.py
def numbers_to_constant(prefix_expression: list[str], inplace: bool = False) -> list[str]:
    """Replace all numeric literals in a prefix expression with '<constant>'.

    This function standardizes an expression by replacing all tokens that can be
    interpreted as numbers with a generic `<constant>` placeholder. This is
    useful for structural comparison and rule matching.

    Parameters
    ----------
    prefix_expression : list[str]
        The prefix expression to process.
    inplace : bool, optional
        If True, modifies the list in-place; otherwise, returns a new list.
        Defaults to False.

    Returns
    -------
    list[str]
        The modified prefix expression.

    Examples
    --------
    >>> expr = ['+', 'x', '3.14', '*', 'y', '-2']
    >>> numbers_to_constant(expr)
    ['+', 'x', '<constant>', '*', 'y', '<constant>']
    """
    if inplace:
        modified_prefix_expression = prefix_expression
    else:
        modified_prefix_expression = prefix_expression.copy()

    for i, token in enumerate(prefix_expression):
        try:
            float(token)
            modified_prefix_expression[i] = '<constant>'
        except ValueError:
            modified_prefix_expression[i] = token

    return modified_prefix_expression
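
The listing treats a token as numeric whenever `float()` accepts it, which also covers strings such as `'inf'` and `'nan'`. The sketch below restates that logic so it runs standalone (slightly condensed, not the library's exact code) and shows the difference between the default copy and `inplace=True`.

```python
def numbers_to_constant(prefix_expression: list[str], inplace: bool = False) -> list[str]:
    # Same rule as the listing: a token is numeric if float() accepts it.
    result = prefix_expression if inplace else prefix_expression.copy()
    for i, token in enumerate(prefix_expression):
        try:
            float(token)
        except ValueError:
            continue
        result[i] = '<constant>'
    return result


expr = ['+', 'x', '*', '1e3', 'inf']
masked = numbers_to_constant(expr)
print(masked)   # ['+', 'x', '*', '<constant>', '<constant>']
print(expr[3])  # '1e3': the default call worked on a copy

same = numbers_to_constant(expr, inplace=True)
print(same is expr)  # True: the input list itself was modified and returned
```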

explicit_constant_placeholders

explicit_constant_placeholders(prefix_expression: list[str], constants: list[str] | None = None, inplace: bool = False, convert_numbers_to_constant: bool = True) -> tuple[list[str], list[str]]

Convert placeholder tokens to explicit constant names (for example C_0, C_1).

"<constant>" tokens — and, when convert_numbers_to_constant is True, integer-like numeric strings or existing C_i tokens — are replaced with explicit constant identifiers. This is useful for generating call signatures where constants are passed as named arguments.

PARAMETER DESCRIPTION
prefix_expression

The prefix expression to process.

TYPE: list[str]

constants

Constant names to assign first, in order, before new C_i identifiers are generated. The returned list contains the provided names that were actually used, followed by any newly generated identifiers.

TYPE: list[str] or None DEFAULT: None

inplace

If True, modifies the input list; otherwise, works on a shallow copy. Defaults to False.

TYPE: bool DEFAULT: False

convert_numbers_to_constant

If True, digit-only numeric strings and existing C_i tokens are also replaced. Defaults to True.

TYPE: bool DEFAULT: True

RETURNS DESCRIPTION
tuple[list[str], list[str]]

Two items: the modified prefix expression and the list of constant names used in order of appearance.

Examples:

>>> expr = ['*', '<constant>', '+', 'x', '2']
>>> explicit_constant_placeholders(expr)
(['*', 'C_0', '+', 'x', 'C_1'], ['C_0', 'C_1'])
>>> explicit_constant_placeholders(['+', 'C_3', '<constant>'], constants=['K'])
(['+', 'K', 'C_0'], ['K', 'C_0'])
Source code in src/simplipy/utils.py
def explicit_constant_placeholders(prefix_expression: list[str], constants: list[str] | None = None, inplace: bool = False, convert_numbers_to_constant: bool = True) -> tuple[list[str], list[str]]:
    """Convert placeholder tokens to explicit constant names (for example ``C_0``, ``C_1``).

    ``"<constant>"`` tokens — and, when ``convert_numbers_to_constant`` is ``True``,
    integer-like numeric strings or existing ``C_i`` tokens — are replaced with
    explicit constant identifiers. This is useful for generating call signatures
    where constants are passed as named arguments.

    Parameters
    ----------
    prefix_expression : list[str]
        The prefix expression to process.
    constants : list[str] or None, optional
        Constant names to assign first, in order, before new ``C_i`` identifiers
        are generated. The returned list contains the provided names that were
        actually used, followed by any newly generated identifiers.
    inplace : bool, optional
        If ``True``, modifies the input list; otherwise, works on a shallow copy.
        Defaults to ``False``.
    convert_numbers_to_constant : bool, optional
        If ``True``, digit-only numeric strings and existing ``C_i`` tokens are
        also replaced. Defaults to ``True``.

    Returns
    -------
    tuple[list[str], list[str]]
        Two items: the modified prefix expression and the list of constant
        names used in order of appearance.

    Examples
    --------
    >>> expr = ['*', '<constant>', '+', 'x', '2']
    >>> explicit_constant_placeholders(expr)
    (['*', 'C_0', '+', 'x', 'C_1'], ['C_0', 'C_1'])

    >>> explicit_constant_placeholders(['+', 'C_3', '<constant>'], constants=['K'])
    (['+', 'K', 'C_0'], ['K', 'C_0'])
    """
    if inplace:
        modified_prefix_expression = prefix_expression
    else:
        modified_prefix_expression = prefix_expression.copy()

    provided_constants = list(constants) if constants is not None else []
    used_constants: list[str] = []
    provided_index = 0
    generated_index = 0

    for i, token in enumerate(prefix_expression):
        if token == "<constant>" or (convert_numbers_to_constant and (re.match(r"C_\d+", token) or token.isnumeric())):
            if provided_index < len(provided_constants):
                constant_name = provided_constants[provided_index]
                provided_index += 1
            else:
                constant_name = f"C_{generated_index}"
                generated_index += 1

            modified_prefix_expression[i] = constant_name
            used_constants.append(constant_name)

    return modified_prefix_expression, used_constants
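
The naming order can be checked with a standalone restatement of the listing. The sketch below is condensed from the source above (type hints dropped), and the variable names in the usage part are only illustrative. It shows that provided names are consumed first, that only names actually used are returned, and that `str.isnumeric()` limits number conversion to digit-only tokens.

```python
import re


def explicit_constant_placeholders(prefix_expression, constants=None, inplace=False,
                                   convert_numbers_to_constant=True):
    result = prefix_expression if inplace else prefix_expression.copy()
    provided = list(constants) if constants is not None else []
    used = []
    provided_index = 0
    generated_index = 0
    for i, token in enumerate(prefix_expression):
        is_placeholder = token == "<constant>" or (
            convert_numbers_to_constant
            and (re.match(r"C_\d+", token) or token.isnumeric()))
        if not is_placeholder:
            continue
        if provided_index < len(provided):
            name = provided[provided_index]   # reuse caller-supplied names first
            provided_index += 1
        else:
            name = f"C_{generated_index}"     # then generate C_0, C_1, ...
            generated_index += 1
        result[i] = name
        used.append(name)
    return result, used


default = explicit_constant_placeholders(['*', '<constant>', '+', 'x', '2'])
reused = explicit_constant_placeholders(['+', 'C_3', '<constant>'], constants=['K'])
kept = explicit_constant_placeholders(['*', '<constant>', '+', 'x', '2'],
                                      convert_numbers_to_constant=False)
# '3.14' and '-2' are not digit-only, so only '7' is converted.
digits_only = explicit_constant_placeholders(['+', '3.14', '*', '-2', '7'])

for outcome in (default, reused, kept, digits_only):
    print(outcome)
```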

flatten_nested_list

flatten_nested_list(nested_list: list) -> list[str]

Flatten an arbitrarily nested list into a single list of leaf values.

A stack-based traversal is used to avoid recursion limits. Because a LIFO stack is employed, values appear in reverse depth-first order relative to the original nesting. list(reversed(...)) can be used to restore a left-to-right ordering if required.

PARAMETER DESCRIPTION
nested_list

The nested list to flatten.

TYPE: list

RETURNS DESCRIPTION
list[str]

The flattened list of elements encountered during traversal.

Examples:

>>> flatten_nested_list([1, [2, [3, 4], 5], 6])
[6, 5, 4, 3, 2, 1]
Source code in src/simplipy/utils.py
def flatten_nested_list(nested_list: list) -> list[str]:
    """Flatten an arbitrarily nested list into a single list of leaf values.

    A stack-based traversal is used to avoid recursion limits. Because a LIFO
    stack is employed, values appear in reverse depth-first order relative to
    the original nesting. ``list(reversed(...))`` can be used to restore a
    left-to-right ordering if required.

    Parameters
    ----------
    nested_list : list
        The nested list to flatten.

    Returns
    -------
    list[str]
        The flattened list of elements encountered during traversal.

    Examples
    --------
    >>> flatten_nested_list([1, [2, [3, 4], 5], 6])
    [6, 5, 4, 3, 2, 1]
    """
    flat_list: list[str] = []
    stack = [nested_list]
    while stack:
        current = stack.pop()
        if isinstance(current, list):
            stack.extend(current)
        else:
            flat_list.append(current)
    return flat_list
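
The reversed ordering is easy to overlook when the flattened tokens are compared against a prefix expression. The sketch below restates the listing so it runs standalone and applies `list(reversed(...))` to an expression tree in the `[operator, [operands...]]` format used elsewhere in this module.

```python
def flatten_nested_list(nested_list: list) -> list:
    # Stack-based traversal as in the listing: the rightmost leaf is emitted first.
    flat_list = []
    stack = [nested_list]
    while stack:
        current = stack.pop()
        if isinstance(current, list):
            stack.extend(current)
        else:
            flat_list.append(current)
    return flat_list


flat = flatten_nested_list([1, [2, [3, 4], 5], 6])
ordered = list(reversed(flat))
print(flat)     # [6, 5, 4, 3, 2, 1]
print(ordered)  # [1, 2, 3, 4, 5, 6]

tree = ['mul', [['x'], ['add', [['y'], ['z']]]]]
tokens = list(reversed(flatten_nested_list(tree)))
print(tokens)   # ['mul', 'x', 'add', 'y', 'z']
```

For membership tests, such as the `"<constant>" in flatten_nested_list(...)` check in `match_pattern`, the order does not matter and no reversal is needed.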

is_prime

is_prime(n: int) -> bool

Check if an integer is a prime number.

Determines if the input number n is prime. The implementation includes optimizations such as checking for even numbers and only testing divisors up to the square root of n.

PARAMETER DESCRIPTION
n

The integer to check.

TYPE: int

RETURNS DESCRIPTION
bool

True if n is a prime number, False otherwise.

Examples:

>>> is_prime(29)
True
>>> is_prime(30)
False
Source code in src/simplipy/utils.py
def is_prime(n: int) -> bool:
    """Check if an integer is a prime number.

    Determines if the input number `n` is prime. The implementation includes
    optimizations such as checking for even numbers and only testing divisors
    up to the square root of `n`.

    Parameters
    ----------
    n : int
        The integer to check.

    Returns
    -------
    bool
        True if `n` is a prime number, False otherwise.

    Examples
    --------
    >>> is_prime(29)
    True
    >>> is_prime(30)
    False
    """
    if n % 2 == 0 and n > 2:
        return False
    return all(n % i for i in range(3, int(math.sqrt(n)) + 1, 2))
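
The listing has no special case for inputs below 2: for `n = 0` and `n = 1` the divisor range is empty, so `all(...)` evaluates to `True`. The sketch below reproduces the listed body and adds a hypothetical wrapper, `is_prime_strict`, which is not part of simplipy. It assumes the caller wants `False` for every `n < 2` and uses `math.isqrt` to avoid floating-point square roots.

```python
import math


def is_prime(n: int) -> bool:
    # Body as shown in the listing above.
    if n % 2 == 0 and n > 2:
        return False
    return all(n % i for i in range(3, int(math.sqrt(n)) + 1, 2))


def is_prime_strict(n: int) -> bool:
    # Hypothetical variant (an assumption, not the library's code):
    # reject n < 2 explicitly and use an exact integer square root.
    if n < 2:
        return False
    if n % 2 == 0:
        return n == 2
    return all(n % i for i in range(3, math.isqrt(n) + 1, 2))


print(is_prime(1), is_prime_strict(1))    # True False
print(is_prime(29), is_prime_strict(29))  # True True
```

Negative inputs make `math.sqrt` raise `ValueError` in the listed version, whereas the strict variant returns `False` for them.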

safe_f

safe_f(f: Callable, X: ndarray, constants: ndarray | None = None) -> np.ndarray

Safely evaluate a compiled function on an array of inputs.

The callable f is invoked with the columns of X unpacked as separate arguments, followed by any optional constants. Scalar results are broadcast to all samples to guarantee a one-dimensional NumPy array of length X.shape[0].

PARAMETER DESCRIPTION
f

The function to evaluate.

TYPE: Callable

X

Two-dimensional array of input samples. Each column is passed as a positional argument to f.

TYPE: ndarray

constants

Extra constant values appended when calling f. Defaults to None.

TYPE: ndarray or None DEFAULT: None

RETURNS DESCRIPTION
ndarray

A one-dimensional array with the evaluation results for each row of X.

Examples:

>>> import numpy as np
>>> f = lambda x, y: x + y
>>> safe_f(f, np.array([[1, 2], [3, 4]]))
array([3, 7])
>>> g = lambda x, y, c0: c0
>>> safe_f(g, np.array([[1, 2], [3, 4]]), constants=np.array([5]))
array([5, 5])
Source code in src/simplipy/utils.py
def safe_f(f: Callable, X: np.ndarray, constants: np.ndarray | None = None) -> np.ndarray:
    """Safely evaluate a compiled function on an array of inputs.

    The callable ``f`` is invoked with the columns of ``X`` unpacked as separate
    arguments, followed by any optional ``constants``. Scalar results are
    broadcast to all samples to guarantee a one-dimensional NumPy array of
    length ``X.shape[0]``.

    Parameters
    ----------
    f : Callable
        The function to evaluate.
    X : np.ndarray
        Two-dimensional array of input samples. Each column is passed as a
        positional argument to ``f``.
    constants : np.ndarray or None, optional
        Extra constant values appended when calling ``f``. Defaults to ``None``.

    Returns
    -------
    np.ndarray
        A one-dimensional array with the evaluation results for each row of
        ``X``.

    Examples
    --------
    >>> import numpy as np
    >>> f = lambda x, y: x + y
    >>> safe_f(f, np.array([[1, 2], [3, 4]]))
    array([3, 7])

    >>> g = lambda x, y, c0: c0
    >>> safe_f(g, np.array([[1, 2], [3, 4]]), constants=np.array([5]))
    array([5, 5])
    """
    if constants is None:
        y = f(*X.T)
    else:
        y = f(*X.T, *constants)
    if not isinstance(y, np.ndarray) or y.shape[0] == 1:
        y = np.full(X.shape[0], y)
    return y
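
The broadcasting rule covers Python scalars, NumPy scalars, and length-1 arrays. The sketch below restates the listing in condensed form so it runs standalone; the lambdas are illustrative stand-ins for compiled expressions. It also shows that a zero-dimensional array is not handled by the `y.shape[0]` check, because its shape tuple is empty.

```python
import numpy as np


def safe_f(f, X, constants=None):
    # Condensed restatement of the listing above.
    y = f(*X.T) if constants is None else f(*X.T, *constants)
    if not isinstance(y, np.ndarray) or y.shape[0] == 1:
        y = np.full(X.shape[0], y)
    return y


X = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])

const_scalar = safe_f(lambda x, y: 2.5, X)            # Python float result
const_len1 = safe_f(lambda x, y: np.array([2.5]), X)  # length-1 array result
per_row = safe_f(lambda x, y, c: c * x + y, X, constants=np.array([10.0]))

print(const_scalar.shape, const_len1.shape, per_row.shape)

try:
    safe_f(lambda x, y: np.array(2.5), X)  # 0-d array: shape is ()
    zero_dim_handled = True
except IndexError:
    zero_dim_handled = False
print(zero_dim_handled)
```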

remap_expression

remap_expression(source_expression: list[str], dummy_variables: list[str], variable_mapping: dict | None = None, variable_prefix: str = '_', enumeration_offset: int = 0) -> tuple[list[str], dict]

Standardize variable names in a prefix expression for canonical representation.

Remaps variables (identified from dummy_variables) to a generic, enumerated format (e.g., _0, _1). This is crucial for comparing the structure of two expressions regardless of their original variable names.

PARAMETER DESCRIPTION
source_expression

The prefix expression to remap.

TYPE: list[str]

dummy_variables

A list of tokens to be treated as variables.

TYPE: list[str]

variable_mapping

An existing mapping to apply. If None, a new one is created. Defaults to None.

TYPE: dict or None DEFAULT: None

variable_prefix

The prefix for the new standardized variable names, by default "_".

TYPE: str DEFAULT: '_'

enumeration_offset

The starting number for enumeration, by default 0.

TYPE: int DEFAULT: 0

RETURNS DESCRIPTION
tuple[list[str], dict]

A tuple containing the remapped prefix expression and the variable mapping that was created or used.

Source code in src/simplipy/utils.py
def remap_expression(source_expression: list[str], dummy_variables: list[str], variable_mapping: dict | None = None, variable_prefix: str = "_", enumeration_offset: int = 0) -> tuple[list[str], dict]:
    """Standardize variable names in a prefix expression for canonical representation.

    Remaps variables (identified from `dummy_variables`) to a generic,
    enumerated format (e.g., `_0`, `_1`). This is crucial for comparing the
    structure of two expressions regardless of their original variable names.

    Parameters
    ----------
    source_expression : list[str]
        The prefix expression to remap.
    dummy_variables : list[str]
        A list of tokens to be treated as variables.
    variable_mapping : dict or None, optional
        An existing mapping to apply. If None, a new one is created.
        Defaults to None.
    variable_prefix : str, optional
        The prefix for the new standardized variable names, by default "_".
    enumeration_offset : int, optional
        The starting number for enumeration, by default 0.

    Returns
    -------
    tuple[list[str], dict]
        A tuple containing:
        - The remapped prefix expression.
        - The variable mapping that was created or used.
    """
    source_expression = deepcopy(source_expression)
    if variable_mapping is None:
        variable_mapping = {}
        for i, token in enumerate(source_expression):
            if token in dummy_variables:
                if token not in variable_mapping:
                    variable_mapping[token] = f'{variable_prefix}{len(variable_mapping) + enumeration_offset}'

    for i, token in enumerate(source_expression):
        if token in dummy_variables:
            source_expression[i] = variable_mapping[token]

    return source_expression, variable_mapping
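
The reference gives no example for this function, so the following sketch restates the listing (type hints dropped) and shows three uses: creating a mapping, reusing it for a second expression, and changing the prefix and offset. Names are assigned in order of first appearance in the expression, not in the order of `dummy_variables`.

```python
from copy import deepcopy


def remap_expression(source_expression, dummy_variables, variable_mapping=None,
                     variable_prefix="_", enumeration_offset=0):
    source_expression = deepcopy(source_expression)
    if variable_mapping is None:
        variable_mapping = {}
        for token in source_expression:
            if token in dummy_variables and token not in variable_mapping:
                variable_mapping[token] = f"{variable_prefix}{len(variable_mapping) + enumeration_offset}"
    for i, token in enumerate(source_expression):
        if token in dummy_variables:
            source_expression[i] = variable_mapping[token]
    return source_expression, variable_mapping


variables = ['a', 'b']
source, mapping = remap_expression(['+', 'b', '*', 'a', 'b'], variables)
print(source, mapping)  # ['+', '_0', '*', '_1', '_0'] {'b': '_0', 'a': '_1'}

# Reuse the mapping so a second expression gets consistent names.
target, _ = remap_expression(['*', 'a', 'b'], variables, variable_mapping=mapping)
print(target)           # ['*', '_1', '_0']

shifted, _ = remap_expression(['+', 'a', 'b'], variables,
                              variable_prefix='v', enumeration_offset=2)
print(shifted)          # ['+', 'v2', 'v3']
```

When an existing mapping is passed in, every dummy variable in the expression must already have an entry; otherwise the lookup raises `KeyError`.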

deduplicate_rules

deduplicate_rules(rules_list: list[tuple[tuple[str, ...], tuple[str, ...]]], dummy_variables: list[str], verbose: bool = False) -> list[tuple[tuple[str, ...], tuple[str, ...]]]

Deduplicate a list of simplification rules by canonicalizing variables.

This function processes a list of (source, target) simplification rules. It standardizes the variables in each rule to a canonical form and then removes duplicates. If multiple rules simplify to different targets from the same canonical source, it keeps the one with the shortest target.

PARAMETER DESCRIPTION
rules_list

The list of simplification rules to deduplicate.

TYPE: list[tuple[tuple[str, ...], tuple[str, ...]]]

dummy_variables

A list of tokens to be treated as variables for remapping.

TYPE: list[str]

verbose

If True, displays a progress bar. Defaults to False.

TYPE: bool DEFAULT: False

RETURNS DESCRIPTION
list[tuple[tuple[str, ...], tuple[str, ...]]]

The deduplicated and optimized list of simplification rules.

Source code in src/simplipy/utils.py
def deduplicate_rules(rules_list: list[tuple[tuple[str, ...], tuple[str, ...]]], dummy_variables: list[str], verbose: bool = False) -> list[tuple[tuple[str, ...], tuple[str, ...]]]:
    """Deduplicate a list of simplification rules by canonicalizing variables.

    This function processes a list of (source, target) simplification rules. It
    standardizes the variables in each rule to a canonical form and then
    removes duplicates. If multiple rules simplify to different targets from
    the same canonical source, it keeps the one with the shortest target.

    Parameters
    ----------
    rules_list : list[tuple[tuple[str, ...], tuple[str, ...]]]
        The list of simplification rules to deduplicate.
    dummy_variables : list[str]
        A list of tokens to be treated as variables for remapping.
    verbose : bool, optional
        If True, displays a progress bar. Defaults to False.

    Returns
    -------
    list[tuple[tuple[str, ...], tuple[str, ...]]]
        The deduplicated and optimized list of simplification rules.
    """
    deduplicated_rules: dict[tuple[str, ...], tuple[str, ...]] = {}
    for rule in tqdm(rules_list, desc='Deduplicating rules', disable=not verbose):
        # Rename variables in the source expression
        remapped_source, variable_mapping = remap_expression(list(rule[0]), dummy_variables=dummy_variables)
        remapped_target, _ = remap_expression(list(rule[1]), dummy_variables, variable_mapping)

        remapped_source_key = tuple(remapped_source)
        remapped_target_value = tuple(remapped_target)

        existing_replacement = deduplicated_rules.get(remapped_source_key)
        if existing_replacement is None or len(remapped_target_value) < len(existing_replacement):
            # Found a better (shorter) target expression for the same source
            deduplicated_rules[remapped_source_key] = remapped_target_value

    return list(deduplicated_rules.items())
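
The sketch below condenses the two steps described above: canonicalize the variables, then keep the shortest target for each canonical source. It omits the `tqdm` progress bar and uses a compact helper, `_remap`, as a hypothetical stand-in for `remap_expression` with its default prefix and offset. The example rules are invented for illustration.

```python
def _remap(expression, dummy_variables, mapping=None):
    # Stand-in for remap_expression: names variables _0, _1, ... by first appearance.
    if mapping is None:
        mapping = {}
        for token in expression:
            if token in dummy_variables and token not in mapping:
                mapping[token] = f"_{len(mapping)}"
    return [mapping[t] if t in dummy_variables else t for t in expression], mapping


def deduplicate_rules(rules_list, dummy_variables):
    best = {}
    for source, target in rules_list:
        remapped_source, mapping = _remap(list(source), dummy_variables)
        remapped_target, _ = _remap(list(target), dummy_variables, mapping)
        key, value = tuple(remapped_source), tuple(remapped_target)
        # Keep the first rule seen, or a strictly shorter target for the same source.
        if key not in best or len(value) < len(best[key]):
            best[key] = value
    return list(best.items())


rules = [
    (('+', 'x', '0'), ('x',)),
    (('+', 'y', '0'), ('y',)),            # same rule up to variable renaming
    (('*', 'x', '1'), ('*', '1', 'x')),
    (('*', 'y', '1'), ('y',)),            # same source, shorter target
]
deduplicated = deduplicate_rules(rules, dummy_variables=['x', 'y'])
for rule in deduplicated:
    print(rule)
```

The result preserves the order in which each canonical source was first seen, because it is built from dictionary insertion order.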

is_numeric_string

is_numeric_string(s: str) -> bool

Check if a string represents a number (integer or float).

This function determines if the given string can be interpreted as a numeric value. It handles integers, decimals, and lowercase scientific notation such as 1e5 or 1.5e-2; forms with an explicit plus sign or an uppercase exponent (1e+5, 1E5) are not recognized.

Original author: Cecil Curry. Source: https://stackoverflow.com/questions/354038/how-do-i-check-if-a-string-represents-a-number-float-or-int

PARAMETER DESCRIPTION
s

The string to check.

TYPE: str

RETURNS DESCRIPTION
bool

True if the string represents a number, False otherwise.

Examples:

>>> is_numeric_string("123")
True
>>> is_numeric_string("-1.5e-2")
True
>>> is_numeric_string("abc")
False
Source code in src/simplipy/utils.py
def is_numeric_string(s: str) -> bool:
    """Check if a string represents a number (integer or float).

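    This function determines if the given string can be interpreted as a
    numeric value. It handles integers, decimals, and lowercase scientific
    notation such as 1e5 or 1.5e-2; forms with an explicit plus sign or an
    uppercase exponent (1e+5, 1E5) are not recognized.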

    Original author: Cecil Curry
    Source: https://stackoverflow.com/questions/354038/how-do-i-check-if-a-string-represents-a-number-float-or-int

    Parameters
    ----------
    s : str
        The string to check.

    Returns
    -------
    bool
        True if the string represents a number, False otherwise.

    Examples
    --------
    >>> is_numeric_string("123")
    True
    >>> is_numeric_string("-1.5e-2")
    True
    >>> is_numeric_string("abc")
    False
    """
    return isinstance(s, str) and s.lstrip('-').replace('.', '', 1).replace('e-', '', 1).replace('e', '', 1).isdigit()
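
Because the check is built from string replacements and does not parse the number, it does not always agree with `float()`. The sketch below copies the one-line body from the listing and compares it with a small helper, `accepts_float`, which is introduced here only for illustration.

```python
def is_numeric_string(s: str) -> bool:
    # One-line body copied from the listing above.
    return isinstance(s, str) and s.lstrip('-').replace('.', '', 1).replace('e-', '', 1).replace('e', '', 1).isdigit()


def accepts_float(s: str) -> bool:
    # Illustrative helper: True if Python's float() can parse the string.
    try:
        float(s)
        return True
    except ValueError:
        return False


comparison = {s: (is_numeric_string(s), accepts_float(s))
              for s in ["123", "-1.5e-2", "1e+5", "1E5", "inf", "e5", "--5"]}
for token, (by_replacement, by_float) in comparison.items():
    print(f"{token!r:>10}  is_numeric_string={by_replacement}  float()={by_float}")
```

The first two tokens are accepted by both checks. `"1e+5"`, `"1E5"`, and `"inf"` are accepted only by `float()`, while `"e5"` and `"--5"` are accepted only by `is_numeric_string`. This matters when choosing between `numbers_to_constant`, which relies on `float()`, and `mask_elementary_literals`, which relies on `is_numeric_string`.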

factorize_to_at_most

factorize_to_at_most(p: int, max_factor: int, max_iter: int = 1000) -> list[int]

Factorize an integer into factors limited by max_factor.

This helper decomposes p into a list of factors whose product equals p such that every factor is less than or equal to max_factor. If the decomposition is impossible (for example, because p contains a prime factor larger than max_factor), a ValueError is raised instead of returning an invalid factorization.

PARAMETER DESCRIPTION
p

The integer to factorize. Must be greater than or equal to 1.

TYPE: int

max_factor

The maximum allowable value for any single factor. Must be at least 2.

TYPE: int

max_iter

A soft cap on the number of prime factors processed. If the algorithm exceeds this limit, a ValueError is raised to guard against accidental infinite loops.

TYPE: int DEFAULT: 1000

RETURNS DESCRIPTION
list[int]

The factors of p. Their product is equal to p and each factor is less than or equal to max_factor. The factors are returned in the order they are discovered and are not sorted.

RAISES DESCRIPTION
ValueError

If p cannot be decomposed using the specified max_factor value or if max_iter is exceeded.

Examples:

>>> factorize_to_at_most(100, 10)
[4, 5, 5]
>>> factorize_to_at_most(18, 5)
[2, 3, 3]
Source code in src/simplipy/utils.py
def factorize_to_at_most(p: int, max_factor: int, max_iter: int = 1000) -> list[int]:
    """Factorize an integer into factors limited by ``max_factor``.

    This helper decomposes ``p`` into a list of factors whose product equals
    ``p`` such that every factor is less than or equal to ``max_factor``. If the
    decomposition is impossible (for example because ``p`` contains a prime
    factor larger than ``max_factor``) a :class:`ValueError` is raised instead of
    returning an invalid factorization.

    Parameters
    ----------
    p : int
        The integer to factorize. Must be greater than or equal to ``1``.
    max_factor : int
        The maximum allowable value for any single factor. Must be at least 2.
    max_iter : int, optional
        A soft cap on the number of prime factors processed. If the algorithm
        exceeds this limit, a :class:`ValueError` is raised to guard against
        accidental infinite loops.

    Returns
    -------
    list[int]
        The factors of ``p``. Their product is equal to ``p`` and each factor is
        less than or equal to ``max_factor``. The factors are returned in the
        order they are discovered and are not sorted.

    Raises
    ------
    ValueError
        If ``p`` cannot be decomposed using the specified ``max_factor`` value
        or if ``max_iter`` is exceeded.

    Examples
    --------
    >>> factorize_to_at_most(100, 10)
    [4, 5, 5]
    >>> factorize_to_at_most(18, 5)
    [2, 3, 3]
    """

    if p < 1:
        raise ValueError("p must be a positive integer")
    if max_factor < 2:
        raise ValueError("max_factor must be at least 2")

    if p == 1:
        return []

    remaining = p
    factors: list[int] = []
    current_factor = 1
    processed_factors = 0

    def flush_current() -> None:
        nonlocal current_factor
        if current_factor > 1:
            factors.append(current_factor)
            current_factor = 1

    divisor = 2
    while divisor * divisor <= remaining:
        while remaining % divisor == 0:
            processed_factors += 1
            if processed_factors > max_iter:
                raise ValueError(
                    f'Factorization of {p} into factors <= {max_factor} exceeded {max_iter} steps')

            if divisor > max_factor:
                raise ValueError(f'Cannot factorize {p} with factors <= {max_factor}')

            if current_factor * divisor <= max_factor:
                current_factor *= divisor
            else:
                flush_current()
                current_factor = divisor

            remaining //= divisor
        divisor = 3 if divisor == 2 else divisor + 2

    if remaining > 1:
        # remaining is prime at this point
        if remaining > max_factor:
            raise ValueError(f'Cannot factorize {p} with factors <= {max_factor}')

        if current_factor * remaining <= max_factor:
            current_factor *= remaining
        else:
            flush_current()
            current_factor = remaining

    flush_current()

    return factors
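
The algorithm packs prime factors greedily: each prime is multiplied into the current factor while the product stays within `max_factor`, and a new factor is started otherwise. The sketch below is a condensed restatement of the listing. It omits the `max_iter` guard and merges the two packing branches into a hypothetical inner helper, `absorb`, so it should be read as an illustration and not as the library's code.

```python
import math


def factorize_to_at_most(p: int, max_factor: int) -> list[int]:
    if p < 1:
        raise ValueError("p must be a positive integer")
    if max_factor < 2:
        raise ValueError("max_factor must be at least 2")

    factors: list[int] = []
    current = 1
    remaining = p

    def absorb(prime: int) -> None:
        # Multiply the prime into the current factor if it still fits,
        # otherwise emit the current factor and start a new one.
        nonlocal current
        if prime > max_factor:
            raise ValueError(f"Cannot factorize {p} with factors <= {max_factor}")
        if current * prime <= max_factor:
            current *= prime
        else:
            if current > 1:
                factors.append(current)
            current = prime

    divisor = 2
    while divisor * divisor <= remaining:
        while remaining % divisor == 0:
            absorb(divisor)
            remaining //= divisor
        divisor = 3 if divisor == 2 else divisor + 2
    if remaining > 1:  # whatever is left is a single prime
        absorb(remaining)
    if current > 1:
        factors.append(current)
    return factors


for p, limit in [(100, 10), (18, 5), (64, 10), (30, 10)]:
    result = factorize_to_at_most(p, limit)
    print(p, limit, result, math.prod(result) == p)

try:
    factorize_to_at_most(14, 5)  # 7 is a prime factor larger than 5
except ValueError as error:
    print(error)
```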

mask_elementary_literals

mask_elementary_literals(prefix_expression: list[str], inplace: bool = False) -> list[str]

Replace all numeric string literals with the '<constant>' token.

Scans a prefix expression and replaces any token that represents a number (e.g., "0", "1", "3.14") with the generic placeholder "<constant>". This is used to abstract away specific numbers for general simplification rules.

PARAMETER DESCRIPTION
prefix_expression

The prefix expression to modify.

TYPE: list[str]

inplace

If True, modifies the list in-place; otherwise, returns a new list. Defaults to False.

TYPE: bool DEFAULT: False

RETURNS DESCRIPTION
list[str]

The expression with numeric literals masked.

Source code in src/simplipy/utils.py
def mask_elementary_literals(prefix_expression: list[str], inplace: bool = False) -> list[str]:
    """Replace all numeric string literals with the '<constant>' token.

    Scans a prefix expression and replaces any token that represents a number
    (e.g., "0", "1", "3.14") with the generic placeholder "<constant>". This is
    used to abstract away specific numbers for general simplification rules.

    Parameters
    ----------
    prefix_expression : list[str]
        The prefix expression to modify.
    inplace : bool, optional
        If True, modifies the list in-place; otherwise, returns a new list.
        Defaults to False.

    Returns
    -------
    list[str]
        The expression with numeric literals masked.
    """
    if inplace:
        modified_prefix_expression = prefix_expression
    else:
        modified_prefix_expression = prefix_expression.copy()

    for i, token in enumerate(prefix_expression):
        if is_numeric_string(token):
            modified_prefix_expression[i] = '<constant>'

    return modified_prefix_expression
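
The reference gives no example for this function. The sketch below restates the listing together with the `is_numeric_string` helper it depends on, so it runs standalone. It also shows that tokens such as `'inf'`, which `float()` would accept, are left untouched by this string-based check.

```python
def is_numeric_string(s: str) -> bool:
    # Copied from the is_numeric_string listing in this module.
    return isinstance(s, str) and s.lstrip('-').replace('.', '', 1).replace('e-', '', 1).replace('e', '', 1).isdigit()


def mask_elementary_literals(prefix_expression: list[str], inplace: bool = False) -> list[str]:
    result = prefix_expression if inplace else prefix_expression.copy()
    for i, token in enumerate(prefix_expression):
        if is_numeric_string(token):
            result[i] = '<constant>'
    return result


expr = ['+', 'x', '*', '3.14', '-2']
masked = mask_elementary_literals(expr)
print(masked)  # ['+', 'x', '*', '<constant>', '<constant>']

# The binary '-' operator and 'inf' are not treated as numeric literals.
untouched = mask_elementary_literals(['-', 'x', 'inf'])
print(untouched)  # ['-', 'x', 'inf']
```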

construct_expressions

construct_expressions(expressions_of_length: dict[int, set[tuple[str, ...]]], non_leaf_nodes: dict[str, int], must_have_sizes: list | set | None = None) -> Generator[tuple[str, ...], None, None]

Generate new prefix expressions by combining existing building blocks.

Expressions are grouped by length in expressions_of_length. For each operator in non_leaf_nodes the generator enumerates every compatible tuple of child expressions and yields the resulting prefix encoding. When must_have_sizes is provided, at least one operand must have a length contained in that collection before the expression is yielded.

PARAMETER DESCRIPTION
expressions_of_length

Mapping from expression length to the set of expressions with that length.

TYPE: dict[int, set[tuple[str, ...]]]

non_leaf_nodes

Mapping from operator tokens to their arity.

TYPE: dict[str, int]

must_have_sizes

If provided, filters generated combinations so that at least one child expression has a length contained in this collection. Defaults to None.

TYPE: list or set or None DEFAULT: None

YIELDS DESCRIPTION
tuple[str, ...]

Newly constructed prefix expressions.

Examples:

>>> expressions = {1: {('x',), ('y',)}}
>>> operators = {'+': 2}
>>> sorted(construct_expressions(expressions, operators))
[('+', 'x', 'x'), ('+', 'x', 'y'), ('+', 'y', 'x'), ('+', 'y', 'y')]
Source code in src/simplipy/utils.py
def construct_expressions(expressions_of_length: dict[int, set[tuple[str, ...]]], non_leaf_nodes: dict[str, int], must_have_sizes: list | set | None = None) -> Generator[tuple[str, ...], None, None]:
    """Generate new prefix expressions by combining existing building blocks.

    Expressions are grouped by length in ``expressions_of_length``. For each
    operator in ``non_leaf_nodes`` the generator enumerates every compatible
    tuple of child expressions and yields the resulting prefix encoding. When
    ``must_have_sizes`` is provided, at least one operand must have a length
    contained in that collection before the expression is yielded.

    Parameters
    ----------
    expressions_of_length : dict[int, set[tuple[str, ...]]]
        Mapping from expression length to the set of expressions with that
        length.
    non_leaf_nodes : dict[str, int]
        Mapping from operator tokens to their arity.
    must_have_sizes : list or set or None, optional
        If provided, filters generated combinations so that at least one child
        expression has a length contained in this collection. Defaults to
        ``None``.

    Yields
    ------
    tuple[str, ...]
        Newly constructed prefix expressions.

    Examples
    --------
    >>> expressions = {1: {('x',), ('y',)}}
    >>> operators = {'+': 2}
    >>> sorted(construct_expressions(expressions, operators))
    [('+', 'x', 'x'), ('+', 'x', 'y'), ('+', 'y', 'x'), ('+', 'y', 'y')]
    """
    expressions_of_length_with_lists = {k: list(v) for k, v in expressions_of_length.items()}

    filter_sizes = must_have_sizes is not None and not len(must_have_sizes) == 0
    if must_have_sizes is not None and filter_sizes:
        must_have_sizes_set = set(must_have_sizes)

    # Append existing trees to every operator
    for new_root_operator, arity in non_leaf_nodes.items():
        # Start with the smallest arity-tuples of trees
        for child_lengths in sorted(itertools.product(list(expressions_of_length_with_lists.keys()), repeat=arity), key=lambda x: sum(x)):
            # Check all possible combinations of child trees
            if filter_sizes and not any(length in must_have_sizes_set for length in child_lengths):
                # Skip combinations in which no child expression has one of the required sizes
                continue
            for child_combination in itertools.product(*[expressions_of_length_with_lists[child_length] for child_length in child_lengths]):
                yield (new_root_operator,) + tuple(itertools.chain.from_iterable(child_combination))
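
The `must_have_sizes` filter is useful when expressions are generated in rounds: requiring at least one child from the most recently added lengths avoids regenerating combinations that use only older building blocks. The sketch below is a condensed restatement of the listing, and the operator name `'neg'` and the input sets are invented for illustration.

```python
import itertools


def construct_expressions(expressions_of_length, non_leaf_nodes, must_have_sizes=None):
    by_length = {k: list(v) for k, v in expressions_of_length.items()}
    # As in the listing, None or an empty collection disables the filter.
    required = set(must_have_sizes) if must_have_sizes else None
    for operator, arity in non_leaf_nodes.items():
        # Enumerate child-length tuples, smallest total length first.
        for child_lengths in sorted(itertools.product(by_length, repeat=arity), key=sum):
            if required is not None and not any(n in required for n in child_lengths):
                continue
            for children in itertools.product(*(by_length[n] for n in child_lengths)):
                yield (operator,) + tuple(itertools.chain.from_iterable(children))


expressions = {1: {('x',)}, 3: {('+', 'x', 'x')}}
operators = {'neg': 1, '+': 2}

# Only combinations with at least one child of length 3 are produced.
new_expressions = list(construct_expressions(expressions, operators, must_have_sizes=[3]))
for expression in new_expressions:
    print(expression)

# Group the results by length so they can seed the next round.
grown = {}
for expression in new_expressions:
    grown.setdefault(len(expression), set()).add(expression)
print(sorted(grown))  # [4, 5, 7]
```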

apply_mapping

apply_mapping(tree: list, mapping: dict[str, Any]) -> list

Apply a placeholder-to-subtree mapping to a target expression tree.

Trees are represented as [operator, [operands...]] where each operand is itself a tree. Leaves are encoded as one-element lists, for example ['x']. Placeholders such as '_0' are replaced with the corresponding subtree provided in mapping.

PARAMETER DESCRIPTION
tree

The target expression tree containing placeholders.

TYPE: list

mapping

Dictionary mapping placeholder names to the subtrees that should replace them.

TYPE: dict[str, Any]

RETURNS DESCRIPTION
list

A new expression tree with placeholders substituted.

Examples:

>>> template = ['mul', [['_0'], ['_1']]]
>>> mapping = {'_0': ['x'], '_1': ['add', [['y'], ['z']]]}
>>> apply_mapping(template, mapping)
['mul', [['x'], ['add', [['y'], ['z']]]]]
Source code in src/simplipy/utils.py
def apply_mapping(tree: list, mapping: dict[str, Any]) -> list:
    """Apply a placeholder-to-subtree mapping to a target expression tree.

    Trees are represented as ``[operator, [operands...]]`` where each operand is
    itself a tree. Leaves are encoded as one-element lists, for example
    ``['x']``. Placeholders such as ``'_0'`` are replaced with the corresponding
    subtree provided in ``mapping``.

    Parameters
    ----------
    tree : list
        The target expression tree containing placeholders.
    mapping : dict[str, Any]
        Dictionary mapping placeholder names to the subtrees that should
        replace them.

    Returns
    -------
    list
        A new expression tree with placeholders substituted.

    Examples
    --------
    >>> template = ['mul', [['_0'], ['_1']]]
    >>> mapping = {'_0': ['x'], '_1': ['add', [['y'], ['z']]]}
    >>> apply_mapping(template, mapping)
    ['mul', [['x'], ['add', [['y'], ['z']]]]]
    """
    # If the tree is a leaf node, replace the placeholder with the actual subtree defined in the mapping
    if len(tree) == 1 and isinstance(tree[0], str):
        if tree[0].startswith('_'):
            return mapping[tree[0]]  # TODO: I put a bracket here. Find out why this is necessary
        return tree

    operator, operands = tree
    return [operator, [apply_mapping(operand, mapping) for operand in operands]]
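
In a rewrite rule, `apply_mapping` is what instantiates the right-hand side once the placeholders have been bound. The sketch below copies the listing and applies it to the replacement side of a hypothetical rule; the operator names and bindings are invented for illustration. It also shows that substituted subtrees are inserted by reference, not copied.

```python
def apply_mapping(tree: list, mapping: dict) -> list:
    # Body as shown in the listing above.
    if len(tree) == 1 and isinstance(tree[0], str):
        if tree[0].startswith('_'):
            return mapping[tree[0]]
        return tree
    operator, operands = tree
    return [operator, [apply_mapping(operand, mapping) for operand in operands]]


# Replacement side of a hypothetical rule that produces _0 * (_1 + _2).
replacement = ['mul', [['_0'], ['add', [['_1'], ['_2']]]]]
bindings = {'_0': ['sin', [['a']]], '_1': ['b'], '_2': ['c']}

result = apply_mapping(replacement, bindings)
print(result)  # ['mul', [['sin', [['a']]], ['add', [['b'], ['c']]]]]

# The bound subtree is the same object as in the mapping.
shares_subtree = result[1][0] is bindings['_0']
print(shares_subtree)  # True
```

Because subtrees are shared, code that later mutates the result in place would also change the objects stored in the mapping; a `copy.deepcopy` of the bindings avoids this if it matters. A placeholder that has no entry in the mapping raises `KeyError`.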

match_pattern

match_pattern(tree: list, pattern: list, mapping: dict[str, Any] | None = None) -> tuple[bool, dict[str, Any]]

Recursively match an expression tree against a pattern tree.

tree and pattern use the same representation as described in apply_mapping. Placeholders in pattern (for example '_0') match any subtree. When a match succeeds, the mapping is populated with the subtrees that correspond to each placeholder.

PARAMETER DESCRIPTION
tree

The expression tree to be matched.

TYPE: list

pattern

The pattern tree to match against.

TYPE: list

mapping

Initial mapping dictionary. If None, an empty one is created.

TYPE: dict[str, Any] or None DEFAULT: None

RETURNS DESCRIPTION
tuple[bool, dict[str, Any]]

(True, mapping) when the structures align; otherwise (False, mapping). The returned mapping may contain partial assignments even when the match fails.

Examples:

>>> tree = ['mul', [['x'], ['add', [['y'], ['z']]]]]
>>> pattern = ['mul', [['_a'], ['_b']]]
>>> match_pattern(tree, pattern)
(True, {'_a': ['x'], '_b': ['add', [['y'], ['z']]]})
Source code in src/simplipy/utils.py
def match_pattern(tree: list, pattern: list, mapping: dict[str, Any] | None = None) -> tuple[bool, dict[str, Any]]:
    """Recursively match an expression tree against a pattern tree.

    ``tree`` and ``pattern`` use the same representation as described in
    :func:`apply_mapping`. Placeholders in ``pattern`` (for example ``'_0'``)
    match any subtree. When a match succeeds the mapping is populated with the
    subtrees that correspond to each placeholder.

    Parameters
    ----------
    tree : list
        The expression tree to be matched.
    pattern : list
        The pattern tree to match against.
    mapping : dict[str, Any] or None, optional
        Initial mapping dictionary. If ``None``, an empty one is created.

    Returns
    -------
    tuple[bool, dict[str, Any]]
        ``(True, mapping)`` when the structures align; otherwise ``(False, mapping)``.
        The returned mapping may contain partial assignments even when the match
        fails.

    Examples
    --------
    >>> tree = ['mul', [['x'], ['add', [['y'], ['z']]]]]
    >>> pattern = ['mul', [['_a'], ['_b']]]
    >>> match_pattern(tree, pattern)
    (True, {'_a': ['x'], '_b': ['add', [['y'], ['z']]]})
    """
    if mapping is None:
        mapping = {}

    pattern_length = len(pattern)

    # The leaf node is a variable but the pattern is not
    if len(tree) == 1 and isinstance(tree[0], str) and pattern_length != 1:
        return False, mapping

    # Elementary pattern
    pattern_key = pattern[0]
    if pattern_length == 1 and isinstance(pattern_key, str):
        # Check if the pattern is a placeholder to be filled with the tree
        if pattern_key.startswith('_'):
            # Try to match the tree with the placeholder pattern
            existing_value = mapping.get(pattern_key)
            if existing_value is None:
                # Placeholder is not yet filled, can be filled with the tree
                mapping[pattern_key] = tree
                return True, mapping
            else:
                # The placeholder has a mapped value already

                # If the existing value is a constant, it is not a match
                # We cannot map multiple (independent) constants to the same placeholder
                if "<constant>" in flatten_nested_list(existing_value):
                    return False, mapping

                # Placeholder is occupied by another tree, check if the existing value matches the tree
                return (existing_value == tree), mapping

        # The literal pattern must match the tree
        return (tree == pattern), mapping

    # The pattern is tree-structured
    tree_operator, tree_operands = tree
    pattern_operator, pattern_operands = pattern

    # If the operators do not match, the tree does not match the pattern
    if tree_operator != pattern_operator:
        return False, mapping

    # Try to recursively match the operands
    for tree_operand, pattern_operand in zip(tree_operands, pattern_operands):
        # If the pattern operand is a leaf node
        if isinstance(pattern_operand, str):
            # Check if the pattern operand is a placeholder to be filled with the tree operand
            existing_value = mapping.get(pattern_operand)
            if existing_value is None:
                # Placeholder is not yet filled, can be filled with the tree operand
                mapping[pattern_operand] = tree_operand
                continue  # keep matching the remaining operands instead of returning early
            elif existing_value != tree_operand:
                # Placeholder is occupied by another tree, the tree does not match the pattern
                return False, mapping
        else:
            # Recursively match the tree operand with the pattern operand
            does_match, mapping = match_pattern(tree_operand, pattern_operand, mapping)

            # If the tree operand does not match the pattern operand, the tree does not match the pattern
            if not does_match:
                return False, mapping

    # The tree matches the pattern
    return True, mapping

remove_pow1

remove_pow1(prefix_expression: list[str]) -> list[str]

Remove identity power operations from a prefix expression.

This utility cleans up an expression by removing pow1 operators, which represent raising to the power of 1 (an identity operation), and replaces pow_1 (power of -1) with its canonical equivalent, inv.

PARAMETER DESCRIPTION
prefix_expression

The prefix expression to clean.

TYPE: list[str]

RETURNS DESCRIPTION
list[str]

The cleaned prefix expression without pow1 or pow_1 tokens.

Examples:

>>> expr = ['pow1', 'x', '+', 'y', 'pow_1', 'z']
>>> remove_pow1(expr)
['x', '+', 'y', 'inv', 'z']
Source code in src/simplipy/utils.py
def remove_pow1(prefix_expression: list[str]) -> list[str]:
    """Remove identity power operations from a prefix expression.

    This utility cleans up an expression by removing `pow1` operators, which
    represent raising to the power of 1 (an identity operation), and replaces
    `pow_1` (power of -1) with its canonical equivalent, `inv`.

    Parameters
    ----------
    prefix_expression : list[str]
        The prefix expression to clean.

    Returns
    -------
    list[str]
        The cleaned prefix expression without `pow1` or `pow_1` tokens.

    Examples
    --------
    >>> expr = ['pow1', 'x', '+', 'y', 'pow_1', 'z']
    >>> remove_pow1(expr)
    ['x', '+', 'y', 'inv', 'z']
    """
    filtered_expression = []
    for token in prefix_expression:
        if token == 'pow1':
            continue

        if token == 'pow_1':
            filtered_expression.append('inv')
            continue

        filtered_expression.append(token)

    return filtered_expression

violates_wildcard_multiplicity

violates_wildcard_multiplicity(lhs: list[str] | tuple[str, ...], rhs: list[str] | tuple[str, ...]) -> bool

Check whether a rule violates the non-increasing wildcard multiplicity condition.

A rule lhs -> rhs violates the condition when any wildcard token (matching _\d+) appears more times on the right-hand side than on the left-hand side. Enforcing this property prevents duplication of wildcard-matched subtrees by ensuring that no wildcard occurs more often in the replacement than in the pattern.

PARAMETER DESCRIPTION
lhs

The source (left-hand side) of the rule in prefix notation.

TYPE: list[str] or tuple[str, ...]

rhs

The target (right-hand side) of the rule in prefix notation.

TYPE: list[str] or tuple[str, ...]

RETURNS DESCRIPTION
bool

True if the rule violates the condition (i.e. some wildcard has higher multiplicity on the RHS), False otherwise.

Source code in src/simplipy/utils.py
def violates_wildcard_multiplicity(lhs: list[str] | tuple[str, ...], rhs: list[str] | tuple[str, ...]) -> bool:
    """Check whether a rule violates the non-increasing wildcard multiplicity condition.

    A rule ``lhs -> rhs`` violates the condition when any wildcard token
    (matching ``_\\d+``) appears *more* times on the right-hand side than on
    the left-hand side. Enforcing this property prevents duplication of
    wildcard-matched subtrees by ensuring that no wildcard occurs more often
    in the replacement than in the pattern.

    Parameters
    ----------
    lhs : list[str] or tuple[str, ...]
        The source (left-hand side) of the rule in prefix notation.
    rhs : list[str] or tuple[str, ...]
        The target (right-hand side) of the rule in prefix notation.

    Returns
    -------
    bool
        ``True`` if the rule violates the condition (i.e. some wildcard has
        higher multiplicity on the RHS), ``False`` otherwise.
    """
    lhs_wc = Counter(t for t in lhs if _WILDCARD_RE.match(t))
    rhs_wc = Counter(t for t in rhs if _WILDCARD_RE.match(t))
    return any(rhs_wc[w] > lhs_wc[w] for w in rhs_wc)

I/O Functions

load_config

load_config(config: dict[str, Any] | str, resolve_paths: bool = True) -> dict[str, Any]

Load a configuration from a YAML file, or return a given dictionary unchanged.

PARAMETER DESCRIPTION
config

The configuration dictionary or path to the configuration file.

TYPE: dict or str

resolve_paths

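Whether to resolve relative path values (strings that start with '.' and end in '.yaml' or '.json') against the directory of the configuration file, by default True. Ignored when config is a dictionary.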

TYPE: bool DEFAULT: True

RETURNS DESCRIPTION
dict

The configuration dictionary.

Source code in src/simplipy/io.py
def load_config(config: dict[str, Any] | str, resolve_paths: bool = True) -> dict[str, Any]:
    '''
    Load a configuration from a YAML file, or return a given dictionary unchanged.

    Parameters
    ----------
    config : dict or str
        The configuration dictionary or path to the configuration file.
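    resolve_paths : bool, optional
        Whether to resolve relative path values (strings that start with '.'
        and end in '.yaml' or '.json') against the directory of the
        configuration file, by default True. Ignored when config is a dictionary.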

    Returns
    -------
    dict
        The configuration dictionary.
    '''

    if isinstance(config, str):
        config_path = config
        config_base_path = os.path.dirname(config_path)

        if not os.path.exists(config_path):
            raise FileNotFoundError(f'Config file {config_path} not found.')
        if os.path.isfile(config_path):
            with open(config_path, 'r') as config_file:
                config_ = yaml.safe_load(config_file)
        else:
            raise ValueError(f'Config file {config_path} is not a valid file.')

        def resolve_path(value: Any) -> Any:
            # HACK: Find a way to check if a string is a path
            if isinstance(value, str) and value.endswith(('.yaml', '.json')) and value.startswith('.'):
                return os.path.join(config_base_path, value)
            return value

        if resolve_paths:
            config_ = apply_on_nested(config_, resolve_path)

    else:
        config_ = config

    return config_
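
The path heuristic used in the listing above can be tried in isolation. The following is a minimal sketch: `resolve_path_sketch` is a hypothetical standalone version of the nested `resolve_path` helper, with the base directory passed explicitly instead of being captured from the enclosing function.

```python
import os
from typing import Any


def resolve_path_sketch(value: Any, config_base_path: str) -> Any:
    """Join relative '.yaml'/'.json' references onto the config's directory."""
    if isinstance(value, str) and value.endswith(('.yaml', '.json')) and value.startswith('.'):
        return os.path.join(config_base_path, value)
    return value


# Starts with '.' and ends in '.json': resolved against the base directory
resolved = resolve_path_sketch('./rules.json', 'configs')
# Does not start with '.': left untouched, even though it looks like a file name
untouched = resolve_path_sketch('rules.json', 'configs')
# Non-string values pass through unchanged
number = resolve_path_sketch(3, 'configs')
print(resolved, untouched, number)
```

Only values written with a leading '.' are rewritten, so a reference such as 'rules.json' stays relative to the current working directory rather than to the configuration file.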

save_config

save_config(config: dict[str, Any], directory: str, filename: str, reference: str = 'relative', recursive: bool = True, resolve_paths: bool = False) -> None

Save a configuration dictionary to a YAML file.

PARAMETER DESCRIPTION
config

The configuration dictionary to save.

TYPE: dict

directory

The directory to save the configuration file to.

TYPE: str

filename

The name of the configuration file.

TYPE: str

reference

Determines the reference base path. One of 'relative' or 'absolute'; any other value raises a ValueError when recursive is True.

TYPE: str DEFAULT: 'relative'

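recursive

Whether to also save any referenced configs (string values ending in '.yaml') into directory.

TYPE: bool DEFAULT: True

resolve_paths

Passed to load_config when loading referenced configs.

TYPE: bool DEFAULT: False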

Source code in src/simplipy/io.py
def save_config(config: dict[str, Any], directory: str, filename: str, reference: str = 'relative', recursive: bool = True, resolve_paths: bool = False) -> None:
    '''
    Save a configuration dictionary to a YAML file.

    Parameters
    ----------
    config : dict
        The configuration dictionary to save.
    directory : str
        The directory to save the configuration file to.
    filename : str
        The name of the configuration file.
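    reference : str, optional
        Determines the reference base path, by default 'relative'.
        One of 'relative' or 'absolute'. Any other value raises a
        ValueError when recursive is True.
    recursive : bool, optional
        Whether to also save any referenced configs (string values ending
        in '.yaml') into directory, by default True.
    resolve_paths : bool, optional
        Passed to load_config when loading referenced configs, by default False.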
    '''
    config_ = copy.deepcopy(config)

    def save_config_relative_func(value: Any) -> Any:
        if isinstance(value, str) and value.endswith('.yaml'):
            relative_path = value
            if not value.startswith('.'):
                relative_path = os.path.join('.', os.path.basename(value))
            save_config(load_config(value, resolve_paths=resolve_paths), directory, os.path.basename(relative_path), reference=reference, recursive=recursive, resolve_paths=resolve_paths)
        return value

    def save_config_absolute_func(value: Any) -> Any:
        if isinstance(value, str) and value.endswith('.yaml'):
            relative_path = value
            if not value.startswith('.'):
                relative_path = os.path.abspath(value)
            save_config(load_config(value, resolve_paths=resolve_paths), directory, os.path.basename(relative_path), reference=reference, recursive=recursive, resolve_paths=resolve_paths)
        return value

    if recursive:
        match reference:
            case 'relative':
                apply_on_nested(config_, save_config_relative_func)
            case 'absolute':
                apply_on_nested(config_, save_config_absolute_func)
            case _:
                raise ValueError(f'Invalid reference type: {reference}')

    save_path = os.path.join(directory, filename)
    os.makedirs(os.path.dirname(save_path), exist_ok=True)

    with open(save_path, 'w') as config_file:
        yaml.dump(config_, config_file, sort_keys=False)
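
The last few lines of the listing build the output path and create missing parent directories before writing. The sketch below isolates that step using only the standard library; `prepare_save_path` is a hypothetical helper name, and the directory and file names are illustrative.

```python
import os
import tempfile


def prepare_save_path(directory: str, filename: str) -> str:
    """Join directory and filename and create any missing parent directories."""
    save_path = os.path.join(directory, filename)
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    return save_path


with tempfile.TemporaryDirectory() as tmp:
    # filename may itself contain a subdirectory; it is created as needed
    path = prepare_save_path(os.path.join(tmp, 'run_1'), os.path.join('configs', 'engine.yaml'))
    created = os.path.isdir(os.path.dirname(path))
    print(created)  # True
```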