Introducing column formulas

This commit is contained in:
2026-02-13 21:38:00 +01:00
parent 0df78c0513
commit e8443f07f9
29 changed files with 3889 additions and 15 deletions

View File

@@ -418,9 +418,41 @@ class DataGrid(MultipleInstance):
self._df_store.ns_fast_access = _init_fast_access(self._df)
self._df_store.ns_row_data = _init_row_data(self._df)
self._df_store.ns_total_rows = len(self._df) if self._df is not None else 0
if init_state:
self._register_existing_formulas()
return self
def _register_existing_formulas(self) -> None:
"""
Re-register all formula columns with the FormulaEngine.
Called after data reload to ensure the engine knows about all
formula columns and their expressions.
"""
engine = self._get_formula_engine()
if engine is None:
return
table = self.get_table_name()
for col_def in self._state.columns:
if col_def.formula:
try:
engine.set_formula(table, col_def.col_id, col_def.formula)
except Exception as e:
logger.warning("Failed to register formula for %s.%s: %s", table, col_def.col_id, e)
def _recalculate_formulas(self) -> None:
"""
Recalculate dirty formula columns before rendering.
Called at the start of mk_body_content_page() to ensure formula
columns are up-to-date before cells are rendered.
"""
engine = self._get_formula_engine()
if engine is None:
return
engine.recalculate_if_needed(self.get_table_name(), self._df_store)
def _get_format_rules(self, col_pos, row_index, col_def):
"""
Get format rules for a cell, returning only the most specific level defined.
@@ -575,6 +607,11 @@ class DataGrid(MultipleInstance):
def get_table_name(self):
    """Return ``"namespace.name"`` when a namespace is set, otherwise just the name."""
    settings = self._settings
    if settings.namespace:
        return f"{settings.namespace}.{settings.name}"
    return settings.name
def get_formula_engine(self):
    """
    Return the FormulaEngine exposed by the parent, if available.

    Returns:
        Whatever the parent's ``get_formula_engine()`` returns, or None when
        the parent is missing or does not expose that accessor.
    """
    # Callers test the result against None, so an absent accessor is
    # reported as "no engine" rather than raising AttributeError.
    accessor = getattr(self._parent, "get_formula_engine", None)
    return accessor() if callable(accessor) else None
def mk_headers(self):
resize_cmd = self.commands.set_column_width()
move_cmd = self.commands.move_column()
@@ -701,6 +738,7 @@ class DataGrid(MultipleInstance):
OPTIMIZED: Extract filter keyword once instead of 10,000 times.
OPTIMIZED: Uses OptimizedDiv for rows instead of Div for faster rendering.
"""
self._recalculate_formulas()
df = self._get_filtered_df()
if df is None:
return []

View File

@@ -99,6 +99,9 @@ class DataGridColumnsManager(MultipleInstance):
col_def.type = ColumnType(v)
elif k == "width":
col_def.width = int(v)
elif k == "formula":
col_def.formula = v or ""
self._register_formula(col_def)
else:
setattr(col_def, k, v)
@@ -107,6 +110,21 @@ class DataGridColumnsManager(MultipleInstance):
return self.mk_all_columns()
def _register_formula(self, col_def) -> None:
"""Register or remove a formula column with the FormulaEngine."""
engine = self._parent.get_formula_engine()
if engine is None:
return
table = self._parent.get_table_name()
if col_def.formula:
try:
engine.set_formula(table, col_def.col_id, col_def.formula)
logger.debug("Registered formula for %s.%s", table, col_def.col_id)
except Exception as e:
logger.warning("Formula error for %s.%s: %s", table, col_def.col_id, e)
else:
engine.remove_formula(table, col_def.col_id)
def mk_column_label(self, col_def: DataGridColumnState):
return Div(
mk.mk(
@@ -168,6 +186,17 @@ class DataGridColumnsManager(MultipleInstance):
value=col_def.title,
),
*([
Label("Formula"),
Textarea(
col_def.formula or "",
name="formula",
cls=f"textarea textarea-{size} w-full font-mono",
placeholder="{Column} * {OtherColumn}",
rows=3,
),
] if col_def.type == ColumnType.Formula else []),
legend="Column details",
cls="fieldset border-base-300 rounded-box"
),

View File

@@ -0,0 +1,67 @@
"""
DataGridFormulaEditor — DslEditor for formula column expressions.
Extends DslEditor with formula-specific behavior:
- Parses the formula on content change
- Registers the formula with FormulaEngine
- Triggers a body re-render on the parent DataGrid
"""
import logging
from myfasthtml.controls.DslEditor import DslEditor
from myfasthtml.core.formula.dsl.exceptions import FormulaSyntaxError, FormulaCycleError
logger = logging.getLogger("DataGridFormulaEditor")
class DataGridFormulaEditor(DslEditor):
    """
    Formula editor for a specific DataGrid column.

    Args:
        parent: The parent DataGrid instance.
        col_def: The DataGridColumnState for the formula column.
        conf: DslEditorConf for CodeMirror configuration.
        _id: Optional instance ID.
    """

    def __init__(self, parent, col_def, conf=None, _id=None):
        super().__init__(parent, conf=conf, _id=_id)
        self._col_def = col_def

    def on_content_changed(self):
        """
        Called when the formula text is changed in the editor.

        1. Updates col_def.formula with the new text.
        2. Registers the formula with the FormulaEngine (when one is available).
        3. Saves state and triggers a body re-render of the parent DataGrid.

        Returns:
            The result of ``render_partial("body")`` on the parent grid, or
            None when the engine rejected the formula (nothing is saved or
            re-rendered in that case).
        """
        formula_text = self.get_content()

        # Update the column definition.
        # NOTE(review): this happens before the engine validates the text, so
        # after a rejected formula the in-memory column definition holds text
        # the engine did not accept - confirm this is intended.
        self._col_def.formula = formula_text or ""

        # Register with the FormulaEngine, using the accessor the grid defines.
        engine = self._parent.get_formula_engine()
        if engine is not None:
            table = self._parent.get_table_name()
            try:
                engine.set_formula(table, self._col_def.col_id, formula_text)
                logger.debug(
                    "Formula updated for %s.%s: %s",
                    table, self._col_def.col_id, formula_text,
                )
            except FormulaSyntaxError as e:
                # Expected while the user is still typing: keep it quiet.
                logger.debug("Formula syntax error, keeping old formula: %s", e)
                return
            except FormulaCycleError as e:
                logger.warning("Formula cycle detected for %s.%s: %s", table, self._col_def.col_id, e)
                return
            except Exception as e:
                logger.warning("Formula engine error for %s.%s: %s", table, self._col_def.col_id, e)
                return

        # Save state and re-render the grid body
        self._parent.save_state()
        return self._parent.render_partial("body")

View File

@@ -12,7 +12,7 @@ from myfasthtml.icons.fluent import brain_circuit20_regular
from myfasthtml.icons.fluent_p1 import filter20_regular, search20_regular
from myfasthtml.icons.fluent_p2 import dismiss_circle20_regular
logger = logging.getLogger("DataGridFilter")
logger = logging.getLogger("DataGridQuery")
DG_QUERY_FILTER = "filter"
DG_QUERY_SEARCH = "search"

View File

@@ -16,6 +16,7 @@ from myfasthtml.core.commands import Command
from myfasthtml.core.dbmanager import DbObject
from myfasthtml.core.formatting.dsl.completion.provider import DatagridMetadataProvider
from myfasthtml.core.formatting.presets import DEFAULT_STYLE_PRESETS, DEFAULT_FORMATTER_PRESETS
from myfasthtml.core.formula.engine import FormulaEngine
from myfasthtml.core.instances import InstancesManager, SingleInstance
from myfasthtml.icons.fluent_p1 import table_add20_regular
from myfasthtml.icons.fluent_p3 import folder_open20_regular
@@ -91,6 +92,11 @@ class DataGridsManager(SingleInstance, DatagridMetadataProvider):
self.style_presets: dict = DEFAULT_STYLE_PRESETS.copy()
self.formatter_presets: dict = DEFAULT_FORMATTER_PRESETS.copy()
self.all_tables_formats: list = []
# Formula engine shared across all DataGrids in this session
self._formula_engine = FormulaEngine(
registry_resolver=self._resolve_store_for_table
)
def upload_from_source(self):
file_upload = FileUpload(self)
@@ -167,10 +173,10 @@ class DataGridsManager(SingleInstance, DatagridMetadataProvider):
def list_column_values(self, table_name, column_name):
    """Return the registry's ``get_column_values`` result for a table column."""
    return self._registry.get_column_values(table_name, column_name)
def get_row_count(self, table_name):
    """Return the registry's ``get_row_count`` result for a table."""
    return self._registry.get_row_count(table_name)
def get_column_type(self, table_name, column_name):
    """Return the registry's ``get_column_type`` result for a table column."""
    return self._registry.get_column_type(table_name, column_name)
@@ -180,7 +186,29 @@ class DataGridsManager(SingleInstance, DatagridMetadataProvider):
def list_format_presets(self) -> list[str]:
    """Return the names (keys) of the registered formatter presets."""
    preset_names = list(self.formatter_presets)
    return preset_names
# === Presets Management ===
def _resolve_store_for_table(self, table_name: str):
"""
Resolve the DatagridStore for a given table name.
Used by FormulaEngine as the registry_resolver callback.
Args:
table_name: Full table name in ``"namespace.name"`` format.
Returns:
DatagridStore instance or None if not found.
"""
try:
as_fullname_dict = self._registry._get_entries_as_full_name_dict()
grid_id = as_fullname_dict.get(table_name)
if grid_id is None:
return None
datagrid = InstancesManager.get(self._session, grid_id, None)
if datagrid is None:
return None
return datagrid._df_store
except Exception:
return None
def get_style_presets(self) -> dict:
"""Get the global style presets."""
@@ -190,6 +218,10 @@ class DataGridsManager(SingleInstance, DatagridMetadataProvider):
"""Get the global formatter presets."""
return self.formatter_presets
def get_formula_engine(self) -> FormulaEngine:
    """Return the FormulaEngine held by this manager (``self._formula_engine``)."""
    return self._formula_engine
