Introducing column formulas
This commit is contained in:
@@ -26,6 +26,7 @@ class ColumnType(Enum):
|
||||
Bool = "Boolean"
|
||||
Choice = "Choice"
|
||||
Enum = "Enum"
|
||||
Formula = "Formula"
|
||||
|
||||
|
||||
class ViewType(Enum):
|
||||
|
||||
@@ -98,6 +98,9 @@ class StyleResolver:
|
||||
return StyleContainer(None, "")
|
||||
|
||||
cls = props.pop("__class__", None)
|
||||
if not props:
|
||||
return StyleContainer(cls, "")
|
||||
|
||||
css = "; ".join(f"{key}: {value}" for key, value in props.items()) + ";"
|
||||
|
||||
|
||||
return StyleContainer(cls, css)
|
||||
|
||||
0
src/myfasthtml/core/formula/__init__.py
Normal file
0
src/myfasthtml/core/formula/__init__.py
Normal file
79
src/myfasthtml/core/formula/dataclasses.py
Normal file
79
src/myfasthtml/core/formula/dataclasses.py
Normal file
@@ -0,0 +1,79 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Optional
|
||||
|
||||
|
||||
@dataclass
class FormulaNode:
    """Common root of every node in a parsed formula expression tree.

    Carries no fields of its own; concrete node kinds subclass it.
    """
|
||||
|
||||
|
||||
@dataclass
class LiteralNode(FormulaNode):
    """Constant appearing directly in a formula (number, string or boolean)."""

    # The constant itself, stored as given.
    value: Any
|
||||
|
||||
|
||||
@dataclass
class ColumnRef(FormulaNode):
    """Same-table column reference, written ``{ColumnName}`` in a formula."""

    # Name of the referenced column.
    column: str
|
||||
|
||||
|
||||
@dataclass
class WhereClause:
    """Join condition of a cross-table reference.

    Models ``WHERE remote_table.remote_column = local_column``.
    """

    # Left-hand side: table and column being matched against.
    remote_table: str
    remote_column: str
    # Right-hand side: column of the table that owns the formula.
    local_column: str
|
||||
|
||||
|
||||
@dataclass
class CrossTableRef(FormulaNode):
    """Column of another table, written ``{Table.Column}`` in a formula."""

    # Table and column being referenced.
    table: str
    column: str
    # Optional join condition; None when the reference has no WHERE part.
    where_clause: Optional[WhereClause] = None
|
||||
|
||||
|
||||
@dataclass
class BinaryOp(FormulaNode):
    """Two-operand operation of the form ``left <operator> right``."""

    # Operator identifier as produced by the transformer (e.g. "or", "and").
    operator: str
    # Operand sub-trees.
    left: FormulaNode
    right: FormulaNode
|
||||
|
||||
|
||||
@dataclass
class UnaryOp(FormulaNode):
    """Single-operand operation such as ``-expr`` or ``not expr``."""

    # Operator identifier (e.g. "not").
    operator: str
    # Sub-tree the operator applies to.
    operand: FormulaNode
|
||||
|
||||
|
||||
@dataclass
class FunctionCall(FormulaNode):
    """Invocation of a function: ``func(args...)``."""

    # Name of the function being called.
    function_name: str
    # Argument sub-trees in call order; a fresh list per instance.
    arguments: list = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
class ConditionalExpr(FormulaNode):
    """Suffix conditional: ``value_expr if condition [else else_expr]``.

    Conditionals chain through the else branch, e.g.
    ``val1 if cond1 else val2 if cond2 else val3``.
    """

    # Expression associated with the "if" branch.
    value_expr: FormulaNode
    # Test expression.
    condition: FormulaNode
    # Expression of the "else" branch; None when the formula has no else part.
    else_expr: Optional[FormulaNode] = None
|
||||
|
||||
|
||||
@dataclass
class FormulaDefinition:
    """Parsed formula attached to a column.

    ``str(definition)`` yields ``source_text``.
    """

    # Root of the expression tree.
    expression: FormulaNode
    # Formula text; defaults to the empty string.
    source_text: str = ""

    def __str__(self):
        text = self.source_text
        return text
|
||||
386
src/myfasthtml/core/formula/dependency_graph.py
Normal file
386
src/myfasthtml/core/formula/dependency_graph.py
Normal file
@@ -0,0 +1,386 @@
|
||||
"""
|
||||
Dependency Graph (DAG) for formula columns.
|
||||
|
||||
Tracks column dependencies, propagates dirty flags, and provides
|
||||
topological ordering for incremental recalculation.
|
||||
|
||||
Node IDs use the format ``"table_name.column_id"`` for column-level
|
||||
granularity, designed to be extensible to ``"table_name.column_id[row]"``
|
||||
for cell-level overrides.
|
||||
"""
|
||||
import logging
|
||||
from collections import defaultdict, deque
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional, Set
|
||||
|
||||
from .dataclasses import (
|
||||
FormulaDefinition,
|
||||
ColumnRef,
|
||||
CrossTableRef,
|
||||
BinaryOp,
|
||||
UnaryOp,
|
||||
FunctionCall,
|
||||
ConditionalExpr,
|
||||
FormulaNode,
|
||||
)
|
||||
from .dsl.exceptions import FormulaCycleError
|
||||
|
||||
logger = logging.getLogger("DependencyGraph")
|
||||
|
||||
|
||||
@dataclass
class DependencyNode:
    """
    One column tracked by the dependency graph.

    Attributes:
        node_id: Graph-wide identifier, formatted ``"table.column"``.
        table: Name of the table owning the column.
        column: Name of the column.
        dirty: True while the column is waiting to be recalculated.
        dirty_rows: Row indices waiting for recalculation; an empty set
            stands for "every row".
        formula: Parsed FormulaDefinition when the column is a formula
            column, otherwise None.
    """
    node_id: str
    table: str
    column: str
    dirty: bool = False
    dirty_rows: Set[int] = field(default_factory=set)
    formula: Optional[FormulaDefinition] = None
