"""
Formula Engine — facade orchestrating parsing, DAG, and evaluation.

Coordinates:
- Parsing formula text via the DSL parser
- Registering formulas and their dependencies in the DependencyGraph
- Evaluating dirty formula columns row-by-row via FormulaEvaluator
- Updating ns_fast_access caches in the DatagridStore
"""
|
|
import logging
|
|
from typing import Any, Callable, Optional
|
|
|
|
import numpy as np
|
|
|
|
from .dataclasses import FormulaDefinition, WhereClause
|
|
from .dependency_graph import DependencyGraph
|
|
from .dsl.parser import get_parser
|
|
from .dsl.transformer import FormulaTransformer
|
|
from .evaluator import FormulaEvaluator
|
|
|
|
# Module-wide logger; the fixed name "FormulaEngine" is what logging
# configuration elsewhere must target.
logger = logging.getLogger("FormulaEngine")

# Callback that returns a DatagridStore-like object for a given table name
RegistryResolver = Callable[[str], Any]
|
|
|
|
|
|
def parse_formula(text: str) -> FormulaDefinition | None:
    """Parse a formula expression string into a FormulaDefinition AST.

    Args:
        text: The formula expression string. ``None`` and whitespace-only
            strings are treated as empty.

    Returns:
        FormulaDefinition on success, None if text is empty or the parser
        produces no tree.

    Raises:
        FormulaSyntaxError: If the formula text is syntactically invalid.
    """
    # Normalise first so that None / "" / "   " all take the same early exit.
    text = text.strip() if text else ""
    if not text:
        return None

    tree = get_parser().parse(text)
    if tree is None:
        return None

    formula = FormulaTransformer().transform(tree)
    # Keep the (stripped) source alongside the AST so it can be shown back
    # to the user later.
    formula.source_text = text
    return formula