def add_style_preset(self, name: str, preset: dict):
"""
Add or update a style preset.

View File

@@ -20,6 +20,7 @@ class DataGridColumnState:
visible: bool = True
width: int = DATAGRID_DEFAULT_COLUMN_WIDTH
format: list = field(default_factory=list) #
formula: str = "" # formula expression for ColumnType.Formula columns
@dataclass

View File

@@ -26,6 +26,7 @@ class ColumnType(Enum):
Bool = "Boolean"
Choice = "Choice"
Enum = "Enum"
Formula = "Formula"
class ViewType(Enum):

View File

@@ -98,6 +98,9 @@ class StyleResolver:
return StyleContainer(None, "")
cls = props.pop("__class__", None)
if not props:
return StyleContainer(cls, "")
css = "; ".join(f"{key}: {value}" for key, value in props.items()) + ";"
return StyleContainer(cls, css)

View File

View File

@@ -0,0 +1,79 @@
from dataclasses import dataclass, field
from typing import Any, Optional
@dataclass
class FormulaNode:
    """Base AST node for formula expressions.

    Carries no fields of its own; it is the common parent of the concrete
    node types so that any of them can be used where a FormulaNode is expected.
    """
    pass
@dataclass
class LiteralNode(FormulaNode):
    """A literal value (number, string, boolean)."""
    # The constant itself; typed Any because literals may be of several types.
    value: Any
@dataclass
class ColumnRef(FormulaNode):
    """Reference to a column in the current table: {ColumnName}."""
    # Name of the referenced column (the text between the braces).
    column: str
@dataclass
class WhereClause:
    """WHERE clause for cross-table references: WHERE remote_table.remote_col = local_col.

    Not a FormulaNode itself; it is attached to a CrossTableRef.
    """
    # Table named on the left-hand side of the condition.
    remote_table: str
    # Column of ``remote_table`` compared against the local column.
    remote_column: str
    # Column of the current table on the right-hand side of the condition.
    local_column: str
@dataclass
class CrossTableRef(FormulaNode):
    """Reference to a column in another table: {Table.Column}."""
    # Name of the referenced table.
    table: str
    # Name of the referenced column within ``table``.
    column: str
    # Optional WHERE condition attached to the reference; None when absent.
    where_clause: Optional[WhereClause] = None
@dataclass
class BinaryOp(FormulaNode):
    """Binary operation: left op right."""
    # Operator identifier; the exact spellings are chosen by the code that
    # builds these nodes (not visible here).
    operator: str
    # Left-hand operand.
    left: FormulaNode
    # Right-hand operand.
    right: FormulaNode
@dataclass
class UnaryOp(FormulaNode):
    """Unary operation: -expr or not expr."""
    # Operator identifier; the exact spellings are chosen by the code that
    # builds these nodes (not visible here).
    operator: str
    # The single operand the operator applies to.
    operand: FormulaNode
@dataclass
class FunctionCall(FormulaNode):
    """Function call: func(args...)."""
    # Name of the function being called.
    function_name: str
    # Argument expressions in call order; each instance gets its own list.
    arguments: list = field(default_factory=list)
@dataclass
class ConditionalExpr(FormulaNode):
    """Conditional: value_expr if condition [else else_expr].

    Chainable: val1 if cond1 else val2 if cond2 else val3
    (a chain is represented by nesting another ConditionalExpr in ``else_expr``).
    """
    # Expression written before ``if``.
    value_expr: FormulaNode
    # Expression written after ``if``.
    condition: FormulaNode
    # Expression written after ``else``; None when there is no else branch.
    else_expr: Optional[FormulaNode] = None
@dataclass
class FormulaDefinition:
    """A complete formula definition for a column."""
    # Root node of the parsed formula AST.
    expression: FormulaNode
    # Original formula text; defaults to "" when not supplied.
    source_text: str = ""

    def __str__(self):
        """Return the formula's source text."""
        return self.source_text

View File

@@ -0,0 +1,386 @@
"""
Dependency Graph (DAG) for formula columns.
Tracks column dependencies, propagates dirty flags, and provides
topological ordering for incremental recalculation.
Node IDs use the format ``"table_name.column_id"`` for column-level
granularity, designed to be extensible to ``"table_name.column_id[row]"``
for cell-level overrides.
"""
import logging
from collections import defaultdict, deque
from dataclasses import dataclass, field
from typing import Optional, Set
from .dataclasses import (
FormulaDefinition,
ColumnRef,
CrossTableRef,
BinaryOp,
UnaryOp,
FunctionCall,
ConditionalExpr,
FormulaNode,
)
from .dsl.exceptions import FormulaCycleError
logger = logging.getLogger("DependencyGraph")
@dataclass
class DependencyNode:
    """
    A node in the dependency graph.

    Attributes:
        node_id: Unique identifier in the format ``"table.column"``.
        table: Table name.
        column: Column name.
        dirty: Whether this node needs recalculation.
        dirty_rows: Set of specific row indices that are dirty.
            Empty set means all rows are dirty (only meaningful while
            ``dirty`` is True).
        formula: The parsed FormulaDefinition for formula nodes; None for
            plain data columns.
    """
    node_id: str
    table: str
    column: str
    dirty: bool = False
    dirty_rows: Set[int] = field(default_factory=set)
    formula: Optional[FormulaDefinition] = None