|
||||
|
||||
|
||||
class DependencyGraph:
    """
    Directed Acyclic Graph of formula column dependencies.

    Tracks which columns depend on which other columns and provides:
    - Dirty flag propagation via BFS
    - Topological ordering for recalculation (Kahn's algorithm)
    - Cycle detection when registering new formulas (a rejected formula
      leaves the graph exactly as it was before the call)

    The graph is bidirectional:
    - ``_dependents[A]`` = set of nodes that depend on A (forward edges)
    - ``_precedents[B]`` = set of nodes that B depends on (reverse edges)
    """

    def __init__(self):
        self._nodes: dict[str, DependencyNode] = {}
        # forward: A -> {B, C} means "B and C depend on A"
        self._dependents: dict[str, set[str]] = defaultdict(set)
        # reverse: B -> {A} means "B depends on A"
        self._precedents: dict[str, set[str]] = defaultdict(set)

    def add_formula(
        self,
        table: str,
        column: str,
        formula: FormulaDefinition,
    ) -> None:
        """
        Register a formula for a column and add dependency edges.

        The operation is atomic: when the formula is rejected, edges, nodes
        and the previous formula of the column are restored.

        Args:
            table: Table name.
            column: Column name.
            formula: The parsed FormulaDefinition.

        Raises:
            FormulaCycleError: If adding this formula would create a cycle
                (including a direct self-reference).
        """
        node_id = self._make_node_id(table, column)
        logger.debug("add_formula %s:%s", node_id, formula.source_text)

        # Keep (table, column) next to each id so that dependency nodes are
        # never rebuilt by splitting the id string, which is ambiguous when a
        # table name itself contains a dot.
        deps = {
            self._make_node_id(dep_table, dep_column): (dep_table, dep_column)
            for dep_table, dep_column
            in self._extract_dependency_pairs(formula.expression, table)
        }

        # Reject a direct self-reference before touching the graph.
        if node_id in deps:
            raise FormulaCycleError([node_id])

        # Snapshot everything that is about to change, for rollback.
        previous_precedents = set(self._precedents.get(node_id, ()))
        existing = self._nodes.get(node_id)
        previous_state = None
        if existing is not None:
            previous_state = (existing.formula, existing.dirty, set(existing.dirty_rows))
        created_ids = []

        # Replace the old edges with the new ones and make sure every
        # referenced column has a node (plain data node, no formula).
        self._remove_edges(node_id)
        for dep_id, (dep_table, dep_column) in deps.items():
            self._dependents[dep_id].add(node_id)
            self._precedents[node_id].add(dep_id)
            if dep_id not in self._nodes:
                self._nodes[dep_id] = DependencyNode(
                    node_id=dep_id,
                    table=dep_table,
                    column=dep_column,
                )
                created_ids.append(dep_id)

        if existing is None:
            created_ids.append(node_id)
        node = self._get_or_create_node(table, column)
        node.formula = formula
        node.dirty = True  # New formula -> needs evaluation
        node.dirty_rows.clear()  # ... for every row (empty set = all rows)

        try:
            self._detect_cycles()
        except FormulaCycleError:
            # Undo: drop the new edges, put the old ones back, forget nodes
            # created by this call and restore the previous node state.
            self._remove_edges(node_id)
            for prec_id in previous_precedents:
                self._dependents[prec_id].add(node_id)
                self._precedents[node_id].add(prec_id)
            for created_id in created_ids:
                self._nodes.pop(created_id, None)
            if previous_state is not None:
                existing.formula, existing.dirty, restored_rows = previous_state
                existing.dirty_rows.clear()
                existing.dirty_rows.update(restored_rows)
            raise

        logger.debug("Added formula for %s depending on: %s", node_id, sorted(deps))

    def remove_formula(self, table: str, column: str) -> None:
        """
        Remove a formula column and its edges from the graph.

        The node itself is kept (as a plain data node) while other formulas
        still depend on it; otherwise it is deleted.

        Args:
            table: Table name.
            column: Column name.
        """
        node_id = self._make_node_id(table, column)
        self._remove_edges(node_id)

        node = self._nodes.get(node_id)
        if node is not None:
            node.formula = None
            node.dirty = False
            node.dirty_rows.clear()
            # If the node has no dependents either, remove it
            if not self._dependents.get(node_id):
                del self._nodes[node_id]

        logger.debug("Removed formula for %s", node_id)

    def get_calculation_order(self, table: Optional[str] = None) -> list[DependencyNode]:
        """
        Return dirty formula nodes in topological order.

        Uses Kahn's algorithm (BFS-based topological sort) on the subgraph
        made of nodes that have a formula and are dirty.

        Args:
            table: If provided, filter to only nodes for this table.

        Returns:
            List of dirty DependencyNode objects in calculation order.
        """
        dirty_formulas = {
            nid for nid, node in self._nodes.items()
            if node.formula is not None and node.dirty
        }
        if not dirty_formulas:
            return []

        # In-degree counts only edges coming from other dirty formula nodes.
        in_degree = {
            nid: sum(1 for prec_id in self._precedents.get(nid, ()) if prec_id in dirty_formulas)
            for nid in dirty_formulas
        }

        queue = deque(nid for nid, degree in in_degree.items() if degree == 0)
        result = []
        while queue:
            nid = queue.popleft()
            node = self._nodes[nid]
            if table is None or node.table == table:
                result.append(node)
            for dep_id in self._dependents.get(nid, ()):
                if dep_id in in_degree:
                    in_degree[dep_id] -= 1
                    if in_degree[dep_id] == 0:
                        queue.append(dep_id)

        return result

    def clear_dirty(self, node_id: str) -> None:
        """
        Clear dirty flags for a node after successful recalculation.

        Unknown node IDs are ignored.

        Args:
            node_id: The node ID in format ``"table.column"``.
        """
        node = self._nodes.get(node_id)
        if node is not None:
            node.dirty = False
            node.dirty_rows.clear()

    def get_node(self, table: str, column: str) -> Optional[DependencyNode]:
        """
        Get a node by table and column.

        Args:
            table: Table name.
            column: Column name.

        Returns:
            DependencyNode or None if not found.
        """
        return self._nodes.get(self._make_node_id(table, column))

    def has_formula(self, table: str, column: str) -> bool:
        """
        Check if a column has a formula registered.

        Args:
            table: Table name.
            column: Column name.

        Returns:
            True if the column has a formula.
        """
        node = self.get_node(table, column)
        return node is not None and node.formula is not None

    def mark_dirty(
        self,
        table: str,
        column: str,
        rows: Optional[list[int]] = None,
    ) -> None:
        """
        Mark a column (and its transitive dependents) as dirty.

        Uses BFS to propagate dirty flags through the dependency graph.
        Columns unknown to the graph are skipped but their dependents, if
        any, are still visited.

        Args:
            table: Table name.
            column: Column name.
            rows: Specific row indices to mark dirty. None means all rows.
        """
        start_id = self._make_node_id(table, column)
        self._mark_node_dirty(start_id, rows)

        # BFS propagation through dependents
        queue = deque([start_id])
        visited = {start_id}
        while queue:
            current_id = queue.popleft()
            for dep_id in self._dependents.get(current_id, ()):
                if dep_id in visited:
                    continue
                visited.add(dep_id)
                self._mark_node_dirty(dep_id, rows)
                queue.append(dep_id)

    # ==================== Private helpers ====================

    @staticmethod
    def _make_node_id(table: str, column: str) -> str:
        """Create a standard node ID from table and column names."""
        return f"{table}.{column}"

    def _get_or_create_node(self, table: str, column: str) -> DependencyNode:
        """Get existing node or create a new one."""
        node_id = self._make_node_id(table, column)
        node = self._nodes.get(node_id)
        if node is None:
            node = self._nodes[node_id] = DependencyNode(
                node_id=node_id,
                table=table,
                column=column,
            )
        return node

    def _mark_node_dirty(self, node_id: str, rows: Optional[list[int]]) -> None:
        """
        Mark a specific node as dirty; unknown node IDs are ignored.

        An empty ``dirty_rows`` set on a dirty node means "all rows", so a
        node that is already dirty for all rows must stay that way when a
        later call only names specific rows.
        """
        node = self._nodes.get(node_id)
        if node is None:
            return
        if rows is None:
            node.dirty_rows.clear()  # Empty = all rows dirty
        elif not node.dirty or node.dirty_rows:
            node.dirty_rows.update(rows)
        # else: already dirty for every row -> do not narrow it down
        node.dirty = True

    def _remove_edges(self, node_id: str) -> None:
        """Remove the edges from a node to its precedents (both directions)."""
        # Remove this node from its precedents' dependents sets
        for prec_id in list(self._precedents.get(node_id, ())):
            self._dependents[prec_id].discard(node_id)
        # Clear this node's precedents
        self._precedents[node_id].clear()

    def _detect_cycles(self) -> None:
        """
        Detect cycles among formula nodes using Kahn's algorithm.

        Raises:
            FormulaCycleError: If a cycle is detected; carries the sorted IDs
                of the formula nodes that could not be ordered.
        """
        # Only formula nodes have precedents, so only they can form a cycle
        formula_nodes = {
            nid for nid, node in self._nodes.items()
            if node.formula is not None
        }
        if not formula_nodes:
            return

        in_degree = {
            nid: sum(1 for prec_id in self._precedents.get(nid, ()) if prec_id in formula_nodes)
            for nid in formula_nodes
        }

        queue = deque(nid for nid in formula_nodes if in_degree[nid] == 0)
        processed = set()
        while queue:
            nid = queue.popleft()
            processed.add(nid)
            for dep_id in self._dependents.get(nid, ()):
                if dep_id in formula_nodes:
                    in_degree[dep_id] -= 1
                    if in_degree[dep_id] == 0:
                        queue.append(dep_id)

        # Whatever Kahn's algorithm could not reach sits on (or behind) a cycle
        cycle_nodes = formula_nodes - processed
        if cycle_nodes:
            raise FormulaCycleError(sorted(cycle_nodes))

    def _extract_dependencies(
        self,
        node: FormulaNode,
        current_table: str,
    ) -> set[str]:
        """
        Recursively extract all column dependency IDs from a formula AST.

        Args:
            node: The AST node to walk.
            current_table: The table containing this formula (for ColumnRef).

        Returns:
            Set of dependency node IDs (``"table.column"`` format).
        """
        return {
            self._make_node_id(dep_table, dep_column)
            for dep_table, dep_column in self._extract_dependency_pairs(node, current_table)
        }

    def _extract_dependency_pairs(
        self,
        node: FormulaNode,
        current_table: str,
    ) -> set:
        """
        Recursively collect ``(table, column)`` pairs referenced by an AST.

        Args:
            node: The AST node to walk.
            current_table: The table containing this formula (for ColumnRef).

        Returns:
            Set of ``(table, column)`` tuples.
        """
        pairs = set()

        if isinstance(node, ColumnRef):
            pairs.add((current_table, node.column))

        elif isinstance(node, CrossTableRef):
            pairs.add((node.table, node.column))
            # Also depend on the local column used in WHERE clause
            if node.where_clause is not None:
                pairs.add((current_table, node.where_clause.local_column))

        elif isinstance(node, BinaryOp):
            pairs |= self._extract_dependency_pairs(node.left, current_table)
            pairs |= self._extract_dependency_pairs(node.right, current_table)

        elif isinstance(node, UnaryOp):
            pairs |= self._extract_dependency_pairs(node.operand, current_table)

        elif isinstance(node, FunctionCall):
            for arg in node.arguments:
                pairs |= self._extract_dependency_pairs(arg, current_table)

        elif isinstance(node, ConditionalExpr):
            pairs |= self._extract_dependency_pairs(node.value_expr, current_table)
            pairs |= self._extract_dependency_pairs(node.condition, current_table)
            if node.else_expr is not None:
                pairs |= self._extract_dependency_pairs(node.else_expr, current_table)

        return pairs