|
|
|
|
|
|
class FormulaEngine:
    """
    Facade for the formula calculation system.

    Orchestrates formula parsing, dependency tracking, and incremental
    recalculation of formula columns.

    Args:
        registry_resolver: Callback that takes a table name and returns
            the DatagridStore for that table (used for cross-table refs).
            Provided by DataGridsManager.
    """

    def __init__(self, registry_resolver: Optional[RegistryResolver] = None):
        self._graph = DependencyGraph()
        self._registry_resolver = registry_resolver
        # Cache of parsed formulas: {(table, col): FormulaDefinition}
        self._formulas: dict[tuple[str, str], FormulaDefinition] = {}

    def set_formula(self, table: str, col: str, formula_text: str) -> None:
        """
        Parse and register a formula for a column.

        An empty (or whitespace-only) formula text removes any formula
        currently registered for the column.

        Args:
            table: Table name.
            col: Column name.
            formula_text: The formula expression string.

        Raises:
            FormulaSyntaxError: If the formula is syntactically invalid.
            FormulaCycleError: If the formula would create a circular dependency.
        """
        formula_text = formula_text.strip() if formula_text else ""
        if not formula_text:
            self.remove_formula(table, col)
            return

        formula = parse_formula(formula_text)
        if formula is None:
            self.remove_formula(table, col)
            return

        # Registers in DAG and raises FormulaCycleError if cycle detected.
        # The local cache is only updated once the graph accepted the formula.
        self._graph.add_formula(table, col, formula)
        self._formulas[(table, col)] = formula

        logger.debug("Formula set for %s.%s: %s", table, col, formula_text)

    def remove_formula(self, table: str, col: str) -> None:
        """
        Remove a formula column from the engine.

        Args:
            table: Table name.
            col: Column name.
        """
        self._graph.remove_formula(table, col)
        self._formulas.pop((table, col), None)

    def mark_data_changed(
        self,
        table: str,
        col: str,
        rows: Optional[list[int]] = None,
    ) -> None:
        """
        Mark a column's data as changed, propagating dirty flags.

        Call this when source data is modified so that dependent formula
        columns are re-evaluated on next render.

        Args:
            table: Table name.
            col: Column name.
            rows: Specific row indices that changed. None means all rows.
        """
        self._graph.mark_dirty(table, col, rows)

    def recalculate_if_needed(self, table: str, store: Any) -> bool:
        """
        Recalculate all dirty formula columns for a table.

        Should be called at the start of ``mk_body_content_page()`` to
        ensure formula columns are up-to-date before rendering.

        Updates ``store.ns_fast_access`` and ``store.ns_row_data`` in place.

        Args:
            table: Table name.
            store: The DatagridStore instance for this table.

        Returns:
            True if any columns were recalculated, False otherwise.
        """
        dirty_nodes = self._graph.get_calculation_order(table=table)
        if not dirty_nodes:
            return False

        recalculated = False
        for node in dirty_nodes:
            formula = node.formula
            if formula is None:
                # Nothing to evaluate for this node; it does not count as a
                # recalculation.
                continue
            self._evaluate_column(table, node.column, formula, store)
            self._graph.clear_dirty(node.node_id)
            recalculated = True

        # Rebuild ns_row_data after recalculation so row dicts carry the
        # freshly computed formula values.
        if recalculated and store.ns_fast_access:
            self._rebuild_row_data(store)

        return recalculated

    def has_formula(self, table: str, col: str) -> bool:
        """
        Check if a column has a formula registered.

        Args:
            table: Table name.
            col: Column name.

        Returns:
            True if the column has a registered formula.
        """
        return self._graph.has_formula(table, col)

    def get_formula_text(self, table: str, col: str) -> Optional[str]:
        """
        Get the source text of a registered formula.

        Args:
            table: Table name.
            col: Column name.

        Returns:
            Formula source text or None if not registered.
        """
        formula = self._formulas.get((table, col))
        return formula.source_text if formula else None

    # ==================== Private helpers ====================

    def _evaluate_column(
        self,
        table: str,
        col: str,
        formula: FormulaDefinition,
        store: Any,
    ) -> None:
        """
        Evaluate a formula column row-by-row and update ns_fast_access.

        Does nothing when the store has no row data.

        Args:
            table: Table name.
            col: Column name.
            formula: The parsed FormulaDefinition.
            store: The DatagridStore with ns_fast_access and ns_row_data.
        """
        if store.ns_row_data is None or len(store.ns_row_data) == 0:
            return

        n_rows = len(store.ns_row_data)
        resolver = self._make_cross_table_resolver(table)
        evaluator = FormulaEvaluator(cross_table_resolver=resolver)

        # Ensure ns_fast_access exists before the loop so that formula columns
        # evaluated earlier in the same pass are visible to subsequent columns.
        if store.ns_fast_access is None:
            store.ns_fast_access = {}

        # Filter out empty columns and measure each array once instead of once
        # per row. The result for `col` is only written after the loop, so the
        # set of readable columns is the same for every row.
        readable_columns = [
            (name, arr, len(arr))
            for name, arr in store.ns_fast_access.items()
            if arr is not None
        ]

        results = np.empty(n_rows, dtype=object)
        for row_index in range(n_rows):
            # Build row_data from ns_fast_access so that formula columns evaluated
            # earlier in this pass (e.g. B) are available to dependent columns (e.g. C).
            row_data = {
                name: arr[row_index]
                for name, arr, size in readable_columns
                if row_index < size
            }
            results[row_index] = evaluator.evaluate(formula, row_data, row_index)

        store.ns_fast_access[col] = results

        logger.debug("Evaluated formula column %s.%s (%d rows)", table, col, n_rows)

    def _rebuild_row_data(self, store: Any) -> None:
        """
        Rebuild ns_row_data to include formula column results.

        This ensures formula values are available to dependent formulas
        in subsequent evaluation passes. Does nothing when either
        ``ns_fast_access`` or ``ns_row_data`` is missing.

        Args:
            store: The DatagridStore to update.
        """
        if store.ns_fast_access is None or store.ns_row_data is None:
            return

        readable_columns = [
            (name, arr, len(arr))
            for name, arr in store.ns_fast_access.items()
            if arr is not None
        ]

        rows = store.ns_row_data
        # Index explicitly: the only operations this class relies on for
        # ns_row_data are len() and integer indexing.
        for row_index in range(len(rows)):
            row = rows[row_index]
            for name, arr, size in readable_columns:
                if row_index < size:
                    row[name] = arr[row_index]

    @staticmethod
    def _first_match_index(key_array: Any, value: Any) -> Optional[int]:
        """
        Find the position of the first element of ``key_array`` equal to ``value``.

        Args:
            key_array: Array of key values to search.
            value: The value to look for.

        Returns:
            Index of the first match, or None when nothing matches or the
            comparison itself fails.
        """
        try:
            hits = np.where(key_array == value)[0]
        except Exception:
            # Best effort: an incomparable key must not abort evaluation of
            # the whole column, so treat it as "no match" but leave a trace.
            logger.debug("Key comparison failed for value %r", value, exc_info=True)
            return None
        return int(hits[0]) if len(hits) > 0 else None

    def _make_cross_table_resolver(self, current_table: str):
        """
        Create a cross-table resolver callback for the given table context.

        Resolution strategy:
        1. Explicit WHERE clause: scan remote column for matching rows.
        2. Implicit join by ``id`` column: match rows where both tables share
           the same id value.
        3. Fallback: match by row_index.

        Args:
            current_table: The table that contains the formula.

        Returns:
            A callable ``resolver(table, column, where_clause, row_index) -> value``.
        """

        def resolver(
            remote_table: str,
            remote_column: str,
            where_clause: Optional[WhereClause],
            row_index: int,
        ) -> Any:
            if self._registry_resolver is None:
                logger.warning(
                    "No registry_resolver set for cross-table ref %s.%s",
                    remote_table, remote_column,
                )
                return None

            remote_store = self._registry_resolver(remote_table)
            if remote_store is None:
                logger.warning("Table '%s' not found in registry", remote_table)
                return None

            ns = remote_store.ns_fast_access
            if not ns or remote_column not in ns:
                logger.debug(
                    "Column '%s' not found in table '%s'", remote_column, remote_table
                )
                return None

            remote_array = ns[remote_column]

            # Strategy 1: Explicit WHERE clause
            if where_clause is not None:
                return self._resolve_with_where(
                    where_clause, remote_store, remote_column,
                    remote_array, current_table, row_index,
                )

            # Strategy 2: Implicit join by 'id' column
            current_store = self._registry_resolver(current_table)
            if (
                current_store is not None
                and current_store.ns_fast_access is not None
                and "id" in current_store.ns_fast_access
                and "id" in ns
            ):
                local_id_arr = current_store.ns_fast_access["id"]
                if row_index < len(local_id_arr):
                    # Find first matching row in remote table
                    match = self._first_match_index(ns["id"], local_id_arr[row_index])
                    # The id column and the value column may differ in
                    # length; never index past the value column.
                    if match is not None and match < len(remote_array):
                        return remote_array[match]
                # Both tables expose 'id': the join is authoritative, so do
                # not fall back to positional matching.
                return None

            # Strategy 3: Fallback — match by row_index
            if row_index < len(remote_array):
                return remote_array[row_index]
            return None

        return resolver

    def _resolve_with_where(
        self,
        where_clause: WhereClause,
        remote_store: Any,
        remote_column: str,
        remote_array: Any,
        current_table: str,
        row_index: int,
    ) -> Any:
        """
        Resolve a cross-table reference using an explicit WHERE clause.

        Args:
            where_clause: The parsed WHERE clause.
            remote_store: DatagridStore for the remote table.
            remote_column: Column to return value from.
            remote_array: numpy array of the remote column values.
            current_table: Table containing the formula.
            row_index: Current row being evaluated.

        Returns:
            The value from the first matching remote row, or None.
        """
        remote_ns = remote_store.ns_fast_access
        if not remote_ns:
            return None

        # Get the remote key column array
        remote_key_col = where_clause.remote_column
        if remote_key_col not in remote_ns:
            logger.debug(
                "WHERE key column '%s' not found in remote table", remote_key_col
            )
            return None

        remote_key_array = remote_ns[remote_key_col]

        # Get the local value to compare
        current_store = self._registry_resolver(current_table) if self._registry_resolver else None
        if current_store is None or current_store.ns_fast_access is None:
            return None

        local_col = where_clause.local_column
        if local_col not in current_store.ns_fast_access:
            logger.debug("WHERE local column '%s' not found", local_col)
            return None

        local_array = current_store.ns_fast_access[local_col]
        if row_index >= len(local_array):
            return None

        # Find matching rows
        match = self._first_match_index(remote_key_array, local_array[row_index])
        if match is None or match >= len(remote_array):
            return None

        # Return value from first match (use aggregation functions for multi-row)
        return remote_array[match]
|