class DependencyGraph:
    """
    Directed Acyclic Graph of formula column dependencies.

    Tracks which columns depend on which other columns and provides:
    - Dirty flag propagation via BFS
    - Topological ordering for recalculation (Kahn's algorithm)
    - Cycle detection when registering new formulas

    The graph is bidirectional:
    - ``_dependents[A]`` = set of nodes that depend on A (forward edges)
    - ``_precedents[B]`` = set of nodes that B depends on (reverse edges)
    """

    def __init__(self):
        self._nodes: dict[str, DependencyNode] = {}
        # forward: A -> {B, C} means "B and C depend on A"
        self._dependents: dict[str, set[str]] = defaultdict(set)
        # reverse: B -> {A} means "B depends on A"
        self._precedents: dict[str, set[str]] = defaultdict(set)

    def add_formula(
        self,
        table: str,
        column: str,
        formula: FormulaDefinition,
    ) -> None:
        """
        Register a formula for a column and add dependency edges.

        The operation is atomic: when the formula is rejected because of a
        cycle, the graph is left exactly as it was before the call.

        Args:
            table: Table name.
            column: Column name.
            formula: The parsed FormulaDefinition.

        Raises:
            FormulaCycleError: If adding this formula would create a cycle.
        """
        logger.debug("add_formula %s.%s:%s", table, column, formula.source_text)
        node_id = self._make_node_id(table, column)

        # Keep (table, column) pairs: table names may contain dots, so they
        # cannot be recovered reliably by splitting a node id.
        dep_refs = self._extract_dependency_refs(formula.expression, table)
        dep_ids = {self._make_node_id(dep_table, dep_column) for dep_table, dep_column in dep_refs}

        # Reject direct self-references before touching the graph.
        if node_id in dep_ids:
            raise FormulaCycleError([node_id])

        # Snapshot everything that is about to change so it can be restored.
        existing = self._nodes.get(node_id)
        saved_state = None
        if existing is not None:
            saved_state = (existing.formula, existing.dirty, set(existing.dirty_rows))
        saved_precedents = set(self._precedents.get(node_id, ()))
        created_ids = []

        # Replace old edges with the new ones (avoids stale dependencies).
        self._remove_edges(node_id)
        for dep_table, dep_column in dep_refs:
            dep_id = self._make_node_id(dep_table, dep_column)
            self._dependents[dep_id].add(node_id)
            self._precedents[node_id].add(dep_id)
            # Ensure referenced nodes exist (as data nodes without formulas).
            if dep_id not in self._nodes:
                self._nodes[dep_id] = DependencyNode(
                    node_id=dep_id,
                    table=dep_table,
                    column=dep_column,
                )
                created_ids.append(dep_id)

        node = self._get_or_create_node(table, column)
        node.formula = formula
        node.dirty = True  # New formula -> needs evaluation
        node.dirty_rows.clear()  # ... for every row, not a stale subset

        try:
            self._detect_cycles()
        except FormulaCycleError:
            self._rollback_formula(node_id, saved_state, saved_precedents, created_ids)
            raise

        logger.debug("Added formula for %s depending on: %s", node_id, dep_ids)

    def remove_formula(self, table: str, column: str) -> None:
        """
        Remove a formula column and its edges from the graph.

        The node itself is kept (as a plain data node) while other formulas
        still depend on it.

        Args:
            table: Table name.
            column: Column name.
        """
        node_id = self._make_node_id(table, column)
        self._remove_edges(node_id)

        node = self._nodes.get(node_id)
        if node is not None:
            node.formula = None
            node.dirty = False
            node.dirty_rows.clear()
            # If the node has no dependents either, remove it
            if not self._dependents.get(node_id):
                del self._nodes[node_id]

        logger.debug("Removed formula for %s", node_id)

    def get_calculation_order(self, table: Optional[str] = None) -> list[DependencyNode]:
        """
        Return dirty formula nodes in topological order.

        Uses Kahn's algorithm (BFS-based topological sort).
        Only returns nodes with a formula that are dirty.

        Args:
            table: If provided, filter to only nodes for this table.

        Returns:
            List of dirty DependencyNode objects in calculation order.
        """
        formula_nodes = {
            nid: node for nid, node in self._nodes.items()
            if node.formula is not None and node.dirty
        }
        if not formula_nodes:
            return []

        # In-degree counts only edges inside the dirty-formula subgraph.
        in_degree = {
            nid: sum(1 for prec_id in self._precedents.get(nid, ()) if prec_id in formula_nodes)
            for nid in formula_nodes
        }

        queue = deque(nid for nid, degree in in_degree.items() if degree == 0)
        result = []
        while queue:
            nid = queue.popleft()
            node = self._nodes[nid]
            if table is None or node.table == table:
                result.append(node)
            for dep_id in self._dependents.get(nid, ()):
                if dep_id in in_degree:
                    in_degree[dep_id] -= 1
                    if in_degree[dep_id] == 0:
                        queue.append(dep_id)
        return result

    def clear_dirty(self, node_id: str) -> None:
        """
        Clear dirty flags for a node after successful recalculation.

        Args:
            node_id: The node ID in format ``"table.column"``.
        """
        node = self._nodes.get(node_id)
        if node is not None:
            node.dirty = False
            node.dirty_rows.clear()

    def get_node(self, table: str, column: str) -> Optional[DependencyNode]:
        """
        Get a node by table and column.

        Args:
            table: Table name.
            column: Column name.

        Returns:
            DependencyNode or None if not found.
        """
        return self._nodes.get(self._make_node_id(table, column))

    def has_formula(self, table: str, column: str) -> bool:
        """
        Check if a column has a formula registered.

        Args:
            table: Table name.
            column: Column name.

        Returns:
            True if the column has a formula.
        """
        node = self.get_node(table, column)
        return node is not None and node.formula is not None

    def mark_dirty(
        self,
        table: str,
        column: str,
        rows: Optional[list[int]] = None,
    ) -> None:
        """
        Mark a column (and its transitive dependents) as dirty.

        Uses BFS to propagate dirty flags through the dependency graph.

        Args:
            table: Table name.
            column: Column name.
            rows: Specific row indices to mark dirty. None means all rows.
        """
        node_id = self._make_node_id(table, column)
        self._mark_node_dirty(node_id, rows)

        # BFS propagation through dependents
        queue = deque([node_id])
        visited = {node_id}
        while queue:
            current_id = queue.popleft()
            for dep_id in self._dependents.get(current_id, ()):
                self._mark_node_dirty(dep_id, rows)
                if dep_id not in visited:
                    visited.add(dep_id)
                    queue.append(dep_id)

    # ==================== Private helpers ====================

    @staticmethod
    def _make_node_id(table: str, column: str) -> str:
        """Create a standard node ID from table and column names."""
        return f"{table}.{column}"

    def _get_or_create_node(self, table: str, column: str) -> DependencyNode:
        """Get existing node or create a new one."""
        node_id = self._make_node_id(table, column)
        if node_id not in self._nodes:
            self._nodes[node_id] = DependencyNode(
                node_id=node_id,
                table=table,
                column=column,
            )
        return self._nodes[node_id]

    def _mark_node_dirty(self, node_id: str, rows: Optional[list[int]]) -> None:
        """Mark a specific node as dirty; unknown node IDs are ignored."""
        node = self._nodes.get(node_id)
        if node is None:
            return

        # An already-dirty node with no specific rows means "all rows dirty";
        # adding rows to it would wrongly narrow the pending recalculation.
        all_rows_dirty = node.dirty and not node.dirty_rows
        node.dirty = True
        if rows is None or all_rows_dirty:
            node.dirty_rows.clear()  # Empty = all rows dirty
        else:
            node.dirty_rows.update(rows)

    def _remove_edges(self, node_id: str) -> None:
        """Remove the edges from a node to its precedents (both directions)."""
        for prec_id in list(self._precedents.get(node_id, ())):
            self._dependents[prec_id].discard(node_id)
        self._precedents[node_id].clear()

    def _rollback_formula(self, node_id, saved_state, saved_precedents, created_ids) -> None:
        """Restore the graph to its state before a rejected ``add_formula``."""
        self._remove_edges(node_id)
        for dep_id in saved_precedents:
            self._dependents[dep_id].add(node_id)
            self._precedents[node_id].add(dep_id)

        # Drop data nodes that only existed for the rejected formula.
        for dep_id in created_ids:
            if not self._dependents.get(dep_id):
                self._nodes.pop(dep_id, None)

        if saved_state is None:
            self._nodes.pop(node_id, None)
        else:
            node = self._nodes[node_id]
            node.formula, node.dirty, node.dirty_rows = saved_state

    def _detect_cycles(self) -> None:
        """
        Detect cycles among formula nodes using Kahn's algorithm.

        Raises:
            FormulaCycleError: If a cycle is detected.
        """
        formula_nodes = {
            nid for nid, node in self._nodes.items()
            if node.formula is not None
        }
        if not formula_nodes:
            return

        in_degree = {
            nid: sum(1 for prec_id in self._precedents.get(nid, ()) if prec_id in formula_nodes)
            for nid in formula_nodes
        }

        queue = deque(nid for nid in formula_nodes if in_degree[nid] == 0)
        processed = set()
        while queue:
            nid = queue.popleft()
            processed.add(nid)
            for dep_id in self._dependents.get(nid, ()):
                if dep_id in formula_nodes:
                    in_degree[dep_id] -= 1
                    if in_degree[dep_id] == 0:
                        queue.append(dep_id)

        # Nodes never reaching in-degree 0 are part of (or behind) a cycle.
        cycle_nodes = formula_nodes - processed
        if cycle_nodes:
            raise FormulaCycleError(sorted(cycle_nodes))

    def _extract_dependencies(
        self,
        node: FormulaNode,
        current_table: str,
    ) -> set[str]:
        """
        Recursively extract all column dependency IDs from a formula AST.

        Args:
            node: The AST node to walk.
            current_table: The table containing this formula (for ColumnRef).

        Returns:
            Set of dependency node IDs (``"table.column"`` format).
        """
        return {
            self._make_node_id(dep_table, dep_column)
            for dep_table, dep_column in self._extract_dependency_refs(node, current_table)
        }

    def _extract_dependency_refs(
        self,
        node: FormulaNode,
        current_table: str,
    ) -> set[tuple[str, str]]:
        """
        Recursively collect ``(table, column)`` pairs referenced by a formula AST.

        Args:
            node: The AST node to walk.
            current_table: The table containing this formula (for ColumnRef).

        Returns:
            Set of ``(table, column)`` tuples.
        """
        refs: set[tuple[str, str]] = set()

        if isinstance(node, ColumnRef):
            refs.add((current_table, node.column))
        elif isinstance(node, CrossTableRef):
            refs.add((node.table, node.column))
            # Also depend on the local column used in WHERE clause
            if node.where_clause is not None:
                refs.add((current_table, node.where_clause.local_column))
        elif isinstance(node, BinaryOp):
            refs |= self._extract_dependency_refs(node.left, current_table)
            refs |= self._extract_dependency_refs(node.right, current_table)
        elif isinstance(node, UnaryOp):
            refs |= self._extract_dependency_refs(node.operand, current_table)
        elif isinstance(node, FunctionCall):
            for arg in node.arguments:
                refs |= self._extract_dependency_refs(arg, current_table)
        elif isinstance(node, ConditionalExpr):
            refs |= self._extract_dependency_refs(node.value_expr, current_table)
            refs |= self._extract_dependency_refs(node.condition, current_table)
            if node.else_expr is not None:
                refs |= self._extract_dependency_refs(node.else_expr, current_table)

        return refs

View File

@@ -0,0 +1,180 @@
"""
Autocompletion engine for the DataGrid Formula DSL.
Provides context-aware suggestions for:
- Column names (after ``{``)
- Cross-table references (``{Table.``)
- Built-in function names
- Keywords: ``if``, ``else``, ``and``, ``or``, ``not``, ``where``
"""
import re
from myfasthtml.core.dsl.base_completion import BaseCompletionEngine
from myfasthtml.core.dsl.types import Position, Suggestion
from myfasthtml.core.formula.evaluator import BUILTIN_FUNCTIONS
from myfasthtml.core.utils import make_safe_id
# Keywords offered by the completion engine, in the order they are suggested.
FORMULA_KEYWORDS = [
    "if", "else", "and", "or", "not", "where",
    "between", "in", "isempty", "isnotempty", "isnan",
    "contains", "startswith", "endswith",
    "true", "false",
]
class FormulaCompletionEngine(BaseCompletionEngine):
    """
    Context-aware completion engine for formula expressions.

    Offers column references, built-in functions and keywords depending on
    where the cursor sits in the formula text.

    Args:
        provider: DataGrid metadata provider (DataGridsManager or similar).
        table_name: Name of the current table in ``"namespace.name"`` format.
    """

    def __init__(self, provider, table_name: str):
        super().__init__(provider)
        self.table_name = table_name
        self._id = "formula_completion_engine#" + make_safe_id(table_name)

    def detect_scope(self, text: str, current_line: int):
        """Formulas are a single expression, so there is never a scope."""
        return None

    def detect_context(self, text: str, cursor: Position, scope):
        """
        Classify the cursor position within the formula text.

        Args:
            text: The full formula text.
            cursor: Cursor position (line, ch).
            scope: Unused (formulas have no scopes).

        Returns:
            Context string: ``"cross_table"`` or ``"column_ref"`` inside an
            unclosed ``{``, ``"function_or_keyword"`` right after an
            identifier, otherwise ``"general"``.
        """
        text_before = self._line_before_cursor(text, cursor)

        # Inside a { ... } reference that has not been closed yet?
        _, brace, inside = text_before.rpartition("{")
        if brace and "}" not in inside:
            return "cross_table" if "." in inside else "column_ref"

        # Typing an identifier (function name or keyword)?
        if re.search(r"[a-z_][a-z0-9_]*$", text_before, re.IGNORECASE):
            return "function_or_keyword"

        return "general"

    def get_suggestions(self, text: str, cursor: Position, scope, context) -> list:
        """
        Build the suggestion list for a detected context.

        Args:
            text: The full formula text.
            cursor: Cursor position.
            scope: Unused.
            context: String from ``detect_context``.

        Returns:
            List of Suggestion objects.
        """
        if context == "column_ref":
            # Columns of the table this formula belongs to.
            return list(self._column_suggestions(self.table_name))

        if context == "cross_table":
            # Everything between the last "{" and the last "." names the table.
            _, brace, tail = self._line_before_cursor(text, cursor).rpartition("{")
            inside = tail if brace else ""
            table_prefix, _, _ = inside.rpartition(".")
            if table_prefix:
                return list(self._column_suggestions(table_prefix))
            return list(self._table_suggestions())

        suggestions = self._function_suggestions() + self._keyword_suggestions()
        if context != "function_or_keyword":
            # General context: also offer to open a column reference.
            suggestions.append(
                Suggestion(
                    label="{",
                    detail="Column reference",
                    insert_text="{",
                )
            )
        return suggestions

    # ==================== Private helpers ====================

    @staticmethod
    def _line_before_cursor(text: str, cursor: Position) -> str:
        """Return the part of the cursor's line that precedes the cursor."""
        lines = text.split("\n")
        current_line = lines[min(cursor.line, len(lines) - 1)]
        return current_line[:cursor.ch]

    def _column_suggestions(self, table_name: str) -> list:
        """Column-name suggestions for a table; empty on any provider error."""
        try:
            suggestions = []
            for col in self.provider.list_columns(table_name) or []:
                suggestions.append(
                    Suggestion(
                        label=col,
                        detail=f"Column from {table_name}",
                        insert_text=col,
                    )
                )
            return suggestions
        except Exception:
            return []

    def _table_suggestions(self) -> list:
        """Table-name suggestions; empty on any provider error."""
        try:
            suggestions = []
            for t in self.provider.list_tables() or []:
                suggestions.append(
                    Suggestion(
                        label=t,
                        detail="Table",
                        insert_text=t,
                    )
                )
            return suggestions
        except Exception:
            return []

    def _function_suggestions(self) -> list:
        """Built-in function suggestions, sorted by name, inserting ``name(``."""
        suggestions = []
        for name in sorted(BUILTIN_FUNCTIONS.keys()):
            suggestions.append(
                Suggestion(
                    label=name,
                    detail="Function",
                    insert_text=f"{name}(",
                )
            )
        return suggestions

    def _keyword_suggestions(self) -> list:
        """Keyword suggestions, in ``FORMULA_KEYWORDS`` order."""
        suggestions = []
        for kw in FORMULA_KEYWORDS:
            suggestions.append(Suggestion(label=kw, detail="Keyword", insert_text=kw))
        return suggestions