|
||||
0
src/myfasthtml/core/formula/dsl/__init__.py
Normal file
0
src/myfasthtml/core/formula/dsl/__init__.py
Normal file
@@ -0,0 +1,180 @@
|
||||
"""
|
||||
Autocompletion engine for the DataGrid Formula DSL.
|
||||
|
||||
Provides context-aware suggestions for:
|
||||
- Column names (after ``{``)
|
||||
- Cross-table references (``{Table.``)
|
||||
- Built-in function names
|
||||
- Keywords: ``if``, ``else``, ``and``, ``or``, ``not``, ``where``
|
||||
"""
|
||||
import re
|
||||
|
||||
from myfasthtml.core.dsl.base_completion import BaseCompletionEngine
|
||||
from myfasthtml.core.dsl.types import Position, Suggestion
|
||||
from myfasthtml.core.formula.evaluator import BUILTIN_FUNCTIONS
|
||||
from myfasthtml.core.utils import make_safe_id
|
||||
|
||||
# Keywords offered by the completion engine, in display order.
FORMULA_KEYWORDS = [
    # conditionals, logic and cross-table filter
    "if",
    "else",
    "and",
    "or",
    "not",
    "where",
    # range / membership / emptiness tests
    "between",
    "in",
    "isempty",
    "isnotempty",
    "isnan",
    # string tests
    "contains",
    "startswith",
    "endswith",
    # boolean literals
    "true",
    "false",
]
|
||||
|
||||
|
||||
class FormulaCompletionEngine(BaseCompletionEngine):
    """
    Context-aware completion engine for formula expressions.

    Provides suggestions for column references, functions, and keywords.

    Args:
        provider: DataGrid metadata provider (DataGridsManager or similar).
            Must expose ``list_columns(table_name)`` and ``list_tables()``.
        table_name: Name of the current table in ``"namespace.name"`` format.
    """

    def __init__(self, provider, table_name: str):
        super().__init__(provider)
        self.table_name = table_name
        self._id = "formula_completion_engine#" + make_safe_id(table_name)

    def detect_scope(self, text: str, current_line: int):
        """Formula has no scope — always the same single-expression scope."""
        return None

    def detect_context(self, text: str, cursor: Position, scope):
        """
        Detect completion context based on cursor position in formula text.

        Args:
            text: The full formula text.
            cursor: Cursor position (line, ch).
            scope: Unused (formulas have no scopes).

        Returns:
            Context string: ``"column_ref"`` (inside an open ``{`` with no
            dot yet), ``"cross_table"`` (inside an open ``{`` after a dot),
            ``"function_or_keyword"`` (cursor right after an identifier),
            or ``"general"``.
        """
        text_before = self._text_before_cursor(text, cursor)

        # Check if we are inside a { ... } reference that is still open
        last_brace = text_before.rfind("{")
        if last_brace >= 0:
            inside = text_before[last_brace + 1:]
            if "}" not in inside:
                return "cross_table" if "." in inside else "column_ref"

        # Check if we are typing a function name or keyword (identifier
        # ending exactly at the cursor)
        if re.search(r"[a-z_][a-z0-9_]*$", text_before, re.IGNORECASE):
            return "function_or_keyword"

        return "general"

    def get_suggestions(self, text: str, cursor: Position, scope, context) -> list:
        """
        Generate suggestions based on the detected context.

        Args:
            text: The full formula text.
            cursor: Cursor position.
            scope: Unused.
            context: String from ``detect_context``.

        Returns:
            List of Suggestion objects.
        """
        if context == "column_ref":
            # Suggest columns from the current table
            return self._column_suggestions(self.table_name)

        if context == "cross_table":
            # The table name is whatever sits between the last "{" and the
            # last "." before the cursor
            text_before = self._text_before_cursor(text, cursor)
            last_brace = text_before.rfind("{")
            inside = text_before[last_brace + 1:] if last_brace >= 0 else ""
            dot_pos = inside.rfind(".")
            table_prefix = inside[:dot_pos] if dot_pos >= 0 else ""

            # Suggest columns from the referenced table, or tables when no
            # table name has been typed yet
            if table_prefix:
                return self._column_suggestions(table_prefix)
            return self._table_suggestions()

        # "function_or_keyword" and "general" share functions and keywords
        suggestions = self._function_suggestions() + self._keyword_suggestions()
        if context != "function_or_keyword":
            # general: additionally offer to open a column reference
            suggestions.append(
                Suggestion(
                    label="{",
                    detail="Column reference",
                    insert_text="{",
                )
            )
        return suggestions

    # ==================== Private helpers ====================

    @staticmethod
    def _text_before_cursor(text: str, cursor: Position) -> str:
        """Return the part of the cursor's line located left of the cursor.

        A cursor line beyond the last line is clamped to the last line.
        """
        lines = text.split("\n")
        line_text = lines[min(cursor.line, len(lines) - 1)]
        return line_text[:cursor.ch]

    def _column_suggestions(self, table_name: str) -> list:
        """Get column name suggestions for a table."""
        # Best effort: a failing provider must not break the editor, so any
        # error simply yields no suggestion.
        try:
            columns = self.provider.list_columns(table_name)
            return [
                Suggestion(
                    label=col,
                    detail=f"Column from {table_name}",
                    insert_text=col,
                )
                for col in (columns or [])
            ]
        except Exception:
            return []

    def _table_suggestions(self) -> list:
        """Get table name suggestions."""
        # Best effort, same policy as _column_suggestions.
        try:
            tables = self.provider.list_tables()
            return [
                Suggestion(
                    label=t,
                    detail="Table",
                    insert_text=t,
                )
                for t in (tables or [])
            ]
        except Exception:
            return []

    def _function_suggestions(self) -> list:
        """Get built-in function name suggestions (sorted by name)."""
        return [
            Suggestion(
                label=name,
                detail="Function",
                insert_text=f"{name}(",
            )
            for name in sorted(BUILTIN_FUNCTIONS.keys())
        ]

    def _keyword_suggestions(self) -> list:
        """Get keyword suggestions."""
        return [
            Suggestion(label=kw, detail="Keyword", insert_text=kw)
            for kw in FORMULA_KEYWORDS
        ]
|
||||
1
src/myfasthtml/core/formula/dsl/completion/__init__.py
Normal file
1
src/myfasthtml/core/formula/dsl/completion/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
|
||||
79
src/myfasthtml/core/formula/dsl/definition.py
Normal file
79
src/myfasthtml/core/formula/dsl/definition.py
Normal file
@@ -0,0 +1,79 @@
|
||||
"""
|
||||
FormulaDSL definition for the DslEditor control.
|
||||
|
||||
Provides the Lark grammar and derived completions for the
|
||||
DataGrid Formula DSL (CodeMirror 5 Simple Mode).
|
||||
"""
|
||||
|
||||
from functools import cached_property
|
||||
from typing import Dict, Any
|
||||
|
||||
from myfasthtml.core.dsl.base import DSLDefinition
|
||||
from .grammar import FORMULA_GRAMMAR
|
||||
|
||||
|
||||
class FormulaDSL(DSLDefinition):
    """
    DSL definition for DataGrid formula expressions.

    Uses the Lark grammar from grammar.py to drive syntax highlighting
    and autocompletion in the DslEditor.
    """

    name: str = "Formula DSL"

    def get_grammar(self) -> str:
        """Return the Lark grammar for the formula DSL."""
        return FORMULA_GRAMMAR

    @cached_property
    def simple_mode_config(self) -> Dict[str, Any]:
        """
        Return a hand-tuned CodeMirror 5 Simple Mode config for formula syntax.

        Overrides the base class to provide optimized highlighting rules
        for comments, column references, operators, functions, and keywords.
        The rules mirror the grammar: ``#`` starts a line comment and column
        names may contain spaces.
        """
        return {
            "start": [
                # Line comments: the grammar ignores "# ..." up to end of line
                {
                    "regex": r"#.*",
                    "token": "comment",
                },
                # Column references: {ColumnName}, {Table.Column} or
                # {Table.Column where Other.Col = LocalCol}
                {
                    "regex": r"\{[A-Za-z_][A-Za-z0-9_. ]*(?:\s+where\s+[A-Za-z_][A-Za-z0-9_.]*\s*=\s*[A-Za-z_][A-Za-z0-9_ ]*)?\}",
                    "token": "variable-2",
                },
                # Function names before parenthesis
                {
                    "regex": r"[a-z_][a-z0-9_]*(?=\s*\()",
                    "token": "keyword",
                },
                # Keywords: if, else, and, or, not, where, between, in, ...
                {
                    "regex": r"\b(if|else|and|or|not|where|between|in|isempty|isnotempty|isnan|contains|startswith|endswith|true|false)\b",
                    "token": "keyword",
                },
                # Numbers
                {
                    "regex": r"[+-]?(\d+\.?\d*|\.\d+)([eE][+-]?\d+)?",
                    "token": "number",
                },
                # Strings (no escape sequences, same as the grammar)
                {
                    "regex": r'"[^"\\]*"',
                    "token": "string",
                },
                # Operators
                {
                    "regex": r"[=!<>]=?|[+\-*/%^]",
                    "token": "operator",
                },
                # Parentheses and brackets
                {
                    "regex": r"[()[\],]",
                    "token": "punctuation",
                },
            ],
            "meta": {
                "dontIndentStates": ["comment"],
                "lineComment": "#",
            },
        }
|
||||
35
src/myfasthtml/core/formula/dsl/exceptions.py
Normal file
35
src/myfasthtml/core/formula/dsl/exceptions.py
Normal file
@@ -0,0 +1,35 @@
|
||||
class FormulaError(Exception):
    """Root of the formula exception hierarchy."""
|
||||
|
||||
|
||||
class FormulaSyntaxError(FormulaError):
    """Signals that a formula could not be parsed.

    The exception text is ``message`` followed, when available, by
    ``at line <line>`` and ``col <column>``.
    """

    def __init__(self, message, line=None, column=None, context=None):
        self.message, self.line = message, line
        self.column, self.context = column, context
        super().__init__(self._format_message())

    def _format_message(self):
        # The column is only reported together with a line number.
        location = []
        if self.line is not None:
            location.append(f"at line {self.line}")
            if self.column is not None:
                location.append(f"col {self.column}")
        return " ".join([self.message, *location])
|
||||
|
||||
|
||||
class FormulaValidationError(FormulaError):
    """Signals a formula that parses correctly but is semantically invalid."""
|
||||
|
||||
|
||||
class FormulaCycleError(FormulaError):
    """Signals a circular dependency between formulas.

    ``cycle_nodes`` keeps the node identifiers passed by the raiser.
    """

    def __init__(self, cycle_nodes):
        self.cycle_nodes = cycle_nodes
        involved = ", ".join(cycle_nodes)
        super().__init__(f"Circular dependency detected involving: {involved}")
|
||||
100
src/myfasthtml/core/formula/dsl/grammar.py
Normal file
100
src/myfasthtml/core/formula/dsl/grammar.py
Normal file
@@ -0,0 +1,100 @@
|
||||
FORMULA_GRAMMAR = r"""
|
||||
start: expression
|
||||
|
||||
// ==================== Top-level expression ====================
|
||||
|
||||
?expression: conditional_expr
|
||||
|
||||
// Suffix-if: value_expr if condition [else expression]
|
||||
// Right-associative for chaining: a if c1 else b if c2 else d
|
||||
?conditional_expr: or_expr "if" or_expr "else" conditional_expr -> conditional_with_else
|
||||
| or_expr "if" or_expr -> conditional_no_else
|
||||
| or_expr
|
||||
|
||||
// ==================== Logical ====================
|
||||
|
||||
?or_expr: and_expr ("or" and_expr)* -> or_op
|
||||
?and_expr: not_expr ("and" not_expr)* -> and_op
|
||||
?not_expr: "not" not_expr -> not_op
|
||||
| comparison
|
||||
|
||||
// ==================== Comparison ====================
|
||||
|
||||
?comparison: addition comp_op addition -> comparison_expr
|
||||
| addition "in" "[" literal ("," literal)* "]" -> in_expr
|
||||
| addition "between" addition "and" addition -> between_expr
|
||||
| addition "contains" addition -> contains_expr
|
||||
| addition "startswith" addition -> startswith_expr
|
||||
| addition "endswith" addition -> endswith_expr
|
||||
| addition "isempty" -> isempty_expr
|
||||
| addition "isnotempty" -> isnotempty_expr
|
||||
| addition "isnan" -> isnan_expr
|
||||
| addition
|
||||
|
||||
comp_op: "==" -> eq
|
||||
| "!=" -> ne
|
||||
| "<=" -> le
|
||||
| "<" -> lt
|
||||
| ">=" -> ge
|
||||
| ">" -> gt
|
||||
|
||||
// ==================== Arithmetic ====================
|
||||
|
||||
?addition: multiplication (add_op multiplication)* -> add_expr
|
||||
?multiplication: power (mul_op power)* -> mul_expr
|
||||
?power: unary ("^" unary)* -> pow_expr
|
||||
|
||||
add_op: "+" -> plus
|
||||
| "-" -> minus
|
||||
|
||||
mul_op: "*" -> times
|
||||
| "/" -> divide
|
||||
| "%" -> modulo
|
||||
|
||||
?unary: "-" unary -> neg
|
||||
| atom
|
||||
|
||||
// ==================== Atoms ====================
|
||||
|
||||
?atom: function_call
|
||||
| cross_table_ref
|
||||
| column_ref
|
||||
| literal
|
||||
| "(" expression ")" -> paren
|
||||
|
||||
// ==================== References ====================
|
||||
|
||||
// Cross-table must be checked before column_ref since both use { }
|
||||
// TABLE_NAME.COL_NAME with optional WHERE clause
|
||||
// Note: whitespace around "where" and "=" is handled by %ignore
|
||||
cross_table_ref: "{" TABLE_COL_REF "}" -> cross_ref_simple
|
||||
| "{" TABLE_COL_REF "where" where_clause "}" -> cross_ref_where
|
||||
|
||||
column_ref: "{" COL_NAME "}"
|
||||
|
||||
where_clause: TABLE_COL_REF "=" COL_NAME
|
||||
|
||||
// TABLE_COL_REF matches "TableName.ColumnName" (dot-separated, no spaces)
|
||||
TABLE_COL_REF: /[A-Za-z_][A-Za-z0-9_]*\.[A-Za-z_][A-Za-z0-9_]*/
|
||||
COL_NAME: /[A-Za-z_][A-Za-z0-9_ ]*/
|
||||
|
||||
// ==================== Functions ====================
|
||||
|
||||
function_call: FUNC_NAME "(" [expression ("," expression)*] ")"
|
||||
FUNC_NAME: /[a-z_][a-z0-9_]*/
|
||||
|
||||
// ==================== Literals ====================
|
||||
|
||||
?literal: NUMBER -> number_literal
|
||||
| ESCAPED_STRING -> string_literal
|
||||
| "true"i -> true_literal
|
||||
| "false"i -> false_literal
|
||||
|
||||
// ==================== Terminals ====================
|
||||
|
||||
NUMBER: /[+-]?(\d+\.?\d*|\.\d+)([eE][+-]?\d+)?/
|
||||
ESCAPED_STRING: "\"" /[^"\\]*/ "\""
|
||||
|
||||
%ignore /[ \t\f]+/
|
||||
%ignore /\#[^\n]*/
|
||||
"""
|
||||
85
src/myfasthtml/core/formula/dsl/parser.py
Normal file
85
src/myfasthtml/core/formula/dsl/parser.py
Normal file
@@ -0,0 +1,85 @@
|
||||
"""
|
||||
Formula DSL parser using Lark.
|
||||
|
||||
Handles parsing of formula expression strings into a Lark AST.
|
||||
No indentation handling needed — formulas are single-line expressions.
|
||||
"""
|
||||
from lark import Lark, UnexpectedInput
|
||||
|
||||
from .exceptions import FormulaSyntaxError
|
||||
from .grammar import FORMULA_GRAMMAR
|
||||
|
||||
|
||||
class FormulaParser:
    """
    Parser for the DataGrid formula language.

    Uses Lark LALR parser without indentation handling.

    Example:
        parser = FormulaParser()
        tree = parser.parse("{Price} * {Quantity}")
    """

    def __init__(self):
        self._parser = Lark(
            FORMULA_GRAMMAR,
            parser="lalr",
            propagate_positions=False,
        )

    def parse(self, text: str):
        """
        Parse a formula expression string into a Lark Tree.

        Leading and trailing whitespace is stripped first, so reported
        positions are relative to the stripped text.

        Args:
            text: The formula expression text.

        Returns:
            lark.Tree: The parsed AST, or None when the text is empty or
            contains only whitespace.

        Raises:
            FormulaSyntaxError: If the text has syntax errors.
        """
        text = text.strip()
        if not text:
            return None

        try:
            return self._parser.parse(text)
        except UnexpectedInput as e:
            context = e.get_context(text) if hasattr(e, "get_context") else None
            raise FormulaSyntaxError(
                message=self._format_error_message(e),
                line=getattr(e, "line", None),
                column=getattr(e, "column", None),
                context=context,
            ) from e

    def _format_error_message(self, error: UnexpectedInput) -> str:
        """
        Format a user-friendly error message from a Lark exception.

        When the exception exposes an ``expected`` collection, the message
        lists it (sorted, so the text is stable from run to run) as long as
        it holds between one and five entries; otherwise a generic message is
        used. Exceptions without usable ``expected`` information fall back to
        their own string form.
        """
        expected = getattr(error, "expected", None)
        if expected is None:
            return str(error)

        # Sorting removes the arbitrary iteration order of the underlying set
        names = sorted(str(name) for name in expected)
        if len(names) == 1:
            return f"Expected {names[0]}"
        if 1 < len(names) <= 5:
            return f"Expected one of: {', '.join(names)}"
        return "Unexpected input"
|
||||
|
||||
|
||||
# Module-wide parser, built on first use by get_parser()
_parser_instance = None


def get_parser() -> FormulaParser:
    """Return the shared FormulaParser, creating it on the first call."""
    global _parser_instance
    parser = _parser_instance
    if parser is None:
        parser = _parser_instance = FormulaParser()
    return parser