View File

@@ -0,0 +1 @@

View File

@@ -0,0 +1,79 @@
"""
FormulaDSL definition for the DslEditor control.
Provides the Lark grammar and derived completions for the
DataGrid Formula DSL (CodeMirror 5 Simple Mode).
"""
from functools import cached_property
from typing import Dict, Any
from myfasthtml.core.dsl.base import DSLDefinition
from .grammar import FORMULA_GRAMMAR
class FormulaDSL(DSLDefinition):
    """
    DSL definition for DataGrid formula expressions.

    Uses the Lark grammar from grammar.py to drive syntax highlighting
    and autocompletion in the DslEditor.
    """
    name: str = "Formula DSL"

    def get_grammar(self) -> str:
        """Return the Lark grammar for the formula DSL."""
        return FORMULA_GRAMMAR

    @cached_property
    def simple_mode_config(self) -> Dict[str, Any]:
        """
        Return a hand-tuned CodeMirror 5 Simple Mode config for formula syntax.

        Overrides the base class to provide highlighting rules for comments,
        column references, operators, functions, and keywords. Rules are
        listed in the order they should be tried.
        """
        return {
            "start": [
                # Line comments: the grammar ignores "#..." up to end of line,
                # so highlight them as comments instead of as formula tokens.
                {
                    "regex": r"#.*",
                    "token": "comment",
                },
                # Column references: {Column Name} or {Table.Column}, with an
                # optional "where Table.Col = Local Col" part. Column names
                # may contain spaces, as the grammar's COL_NAME allows.
                {
                    "regex": r"\{[A-Za-z_][A-Za-z0-9_. ]*(?:\s+where\s+[A-Za-z_][A-Za-z0-9_.]*\s*=\s*[A-Za-z_][A-Za-z0-9_ ]*)?\}",
                    "token": "variable-2",
                },
                # Function names before parenthesis
                {
                    "regex": r"[a-z_][a-z0-9_]*(?=\s*\()",
                    "token": "keyword",
                },
                # Keywords: if, else, and, or, not, where, between, in, ...
                {
                    "regex": r"\b(if|else|and|or|not|where|between|in|isempty|isnotempty|isnan|contains|startswith|endswith|true|false)\b",
                    "token": "keyword",
                },
                # Numbers
                {
                    "regex": r"[+-]?(\d+\.?\d*|\.\d+)([eE][+-]?\d+)?",
                    "token": "number",
                },
                # Strings
                {
                    "regex": r'"[^"\\]*"',
                    "token": "string",
                },
                # Operators
                {
                    "regex": r"[=!<>]=?|[+\-*/%^]",
                    "token": "operator",
                },
                # Parentheses and brackets
                {
                    "regex": r"[()[\],]",
                    "token": "punctuation",
                },
            ],
            "meta": {
                "dontIndentStates": ["comment"],
                "lineComment": "#",
            },
        }

View File

@@ -0,0 +1,35 @@
class FormulaError(Exception):
    """Root of the formula exception hierarchy; catch this to handle any formula error."""
class FormulaSyntaxError(FormulaError):
    """Raised when the formula text cannot be parsed.

    The exception text is ``message`` followed, when known, by
    ``at line <line>`` and ``col <column>``. ``context`` is stored on the
    instance but is not part of the text.
    """

    def __init__(self, message, line=None, column=None, context=None):
        self.message, self.line, self.column, self.context = message, line, column, context
        super().__init__(self._format_message())

    def _format_message(self):
        """Assemble the exception text from the message and the known position."""
        location = []
        if self.line is not None:
            location.append(f"at line {self.line}")
        if self.column is not None:
            location.append(f"col {self.column}")
        return " ".join([self.message, *location])
class FormulaValidationError(FormulaError):
    """Raised for a formula that parses correctly but is semantically invalid."""
class FormulaCycleError(FormulaError):
    """Raised when formula dependencies contain a cycle.

    ``cycle_nodes`` holds the node identifiers passed by the raiser; they are
    also listed, comma-separated, in the exception text.
    """

    def __init__(self, cycle_nodes):
        self.cycle_nodes = cycle_nodes
        involved = ", ".join(cycle_nodes)
        super().__init__(f"Circular dependency detected involving: {involved}")

View File

@@ -0,0 +1,100 @@
FORMULA_GRAMMAR = r"""
start: expression
// ==================== Top-level expression ====================
?expression: conditional_expr
// Suffix-if: value_expr if condition [else expression]
// Right-associative for chaining: a if c1 else b if c2 else d
?conditional_expr: or_expr "if" or_expr "else" conditional_expr -> conditional_with_else
| or_expr "if" or_expr -> conditional_no_else
| or_expr
// ==================== Logical ====================
?or_expr: and_expr ("or" and_expr)* -> or_op
?and_expr: not_expr ("and" not_expr)* -> and_op
?not_expr: "not" not_expr -> not_op
| comparison
// ==================== Comparison ====================
?comparison: addition comp_op addition -> comparison_expr
| addition "in" "[" literal ("," literal)* "]" -> in_expr
| addition "between" addition "and" addition -> between_expr
| addition "contains" addition -> contains_expr
| addition "startswith" addition -> startswith_expr
| addition "endswith" addition -> endswith_expr
| addition "isempty" -> isempty_expr
| addition "isnotempty" -> isnotempty_expr
| addition "isnan" -> isnan_expr
| addition
comp_op: "==" -> eq
| "!=" -> ne
| "<=" -> le
| "<" -> lt
| ">=" -> ge
| ">" -> gt
// ==================== Arithmetic ====================
?addition: multiplication (add_op multiplication)* -> add_expr
?multiplication: power (mul_op power)* -> mul_expr
?power: unary ("^" unary)* -> pow_expr
add_op: "+" -> plus
| "-" -> minus
mul_op: "*" -> times
| "/" -> divide
| "%" -> modulo
?unary: "-" unary -> neg
| atom
// ==================== Atoms ====================
?atom: function_call
| cross_table_ref
| column_ref
| literal
| "(" expression ")" -> paren
// ==================== References ====================
// Cross-table must be checked before column_ref since both use { }
// TABLE_NAME.COL_NAME with optional WHERE clause
// Note: whitespace around "where" and "=" is handled by %ignore
cross_table_ref: "{" TABLE_COL_REF "}" -> cross_ref_simple
| "{" TABLE_COL_REF "where" where_clause "}" -> cross_ref_where
column_ref: "{" COL_NAME "}"
where_clause: TABLE_COL_REF "=" COL_NAME
// TABLE_COL_REF matches "TableName.ColumnName" (dot-separated, no spaces)
TABLE_COL_REF: /[A-Za-z_][A-Za-z0-9_]*\.[A-Za-z_][A-Za-z0-9_]*/
COL_NAME: /[A-Za-z_][A-Za-z0-9_ ]*/
// ==================== Functions ====================
function_call: FUNC_NAME "(" [expression ("," expression)*] ")"
FUNC_NAME: /[a-z_][a-z0-9_]*/
// ==================== Literals ====================
?literal: NUMBER -> number_literal
| ESCAPED_STRING -> string_literal
| "true"i -> true_literal
| "false"i -> false_literal
// ==================== Terminals ====================
NUMBER: /[+-]?(\d+\.?\d*|\.\d+)([eE][+-]?\d+)?/
ESCAPED_STRING: "\"" /[^"\\]*/ "\""
%ignore /[ \t\f]+/
%ignore /\#[^\n]*/
"""

View File