|
||||
274
src/myfasthtml/core/formula/dsl/transformer.py
Normal file
274
src/myfasthtml/core/formula/dsl/transformer.py
Normal file
@@ -0,0 +1,274 @@
|
||||
"""
|
||||
Formula DSL Transformer.
|
||||
|
||||
Converts a Lark AST tree into FormulaDefinition and related AST dataclasses.
|
||||
"""
|
||||
from lark import Transformer
|
||||
|
||||
from ..dataclasses import (
|
||||
FormulaDefinition,
|
||||
FormulaNode,
|
||||
LiteralNode,
|
||||
ColumnRef,
|
||||
CrossTableRef,
|
||||
WhereClause,
|
||||
BinaryOp,
|
||||
UnaryOp,
|
||||
FunctionCall,
|
||||
ConditionalExpr,
|
||||
)
|
||||
|
||||
|
||||
class FormulaTransformer(Transformer):
    """
    Transforms the Lark parse tree into FormulaDefinition AST dataclasses.

    Each public method is named after a grammar rule (or rule alias) and
    receives the already-transformed children of that rule in ``items``.

    Handles left-associative folding for arithmetic and logical operators.
    Handles right-associative chaining for conditional expressions.
    """

    # ==================== Top-level ====================

    def start(self, items):
        """Return the FormulaDefinition wrapping the single expression."""
        expr = items[0]
        return FormulaDefinition(expression=expr)

    # ==================== Conditionals ====================

    def conditional_with_else(self, items):
        """value_expr if condition else else_expr"""
        value_expr, condition, else_expr = items
        return ConditionalExpr(
            value_expr=value_expr,
            condition=condition,
            else_expr=else_expr,
        )

    def conditional_no_else(self, items):
        """value_expr if condition"""
        value_expr, condition = items
        return ConditionalExpr(
            value_expr=value_expr,
            condition=condition,
            else_expr=None,
        )

    # ==================== Logical ====================

    def or_op(self, items):
        """Fold left-associatively: a or b or c -> BinaryOp(or, BinaryOp(or, a, b), c)"""
        return self._fold_left(items, "or")

    def and_op(self, items):
        """Fold left-associatively: a and b and c -> BinaryOp(and, BinaryOp(and, a, b), c)"""
        return self._fold_left(items, "and")

    def not_op(self, items):
        """not expr"""
        return UnaryOp(operator="not", operand=items[0])

    # ==================== Comparisons ====================

    def comparison_expr(self, items):
        """left comp_op right"""
        left, op, right = items
        return BinaryOp(operator=op, left=left, right=right)

    def in_expr(self, items):
        """operand in [literal, ...]"""
        operand = items[0]
        # The candidate values stay as transformed nodes inside one LiteralNode.
        values = list(items[1:])
        return BinaryOp(operator="in", left=operand, right=LiteralNode(value=values))

    def between_expr(self, items):
        """operand between low and high"""
        operand, low, high = items
        return BinaryOp(
            operator="between",
            left=operand,
            right=LiteralNode(value=[low, high]),
        )

    def contains_expr(self, items):
        """left contains right"""
        left, right = items
        return BinaryOp(operator="contains", left=left, right=right)

    def startswith_expr(self, items):
        """left startswith right"""
        left, right = items
        return BinaryOp(operator="startswith", left=left, right=right)

    def endswith_expr(self, items):
        """left endswith right"""
        left, right = items
        return BinaryOp(operator="endswith", left=left, right=right)

    def isempty_expr(self, items):
        """operand isempty"""
        return UnaryOp(operator="isempty", operand=items[0])

    def isnotempty_expr(self, items):
        """operand isnotempty"""
        return UnaryOp(operator="isnotempty", operand=items[0])

    def isnan_expr(self, items):
        """operand isnan"""
        return UnaryOp(operator="isnan", operand=items[0])

    # ==================== Comparison operators ====================

    def eq(self, items):
        return "=="

    def ne(self, items):
        return "!="

    def le(self, items):
        return "<="

    def lt(self, items):
        return "<"

    def ge(self, items):
        return ">="

    def gt(self, items):
        return ">"

    # ==================== Arithmetic ====================

    def add_expr(self, items):
        """Fold left-associatively with alternating operands and operators."""
        return self._fold_binary_with_ops(items)

    def mul_expr(self, items):
        """Fold left-associatively with alternating operands and operators."""
        return self._fold_binary_with_ops(items)

    def pow_expr(self, items):
        """Fold left-associatively for power expressions (^ is left-assoc here)."""
        # pow_expr items are [base, exp1, exp2, ...] — operator is always "^"
        return self._fold_left(items, "^")

    def plus(self, items):
        return "+"

    def minus(self, items):
        return "-"

    def times(self, items):
        return "*"

    def divide(self, items):
        return "/"

    def modulo(self, items):
        return "%"

    def neg(self, items):
        """Unary negation: -expr"""
        return UnaryOp(operator="-", operand=items[0])

    # ==================== References ====================

    def cross_ref_simple(self, items):
        """{ Table.Column }"""
        table, column = self._split_table_column(items[0])
        return CrossTableRef(table=table, column=column)

    def cross_ref_where(self, items):
        """{ Table.Column WHERE remote_table.remote_col = local_col }"""
        table, column = self._split_table_column(items[0])
        where = items[1]
        return CrossTableRef(table=table, column=column, where_clause=where)

    def column_ref(self, items):
        """{ ColumnName }"""
        col_name = str(items[0]).strip()
        return ColumnRef(column=col_name)

    def where_clause(self, items):
        """TABLE_COL_REF = COL_NAME"""
        remote_table, remote_col = self._split_table_column(items[0])
        local_col = str(items[1]).strip()
        return WhereClause(
            remote_table=remote_table,
            remote_column=remote_col,
            local_column=local_col,
        )

    # ==================== Functions ====================

    def function_call(self, items):
        """func_name(arg1, arg2, ...)"""
        # Function names are normalised to lower case.
        func_name = str(items[0]).lower()
        args = list(items[1:])
        return FunctionCall(function_name=func_name, arguments=args)

    # ==================== Literals ====================

    def number_literal(self, items):
        """Numeric literal: int when possible, float otherwise."""
        value = str(items[0])
        # A decimal point or an exponent marks the literal as a float.
        if "." in value or "e" in value.lower():
            return LiteralNode(value=float(value))
        try:
            return LiteralNode(value=int(value))
        except ValueError:
            return LiteralNode(value=float(value))

    def string_literal(self, items):
        """String literal with its surrounding double quotes removed."""
        raw = str(items[0])
        # Remove surrounding double quotes
        if raw.startswith('"') and raw.endswith('"'):
            return LiteralNode(value=raw[1:-1])
        return LiteralNode(value=raw)

    def true_literal(self, items):
        return LiteralNode(value=True)

    def false_literal(self, items):
        return LiteralNode(value=False)

    def paren(self, items):
        """Parenthesized expression — transparent pass-through."""
        return items[0]

    # ==================== Helpers ====================

    @staticmethod
    def _split_table_column(ref) -> tuple[str, str]:
        """
        Split a ``Table.Column`` token into its two names.

        Both parts are stripped so that padding inside the braces
        (``{ Table.Column }``) does not end up in the names, matching the
        stripping already applied to plain column references.

        Args:
            ref: Token (or string) of the form ``Table.Column``.

        Returns:
            ``(table, column)``; only the first dot separates the two.
        """
        table, column = str(ref).split(".", 1)
        return table.strip(), column.strip()

    def _fold_left(self, items: list, op: str) -> FormulaNode:
        """
        Fold a list of operands left-associatively with a fixed operator.

        Args:
            items: List of FormulaNode operands.
            op: Operator string.

        Returns:
            Left-folded BinaryOp tree, or the single item if only one.
        """
        result = items[0]
        for operand in items[1:]:
            result = BinaryOp(operator=op, left=result, right=operand)
        return result

    def _fold_binary_with_ops(self, items: list) -> FormulaNode:
        """
        Fold a list of alternating [operand, op, operand, op, operand, ...]
        left-associatively.

        Args:
            items: Alternating list: [expr, op_str, expr, op_str, expr, ...]

        Returns:
            Left-folded BinaryOp tree.
        """
        result = items[0]
        # Walk the (operator, operand) pairs that follow the first operand.
        for i in range(1, len(items), 2):
            result = BinaryOp(operator=items[i], left=result, right=items[i + 1])
        return result
|
||||
398
src/myfasthtml/core/formula/engine.py
Normal file
398
src/myfasthtml/core/formula/engine.py
Normal file
@@ -0,0 +1,398 @@
|
||||
"""
|
||||
Formula Engine — facade orchestrating parsing, DAG, and evaluation.
|
||||
|
||||
Coordinates:
|
||||
- Parsing formula text via the DSL parser
|
||||
- Registering formulas and their dependencies in the DependencyGraph
|
||||
- Evaluating dirty formula columns row-by-row via FormulaEvaluator
|
||||
- Updating ns_fast_access caches in the DatagridStore
|
||||
"""
|
||||
import logging
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .dataclasses import FormulaDefinition, WhereClause
|
||||
from .dependency_graph import DependencyGraph
|
||||
from .dsl.parser import get_parser
|
||||
from .dsl.transformer import FormulaTransformer
|
||||
from .evaluator import FormulaEvaluator
|
||||
|
||||
# Module-level logger used by the engine facade below
logger = logging.getLogger("FormulaEngine")

# Callback that returns a DatagridStore-like object for a given table name
RegistryResolver = Callable[[str], Any]
|
||||
|
||||
|
||||
def parse_formula(text: str) -> FormulaDefinition | None:
    """Turn a formula expression string into a FormulaDefinition AST.

    Args:
        text: The formula expression string (may be None or blank).

    Returns:
        The FormulaDefinition, with ``source_text`` set to the stripped
        input, or None when there is nothing to parse.

    Raises:
        FormulaSyntaxError: If the formula text is syntactically invalid.
    """
    source = (text or "").strip()
    if source == "":
        return None

    tree = get_parser().parse(source)
    if tree is None:
        return None

    definition = FormulaTransformer().transform(tree)
    definition.source_text = source
    return definition
|
||||
|
||||
|
||||
class FormulaEngine:
    """
    Facade for the formula calculation system.

    Orchestrates formula parsing, dependency tracking, and incremental
    recalculation of formula columns.

    Args:
        registry_resolver: Callback that takes a table name and returns
            the DatagridStore for that table (used for cross-table refs).
            Provided by DataGridsManager.
    """

    def __init__(self, registry_resolver: Optional[RegistryResolver] = None):
        self._graph = DependencyGraph()
        self._registry_resolver = registry_resolver
        # Cache of parsed formulas: {(table, col): FormulaDefinition}
        self._formulas: dict[tuple[str, str], FormulaDefinition] = {}

    def set_formula(self, table: str, col: str, formula_text: str) -> None:
        """
        Parse and register a formula for a column.

        Blank or None text removes any formula registered for the column.

        Args:
            table: Table name.
            col: Column name.
            formula_text: The formula expression string.

        Raises:
            FormulaSyntaxError: If the formula is syntactically invalid.
            FormulaCycleError: If the formula would create a circular dependency.
        """
        formula_text = formula_text.strip() if formula_text else ""
        formula = parse_formula(formula_text) if formula_text else None
        if formula is None:
            self.remove_formula(table, col)
            return

        # Registers in DAG and raises FormulaCycleError if cycle detected
        self._graph.add_formula(table, col, formula)
        self._formulas[(table, col)] = formula

        logger.debug("Formula set for %s.%s: %s", table, col, formula_text)

    def remove_formula(self, table: str, col: str) -> None:
        """
        Remove a formula column from the engine.

        Args:
            table: Table name.
            col: Column name.
        """
        self._graph.remove_formula(table, col)
        self._formulas.pop((table, col), None)

    def mark_data_changed(
        self,
        table: str,
        col: str,
        rows: Optional[list[int]] = None,
    ) -> None:
        """
        Mark a column's data as changed, propagating dirty flags.

        Call this when source data is modified so that dependent formula
        columns are re-evaluated on next render.

        Args:
            table: Table name.
            col: Column name.
            rows: Specific row indices that changed. None means all rows.
        """
        self._graph.mark_dirty(table, col, rows)

    def recalculate_if_needed(self, table: str, store: Any) -> bool:
        """
        Recalculate all dirty formula columns for a table.

        Should be called at the start of ``mk_body_content_page()`` to
        ensure formula columns are up-to-date before rendering.

        Updates ``store.ns_fast_access`` and ``store.ns_row_data`` in place.

        Args:
            table: Table name.
            store: The DatagridStore instance for this table.

        Returns:
            True if the dependency graph reported dirty columns for the
            table, False otherwise.
        """
        dirty_nodes = self._graph.get_calculation_order(table=table)
        if not dirty_nodes:
            return False

        for node in dirty_nodes:
            formula = node.formula
            if formula is None:
                continue
            self._evaluate_column(table, node.column, formula, store)
            self._graph.clear_dirty(node.node_id)

        # Rebuild ns_row_data after recalculation
        if store.ns_fast_access:
            self._rebuild_row_data(store)

        return True

    def has_formula(self, table: str, col: str) -> bool:
        """
        Check if a column has a formula registered.

        Args:
            table: Table name.
            col: Column name.

        Returns:
            True if the column has a registered formula.
        """
        return self._graph.has_formula(table, col)

    def get_formula_text(self, table: str, col: str) -> Optional[str]:
        """
        Get the source text of a registered formula.

        Args:
            table: Table name.
            col: Column name.

        Returns:
            Formula source text or None if not registered.
        """
        formula = self._formulas.get((table, col))
        return formula.source_text if formula else None

    # ==================== Private helpers ====================

    def _evaluate_column(
        self,
        table: str,
        col: str,
        formula: FormulaDefinition,
        store: Any,
    ) -> None:
        """
        Evaluate a formula column row-by-row and update ns_fast_access.

        Does nothing when the store has no row data.

        Args:
            table: Table name.
            col: Column name.
            formula: The parsed FormulaDefinition.
            store: The DatagridStore with ns_fast_access and ns_row_data.
        """
        if store.ns_row_data is None or len(store.ns_row_data) == 0:
            return

        n_rows = len(store.ns_row_data)
        resolver = self._make_cross_table_resolver(table)
        evaluator = FormulaEvaluator(cross_table_resolver=resolver)

        # Ensure ns_fast_access exists before the loop so that formula columns
        # evaluated earlier in the same pass are visible to subsequent columns.
        if store.ns_fast_access is None:
            store.ns_fast_access = {}

        results = np.empty(n_rows, dtype=object)

        for row_index in range(n_rows):
            # Build row_data from ns_fast_access so that formula columns evaluated
            # earlier in this pass (e.g. B) are available to dependent columns (e.g. C).
            row_data = {
                c: arr[row_index]
                for c, arr in store.ns_fast_access.items()
                if arr is not None and row_index < len(arr)
            }
            results[row_index] = evaluator.evaluate(formula, row_data, row_index)

        store.ns_fast_access[col] = results

        logger.debug("Evaluated formula column %s.%s (%d rows)", table, col, n_rows)

    def _rebuild_row_data(self, store: Any) -> None:
        """
        Rebuild ns_row_data to include formula column results.

        This ensures formula values are available to dependent formulas
        in subsequent evaluation passes.

        Args:
            store: The DatagridStore to update.
        """
        # Nothing to copy from, or nothing to copy into.
        if store.ns_fast_access is None or store.ns_row_data is None:
            return

        n_rows = len(store.ns_row_data)
        for row_index in range(n_rows):
            row = store.ns_row_data[row_index]
            for col, arr in store.ns_fast_access.items():
                if arr is not None and row_index < len(arr):
                    row[col] = arr[row_index]

    @staticmethod
    def _first_match_index(key_array: Any, value: Any) -> Optional[int]:
        """
        Return the index of the first element of ``key_array`` equal to ``value``.

        Args:
            key_array: Column array to search.
            value: Value to look for.

        Returns:
            Index of the first match, or None when there is no match or
            the comparison itself fails.
        """
        try:
            matches = np.where(key_array == value)[0]
        except Exception:
            # Best effort: an incomparable key is treated as "no match",
            # but the failure is logged rather than silently dropped.
            logger.debug("Cross-table key comparison failed", exc_info=True)
            return None
        if len(matches) == 0:
            return None
        return matches[0]

    def _make_cross_table_resolver(self, current_table: str):
        """
        Create a cross-table resolver callback for the given table context.

        Resolution strategy:
        1. Explicit WHERE clause: scan remote column for matching rows.
        2. Implicit join by ``id`` column: match rows where both tables share
           the same id value.
        3. Fallback: match by row_index.

        Args:
            current_table: The table that contains the formula.

        Returns:
            A callable ``resolver(table, column, where_clause, row_index) -> value``.
        """

        def resolver(
            remote_table: str,
            remote_column: str,
            where_clause: Optional[WhereClause],
            row_index: int,
        ) -> Any:
            if self._registry_resolver is None:
                logger.warning(
                    "No registry_resolver set for cross-table ref %s.%s",
                    remote_table, remote_column,
                )
                return None

            remote_store = self._registry_resolver(remote_table)
            if remote_store is None:
                logger.warning("Table '%s' not found in registry", remote_table)
                return None

            ns = remote_store.ns_fast_access
            if not ns or remote_column not in ns:
                logger.debug(
                    "Column '%s' not found in table '%s'", remote_column, remote_table
                )
                return None

            remote_array = ns[remote_column]

            # Strategy 1: Explicit WHERE clause
            if where_clause is not None:
                return self._resolve_with_where(
                    where_clause, remote_store, remote_column,
                    remote_array, current_table, row_index,
                )

            # Strategy 2: Implicit join by 'id' column
            current_store = self._registry_resolver(current_table)
            if (
                current_store is not None
                and current_store.ns_fast_access is not None
                and "id" in current_store.ns_fast_access
                and "id" in ns
            ):
                local_id_arr = current_store.ns_fast_access["id"]
                remote_id_arr = ns["id"]
                if row_index < len(local_id_arr):
                    local_id = local_id_arr[row_index]
                    # Find first matching row in remote table
                    match = self._first_match_index(remote_id_arr, local_id)
                    if match is not None:
                        return remote_array[match]
                    return None

            # Strategy 3: Fallback — match by row_index
            if row_index < len(remote_array):
                return remote_array[row_index]
            return None

        return resolver

    def _resolve_with_where(
        self,
        where_clause: WhereClause,
        remote_store: Any,
        remote_column: str,
        remote_array: Any,
        current_table: str,
        row_index: int,
    ) -> Any:
        """
        Resolve a cross-table reference using an explicit WHERE clause.

        Args:
            where_clause: The parsed WHERE clause.
            remote_store: DatagridStore for the remote table.
            remote_column: Column to return value from.
            remote_array: numpy array of the remote column values.
            current_table: Table containing the formula.
            row_index: Current row being evaluated.

        Returns:
            The value from the first matching remote row, or None.
        """
        remote_ns = remote_store.ns_fast_access
        if not remote_ns:
            return None

        # Get the remote key column array
        remote_key_col = where_clause.remote_column
        if remote_key_col not in remote_ns:
            logger.debug(
                "WHERE key column '%s' not found in remote table", remote_key_col
            )
            return None

        remote_key_array = remote_ns[remote_key_col]

        # Get the local value to compare
        current_store = self._registry_resolver(current_table) if self._registry_resolver else None
        if current_store is None or current_store.ns_fast_access is None:
            return None

        local_col = where_clause.local_column
        if local_col not in current_store.ns_fast_access:
            logger.debug("WHERE local column '%s' not found", local_col)
            return None

        local_array = current_store.ns_fast_access[local_col]
        if row_index >= len(local_array):
            return None

        local_value = local_array[row_index]

        # Return value from first match (use aggregation functions for multi-row)
        match = self._first_match_index(remote_key_array, local_value)
        if match is None:
            return None
        return remote_array[match]
|
||||
522
src/myfasthtml/core/formula/evaluator.py
Normal file
522
src/myfasthtml/core/formula/evaluator.py
Normal file
@@ -0,0 +1,522 @@
|
||||
"""
|
||||
Formula Evaluator.
|
||||
|
||||
Evaluates a FormulaDefinition AST row-by-row using column data.
|
||||
"""
|
||||
import logging
|
||||
import math
|
||||
from datetime import date, datetime
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
from .dataclasses import (
|
||||
FormulaNode,
|
||||
FormulaDefinition,
|
||||
LiteralNode,
|
||||
ColumnRef,
|
||||
CrossTableRef,
|
||||
BinaryOp,
|
||||
UnaryOp,
|
||||
FunctionCall,
|
||||
ConditionalExpr,
|
||||
)
|
||||
|
||||
# Module-level logger used by the evaluator below
logger = logging.getLogger("FormulaEvaluator")