@@ -0,0 +1,85 @@
"""
Formula DSL parser using Lark.
Handles parsing of formula expression strings into a Lark AST.
No indentation handling needed — formulas are single-line expressions.
"""
from lark import Lark, UnexpectedInput
from .exceptions import FormulaSyntaxError
from .grammar import FORMULA_GRAMMAR
class FormulaParser:
    """
    Parser for the DataGrid formula language.

    Uses Lark LALR parser without indentation handling.

    Example:
        parser = FormulaParser()
        tree = parser.parse("{Price} * {Quantity}")
    """

    # Above this many expected tokens the list is too noisy to be helpful.
    _MAX_EXPECTED_LISTED = 5

    def __init__(self):
        self._parser = Lark(
            FORMULA_GRAMMAR,
            parser="lalr",
            propagate_positions=False,
        )

    def parse(self, text: str):
        """
        Parse a formula expression string into a Lark Tree.

        Args:
            text: The formula expression text. None is treated like an
                empty string.

        Returns:
            lark.Tree: The parsed AST, or None when the text is empty or
            contains only whitespace.

        Raises:
            FormulaSyntaxError: If the text has syntax errors.
        """
        text = (text or "").strip()
        if not text:
            return None

        try:
            return self._parser.parse(text)
        except UnexpectedInput as e:
            context = None
            if hasattr(e, "get_context"):
                context = e.get_context(text)
            raise FormulaSyntaxError(
                message=self._format_error_message(e),
                line=getattr(e, "line", None),
                column=getattr(e, "column", None),
                context=context,
            ) from e

    def _format_error_message(self, error: UnexpectedInput) -> str:
        """Format a user-friendly error message from a Lark exception."""
        expected = getattr(error, "expected", None)
        if expected is None:
            # No expectation information on this error: use its own text.
            return str(error)

        # Sorted so the message does not depend on set iteration order.
        names = sorted(str(name) for name in expected)
        if not names or len(names) > self._MAX_EXPECTED_LISTED:
            return "Unexpected input"
        if len(names) == 1:
            return f"Expected {names[0]}"
        return f"Expected one of: {', '.join(names)}"
# Lazily created parser shared by every caller of get_parser().
_parser_instance = None


def get_parser() -> FormulaParser:
    """Return the shared FormulaParser, building it on the first call."""
    global _parser_instance
    if _parser_instance is not None:
        return _parser_instance
    _parser_instance = FormulaParser()
    return _parser_instance

View File

@@ -0,0 +1,274 @@
"""
Formula DSL Transformer.
Converts a Lark AST tree into FormulaDefinition and related AST dataclasses.
"""
from lark import Transformer
from ..dataclasses import (
FormulaDefinition,
FormulaNode,
LiteralNode,
ColumnRef,
CrossTableRef,
WhereClause,
BinaryOp,
UnaryOp,
FunctionCall,
ConditionalExpr,
)
class FormulaTransformer(Transformer):
    """
    Transforms the Lark parse tree into FormulaDefinition AST dataclasses.

    Each method receives the already-transformed children of the matching
    grammar rule (or alias) as ``items``.

    Handles left-associative folding for arithmetic and logical operators.
    Handles right-associative chaining for conditional expressions.
    """

    # ==================== Top-level ====================

    def start(self, items):
        """Return the FormulaDefinition wrapping the single expression."""
        expr = items[0]
        return FormulaDefinition(expression=expr)

    # ==================== Conditionals ====================

    def conditional_with_else(self, items):
        """value_expr if condition else else_expr"""
        value_expr, condition, else_expr = items
        return ConditionalExpr(
            value_expr=value_expr,
            condition=condition,
            else_expr=else_expr,
        )

    def conditional_no_else(self, items):
        """value_expr if condition"""
        value_expr, condition = items
        return ConditionalExpr(
            value_expr=value_expr,
            condition=condition,
            else_expr=None,
        )

    # ==================== Logical ====================

    def or_op(self, items):
        """Fold left-associatively: a or b or c -> BinaryOp(or, BinaryOp(or, a, b), c)"""
        return self._fold_left(items, "or")

    def and_op(self, items):
        """Fold left-associatively: a and b and c -> BinaryOp(and, BinaryOp(and, a, b), c)"""
        return self._fold_left(items, "and")

    def not_op(self, items):
        """not expr"""
        return UnaryOp(operator="not", operand=items[0])

    # ==================== Comparisons ====================

    def comparison_expr(self, items):
        """left comp_op right (comp_op is already a string such as "==")."""
        left, op, right = items
        return BinaryOp(operator=op, left=left, right=right)

    def in_expr(self, items):
        """operand in [literal, ...]

        The candidate nodes are stored, untransformed into plain values,
        inside a single LiteralNode whose value is a list.
        """
        operand = items[0]
        values = list(items[1:])
        return BinaryOp(operator="in", left=operand, right=LiteralNode(value=values))

    def between_expr(self, items):
        """operand between low and high

        Both bounds are stored as nodes inside a LiteralNode holding a
        two-element list.
        """
        operand, low, high = items
        return BinaryOp(
            operator="between",
            left=operand,
            right=LiteralNode(value=[low, high]),
        )

    def contains_expr(self, items):
        """left contains right"""
        left, right = items
        return BinaryOp(operator="contains", left=left, right=right)

    def startswith_expr(self, items):
        """left startswith right"""
        left, right = items
        return BinaryOp(operator="startswith", left=left, right=right)

    def endswith_expr(self, items):
        """left endswith right"""
        left, right = items
        return BinaryOp(operator="endswith", left=left, right=right)

    def isempty_expr(self, items):
        """operand isempty"""
        return UnaryOp(operator="isempty", operand=items[0])

    def isnotempty_expr(self, items):
        """operand isnotempty"""
        return UnaryOp(operator="isnotempty", operand=items[0])

    def isnan_expr(self, items):
        """operand isnan"""
        return UnaryOp(operator="isnan", operand=items[0])

    # ==================== Comparison operators ====================
    # Each operator rule collapses to the operator string stored in BinaryOp.

    def eq(self, items):
        return "=="

    def ne(self, items):
        return "!="

    def le(self, items):
        return "<="

    def lt(self, items):
        return "<"

    def ge(self, items):
        return ">="

    def gt(self, items):
        return ">"

    # ==================== Arithmetic ====================

    def add_expr(self, items):
        """Fold left-associatively with alternating operands and operators."""
        return self._fold_binary_with_ops(items)

    def mul_expr(self, items):
        """Fold left-associatively with alternating operands and operators."""
        return self._fold_binary_with_ops(items)

    def pow_expr(self, items):
        """Fold left-associatively for power expressions (^ is left-assoc here)."""
        # pow_expr items are [base, exp1, exp2, ...] — operator is always "^"
        return self._fold_left(items, "^")

    def plus(self, items):
        return "+"

    def minus(self, items):
        return "-"

    def times(self, items):
        return "*"

    def divide(self, items):
        return "/"

    def modulo(self, items):
        return "%"

    def neg(self, items):
        """Unary negation: -expr"""
        return UnaryOp(operator="-", operand=items[0])

    # ==================== References ====================

    def cross_ref_simple(self, items):
        """{ Table.Column }"""
        table, column = str(items[0]).split(".", 1)
        return CrossTableRef(table=table, column=column)

    def cross_ref_where(self, items):
        """{ Table.Column WHERE remote_table.remote_col = local_col }"""
        table, column = str(items[0]).split(".", 1)
        return CrossTableRef(table=table, column=column, where_clause=items[1])

    def column_ref(self, items):
        """{ ColumnName }"""
        # COL_NAME allows spaces, so surrounding whitespace is trimmed.
        col_name = str(items[0]).strip()
        return ColumnRef(column=col_name)

    def where_clause(self, items):
        """TABLE_COL_REF = COL_NAME"""
        remote_table, remote_col = str(items[0]).split(".", 1)
        local_col = str(items[1]).strip()
        return WhereClause(
            remote_table=remote_table,
            remote_column=remote_col,
            local_column=local_col,
        )

    # ==================== Functions ====================

    def function_call(self, items):
        """func_name(arg1, arg2, ...)

        The argument list is an optional ``[...]`` group in the grammar.
        When Lark's ``maybe_placeholders`` option is active (the Lark 1.x
        default) an omitted optional group is reported as a ``None`` child,
        so a call such as ``today()`` would otherwise carry a bogus ``None``
        argument. Real arguments are always AST nodes, never ``None``, so
        dropping ``None`` children is safe.
        """
        func_name = str(items[0]).lower()
        args = [arg for arg in items[1:] if arg is not None]
        return FunctionCall(function_name=func_name, arguments=args)

    # ==================== Literals ====================

    def number_literal(self, items):
        """Build an int LiteralNode when possible, otherwise a float one."""
        value = str(items[0])
        # A decimal point or an exponent always means float.
        if "." in value or "e" in value.lower():
            return LiteralNode(value=float(value))
        try:
            return LiteralNode(value=int(value))
        except ValueError:
            return LiteralNode(value=float(value))

    def string_literal(self, items):
        """Build a string LiteralNode, stripping the surrounding double quotes."""
        raw = str(items[0])
        if raw.startswith('"') and raw.endswith('"'):
            return LiteralNode(value=raw[1:-1])
        return LiteralNode(value=raw)

    def true_literal(self, items):
        return LiteralNode(value=True)

    def false_literal(self, items):
        return LiteralNode(value=False)

    def paren(self, items):
        """Parenthesized expression — transparent pass-through."""
        return items[0]

    # ==================== Helpers ====================

    def _fold_left(self, items: list, op: str) -> FormulaNode:
        """
        Fold a list of operands left-associatively with a fixed operator.

        Args:
            items: List of FormulaNode operands.
            op: Operator string.

        Returns:
            Left-folded BinaryOp tree, or the single item if only one.
        """
        result = items[0]
        for operand in items[1:]:
            result = BinaryOp(operator=op, left=result, right=operand)
        return result

    def _fold_binary_with_ops(self, items: list) -> FormulaNode:
        """
        Fold a list of alternating [operand, op, operand, op, operand, ...]
        left-associatively.

        Args:
            items: Alternating list: [expr, op_str, expr, op_str, expr, ...]

        Returns:
            Left-folded BinaryOp tree, or the single operand if there is no
            operator.
        """
        result = items[0]
        # Odd positions hold operators, the following even positions their
        # right-hand operands.
        for op, right in zip(items[1::2], items[2::2]):
            result = BinaryOp(operator=op, left=result, right=right)
        return result

View File

@@ -0,0 +1,398 @@
"""
Formula Engine — facade orchestrating parsing, DAG, and evaluation.
Coordinates:
- Parsing formula text via the DSL parser
- Registering formulas and their dependencies in the DependencyGraph
- Evaluating dirty formula columns row-by-row via FormulaEvaluator
- Updating ns_fast_access caches in the DatagridStore
"""
import logging
from typing import Any, Callable, Optional
import numpy as np
from .dataclasses import FormulaDefinition, WhereClause
from .dependency_graph import DependencyGraph
from .dsl.parser import get_parser
from .dsl.transformer import FormulaTransformer
from .evaluator import FormulaEvaluator
# Module-level logger for the formula engine (logger name: "FormulaEngine").
logger = logging.getLogger("FormulaEngine")
# Callback that returns a DatagridStore-like object for a given table name.
# The engine treats a None result as "table not found".
RegistryResolver = Callable[[str], Any]
def parse_formula(text: str) -> FormulaDefinition | None:
    """Turn a formula expression string into a FormulaDefinition AST.

    Args:
        text: The formula expression string.

    Returns:
        The FormulaDefinition (with ``source_text`` set to the stripped
        input), or None when the input is empty or blank.

    Raises:
        FormulaSyntaxError: If the formula text is syntactically invalid.
    """
    source = (text or "").strip()
    if not source:
        return None

    tree = get_parser().parse(source)
    if tree is None:
        return None

    definition = FormulaTransformer().transform(tree)
    # Keep the original text so it can be shown back to the user.
    definition.source_text = source
    return definition
class FormulaEngine:
    """
    Facade for the formula calculation system.

    Orchestrates formula parsing, dependency tracking, and incremental
    recalculation of formula columns.

    Args:
        registry_resolver: Callback that takes a table name and returns
            the DatagridStore for that table (used for cross-table refs).
            Provided by DataGridsManager.
    """

    def __init__(self, registry_resolver: Optional[RegistryResolver] = None):
        self._graph = DependencyGraph()
        self._registry_resolver = registry_resolver
        # Cache of parsed formulas: {(table, col): FormulaDefinition}
        self._formulas: dict[tuple[str, str], FormulaDefinition] = {}

    def set_formula(self, table: str, col: str, formula_text: str) -> None:
        """
        Parse and register a formula for a column.

        An empty or blank formula removes any formula registered for the
        column.

        Args:
            table: Table name.
            col: Column name.
            formula_text: The formula expression string.

        Raises:
            FormulaSyntaxError: If the formula is syntactically invalid.
            FormulaCycleError: If the formula would create a circular dependency.
        """
        formula_text = formula_text.strip() if formula_text else ""
        formula = parse_formula(formula_text) if formula_text else None
        if formula is None:
            self.remove_formula(table, col)
            return

        # Registers in DAG and raises FormulaCycleError if cycle detected.
        # The cache is only updated once the graph accepted the formula.
        self._graph.add_formula(table, col, formula)
        self._formulas[(table, col)] = formula
        logger.debug("Formula set for %s.%s: %s", table, col, formula_text)

    def remove_formula(self, table: str, col: str) -> None:
        """
        Remove a formula column from the engine.

        Args:
            table: Table name.
            col: Column name.
        """
        self._graph.remove_formula(table, col)
        self._formulas.pop((table, col), None)

    def mark_data_changed(
        self,
        table: str,
        col: str,
        rows: Optional[list[int]] = None,
    ) -> None:
        """
        Mark a column's data as changed, propagating dirty flags.

        Call this when source data is modified so that dependent formula
        columns are re-evaluated on next render.

        Args:
            table: Table name.
            col: Column name.
            rows: Specific row indices that changed. None means all rows.
        """
        self._graph.mark_dirty(table, col, rows)

    def recalculate_if_needed(self, table: str, store: Any) -> bool:
        """
        Recalculate all dirty formula columns for a table.

        Should be called at the start of ``mk_body_content_page()`` to
        ensure formula columns are up-to-date before rendering.

        Updates ``store.ns_fast_access`` and ``store.ns_row_data`` in place.

        Args:
            table: Table name.
            store: The DatagridStore instance for this table.

        Returns:
            True if any columns were recalculated, False otherwise.
        """
        dirty_nodes = self._graph.get_calculation_order(table=table)
        if not dirty_nodes:
            return False

        recalculated = False
        for node in dirty_nodes:
            formula = node.formula
            if formula is None:
                # Nothing to evaluate for this node.
                continue
            self._evaluate_column(table, node.column, formula, store)
            self._graph.clear_dirty(node.node_id)
            recalculated = True

        # Rebuild ns_row_data only when a column was actually evaluated.
        if recalculated and store.ns_fast_access:
            self._rebuild_row_data(store)
        return recalculated

    def has_formula(self, table: str, col: str) -> bool:
        """
        Check if a column has a formula registered.

        Args:
            table: Table name.
            col: Column name.

        Returns:
            True if the column has a registered formula.
        """
        return self._graph.has_formula(table, col)

    def get_formula_text(self, table: str, col: str) -> Optional[str]:
        """
        Get the source text of a registered formula.

        Args:
            table: Table name.
            col: Column name.

        Returns:
            Formula source text or None if not registered.
        """
        formula = self._formulas.get((table, col))
        return formula.source_text if formula else None

    # ==================== Private helpers ====================

    def _evaluate_column(
        self,
        table: str,
        col: str,
        formula: FormulaDefinition,
        store: Any,
    ) -> None:
        """
        Evaluate a formula column row-by-row and update ns_fast_access.

        Does nothing when the store has no row data.

        Args:
            table: Table name.
            col: Column name.
            formula: The parsed FormulaDefinition.
            store: The DatagridStore with ns_fast_access and ns_row_data.
        """
        if store.ns_row_data is None or len(store.ns_row_data) == 0:
            return

        n_rows = len(store.ns_row_data)
        resolver = self._make_cross_table_resolver(table)
        evaluator = FormulaEvaluator(cross_table_resolver=resolver)

        # Ensure ns_fast_access exists so the result can be stored below.
        if store.ns_fast_access is None:
            store.ns_fast_access = {}

        # ns_fast_access is not modified while this column is evaluated, so
        # the usable columns and their lengths are collected once instead of
        # once per row.
        columns = [
            (name, arr, len(arr))
            for name, arr in store.ns_fast_access.items()
            if arr is not None
        ]

        results = np.empty(n_rows, dtype=object)
        for row_index in range(n_rows):
            # Build row_data from ns_fast_access so that formula columns evaluated
            # earlier in this pass (e.g. B) are available to dependent columns (e.g. C).
            row_data = {
                name: arr[row_index]
                for name, arr, length in columns
                if row_index < length
            }
            results[row_index] = evaluator.evaluate(formula, row_data, row_index)

        store.ns_fast_access[col] = results
        logger.debug("Evaluated formula column %s.%s (%d rows)", table, col, n_rows)

    def _rebuild_row_data(self, store: Any) -> None:
        """
        Copy every ns_fast_access column value into the ns_row_data rows.

        This makes formula results visible in the per-row data as well.
        Does nothing when either structure is missing.

        Args:
            store: The DatagridStore to update.
        """
        if store.ns_fast_access is None or store.ns_row_data is None:
            return

        columns = [
            (name, arr, len(arr))
            for name, arr in store.ns_fast_access.items()
            if arr is not None
        ]
        # Indexed access: only len() and [] are assumed of ns_row_data.
        for row_index in range(len(store.ns_row_data)):
            row = store.ns_row_data[row_index]
            for name, arr, length in columns:
                if row_index < length:
                    row[name] = arr[row_index]

    def _make_cross_table_resolver(self, current_table: str):
        """
        Create a cross-table resolver callback for the given table context.

        Resolution strategy:
        1. Explicit WHERE clause: scan remote column for matching rows.
        2. Implicit join by ``id`` column: match rows where both tables share
           the same id value (no positional fallback when ids are present).
        3. Fallback: match by row_index.

        Args:
            current_table: The table that contains the formula.

        Returns:
            A callable ``resolver(table, column, where_clause, row_index) -> value``.
        """

        def resolver(
            remote_table: str,
            remote_column: str,
            where_clause: Optional[WhereClause],
            row_index: int,
        ) -> Any:
            if self._registry_resolver is None:
                logger.warning(
                    "No registry_resolver set for cross-table ref %s.%s",
                    remote_table, remote_column,
                )
                return None

            remote_store = self._registry_resolver(remote_table)
            if remote_store is None:
                logger.warning("Table '%s' not found in registry", remote_table)
                return None

            ns = remote_store.ns_fast_access
            if not ns or remote_column not in ns:
                logger.debug(
                    "Column '%s' not found in table '%s'", remote_column, remote_table
                )
                return None

            remote_array = ns[remote_column]

            # Strategy 1: Explicit WHERE clause
            if where_clause is not None:
                return self._resolve_with_where(
                    where_clause, remote_store, remote_column,
                    remote_array, current_table, row_index,
                )

            # Strategy 2: Implicit join by 'id' column
            current_store = self._registry_resolver(current_table)
            if (
                current_store is not None
                and current_store.ns_fast_access is not None
                and "id" in current_store.ns_fast_access
                and "id" in ns
            ):
                local_id_arr = current_store.ns_fast_access["id"]
                if row_index < len(local_id_arr):
                    # Use the first remote row carrying the same id.
                    match = self._first_match_index(ns["id"], local_id_arr[row_index])
                    if match is not None:
                        return remote_array[match]
                return None

            # Strategy 3: Fallback — match by row_index
            if row_index < len(remote_array):
                return remote_array[row_index]
            return None

        return resolver

    def _resolve_with_where(
        self,
        where_clause: WhereClause,
        remote_store: Any,
        remote_column: str,
        remote_array: Any,
        current_table: str,
        row_index: int,
    ) -> Any:
        """
        Resolve a cross-table reference using an explicit WHERE clause.

        Args:
            where_clause: The parsed WHERE clause.
            remote_store: DatagridStore for the remote table.
            remote_column: Column to return value from.
            remote_array: Array of the remote column values.
            current_table: Table containing the formula.
            row_index: Current row being evaluated.

        Returns:
            The value from the first matching remote row, or None.
        """
        remote_ns = remote_store.ns_fast_access
        if not remote_ns:
            return None

        # Get the remote key column array
        remote_key_col = where_clause.remote_column
        if remote_key_col not in remote_ns:
            logger.debug(
                "WHERE key column '%s' not found in remote table", remote_key_col
            )
            return None

        # Get the local value to compare
        current_store = self._registry_resolver(current_table) if self._registry_resolver else None
        if current_store is None or current_store.ns_fast_access is None:
            return None

        local_col = where_clause.local_column
        if local_col not in current_store.ns_fast_access:
            logger.debug("WHERE local column '%s' not found", local_col)
            return None

        local_array = current_store.ns_fast_access[local_col]
        if row_index >= len(local_array):
            return None

        # Return value from first match (use aggregation functions for multi-row)
        match = self._first_match_index(remote_ns[remote_key_col], local_array[row_index])
        return remote_array[match] if match is not None else None

    @staticmethod
    def _first_match_index(key_array: Any, value: Any) -> Optional[int]:
        """
        Return the index of the first element of ``key_array`` equal to ``value``.

        The keys are converted with ``np.asarray`` so that list-backed
        columns are compared element by element too.

        Args:
            key_array: Sequence or array of candidate keys.
            value: The value to look for.

        Returns:
            The integer index of the first match, or None when there is no
            match or the comparison itself fails.
        """
        try:
            hits = np.flatnonzero(np.asarray(key_array) == value)
        except Exception as exc:
            # Best effort: a failing comparison counts as "no match", but
            # the failure is no longer silent.
            logger.debug("Cross-table key comparison failed: %s", exc)
            return None
        return int(hits[0]) if hits.size else None