# Type alias for the cross-table resolver callback.
# Called as resolver(table, column, where_clause, row_index) -> value.
CrossTableResolver = Callable[[str, str, Optional[object], int], Any]
|
||||
|
||||
|
||||
def _safe_numeric(value) -> Optional[float]:
|
||||
"""Convert value to float, returning None if not possible."""
|
||||
if value is None:
|
||||
return None
|
||||
try:
|
||||
return float(value)
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
|
||||
|
||||
# ==================== Built-in function registry ====================
|
||||
|
||||
def _fn_round(args):
|
||||
if len(args) < 1:
|
||||
return None
|
||||
value = args[0]
|
||||
decimals = int(args[1]) if len(args) > 1 else 0
|
||||
v = _safe_numeric(value)
|
||||
return round(v, decimals) if v is not None else None
|
||||
|
||||
|
||||
def _fn_abs(args):
|
||||
v = _safe_numeric(args[0]) if args else None
|
||||
return abs(v) if v is not None else None
|
||||
|
||||
|
||||
def _fn_min(args):
|
||||
nums = [_safe_numeric(a) for a in args]
|
||||
nums = [n for n in nums if n is not None]
|
||||
return min(nums) if nums else None
|
||||
|
||||
|
||||
def _fn_max(args):
|
||||
nums = [_safe_numeric(a) for a in args]
|
||||
nums = [n for n in nums if n is not None]
|
||||
return max(nums) if nums else None
|
||||
|
||||
|
||||
def _fn_floor(args):
|
||||
v = _safe_numeric(args[0]) if args else None
|
||||
return math.floor(v) if v is not None else None
|
||||
|
||||
|
||||
def _fn_ceil(args):
|
||||
v = _safe_numeric(args[0]) if args else None
|
||||
return math.ceil(v) if v is not None else None
|
||||
|
||||
|
||||
def _fn_sqrt(args):
|
||||
v = _safe_numeric(args[0]) if args else None
|
||||
return math.sqrt(v) if v is not None and v >= 0 else None
|
||||
|
||||
|
||||
def _fn_sum(args):
|
||||
"""Sum of all arguments (used for inline multi-value, not aggregation)."""
|
||||
nums = [_safe_numeric(a) for a in args]
|
||||
nums = [n for n in nums if n is not None]
|
||||
return sum(nums) if nums else None
|
||||
|
||||
|
||||
def _fn_avg(args):
|
||||
nums = [_safe_numeric(a) for a in args]
|
||||
nums = [n for n in nums if n is not None]
|
||||
return sum(nums) / len(nums) if nums else None
|
||||
|
||||
|
||||
def _fn_len(args):
|
||||
v = args[0] if args else None
|
||||
if v is None:
|
||||
return None
|
||||
return len(str(v))
|
||||
|
||||
|
||||
def _fn_upper(args):
|
||||
v = args[0] if args else None
|
||||
return str(v).upper() if v is not None else None
|
||||
|
||||
|
||||
def _fn_lower(args):
|
||||
v = args[0] if args else None
|
||||
return str(v).lower() if v is not None else None
|
||||
|
||||
|
||||
def _fn_trim(args):
|
||||
v = args[0] if args else None
|
||||
return str(v).strip() if v is not None else None
|
||||
|
||||
|
||||
def _fn_left(args):
|
||||
if len(args) < 2:
|
||||
return None
|
||||
v, n = args[0], args[1]
|
||||
if v is None or n is None:
|
||||
return None
|
||||
return str(v)[:int(n)]
|
||||
|
||||
|
||||
def _fn_right(args):
|
||||
if len(args) < 2:
|
||||
return None
|
||||
v, n = args[0], args[1]
|
||||
if v is None or n is None:
|
||||
return None
|
||||
n = int(n)
|
||||
return str(v)[-n:] if n > 0 else ""
|
||||
|
||||
|
||||
def _fn_concat(args):
|
||||
return "".join(str(a) if a is not None else "" for a in args)
|
||||
|
||||
|
||||
def _fn_year(args):
|
||||
v = args[0] if args else None
|
||||
if v is None:
|
||||
return None
|
||||
if isinstance(v, (datetime, date)):
|
||||
return v.year
|
||||
try:
|
||||
return datetime.fromisoformat(str(v)).year
|
||||
except (ValueError, TypeError):
|
||||
return None
|
||||
|
||||
|
||||
def _fn_month(args):
|
||||
v = args[0] if args else None
|
||||
if v is None:
|
||||
return None
|
||||
if isinstance(v, (datetime, date)):
|
||||
return v.month
|
||||
try:
|
||||
return datetime.fromisoformat(str(v)).month
|
||||
except (ValueError, TypeError):
|
||||
return None
|
||||
|
||||
|
||||
def _fn_day(args):
|
||||
v = args[0] if args else None
|
||||
if v is None:
|
||||
return None
|
||||
if isinstance(v, (datetime, date)):
|
||||
return v.day
|
||||
try:
|
||||
return datetime.fromisoformat(str(v)).day
|
||||
except (ValueError, TypeError):
|
||||
return None
|
||||
|
||||
|
||||
def _fn_today(args):
|
||||
return date.today()
|
||||
|
||||
|
||||
def _fn_datediff(args):
|
||||
if len(args) < 2:
|
||||
return None
|
||||
d1, d2 = args[0], args[1]
|
||||
if d1 is None or d2 is None:
|
||||
return None
|
||||
try:
|
||||
if not isinstance(d1, (datetime, date)):
|
||||
d1 = datetime.fromisoformat(str(d1))
|
||||
if not isinstance(d2, (datetime, date)):
|
||||
d2 = datetime.fromisoformat(str(d2))
|
||||
delta = d1 - d2
|
||||
return delta.days
|
||||
except (ValueError, TypeError):
|
||||
return None
|
||||
|
||||
|
||||
def _fn_coalesce(args):
|
||||
for a in args:
|
||||
if a is not None:
|
||||
return a
|
||||
return None
|
||||
|
||||
|
||||
def _fn_if_error(args):
|
||||
# if_error(expr, fallback) - expr already evaluated, error would be None
|
||||
if len(args) < 2:
|
||||
return args[0] if args else None
|
||||
return args[0] if args[0] is not None else args[1]
|
||||
|
||||
|
||||
def _fn_count(args):
|
||||
"""Count non-None values (used for aggregation results)."""
|
||||
return sum(1 for a in args if a is not None)
|
||||
|
||||
|
||||
# ==================== Function registry ====================
|
||||
|
||||
# Maps a function name to its implementation. Every implementation takes a
# single sequence holding the call's argument values.
# NOTE(review): keys are lower case; presumably lookups use the lower-cased
# name produced by the transformer — confirm against the evaluator.
BUILTIN_FUNCTIONS: dict[str, Callable] = {
    # Math
    "round": _fn_round,
    "abs": _fn_abs,
    "min": _fn_min,
    "max": _fn_max,
    "floor": _fn_floor,
    "ceil": _fn_ceil,
    "sqrt": _fn_sqrt,
    "sum": _fn_sum,
    "avg": _fn_avg,
    # Text
    "len": _fn_len,
    "upper": _fn_upper,
    "lower": _fn_lower,
    "trim": _fn_trim,
    "left": _fn_left,
    "right": _fn_right,
    "concat": _fn_concat,
    # Date
    "year": _fn_year,
    "month": _fn_month,
    "day": _fn_day,
    "today": _fn_today,
    "datediff": _fn_datediff,
    # Utility
    "coalesce": _fn_coalesce,
    "if_error": _fn_if_error,
    "count": _fn_count,
}
|
||||
|
||||
|
||||
# ==================== Evaluator ====================
|
||||
|
||||
class FormulaEvaluator:
|
||||
"""
|
||||
Row-by-row formula evaluator.
|
||||
|
||||
Evaluates a FormulaDefinition AST against a single row of data.
|
||||
|
||||
Args:
|
||||
cross_table_resolver: Optional callback for cross-table references.
|
||||
Signature: resolver(table, column, where_clause, row_index) -> value
|
||||
"""
|
||||
|
||||
def __init__(self, cross_table_resolver: Optional[CrossTableResolver] = None):
    """Remember the optional callback used to resolve cross-table references."""
    self._cross_table_resolver = cross_table_resolver
|
||||
|
||||
def evaluate(
    self,
    formula: FormulaDefinition,
    row_data: dict,
    row_index: int,
) -> Any:
    """
    Compute the value of a formula for one row.

    Args:
        formula: The parsed FormulaDefinition AST.
        row_data: Dict mapping column_id -> value for the current row.
        row_index: The integer index of the current row.

    Returns:
        The computed value. Any exception raised during evaluation is
        logged as a warning and turned into None.
    """
    try:
        outcome = self._eval(formula.expression, row_data, row_index)
    except Exception as exc:
        logger.warning(
            "Formula evaluation error at row %d: %s", row_index, exc
        )
        outcome = None
    return outcome
|
||||
|
||||
def _eval(self, node: FormulaNode, row_data: dict, row_index: int) -> Any:
    """
    Recursively evaluate an AST node by dispatching on its type.

    Args:
        node: The AST node to evaluate.
        row_data: Current row data dict.
        row_index: Current row index.

    Returns:
        Evaluated value; None (with a warning) for an unknown node type.
    """
    # Leaves: literals and references.
    if isinstance(node, LiteralNode):
        return node.value
    if isinstance(node, ColumnRef):
        return self._resolve_column(node.column, row_data)
    if isinstance(node, CrossTableRef):
        return self._resolve_cross_table(node, row_index)

    # Composite nodes all share the (node, row_data, row_index) signature.
    composite_handlers = (
        (BinaryOp, self._eval_binary),
        (UnaryOp, self._eval_unary),
        (FunctionCall, self._eval_function),
        (ConditionalExpr, self._eval_conditional),
    )
    for node_type, handler in composite_handlers:
        if isinstance(node, node_type):
            return handler(node, row_data, row_index)

    logger.warning("Unknown AST node type: %s", type(node).__name__)
    return None