View File

@@ -0,0 +1,522 @@
"""
Formula Evaluator.
Evaluates a FormulaDefinition AST row-by-row using column data.
"""
import logging
import math
from datetime import date, datetime
from typing import Any, Callable, Optional
from .dataclasses import (
FormulaNode,
FormulaDefinition,
LiteralNode,
ColumnRef,
CrossTableRef,
BinaryOp,
UnaryOp,
FunctionCall,
ConditionalExpr,
)
# Module-level logger for the evaluator (logger name: "FormulaEvaluator").
logger = logging.getLogger("FormulaEvaluator")
# Type alias for the cross-table resolver callback.
# It is invoked as resolver(table, column, where_clause, row_index) and
# returns the resolved value.
CrossTableResolver = Callable[[str, str, Optional[object], int], Any]
def _safe_numeric(value) -> Optional[float]:
"""Convert value to float, returning None if not possible."""
if value is None:
return None
try:
return float(value)
except (TypeError, ValueError):
return None
# ==================== Built-in function registry ====================
def _fn_round(args):
    """round(value[, decimals]): round a number, 0 decimals by default.

    Returns None when no argument is given or the value is not numeric.
    """
    if not args:
        return None
    # The digit count is converted first, exactly as before, so an invalid
    # second argument still raises before the value is looked at.
    digits = int(args[1]) if len(args) > 1 else 0
    number = _safe_numeric(args[0])
    if number is None:
        return None
    return round(number, digits)
def _fn_abs(args):
    """abs(value): absolute value, or None if missing or not numeric."""
    if not args:
        return None
    number = _safe_numeric(args[0])
    if number is None:
        return None
    return abs(number)
def _fn_min(args):
    """min(a, b, ...): smallest numeric argument; non-numeric ones are skipped."""
    candidates = [n for n in map(_safe_numeric, args) if n is not None]
    if not candidates:
        return None
    return min(candidates)
def _fn_max(args):
    """max(a, b, ...): largest numeric argument; non-numeric ones are skipped."""
    candidates = [n for n in map(_safe_numeric, args) if n is not None]
    if not candidates:
        return None
    return max(candidates)
def _fn_floor(args):
    """floor(value): round down, or None if missing or not numeric."""
    if not args:
        return None
    number = _safe_numeric(args[0])
    if number is None:
        return None
    return math.floor(number)
def _fn_ceil(args):
    """ceil(value): round up, or None if missing or not numeric."""
    if not args:
        return None
    number = _safe_numeric(args[0])
    if number is None:
        return None
    return math.ceil(number)
def _fn_sqrt(args):
    """sqrt(value): square root; None if missing, not numeric or negative."""
    number = _safe_numeric(args[0]) if args else None
    if number is not None and number >= 0:
        return math.sqrt(number)
    return None
def _fn_sum(args):
    """Sum of all arguments (used for inline multi-value, not aggregation).

    Non-numeric arguments are ignored; None is returned when nothing
    numeric remains.
    """
    numbers = [n for n in map(_safe_numeric, args) if n is not None]
    if not numbers:
        return None
    return sum(numbers)
def _fn_avg(args):
    """avg(a, b, ...): mean of the numeric arguments, or None if there are none."""
    numbers = [n for n in map(_safe_numeric, args) if n is not None]
    if not numbers:
        return None
    return sum(numbers) / len(numbers)
def _fn_len(args):
v = args[0] if args else None
if v is None:
return None
return len(str(v))
def _fn_upper(args):
v = args[0] if args else None
return str(v).upper() if v is not None else None
def _fn_lower(args):
v = args[0] if args else None
return str(v).lower() if v is not None else None
def _fn_trim(args):
v = args[0] if args else None
return str(v).strip() if v is not None else None
def _fn_left(args):
if len(args) < 2:
return None
v, n = args[0], args[1]
if v is None or n is None:
return None
return str(v)[:int(n)]
def _fn_right(args):
if len(args) < 2:
return None
v, n = args[0], args[1]
if v is None or n is None:
return None
n = int(n)
return str(v)[-n:] if n > 0 else ""
def _fn_concat(args):
return "".join(str(a) if a is not None else "" for a in args)
def _fn_year(args):
v = args[0] if args else None
if v is None:
return None
if isinstance(v, (datetime, date)):
return v.year
try:
return datetime.fromisoformat(str(v)).year
except (ValueError, TypeError):
return None
def _fn_month(args):
v = args[0] if args else None
if v is None:
return None
if isinstance(v, (datetime, date)):
return v.month
try:
return datetime.fromisoformat(str(v)).month
except (ValueError, TypeError):
return None
def _fn_day(args):
v = args[0] if args else None
if v is None:
return None
if isinstance(v, (datetime, date)):
return v.day
try:
return datetime.fromisoformat(str(v)).day
except (ValueError, TypeError):
return None
def _fn_today(args):
return date.today()
def _fn_datediff(args):
if len(args) < 2:
return None
d1, d2 = args[0], args[1]
if d1 is None or d2 is None:
return None
try:
if not isinstance(d1, (datetime, date)):
d1 = datetime.fromisoformat(str(d1))
if not isinstance(d2, (datetime, date)):
d2 = datetime.fromisoformat(str(d2))
delta = d1 - d2
return delta.days
except (ValueError, TypeError):
return None
def _fn_coalesce(args):
for a in args:
if a is not None:
return a
return None
def _fn_if_error(args):
# if_error(expr, fallback) - expr already evaluated, error would be None
if len(args) < 2:
return args[0] if args else None
return args[0] if args[0] is not None else args[1]
def _fn_count(args):
"""Count non-None values (used for aggregation results)."""
return sum(1 for a in args if a is not None)
# ==================== Function registry ====================
# Built-in functions grouped by category; each callable receives the list of
# already-evaluated arguments.
_MATH_FUNCTIONS: dict[str, Callable] = {
    "round": _fn_round,
    "abs": _fn_abs,
    "min": _fn_min,
    "max": _fn_max,
    "floor": _fn_floor,
    "ceil": _fn_ceil,
    "sqrt": _fn_sqrt,
    "sum": _fn_sum,
    "avg": _fn_avg,
}
_TEXT_FUNCTIONS: dict[str, Callable] = {
    "len": _fn_len,
    "upper": _fn_upper,
    "lower": _fn_lower,
    "trim": _fn_trim,
    "left": _fn_left,
    "right": _fn_right,
    "concat": _fn_concat,
}
_DATE_FUNCTIONS: dict[str, Callable] = {
    "year": _fn_year,
    "month": _fn_month,
    "day": _fn_day,
    "today": _fn_today,
    "datediff": _fn_datediff,
}
_UTILITY_FUNCTIONS: dict[str, Callable] = {
    "coalesce": _fn_coalesce,
    "if_error": _fn_if_error,
    "count": _fn_count,
}
# Name -> implementation table; keys are the lower-case function names.
BUILTIN_FUNCTIONS: dict[str, Callable] = {
    **_MATH_FUNCTIONS,
    **_TEXT_FUNCTIONS,
    **_DATE_FUNCTIONS,
    **_UTILITY_FUNCTIONS,
}
# ==================== Evaluator ====================
class FormulaEvaluator:
    """
    Row-by-row formula evaluator.

    Evaluates a FormulaDefinition AST against a single row of data.

    Args:
        cross_table_resolver: Optional callback for cross-table references.
            Signature: resolver(table, column, where_clause, row_index) -> value
    """

    # Numeric operators, applied through _num_op(). "+" is listed here for
    # its numeric form only; string concatenation is handled separately.
    _ARITHMETIC = {
        "+": lambda a, b: a + b,
        "-": lambda a, b: a - b,
        "*": lambda a, b: a * b,
        "/": lambda a, b: a / b,
        "%": lambda a, b: a % b,
        "^": lambda a, b: a ** b,
    }

    # Ordering operators as predicates over the -1/0/1 result of _compare().
    _ORDERING = {
        "<": lambda c: c < 0,
        "<=": lambda c: c <= 0,
        ">": lambda c: c > 0,
        ">=": lambda c: c >= 0,
    }

    def __init__(self, cross_table_resolver: Optional[CrossTableResolver] = None):
        self._cross_table_resolver = cross_table_resolver

    def evaluate(
        self,
        formula: FormulaDefinition,
        row_data: dict,
        row_index: int,
    ) -> Any:
        """
        Evaluate a formula for a single row.

        Args:
            formula: The parsed FormulaDefinition AST.
            row_data: Dict mapping column_id -> value for the current row.
            row_index: The integer index of the current row.

        Returns:
            The computed value, or None on error.
        """
        try:
            return self._eval(formula.expression, row_data, row_index)
        except Exception as exc:
            # Evaluation boundary: one bad row must not abort the whole
            # column, so the error is logged and the cell becomes None.
            logger.warning(
                "Formula evaluation error at row %d: %s", row_index, exc
            )
            return None

    def _eval(self, node: FormulaNode, row_data: dict, row_index: int) -> Any:
        """
        Recursively evaluate an AST node.

        Args:
            node: The AST node to evaluate.
            row_data: Current row data dict.
            row_index: Current row index.

        Returns:
            Evaluated value; None (with a warning) for an unknown node type.
        """
        if isinstance(node, LiteralNode):
            return node.value
        if isinstance(node, ColumnRef):
            return self._resolve_column(node.column, row_data)
        if isinstance(node, CrossTableRef):
            return self._resolve_cross_table(node, row_index)
        if isinstance(node, BinaryOp):
            return self._eval_binary(node, row_data, row_index)
        if isinstance(node, UnaryOp):
            return self._eval_unary(node, row_data, row_index)
        if isinstance(node, FunctionCall):
            return self._eval_function(node, row_data, row_index)
        if isinstance(node, ConditionalExpr):
            return self._eval_conditional(node, row_data, row_index)

        logger.warning("Unknown AST node type: %s", type(node).__name__)
        return None

    def _resolve_column(self, column_name: str, row_data: dict) -> Any:
        """Resolve a column reference in the current row.

        An exact key match wins; otherwise the first key that matches
        case-insensitively is used. Returns None when nothing matches.
        """
        if column_name in row_data:
            return row_data[column_name]

        # Try case-insensitive match
        lower_name = column_name.lower()
        for key, value in row_data.items():
            if str(key).lower() == lower_name:
                return value

        logger.debug("Column '%s' not found in row_data", column_name)
        return None

    def _resolve_cross_table(self, node: CrossTableRef, row_index: int) -> Any:
        """Resolve a cross-table reference through the configured resolver.

        Returns None (with a warning) when no resolver was provided.
        """
        if self._cross_table_resolver is None:
            logger.warning(
                "No cross_table_resolver set for cross-table ref %s.%s",
                node.table, node.column,
            )
            return None
        return self._cross_table_resolver(
            node.table, node.column, node.where_clause, row_index
        )

    def _eval_binary(self, node: BinaryOp, row_data: dict, row_index: int) -> Any:
        """Evaluate a binary operation.

        ``and`` / ``or`` short-circuit: the right operand is only evaluated
        when the left one does not decide the result. Every other operator
        evaluates both operands and delegates to ``_apply_binary``.
        """
        op = node.operator
        left = self._eval(node.left, row_data, row_index)

        # Short-circuit for logical operators
        if op == "and":
            if not self._truthy(left):
                return False
            return self._truthy(self._eval(node.right, row_data, row_index))
        if op == "or":
            if self._truthy(left):
                return True
            return self._truthy(self._eval(node.right, row_data, row_index))

        right = self._eval(node.right, row_data, row_index)
        if op in ("in", "between"):
            # The operand list holds AST nodes, not plain values.
            right = self._resolve_operands(right, row_data, row_index)
        return self._apply_binary(op, left, right)

    def _resolve_operands(self, value: Any, row_data: dict, row_index: int) -> Any:
        """Evaluate AST nodes nested in the right operand of ``in`` / ``between``.

        Every element that is still an AST node is evaluated for the current
        row, so bounds and candidates may be any expression, not only
        literals. Values that are not nodes are returned unchanged.
        """
        node_types = (
            LiteralNode, ColumnRef, CrossTableRef, BinaryOp,
            UnaryOp, FunctionCall, ConditionalExpr,
        )
        if isinstance(value, node_types):
            value = self._eval(value, row_data, row_index)
        if isinstance(value, list):
            return [
                self._eval(item, row_data, row_index)
                if isinstance(item, node_types) else item
                for item in value
            ]
        return value

    def _apply_binary(self, op: str, left: Any, right: Any) -> Any:
        """Apply a non-logical binary operator to two evaluated operands.

        Args:
            op: Operator string as stored in BinaryOp.
            left: Evaluated left operand.
            right: Evaluated right operand; for ``in`` / ``between`` a list
                of evaluated values.

        Returns:
            The result, or None for an unknown operator, non-numeric
            arithmetic operands, division/modulo by zero, or a malformed
            ``between`` operand.
        """
        # "+" concatenates as soon as one side is a string. Only None maps
        # to "": numeric 0 and False keep their text form ("0", "False").
        if op == "+" and (isinstance(left, str) or isinstance(right, str)):
            return self._to_text(left) + self._to_text(right)

        # Arithmetic. Division and modulo by zero raise ZeroDivisionError,
        # which _num_op turns into None.
        if op in self._ARITHMETIC:
            return self._num_op(left, right, self._ARITHMETIC[op])

        # Comparison
        if op == "==":
            return left == right
        if op == "!=":
            return left != right
        if op in self._ORDERING:
            return self._ORDERING[op](self._compare(left, right))

        # String operations: a missing operand never matches (otherwise the
        # text "None" would be searched for).
        if op in ("contains", "startswith", "endswith"):
            if left is None or right is None:
                return False
            haystack, needle = str(left), str(right)
            if op == "contains":
                return needle in haystack
            if op == "startswith":
                return haystack.startswith(needle)
            return haystack.endswith(needle)

        # Collection operations
        if op == "in":
            if isinstance(right, list):
                return left in right
            return left in (right or [])
        if op == "between":
            if isinstance(right, list) and len(right) == 2:
                low, high = right
                return low <= left <= high
            return None

        logger.warning("Unknown binary operator: %s", op)
        return None

    def _eval_unary(self, node: UnaryOp, row_data: dict, row_index: int) -> Any:
        """Evaluate a unary operation (``-``, ``not``, ``isempty``, ``isnotempty``, ``isnan``)."""
        operand = self._eval(node.operand, row_data, row_index)
        op = node.operator

        if op == "-":
            v = _safe_numeric(operand)
            return -v if v is not None else None
        if op == "not":
            return not self._truthy(operand)
        if op == "isempty":
            return operand is None or operand == "" or operand == []
        if op == "isnotempty":
            return operand is not None and operand != "" and operand != []
        if op == "isnan":
            try:
                return math.isnan(float(operand))
            except (TypeError, ValueError):
                # Not convertible to float, hence not NaN.
                return False

        logger.warning("Unknown unary operator: %s", op)
        return None

    def _eval_function(self, node: FunctionCall, row_data: dict, row_index: int) -> Any:
        """Evaluate a function call.

        All arguments are evaluated first, then the function is looked up in
        BUILTIN_FUNCTIONS by lower-cased name. Unknown functions and
        functions that raise produce None with a warning.
        """
        name = node.function_name.lower()
        args = [self._eval(arg, row_data, row_index) for arg in node.arguments]

        func = BUILTIN_FUNCTIONS.get(name)
        if func is None:
            logger.warning("Unknown function: %s", name)
            return None
        try:
            return func(args)
        except Exception as exc:
            logger.warning("Function '%s' error: %s", name, exc)
            return None

    def _eval_conditional(self, node: ConditionalExpr, row_data: dict, row_index: int) -> Any:
        """Evaluate a conditional expression.

        Only the selected branch is evaluated. Without an ``else`` branch a
        false condition yields None.
        """
        condition = self._eval(node.condition, row_data, row_index)
        if self._truthy(condition):
            return self._eval(node.value_expr, row_data, row_index)
        if node.else_expr is not None:
            return self._eval(node.else_expr, row_data, row_index)
        return None

    @staticmethod
    def _truthy(value: Any) -> bool:
        """Convert a value to boolean for conditional evaluation.

        None is false, numbers are true when non-zero, strings when
        non-empty; anything else uses ``bool(value)``.
        """
        if value is None:
            return False
        if isinstance(value, bool):
            return value
        if isinstance(value, (int, float)):
            return value != 0
        if isinstance(value, str):
            return len(value) > 0
        return bool(value)

    @staticmethod
    def _to_text(value: Any) -> str:
        """Text form used by string concatenation: None becomes ""."""
        return "" if value is None else str(value)

    @staticmethod
    def _num_op(left: Any, right: Any, fn: Callable) -> Any:
        """Apply a numeric binary function, returning None if inputs are non-numeric.

        Both operands are converted with ``_safe_numeric`` first. None is
        also returned when the operation raises ZeroDivisionError,
        OverflowError or ValueError, or when it produces a complex number.
        """
        a = _safe_numeric(left)
        b = _safe_numeric(right)
        if a is None or b is None:
            return None
        try:
            result = fn(a, b)
        except (ZeroDivisionError, OverflowError, ValueError):
            return None
        # A negative base raised to a fractional power yields a complex
        # number in Python, which is not a usable cell value.
        if isinstance(result, complex):
            return None
        return result

    @staticmethod
    def _compare(left: Any, right: Any) -> int:
        """
        Compare two values, returning -1, 0, or 1.

        Values that cannot be ordered directly (the comparison raises
        TypeError) are compared by their string forms instead.
        """
        try:
            if left < right:
                return -1
            elif left > right:
                return 1
            return 0
        except TypeError:
            # Fallback: compare as strings
            sl, sr = str(left), str(right)
            if sl < sr:
                return -1
            elif sl > sr:
                return 1
            return 0