|
||||
|
||||
def _resolve_column(self, column_name: str, row_data: dict) -> Any:
|
||||
"""Resolve a column reference in the current row."""
|
||||
if column_name in row_data:
|
||||
return row_data[column_name]
|
||||
# Try case-insensitive match
|
||||
lower_name = column_name.lower()
|
||||
for key, value in row_data.items():
|
||||
if str(key).lower() == lower_name:
|
||||
return value
|
||||
logger.debug("Column '%s' not found in row_data", column_name)
|
||||
return None
|
||||
|
||||
def _resolve_cross_table(self, node: CrossTableRef, row_index: int) -> Any:
    """Resolve a cross-table reference through the configured resolver.

    Returns None (with a warning) when no resolver was supplied.
    """
    resolve = self._cross_table_resolver
    if resolve is not None:
        return resolve(node.table, node.column, node.where_clause, row_index)

    logger.warning(
        "No cross_table_resolver set for cross-table ref %s.%s",
        node.table, node.column,
    )
    return None
|
||||
|
||||
def _eval_binary(self, node: BinaryOp, row_data: dict, row_index: int) -> Any:
    """
    Evaluate a binary operation.

    The left operand is always evaluated. For "and" / "or" the right
    operand is evaluated only when the left one does not already decide
    the result; every other operator evaluates both operands.

    Args:
        node: Binary operation node (operator plus left/right operands).
        row_data: Current row data dict.
        row_index: Current row index.

    Returns:
        The operation result. None is returned for division or modulo by
        zero, for arithmetic that _num_op rejects, for a "between" whose
        right side is not a two-element list, and (after a warning is
        logged) for an unknown operator.
    """
    left = self._eval(node.left, row_data, row_index)
    op = node.operator

    # Short-circuit for logical operators: once the left side has not
    # decided the outcome, the result is simply the truthiness of the right.
    if op == "and":
        if not self._truthy(left):
            return False
        return self._truthy(self._eval(node.right, row_data, row_index))

    if op == "or":
        if self._truthy(left):
            return True
        return self._truthy(self._eval(node.right, row_data, row_index))

    right = self._eval(node.right, row_data, row_index)

    # Arithmetic
    if op == "+":
        if isinstance(left, str) or isinstance(right, str):
            # String concatenation: only None becomes "". Falsy values such
            # as 0 or False must keep their text (0 + " items" -> "0 items").
            left_text = "" if left is None else str(left)
            right_text = "" if right is None else str(right)
            return left_text + right_text
        return self._num_op(left, right, lambda a, b: a + b)
    if op == "-":
        return self._num_op(left, right, lambda a, b: a - b)
    if op == "*":
        return self._num_op(left, right, lambda a, b: a * b)
    if op == "/":
        if right == 0:
            return None  # Division by zero -> None
        return self._num_op(left, right, lambda a, b: a / b)
    if op == "%":
        if right == 0:
            return None  # Modulo by zero -> None
        return self._num_op(left, right, lambda a, b: a % b)
    if op == "^":
        # NOTE(review): in Python a negative base with a fractional exponent
        # yields a complex number rather than raising - confirm whether such
        # results should be mapped to None.
        return self._num_op(left, right, lambda a, b: a ** b)

    # Comparison
    if op == "==":
        return left == right
    if op == "!=":
        return left != right
    if op == "<":
        return self._compare(left, right) < 0
    if op == "<=":
        return self._compare(left, right) <= 0
    if op == ">":
        return self._compare(left, right) > 0
    if op == ">=":
        return self._compare(left, right) >= 0

    # String operations
    if op == "contains":
        return str(right) in str(left) if left is not None else False
    if op == "startswith":
        return str(left).startswith(str(right)) if left is not None else False
    if op == "endswith":
        return str(left).endswith(str(right)) if left is not None else False

    # Collection operations
    if op == "in":
        values = right.value if isinstance(right, LiteralNode) else right
        if isinstance(values, list):
            # List items may still be wrapped in literal nodes
            values = [v.value if isinstance(v, LiteralNode) else v for v in values]
        try:
            return left in (values or [])
        except TypeError:
            # Right side is not a container (or a str tested against a
            # non-str value): treat as "not a member" instead of crashing.
            return False
    if op == "between":
        values = right.value if isinstance(right, LiteralNode) else right
        if isinstance(values, list) and len(values) == 2:
            lo = values[0].value if isinstance(values[0], LiteralNode) else values[0]
            hi = values[1].value if isinstance(values[1], LiteralNode) else values[1]
            try:
                return lo <= left <= hi
            except TypeError:
                # Unorderable operands (e.g. None or mixed str/number): use
                # the same fallback as the other relational operators.
                return self._compare(lo, left) <= 0 and self._compare(left, hi) <= 0
        return None

    logger.warning("Unknown binary operator: %s", op)
    return None
|
||||
|
||||
def _eval_unary(self, node: UnaryOp, row_data: dict, row_index: int) -> Any:
    """
    Evaluate a unary operation on the evaluated operand.

    Supported operators: "-", "not", "isempty", "isnotempty" and "isnan".

    Args:
        node: Unary operation node (operator plus operand expression).
        row_data: Current row data dict.
        row_index: Current row index.

    Returns:
        The operation result. Negation returns None when _safe_numeric
        yields None for the operand; an unknown operator logs a warning and
        returns None.
    """
    value = self._eval(node.operand, row_data, row_index)
    operator = node.operator

    if operator == "-":
        number = _safe_numeric(value)
        if number is None:
            return None
        return -number

    if operator == "not":
        return not self._truthy(value)

    if operator == "isempty":
        return value is None or value == "" or value == []

    if operator == "isnotempty":
        return value is not None and value != "" and value != []

    if operator == "isnan":
        # Anything that cannot be converted to float is reported as "not NaN"
        try:
            as_float = float(value)
        except (TypeError, ValueError):
            return False
        return math.isnan(as_float)

    logger.warning("Unknown unary operator: %s", operator)
    return None
|
||||
|
||||
def _eval_function(self, node: FunctionCall, row_data: dict, row_index: int) -> Any:
    """
    Evaluate a call to a built-in formula function.

    The function name is lower-cased before being looked up in
    BUILTIN_FUNCTIONS. All arguments are evaluated first, and the
    implementation receives them as a single list.

    Args:
        node: Function call node (function name plus argument expressions).
        row_data: Current row data dict.
        row_index: Current row index.

    Returns:
        The function result, or None (after a warning is logged) when the
        function is unknown or its implementation raises.
    """
    fn_name = node.function_name.lower()
    evaluated_args = [
        self._eval(argument, row_data, row_index) for argument in node.arguments
    ]

    if fn_name not in BUILTIN_FUNCTIONS:
        logger.warning("Unknown function: %s", fn_name)
        return None

    # Best effort: a failing built-in yields None rather than aborting evaluation
    try:
        return BUILTIN_FUNCTIONS[fn_name](evaluated_args)
    except Exception as exc:
        logger.warning("Function '%s' error: %s", fn_name, exc)
        return None
|
||||
|
||||
def _eval_conditional(self, node: ConditionalExpr, row_data: dict, row_index: int) -> Any:
    """
    Evaluate a conditional expression.

    Only the selected branch is evaluated.

    Args:
        node: Conditional node with condition, value_expr and else_expr.
        row_data: Current row data dict.
        row_index: Current row index.

    Returns:
        The value of value_expr when the condition is truthy, otherwise the
        value of else_expr, or None when there is no else branch.
    """
    outcome = self._eval(node.condition, row_data, row_index)

    if self._truthy(outcome):
        branch = node.value_expr
    else:
        branch = node.else_expr
        if branch is None:
            # Falsy condition without an else branch
            return None

    return self._eval(branch, row_data, row_index)
|
||||
|
||||
@staticmethod
|
||||
def _truthy(value: Any) -> bool:
|
||||
"""Convert a value to boolean for conditional evaluation."""
|
||||
if value is None:
|
||||
return False
|
||||
if isinstance(value, bool):
|
||||
return value
|
||||
if isinstance(value, (int, float)):
|
||||
return value != 0
|
||||
if isinstance(value, str):
|
||||
return len(value) > 0
|
||||
return bool(value)
|
||||
|
||||
@staticmethod
def _num_op(left: Any, right: Any, fn: Callable) -> Any:
    """
    Apply a two-argument numeric function to operands coerced by _safe_numeric.

    Args:
        left: Left operand.
        right: Right operand.
        fn: Callable taking the two coerced operands.

    Returns:
        The result of fn, or None when _safe_numeric yields None for either
        operand or when fn raises ZeroDivisionError, OverflowError or
        ValueError.
    """
    lhs, rhs = _safe_numeric(left), _safe_numeric(right)
    if lhs is not None and rhs is not None:
        try:
            return fn(lhs, rhs)
        except (ZeroDivisionError, OverflowError, ValueError):
            return None
    return None
|
||||
|
||||
@staticmethod
|
||||
def _compare(left: Any, right: Any) -> int:
|
||||
"""
|
||||
Compare two values, returning -1, 0, or 1.
|
||||
|
||||
Handles mixed numeric/string comparisons gracefully.
|
||||
"""
|
||||
try:
|
||||
if left < right:
|
||||
return -1
|
||||
elif left > right:
|
||||
return 1
|
||||
return 0
|
||||
except TypeError:
|
||||
# Fallback: compare as strings
|
||||
sl, sr = str(left), str(right)
|
||||
if sl < sr:
|
||||
return -1
|
||||
elif sl > sr:
|
||||
return 1
|
||||
return 0
|
||||
Reference in New Issue
Block a user