Introducing columns formulas

This commit is contained in:
2026-02-13 21:38:00 +01:00
parent 0df78c0513
commit e8443f07f9
29 changed files with 3889 additions and 15 deletions

365
docs/Datagrid Formulas.md Normal file
View File

@@ -0,0 +1,365 @@
# DataGrid Formulas
## Overview
The DataGrid formula system adds computed columns to the DataGrid. A formula column applies a single expression to every
row, producing derived values from existing data — within the same table or across tables.
The system is designed for:
- **Column-level formulas**: one formula per column, applied to all rows
- **Cross-table references**: direct syntax to reference columns from other tables
- **Reactive recalculation**: dirty flag propagation with page-aware computation
- **Cell-level overrides** (planned): individual cells can override the column formula
## Formula Language
### Basic Syntax
A formula is an expression that references columns with `{ColumnName}` and produces a value for each row:
```
{Price} * {Quantity}
```
References use curly braces `{}` to distinguish column names from keywords and functions. Column names are matched by ID
or title.
### Operators
#### Arithmetic
| Operator | Description | Example |
|----------|----------------|------------------------|
| `+` | Addition | `{Price} + {Tax}` |
| `-` | Subtraction | `{Total} - {Discount}` |
| `*` | Multiplication | `{Price} * {Quantity}` |
| `/` | Division | `{Total} / {Count}` |
| `%` | Modulo | `{Value} % 2` |
| `^` | Power | `{Base} ^ 2` |
#### Comparison
| Operator | Description | Example |
|--------------|--------------------|---------------------------------|
| `==` | Equal | `{Status} == "active"` |
| `!=` | Not equal | `{Status} != "deleted"` |
| `>` | Greater than | `{Price} > 100` |
| `<` | Less than | `{Stock} < 10` |
| `>=` | Greater or equal | `{Score} >= 80` |
| `<=` | Less or equal | `{Age} <= 18` |
| `contains` | String contains | `{Name} contains "Corp"` |
| `startswith` | String starts with | `{Code} startswith "ERR"` |
| `endswith` | String ends with | `{File} endswith ".csv"` |
| `in` | Value in list | `{Status} in ["active", "new"]` |
| `between` | Value in range | `{Age} between 18 and 65` |
| `isempty` | Value is empty | `{Notes} isempty` |
| `isnotempty` | Value is not empty | `{Email} isnotempty` |
| `isnan` | Value is NaN | `{Score} isnan` |
#### Logical
| Operator | Description | Example |
|----------|-------------|---------------------------------------|
| `and` | Logical AND | `{Age} > 18 and {Status} == "active"` |
| `or` | Logical OR | `{Type} == "A" or {Type} == "B"` |
| `not` | Negation | `not {Status} == "deleted"` |
Parentheses control precedence: `({Type} == "A" or {Type} == "B") and {Active} == True`
### Conditions (suffix-if)
Conditions use a **suffix-if** syntax: the result expression comes first, then the condition. This keeps the focus on
the output, not the branching logic.
#### Simple condition (no else — result is None when false)
```
{Price} * 0.8 if {Country} == "FR"
```
#### With else
```
{Price} * 0.8 if {Country} == "FR" else {Price}
```
#### Chained conditions
```
{Price} * 0.8 if {Country} == "FR" else {Price} * 0.9 if {Country} == "DE" else {Price}
```
#### With logical operators
```
{Price} * 0.8 if {Country} == "FR" and {Quantity} > 10 else {Price}
```
#### With grouping
```
{Price} * 0.8 if ({Country} == "FR" or {Country} == "DE") and {Quantity} > 10
```
### Functions
#### Math
| Function | Description | Example |
|-------------------|-----------------------|-------------------------------|
| `round(expr, n)` | Round to n decimals | `round({Price} * 1.2, 2)` |
| `abs(expr)` | Absolute value | `abs({Balance})` |
| `min(expr, expr)` | Minimum of two values | `min({Price}, {MaxPrice})` |
| `max(expr, expr)` | Maximum of two values | `max({Score}, 0)` |
| `sum(expr, ...)` | Sum of values | `sum({Q1}, {Q2}, {Q3}, {Q4})` |
| `avg(expr, ...)` | Average of values | `avg({Q1}, {Q2}, {Q3}, {Q4})` |
#### Text
| Function | Description | Example |
|---------------------|---------------------|--------------------------------|
| `upper(expr)` | Uppercase | `upper({Name})` |
| `lower(expr)` | Lowercase | `lower({Email})` |
| `len(expr)` | String length | `len({Description})` |
| `concat(expr, ...)` | Concatenate strings | `concat({First}, " ", {Last})` |
| `trim(expr)` | Remove whitespace | `trim({Input})` |
| `left(expr, n)` | First n characters | `left({Code}, 3)` |
| `right(expr, n)` | Last n characters | `right({Phone}, 4)` |
#### Date
| Function | Description | Example |
|------------------------|--------------------|--------------------------------|
| `year(expr)` | Extract year | `year({CreatedAt})` |
| `month(expr)` | Extract month | `month({CreatedAt})` |
| `day(expr)` | Extract day | `day({CreatedAt})` |
| `today()` | Current date | `datediff({DueDate}, today())` |
| `datediff(expr, expr)` | Difference in days | `datediff({End}, {Start})` |
#### Aggregation (for cross-table contexts)
| Function | Description | Example |
|---------------|--------------|-----------------------------------------------------|
| `sum(expr)` | Sum values | `sum({Orders.Amount WHERE Orders.ClientId = Id})` |
| `count(expr)` | Count values | `count({Orders.Id WHERE Orders.ClientId = Id})` |
| `avg(expr)` | Average | `avg({Reviews.Score WHERE Reviews.ProductId = Id})` |
| `min(expr)` | Minimum | `min({Bids.Price WHERE Bids.ItemId = Id})` |
| `max(expr)` | Maximum | `max({Bids.Price WHERE Bids.ItemId = Id})` |
## Cross-Table References
### Direct Reference
Reference a column from another table using `{TableName.ColumnName}`:
```
{Products.Price} * {Quantity}
```
### Join Resolution (implicit)
When referencing another table without a WHERE clause, the join is resolved automatically:
1. **By `id` column**: if both tables have a column named `id`, rows are matched on equal `id` values
2. **By row index**: if no `id` column exists in both tables, rows are matched by their internal row index (stable
across sort/filter)
### Explicit Join (WHERE clause)
For explicit control over which row of the other table to use:
```
{Products.Price WHERE Products.Code = ProductCode} * {Quantity}
```
Inside the WHERE clause:
- `Products.Code` refers to a column in the referenced table
- `ProductCode` (no `Table.` prefix) refers to a column in the current table
### Aggregation with Cross-Table
When a cross-table reference matches multiple rows, use an aggregation function:
```
sum({OrderLines.Amount WHERE OrderLines.OrderId = Id})
```
Without aggregation, a multi-row match returns the first matching value.
## Calculation Engine
### Dependency Graph (DAG)
The formula system maintains a **Directed Acyclic Graph** of dependencies between columns:
- **Nodes**: each formula column is a node, identified by `table_name.column_id`
- **Edges**: if column A's formula references column B, an edge B → A exists ("A depends on B")
- Both directions are tracked:
- **Precedents**: columns that a formula reads from
- **Dependents**: columns that need recalculation when this column changes
Cross-table references create edges that span DataGrid instances, managed at the `DataGridsManager` level.
### Dirty Flag Propagation
When a source column's data changes:
1. The source column is marked **dirty**
2. All direct dependents are marked dirty
3. Propagation continues recursively through the DAG
4. Each dirty column maintains a **dirty row set**: the specific row indices that need recalculation
This propagation is **immediate** (fast — only flag marking, no computation).
### Recalculation Strategy (Hybrid)
Actual computation is **deferred to rendering time**:
1. On value change → dirty flags propagate instantly through the DAG
2. On page render (`mk_body_content_page`) → only dirty rows within the visible page (up to 1000 rows) are recalculated
3. Off-screen pages remain dirty until scrolled into view
4. Calculation follows **topological order** of the DAG to ensure precedents are computed before dependents
### Cycle Detection
Before adding a formula, the engine checks for cycles in the DAG using Kahn's algorithm during topological sort. If a
cycle is detected:
- The formula is **rejected**
- The editor displays an error identifying the circular dependency chain
- The previous formula (if any) remains unchanged
### Caching
Each formula column caches its computed values:
- Results are stored in `ns_fast_access[col_id]` alongside raw data columns
- The dirty row set tracks which cached values are stale
- Non-dirty rows return their cached value without re-evaluation
- Cache is invalidated per-row when source data changes
## Evaluation
### Row-by-Row Execution
Formulas are evaluated **row-by-row** within the page being rendered. For each row:
1. Resolve column references `{ColumnName}` to the cell value at the current row index
2. Resolve cross-table references `{Table.Column}` via the join mechanism
3. Evaluate the expression with resolved values
4. Store the result in the cache (`ns_fast_access`)
### Parser
The formula language uses a **custom grammar** parsed with Lark (consistent with the formatting DSL). The parser:
1. Tokenizes the formula string
2. Builds an AST (Abstract Syntax Tree)
3. Transforms the AST into an evaluable representation
4. Extracts column references for dependency graph registration
### Error Handling
| Error Type | Behavior |
|-----------------------|-------------------------------------------------------|
| Syntax error | Editor highlights the error, formula not saved |
| Unknown column | Editor highlights, autocompletion suggests fixes |
| Type mismatch | Cell displays error indicator, other cells unaffected |
| Division by zero | Cell displays `#DIV/0!` or None |
| Circular dependency | Formula rejected, editor shows cycle chain |
| Cross-table not found | Editor highlights unknown table name |
| No join match | Cell displays None |
## User Interface
### Creating a Formula Column
Formula columns are created and edited through the **DataGridColumnsManager**:
1. User opens the Columns Manager panel
2. Adds a new column or edits an existing one
3. Selects column type **"Formula"**
4. A **DslEditor** (CodeMirror 5) opens for formula input
5. The editor provides:
- **Syntax highlighting**: keywords, column references, functions, operators
- **Autocompletion**: column names (current table and other tables), function names, table names
- **Validation**: real-time syntax checking and dependency cycle detection
- **Error markers**: inline error indicators with descriptions
### Formula Column Properties
A formula column extends `DataGridColumnState` with:
| Property | Type | Description |
|---------------------------------------------------------------------------|---------------|------------------------------------------------|
| `formula` | `str` or None | The formula expression (None for data columns) |
| `col_type` | `ColumnType` | Set to `ColumnType.Formula` |
| Other properties (`title`, `visible`, `width`, `format`) remain unchanged |
Formula columns are **read-only** in the grid body — cell values are computed, not editable. Formatting rules from the
formatting DSL apply to formula columns like any other column.
## Integration Points
| Component | Role |
|--------------------------|----------------------------------------------------------|
| `DataGridColumnState` | Stores `formula` field and `ColumnType.Formula` type |
| `DatagridStore` | `ns_fast_access` caches formula results as numpy arrays |
| `DataGridColumnsManager` | UI for creating/editing formula columns |
| `DataGridsManager` | Hosts the global dependency DAG across all tables |
| `DslEditor` | CodeMirror 5 editor with highlighting and autocompletion |
| `FormattingEngine` | Applies formatting rules AFTER formula evaluation |
| `mk_body_content_page()` | Triggers formula computation for visible rows |
| `mk_body_cell_content()` | Reads computed values from `ns_fast_access` |
## Syntax Summary
```
# Basic arithmetic
{Price} * {Quantity}
# Function call
round({Price} * 1.2, 2)
# Simple condition (None if false)
{Price} * 0.8 if {Country} == "FR"
# Condition with else
{Price} * 0.8 if {Country} == "FR" else {Price}
# Chained conditions
{Price} * 0.8 if {Country} == "FR" else {Price} * 0.9 if {Country} == "DE" else {Price}
# Logical operators
{Price} * 0.8 if {Country} == "FR" and {Quantity} > 10
# Grouping
{Price} * 0.8 if ({Country} == "FR" or {Country} == "DE") and {Quantity} > 10
# Cross-table (implicit join on id)
{Products.Price} * {Quantity}
# Cross-table (explicit join)
{Products.Price WHERE Products.Code = ProductCode} * {Quantity}
# Cross-table aggregation
sum({OrderLines.Amount WHERE OrderLines.OrderId = Id})
# Nested functions
round(avg({Q1}, {Q2}, {Q3}, {Q4}), 1)
# Text operations
concat(upper(left({FirstName}, 1)), ". ", {LastName})
```
## Future: Cell-Level Overrides
The architecture supports adding cell-level formula overrides with ~20-30% additional work:
- **Storage**: sparse dict `cell_formulas: dict[(col_id, row_index), str]` (same pattern as `cell_formats`)
- **DAG**: new node type `table.column[row]` alongside existing `table.column` nodes
- **Evaluation**: "does this cell have an override? If yes, use it. Otherwise, use the column formula."
- **Node ID scheme**: designed to be extensible from the start (`table.column` for columns, `table.column[row]` for
cells)

View File

@@ -418,9 +418,41 @@ class DataGrid(MultipleInstance):
self._df_store.ns_fast_access = _init_fast_access(self._df) self._df_store.ns_fast_access = _init_fast_access(self._df)
self._df_store.ns_row_data = _init_row_data(self._df) self._df_store.ns_row_data = _init_row_data(self._df)
self._df_store.ns_total_rows = len(self._df) if self._df is not None else 0 self._df_store.ns_total_rows = len(self._df) if self._df is not None else 0
if init_state:
self._register_existing_formulas()
return self return self
def _register_existing_formulas(self) -> None:
"""
Re-register all formula columns with the FormulaEngine.
Called after data reload to ensure the engine knows about all
formula columns and their expressions.
"""
engine = self._get_formula_engine()
if engine is None:
return
table = self.get_table_name()
for col_def in self._state.columns:
if col_def.formula:
try:
engine.set_formula(table, col_def.col_id, col_def.formula)
except Exception as e:
logger.warning("Failed to register formula for %s.%s: %s", table, col_def.col_id, e)
def _recalculate_formulas(self) -> None:
"""
Recalculate dirty formula columns before rendering.
Called at the start of mk_body_content_page() to ensure formula
columns are up-to-date before cells are rendered.
"""
engine = self._get_formula_engine()
if engine is None:
return
engine.recalculate_if_needed(self.get_table_name(), self._df_store)
def _get_format_rules(self, col_pos, row_index, col_def): def _get_format_rules(self, col_pos, row_index, col_def):
""" """
Get format rules for a cell, returning only the most specific level defined. Get format rules for a cell, returning only the most specific level defined.
@@ -575,6 +607,11 @@ class DataGrid(MultipleInstance):
def get_table_name(self): def get_table_name(self):
return f"{self._settings.namespace}.{self._settings.name}" if self._settings.namespace else self._settings.name return f"{self._settings.namespace}.{self._settings.name}" if self._settings.namespace else self._settings.name
def get_formula_engine(self):
"""Return the FormulaEngine from the DataGridsManager, if available."""
return self._parent.get_formula_engine()
def mk_headers(self): def mk_headers(self):
resize_cmd = self.commands.set_column_width() resize_cmd = self.commands.set_column_width()
move_cmd = self.commands.move_column() move_cmd = self.commands.move_column()
@@ -701,6 +738,7 @@ class DataGrid(MultipleInstance):
OPTIMIZED: Extract filter keyword once instead of 10,000 times. OPTIMIZED: Extract filter keyword once instead of 10,000 times.
OPTIMIZED: Uses OptimizedDiv for rows instead of Div for faster rendering. OPTIMIZED: Uses OptimizedDiv for rows instead of Div for faster rendering.
""" """
self._recalculate_formulas()
df = self._get_filtered_df() df = self._get_filtered_df()
if df is None: if df is None:
return [] return []

View File

@@ -99,6 +99,9 @@ class DataGridColumnsManager(MultipleInstance):
col_def.type = ColumnType(v) col_def.type = ColumnType(v)
elif k == "width": elif k == "width":
col_def.width = int(v) col_def.width = int(v)
elif k == "formula":
col_def.formula = v or ""
self._register_formula(col_def)
else: else:
setattr(col_def, k, v) setattr(col_def, k, v)
@@ -107,6 +110,21 @@ class DataGridColumnsManager(MultipleInstance):
return self.mk_all_columns() return self.mk_all_columns()
def _register_formula(self, col_def) -> None:
"""Register or remove a formula column with the FormulaEngine."""
engine = self._parent.get_formula_engine()
if engine is None:
return
table = self._parent.get_table_name()
if col_def.formula:
try:
engine.set_formula(table, col_def.col_id, col_def.formula)
logger.debug("Registered formula for %s.%s", table, col_def.col_id)
except Exception as e:
logger.warning("Formula error for %s.%s: %s", table, col_def.col_id, e)
else:
engine.remove_formula(table, col_def.col_id)
def mk_column_label(self, col_def: DataGridColumnState): def mk_column_label(self, col_def: DataGridColumnState):
return Div( return Div(
mk.mk( mk.mk(
@@ -168,6 +186,17 @@ class DataGridColumnsManager(MultipleInstance):
value=col_def.title, value=col_def.title,
), ),
*([
Label("Formula"),
Textarea(
col_def.formula or "",
name="formula",
cls=f"textarea textarea-{size} w-full font-mono",
placeholder="{Column} * {OtherColumn}",
rows=3,
),
] if col_def.type == ColumnType.Formula else []),
legend="Column details", legend="Column details",
cls="fieldset border-base-300 rounded-box" cls="fieldset border-base-300 rounded-box"
), ),

View File

@@ -0,0 +1,67 @@
"""
DataGridFormulaEditor — DslEditor for formula column expressions.
Extends DslEditor with formula-specific behavior:
- Parses the formula on content change
- Registers the formula with FormulaEngine
- Triggers a body re-render on the parent DataGrid
"""
import logging
from myfasthtml.controls.DslEditor import DslEditor
from myfasthtml.core.formula.dsl.exceptions import FormulaSyntaxError, FormulaCycleError
logger = logging.getLogger("DataGridFormulaEditor")
class DataGridFormulaEditor(DslEditor):
"""
Formula editor for a specific DataGrid column.
Args:
parent: The parent DataGrid instance.
col_def: The DataGridColumnState for the formula column.
conf: DslEditorConf for CodeMirror configuration.
_id: Optional instance ID.
"""
def __init__(self, parent, col_def, conf=None, _id=None):
super().__init__(parent, conf=conf, _id=_id)
self._col_def = col_def
def on_content_changed(self):
"""
Called when the formula text is changed in the editor.
1. Updates col_def.formula with the new text.
2. Registers the formula with the FormulaEngine.
3. Triggers a body re-render of the parent DataGrid.
"""
formula_text = self.get_content()
# Update the column definition
self._col_def.formula = formula_text or ""
# Register with the FormulaEngine
engine = self._parent._get_formula_engine()
if engine is not None:
table = self._parent.get_table_name()
try:
engine.set_formula(table, self._col_def.col_id, formula_text)
logger.debug(
"Formula updated for %s.%s: %s",
table, self._col_def.col_id, formula_text,
)
except FormulaSyntaxError as e:
logger.debug("Formula syntax error, keeping old formula: %s", e)
return
except FormulaCycleError as e:
logger.warning("Formula cycle detected for %s.%s: %s", table, self._col_def.col_id, e)
return
except Exception as e:
logger.warning("Formula engine error for %s.%s: %s", table, self._col_def.col_id, e)
return
# Save state and re-render the grid body
self._parent.save_state()
return self._parent.render_partial("body")

View File

@@ -12,7 +12,7 @@ from myfasthtml.icons.fluent import brain_circuit20_regular
from myfasthtml.icons.fluent_p1 import filter20_regular, search20_regular from myfasthtml.icons.fluent_p1 import filter20_regular, search20_regular
from myfasthtml.icons.fluent_p2 import dismiss_circle20_regular from myfasthtml.icons.fluent_p2 import dismiss_circle20_regular
logger = logging.getLogger("DataGridFilter") logger = logging.getLogger("DataGridQuery")
DG_QUERY_FILTER = "filter" DG_QUERY_FILTER = "filter"
DG_QUERY_SEARCH = "search" DG_QUERY_SEARCH = "search"

View File

@@ -16,6 +16,7 @@ from myfasthtml.core.commands import Command
from myfasthtml.core.dbmanager import DbObject from myfasthtml.core.dbmanager import DbObject
from myfasthtml.core.formatting.dsl.completion.provider import DatagridMetadataProvider from myfasthtml.core.formatting.dsl.completion.provider import DatagridMetadataProvider
from myfasthtml.core.formatting.presets import DEFAULT_STYLE_PRESETS, DEFAULT_FORMATTER_PRESETS from myfasthtml.core.formatting.presets import DEFAULT_STYLE_PRESETS, DEFAULT_FORMATTER_PRESETS
from myfasthtml.core.formula.engine import FormulaEngine
from myfasthtml.core.instances import InstancesManager, SingleInstance from myfasthtml.core.instances import InstancesManager, SingleInstance
from myfasthtml.icons.fluent_p1 import table_add20_regular from myfasthtml.icons.fluent_p1 import table_add20_regular
from myfasthtml.icons.fluent_p3 import folder_open20_regular from myfasthtml.icons.fluent_p3 import folder_open20_regular
@@ -91,6 +92,11 @@ class DataGridsManager(SingleInstance, DatagridMetadataProvider):
self.style_presets: dict = DEFAULT_STYLE_PRESETS.copy() self.style_presets: dict = DEFAULT_STYLE_PRESETS.copy()
self.formatter_presets: dict = DEFAULT_FORMATTER_PRESETS.copy() self.formatter_presets: dict = DEFAULT_FORMATTER_PRESETS.copy()
self.all_tables_formats: list = [] self.all_tables_formats: list = []
# Formula engine shared across all DataGrids in this session
self._formula_engine = FormulaEngine(
registry_resolver=self._resolve_store_for_table
)
def upload_from_source(self): def upload_from_source(self):
file_upload = FileUpload(self) file_upload = FileUpload(self)
@@ -167,10 +173,10 @@ class DataGridsManager(SingleInstance, DatagridMetadataProvider):
def list_column_values(self, table_name, column_name): def list_column_values(self, table_name, column_name):
return self._registry.get_column_values(table_name, column_name) return self._registry.get_column_values(table_name, column_name)
def get_row_count(self, table_name): def get_row_count(self, table_name):
return self._registry.get_row_count(table_name) return self._registry.get_row_count(table_name)
def get_column_type(self, table_name, column_name): def get_column_type(self, table_name, column_name):
return self._registry.get_column_type(table_name, column_name) return self._registry.get_column_type(table_name, column_name)
@@ -180,7 +186,29 @@ class DataGridsManager(SingleInstance, DatagridMetadataProvider):
def list_format_presets(self) -> list[str]: def list_format_presets(self) -> list[str]:
return list(self.formatter_presets.keys()) return list(self.formatter_presets.keys())
# === Presets Management === def _resolve_store_for_table(self, table_name: str):
"""
Resolve the DatagridStore for a given table name.
Used by FormulaEngine as the registry_resolver callback.
Args:
table_name: Full table name in ``"namespace.name"`` format.
Returns:
DatagridStore instance or None if not found.
"""
try:
as_fullname_dict = self._registry._get_entries_as_full_name_dict()
grid_id = as_fullname_dict.get(table_name)
if grid_id is None:
return None
datagrid = InstancesManager.get(self._session, grid_id, None)
if datagrid is None:
return None
return datagrid._df_store
except Exception:
return None
def get_style_presets(self) -> dict: def get_style_presets(self) -> dict:
"""Get the global style presets.""" """Get the global style presets."""
@@ -190,6 +218,10 @@ class DataGridsManager(SingleInstance, DatagridMetadataProvider):
"""Get the global formatter presets.""" """Get the global formatter presets."""
return self.formatter_presets return self.formatter_presets
def get_formula_engine(self) -> FormulaEngine:
"""The FormulaEngine shared across all DataGrids in this session."""
return self._formula_engine
def add_style_preset(self, name: str, preset: dict): def add_style_preset(self, name: str, preset: dict):
""" """
Add or update a style preset. Add or update a style preset.

View File

@@ -20,6 +20,7 @@ class DataGridColumnState:
visible: bool = True visible: bool = True
width: int = DATAGRID_DEFAULT_COLUMN_WIDTH width: int = DATAGRID_DEFAULT_COLUMN_WIDTH
format: list = field(default_factory=list) # format: list = field(default_factory=list) #
formula: str = "" # formula expression for ColumnType.Formula columns
@dataclass @dataclass

View File

@@ -26,6 +26,7 @@ class ColumnType(Enum):
Bool = "Boolean" Bool = "Boolean"
Choice = "Choice" Choice = "Choice"
Enum = "Enum" Enum = "Enum"
Formula = "Formula"
class ViewType(Enum): class ViewType(Enum):

View File

@@ -98,6 +98,9 @@ class StyleResolver:
return StyleContainer(None, "") return StyleContainer(None, "")
cls = props.pop("__class__", None) cls = props.pop("__class__", None)
if not props:
return StyleContainer(cls, "")
css = "; ".join(f"{key}: {value}" for key, value in props.items()) + ";" css = "; ".join(f"{key}: {value}" for key, value in props.items()) + ";"
return StyleContainer(cls, css) return StyleContainer(cls, css)

View File

View File

@@ -0,0 +1,79 @@
from dataclasses import dataclass, field
from typing import Any, Optional
@dataclass
class FormulaNode:
"""Base AST node for formula expressions."""
pass
@dataclass
class LiteralNode(FormulaNode):
"""A literal value (number, string, boolean)."""
value: Any
@dataclass
class ColumnRef(FormulaNode):
"""Reference to a column in the current table: {ColumnName}."""
column: str
@dataclass
class WhereClause:
"""WHERE clause for cross-table references: WHERE remote_table.remote_col = local_col."""
remote_table: str
remote_column: str
local_column: str
@dataclass
class CrossTableRef(FormulaNode):
"""Reference to a column in another table: {Table.Column}."""
table: str
column: str
where_clause: Optional[WhereClause] = None
@dataclass
class BinaryOp(FormulaNode):
"""Binary operation: left op right."""
operator: str
left: FormulaNode
right: FormulaNode
@dataclass
class UnaryOp(FormulaNode):
"""Unary operation: -expr or not expr."""
operator: str
operand: FormulaNode
@dataclass
class FunctionCall(FormulaNode):
"""Function call: func(args...)."""
function_name: str
arguments: list = field(default_factory=list)
@dataclass
class ConditionalExpr(FormulaNode):
"""Conditional: value_expr if condition [else else_expr].
Chainable: val1 if cond1 else val2 if cond2 else val3
"""
value_expr: FormulaNode
condition: FormulaNode
else_expr: Optional[FormulaNode] = None
@dataclass
class FormulaDefinition:
"""A complete formula definition for a column."""
expression: FormulaNode
source_text: str = ""
def __str__(self):
return self.source_text

View File

@@ -0,0 +1,386 @@
"""
Dependency Graph (DAG) for formula columns.
Tracks column dependencies, propagates dirty flags, and provides
topological ordering for incremental recalculation.
Node IDs use the format ``"table_name.column_id"`` for column-level
granularity, designed to be extensible to ``"table_name.column_id[row]"``
for cell-level overrides.
"""
import logging
from collections import defaultdict, deque
from dataclasses import dataclass, field
from typing import Optional, Set
from .dataclasses import (
FormulaDefinition,
ColumnRef,
CrossTableRef,
BinaryOp,
UnaryOp,
FunctionCall,
ConditionalExpr,
FormulaNode,
)
from .dsl.exceptions import FormulaCycleError
logger = logging.getLogger("DependencyGraph")
@dataclass
class DependencyNode:
"""
A node in the dependency graph.
Attributes:
node_id: Unique identifier in the format ``"table.column"``.
table: Table name.
column: Column name.
dirty: Whether this node needs recalculation.
dirty_rows: Set of specific row indices that are dirty.
Empty set means all rows are dirty.
formula: The parsed FormulaDefinition for formula nodes.
"""
node_id: str
table: str
column: str
dirty: bool = False
dirty_rows: Set[int] = field(default_factory=set)
formula: Optional[FormulaDefinition] = None
class DependencyGraph:
"""
Directed Acyclic Graph of formula column dependencies.
Tracks which columns depend on which other columns and provides:
- Dirty flag propagation via BFS
- Topological ordering for recalculation (Kahn's algorithm)
- Cycle detection before registering new formulas
The graph is bidirectional:
- ``_dependents[A]`` = set of nodes that depend on A (forward edges)
- ``_precedents[B]`` = set of nodes that B depends on (reverse edges)
"""
def __init__(self):
self._nodes: dict[str, DependencyNode] = {}
# forward: A -> {B, C} means "B and C depend on A"
self._dependents: dict[str, set[str]] = defaultdict(set)
# reverse: B -> {A} means "B depends on A"
self._precedents: dict[str, set[str]] = defaultdict(set)
def add_formula(
self,
table: str,
column: str,
formula: FormulaDefinition,
) -> None:
"""
Register a formula for a column and add dependency edges.
Raises:
FormulaCycleError: If adding this formula would create a cycle.
Args:
table: Table name.
column: Column name.
formula: The parsed FormulaDefinition.
"""
logger.debug(f"add_formula {table}.{column}:{formula.source_text}")
node_id = self._make_node_id(table, column)
# Extract dependency node_ids from the formula AST
dep_ids = self._extract_dependencies(formula.expression, table)
# Temporarily remove old edges to avoid stale dependencies
self._remove_edges(node_id)
# Add new edges
for dep_id in dep_ids:
if dep_id == node_id:
raise FormulaCycleError([node_id])
self._dependents[dep_id].add(node_id)
self._precedents[node_id].add(dep_id)
# Ensure all referenced nodes exist (as data nodes without formulas)
for dep_id in dep_ids:
if dep_id not in self._nodes:
dep_table, dep_col = dep_id.split(".", 1)
self._nodes[dep_id] = DependencyNode(
node_id=dep_id,
table=dep_table,
column=dep_col,
)
# Ensure formula node exists
node = self._get_or_create_node(table, column)
node.formula = formula
node.dirty = True # New formula -> needs evaluation
# Detect cycles using Kahn's algorithm
self._detect_cycles()
logger.debug("Added formula for %s depending on: %s", node_id, dep_ids)
def remove_formula(self, table: str, column: str) -> None:
"""
Remove a formula column and its edges from the graph.
Args:
table: Table name.
column: Column name.
"""
node_id = self._make_node_id(table, column)
self._remove_edges(node_id)
if node_id in self._nodes:
node = self._nodes[node_id]
node.formula = None
node.dirty = False
node.dirty_rows.clear()
# If the node has no dependents either, remove it
if not self._dependents.get(node_id):
del self._nodes[node_id]
logger.debug("Removed formula for %s", node_id)
def get_calculation_order(self, table: Optional[str] = None) -> list[DependencyNode]:
"""
Return dirty formula nodes in topological order.
Uses Kahn's algorithm (BFS-based topological sort).
Only returns nodes with a formula that are dirty.
Args:
table: If provided, filter to only nodes for this table.
Returns:
List of dirty DependencyNode objects in calculation order.
"""
# Build in-degree map for nodes with formulas
formula_nodes = {
nid: node for nid, node in self._nodes.items()
if node.formula is not None and node.dirty
}
if not formula_nodes:
return []
# Kahn's algorithm on the subgraph of formula nodes
in_degree = {nid: 0 for nid in formula_nodes}
for nid in formula_nodes:
for prec_id in self._precedents.get(nid, set()):
if prec_id in formula_nodes:
in_degree[nid] += 1
queue = deque([nid for nid, deg in in_degree.items() if deg == 0])
result = []
while queue:
nid = queue.popleft()
node = self._nodes[nid]
if table is None or node.table == table:
result.append(node)
for dep_id in self._dependents.get(nid, set()):
if dep_id in in_degree:
in_degree[dep_id] -= 1
if in_degree[dep_id] == 0:
queue.append(dep_id)
return result
def clear_dirty(self, node_id: str) -> None:
"""
Clear dirty flags for a node after successful recalculation.
Args:
node_id: The node ID in format ``"table.column"``.
"""
if node_id in self._nodes:
node = self._nodes[node_id]
node.dirty = False
node.dirty_rows.clear()
def get_node(self, table: str, column: str) -> Optional[DependencyNode]:
"""
Get a node by table and column.
Args:
table: Table name.
column: Column name.
Returns:
DependencyNode or None if not found.
"""
node_id = self._make_node_id(table, column)
return self._nodes.get(node_id)
def has_formula(self, table: str, column: str) -> bool:
"""
Check if a column has a formula registered.
Args:
table: Table name.
column: Column name.
Returns:
True if the column has a formula.
"""
node = self.get_node(table, column)
return node is not None and node.formula is not None
def mark_dirty(
self,
table: str,
column: str,
rows: Optional[list[int]] = None,
) -> None:
"""
Mark a column (and its transitive dependents) as dirty.
Uses BFS to propagate dirty flags through the dependency graph.
Args:
table: Table name.
column: Column name.
rows: Specific row indices to mark dirty. None means all rows.
"""
node_id = self._make_node_id(table, column)
self._mark_node_dirty(node_id, rows)
# BFS propagation through dependents
queue = deque([node_id])
visited = {node_id}
while queue:
current_id = queue.popleft()
for dep_id in self._dependents.get(current_id, set()):
self._mark_node_dirty(dep_id, rows)
if dep_id not in visited:
visited.add(dep_id)
queue.append(dep_id)
# ==================== Private helpers ====================
@staticmethod
def _make_node_id(table: str, column: str) -> str:
"""Create a standard node ID from table and column names."""
return f"{table}.{column}"
def _get_or_create_node(self, table: str, column: str) -> DependencyNode:
"""Get existing node or create a new one."""
node_id = self._make_node_id(table, column)
if node_id not in self._nodes:
self._nodes[node_id] = DependencyNode(
node_id=node_id,
table=table,
column=column,
)
return self._nodes[node_id]
def _mark_node_dirty(self, node_id: str, rows: Optional[list[int]]) -> None:
"""Mark a specific node as dirty."""
if node_id not in self._nodes:
return
node = self._nodes[node_id]
node.dirty = True
if rows is not None:
node.dirty_rows.update(rows)
else:
node.dirty_rows.clear() # Empty = all rows dirty
def _remove_edges(self, node_id: str) -> None:
"""Remove all edges connected to a node."""
# Remove this node from its precedents' dependents sets
for prec_id in list(self._precedents.get(node_id, set())):
self._dependents[prec_id].discard(node_id)
# Clear this node's precedents
self._precedents[node_id].clear()
def _detect_cycles(self) -> None:
"""
Detect cycles in the full graph using Kahn's algorithm.
Raises:
FormulaCycleError: If a cycle is detected.
"""
# Only check formula nodes
formula_nodes = {
nid for nid, node in self._nodes.items()
if node.formula is not None
}
if not formula_nodes:
return
in_degree = {}
for nid in formula_nodes:
in_degree[nid] = 0
for nid in formula_nodes:
for prec_id in self._precedents.get(nid, set()):
if prec_id in formula_nodes:
in_degree[nid] = in_degree.get(nid, 0) + 1
queue = deque([nid for nid in formula_nodes if in_degree.get(nid, 0) == 0])
processed = set()
while queue:
nid = queue.popleft()
processed.add(nid)
for dep_id in self._dependents.get(nid, set()):
if dep_id in formula_nodes:
in_degree[dep_id] -= 1
if in_degree[dep_id] == 0:
queue.append(dep_id)
cycle_nodes = formula_nodes - processed
if cycle_nodes:
raise FormulaCycleError(sorted(cycle_nodes))
def _extract_dependencies(
self,
node: FormulaNode,
current_table: str,
) -> set[str]:
"""
Recursively extract all column dependency IDs from a formula AST.
Args:
node: The AST node to walk.
current_table: The table containing this formula (for ColumnRef).
Returns:
Set of dependency node IDs (``"table.column"`` format).
"""
deps = set()
if isinstance(node, ColumnRef):
deps.add(self._make_node_id(current_table, node.column))
elif isinstance(node, CrossTableRef):
deps.add(self._make_node_id(node.table, node.column))
# Also depend on the local column used in WHERE clause
if node.where_clause is not None:
deps.add(self._make_node_id(current_table, node.where_clause.local_column))
elif isinstance(node, BinaryOp):
deps.update(self._extract_dependencies(node.left, current_table))
deps.update(self._extract_dependencies(node.right, current_table))
elif isinstance(node, UnaryOp):
deps.update(self._extract_dependencies(node.operand, current_table))
elif isinstance(node, FunctionCall):
for arg in node.arguments:
deps.update(self._extract_dependencies(arg, current_table))
elif isinstance(node, ConditionalExpr):
deps.update(self._extract_dependencies(node.value_expr, current_table))
deps.update(self._extract_dependencies(node.condition, current_table))
if node.else_expr is not None:
deps.update(self._extract_dependencies(node.else_expr, current_table))
return deps

View File

@@ -0,0 +1,180 @@
"""
Autocompletion engine for the DataGrid Formula DSL.
Provides context-aware suggestions for:
- Column names (after ``{``)
- Cross-table references (``{Table.``)
- Built-in function names
- Keywords: ``if``, ``else``, ``and``, ``or``, ``not``, ``WHERE``
"""
import re
from myfasthtml.core.dsl.base_completion import BaseCompletionEngine
from myfasthtml.core.dsl.types import Position, Suggestion
from myfasthtml.core.formula.evaluator import BUILTIN_FUNCTIONS
from myfasthtml.core.utils import make_safe_id
FORMULA_KEYWORDS = [
"if", "else", "and", "or", "not", "where",
"between", "in", "isempty", "isnotempty", "isnan",
"contains", "startswith", "endswith",
"true", "false",
]
class FormulaCompletionEngine(BaseCompletionEngine):
"""
Context-aware completion engine for formula expressions.
Provides suggestions for column references, functions, and keywords.
Args:
provider: DataGrid metadata provider (DataGridsManager or similar).
table_name: Name of the current table in ``"namespace.name"`` format.
"""
def __init__(self, provider, table_name: str):
super().__init__(provider)
self.table_name = table_name
self._id = "formula_completion_engine#" + make_safe_id(table_name)
def detect_scope(self, text: str, current_line: int):
"""Formula has no scope — always the same single-expression scope."""
return None
def detect_context(self, text: str, cursor: Position, scope):
"""
Detect completion context based on cursor position in formula text.
Args:
text: The full formula text.
cursor: Cursor position (line, ch).
scope: Unused (formulas have no scopes).
Returns:
Context string: ``"column_ref"``, ``"cross_table"``,
``"function"``, ``"keyword"``, or ``"general"``.
"""
# Get text up to cursor
lines = text.split("\n")
line_idx = min(cursor.line, len(lines) - 1)
line_text = lines[line_idx]
text_before = line_text[:cursor.ch]
# Check if we are inside a { ... } reference
last_brace = text_before.rfind("{")
if last_brace >= 0:
inside = text_before[last_brace + 1:]
if "}" not in inside:
if "." in inside:
return "cross_table"
return "column_ref"
# Check if we are typing a function name (alphanumeric at word start)
word_match = re.search(r"[a-z_][a-z0-9_]*$", text_before, re.IGNORECASE)
if word_match:
return "function_or_keyword"
return "general"
def get_suggestions(self, text: str, cursor: Position, scope, context) -> list:
"""
Generate suggestions based on the detected context.
Args:
text: The full formula text.
cursor: Cursor position.
scope: Unused.
context: String from ``detect_context``.
Returns:
List of Suggestion objects.
"""
suggestions = []
if context == "column_ref":
# Suggest columns from the current table
suggestions += self._column_suggestions(self.table_name)
elif context == "cross_table":
# Get the table name prefix from text_before
lines = text.split("\n")
line_text = lines[min(cursor.line, len(lines) - 1)]
text_before = line_text[:cursor.ch]
last_brace = text_before.rfind("{")
inside = text_before[last_brace + 1:] if last_brace >= 0 else ""
dot_pos = inside.rfind(".")
table_prefix = inside[:dot_pos] if dot_pos >= 0 else ""
# Suggest columns from the referenced table
if table_prefix:
suggestions += self._column_suggestions(table_prefix)
else:
suggestions += self._table_suggestions()
elif context == "function_or_keyword":
suggestions += self._function_suggestions()
suggestions += self._keyword_suggestions()
else: # general
suggestions += self._function_suggestions()
suggestions += self._keyword_suggestions()
suggestions += [
Suggestion(
label="{",
detail="Column reference",
insert_text="{",
)
]
return suggestions
# ==================== Private helpers ====================
def _column_suggestions(self, table_name: str) -> list:
"""Get column name suggestions for a table."""
try:
columns = self.provider.list_columns(table_name)
return [
Suggestion(
label=col,
detail=f"Column from {table_name}",
insert_text=col,
)
for col in (columns or [])
]
except Exception:
return []
def _table_suggestions(self) -> list:
"""Get table name suggestions."""
try:
tables = self.provider.list_tables()
return [
Suggestion(
label=t,
detail="Table",
insert_text=t,
)
for t in (tables or [])
]
except Exception:
return []
def _function_suggestions(self) -> list:
"""Get built-in function name suggestions."""
return [
Suggestion(
label=name,
detail="Function",
insert_text=f"{name}(",
)
for name in sorted(BUILTIN_FUNCTIONS.keys())
]
def _keyword_suggestions(self) -> list:
"""Get keyword suggestions."""
return [
Suggestion(label=kw, detail="Keyword", insert_text=kw)
for kw in FORMULA_KEYWORDS
]

View File

@@ -0,0 +1 @@

View File

@@ -0,0 +1,79 @@
"""
FormulaDSL definition for the DslEditor control.
Provides the Lark grammar and derived completions for the
DataGrid Formula DSL (CodeMirror 5 Simple Mode).
"""
from functools import cached_property
from typing import Dict, Any
from myfasthtml.core.dsl.base import DSLDefinition
from .grammar import FORMULA_GRAMMAR
class FormulaDSL(DSLDefinition):
"""
DSL definition for DataGrid formula expressions.
Uses the Lark grammar from grammar.py to drive syntax highlighting
and autocompletion in the DslEditor.
"""
name: str = "Formula DSL"
def get_grammar(self) -> str:
"""Return the Lark grammar for the formula DSL."""
return FORMULA_GRAMMAR
@cached_property
def simple_mode_config(self) -> Dict[str, Any]:
"""
Return a hand-tuned CodeMirror 5 Simple Mode config for formula syntax.
Overrides the base class to provide optimized highlighting rules
for column references, operators, functions, and keywords.
"""
return {
"start": [
# Column references: {ColumnName} or {Table.Column}
{
"regex": r"\{[A-Za-z_][A-Za-z0-9_.]*(?:\s+where\s+[A-Za-z_][A-Za-z0-9_.]*\s*=\s*[A-Za-z_][A-Za-z0-9_]*)?\}",
"token": "variable-2",
},
# Function names before parenthesis
{
"regex": r"[a-z_][a-z0-9_]*(?=\s*\()",
"token": "keyword",
},
# Keywords: if, else, and, or, not, where, between, in
{
"regex": r"\b(if|else|and|or|not|where|between|in|isempty|isnotempty|isnan|contains|startswith|endswith|true|false)\b",
"token": "keyword",
},
# Numbers
{
"regex": r"[+-]?(\d+\.?\d*|\.\d+)([eE][+-]?\d+)?",
"token": "number",
},
# Strings
{
"regex": r'"[^"\\]*"',
"token": "string",
},
# Operators
{
"regex": r"[=!<>]=?|[+\-*/%^]",
"token": "operator",
},
# Parentheses and brackets
{
"regex": r"[()[\],]",
"token": "punctuation",
},
],
"meta": {
"dontIndentStates": ["comment"],
"lineComment": "#",
},
}

View File

@@ -0,0 +1,35 @@
class FormulaError(Exception):
"""Base exception for formula errors."""
pass
class FormulaSyntaxError(FormulaError):
"""Raised when the formula has syntax errors."""
def __init__(self, message, line=None, column=None, context=None):
self.message = message
self.line = line
self.column = column
self.context = context
super().__init__(self._format_message())
def _format_message(self):
parts = [self.message]
if self.line is not None:
parts.append(f"at line {self.line}")
if self.column is not None:
parts.append(f"col {self.column}")
return " ".join(parts)
class FormulaValidationError(FormulaError):
"""Raised when the formula is syntactically correct but semantically invalid."""
pass
class FormulaCycleError(FormulaError):
"""Raised when formula dependencies contain a cycle."""
def __init__(self, cycle_nodes):
self.cycle_nodes = cycle_nodes
super().__init__(f"Circular dependency detected involving: {', '.join(cycle_nodes)}")

View File

@@ -0,0 +1,100 @@
FORMULA_GRAMMAR = r"""
start: expression
// ==================== Top-level expression ====================
?expression: conditional_expr
// Suffix-if: value_expr if condition [else expression]
// Right-associative for chaining: a if c1 else b if c2 else d
?conditional_expr: or_expr "if" or_expr "else" conditional_expr -> conditional_with_else
| or_expr "if" or_expr -> conditional_no_else
| or_expr
// ==================== Logical ====================
?or_expr: and_expr ("or" and_expr)* -> or_op
?and_expr: not_expr ("and" not_expr)* -> and_op
?not_expr: "not" not_expr -> not_op
| comparison
// ==================== Comparison ====================
?comparison: addition comp_op addition -> comparison_expr
| addition "in" "[" literal ("," literal)* "]" -> in_expr
| addition "between" addition "and" addition -> between_expr
| addition "contains" addition -> contains_expr
| addition "startswith" addition -> startswith_expr
| addition "endswith" addition -> endswith_expr
| addition "isempty" -> isempty_expr
| addition "isnotempty" -> isnotempty_expr
| addition "isnan" -> isnan_expr
| addition
comp_op: "==" -> eq
| "!=" -> ne
| "<=" -> le
| "<" -> lt
| ">=" -> ge
| ">" -> gt
// ==================== Arithmetic ====================
?addition: multiplication (add_op multiplication)* -> add_expr
?multiplication: power (mul_op power)* -> mul_expr
?power: unary ("^" unary)* -> pow_expr
add_op: "+" -> plus
| "-" -> minus
mul_op: "*" -> times
| "/" -> divide
| "%" -> modulo
?unary: "-" unary -> neg
| atom
// ==================== Atoms ====================
?atom: function_call
| cross_table_ref
| column_ref
| literal
| "(" expression ")" -> paren
// ==================== References ====================
// Cross-table must be checked before column_ref since both use { }
// TABLE_NAME.COL_NAME with optional WHERE clause
// Note: whitespace around "where" and "=" is handled by %ignore
cross_table_ref: "{" TABLE_COL_REF "}" -> cross_ref_simple
| "{" TABLE_COL_REF "where" where_clause "}" -> cross_ref_where
column_ref: "{" COL_NAME "}"
where_clause: TABLE_COL_REF "=" COL_NAME
// TABLE_COL_REF matches "TableName.ColumnName" (dot-separated, no spaces)
TABLE_COL_REF: /[A-Za-z_][A-Za-z0-9_]*\.[A-Za-z_][A-Za-z0-9_]*/
COL_NAME: /[A-Za-z_][A-Za-z0-9_ ]*/
// ==================== Functions ====================
function_call: FUNC_NAME "(" [expression ("," expression)*] ")"
FUNC_NAME: /[a-z_][a-z0-9_]*/
// ==================== Literals ====================
?literal: NUMBER -> number_literal
| ESCAPED_STRING -> string_literal
| "true"i -> true_literal
| "false"i -> false_literal
// ==================== Terminals ====================
NUMBER: /[+-]?(\d+\.?\d*|\.\d+)([eE][+-]?\d+)?/
ESCAPED_STRING: "\"" /[^"\\]*/ "\""
%ignore /[ \t\f]+/
%ignore /\#[^\n]*/
"""

View File

@@ -0,0 +1,85 @@
"""
Formula DSL parser using Lark.
Handles parsing of formula expression strings into a Lark AST.
No indentation handling needed — formulas are single-line expressions.
"""
from lark import Lark, UnexpectedInput
from .exceptions import FormulaSyntaxError
from .grammar import FORMULA_GRAMMAR
class FormulaParser:
"""
Parser for the DataGrid formula language.
Uses Lark LALR parser without indentation handling.
Example:
parser = FormulaParser()
tree = parser.parse("{Price} * {Quantity}")
"""
def __init__(self):
self._parser = Lark(
FORMULA_GRAMMAR,
parser="lalr",
propagate_positions=False,
)
def parse(self, text: str):
"""
Parse a formula expression string into a Lark Tree.
Args:
text: The formula expression text.
Returns:
lark.Tree: The parsed AST.
Raises:
FormulaSyntaxError: If the text has syntax errors.
"""
text = text.strip()
if not text:
return None
try:
return self._parser.parse(text)
except UnexpectedInput as e:
context = None
if hasattr(e, "get_context"):
context = e.get_context(text)
raise FormulaSyntaxError(
message=self._format_error_message(e),
line=getattr(e, "line", None),
column=getattr(e, "column", None),
context=context,
) from e
def _format_error_message(self, error: UnexpectedInput) -> str:
"""Format a user-friendly error message from a Lark exception."""
if hasattr(error, "expected"):
expected = list(error.expected)
if len(expected) == 1:
return f"Expected {expected[0]}"
elif len(expected) <= 5:
return f"Expected one of: {', '.join(expected)}"
else:
return "Unexpected input"
return str(error)
# Singleton parser instance
_parser_instance = None
def get_parser() -> FormulaParser:
"""Get the singleton FormulaParser instance."""
global _parser_instance
if _parser_instance is None:
_parser_instance = FormulaParser()
return _parser_instance

View File

@@ -0,0 +1,274 @@
"""
Formula DSL Transformer.
Converts a Lark AST tree into FormulaDefinition and related AST dataclasses.
"""
from lark import Transformer
from ..dataclasses import (
FormulaDefinition,
FormulaNode,
LiteralNode,
ColumnRef,
CrossTableRef,
WhereClause,
BinaryOp,
UnaryOp,
FunctionCall,
ConditionalExpr,
)
class FormulaTransformer(Transformer):
"""
Transforms the Lark parse tree into FormulaDefinition AST dataclasses.
Handles left-associative folding for arithmetic and logical operators.
Handles right-associative chaining for conditional expressions.
"""
# ==================== Top-level ====================
def start(self, items):
"""Return the FormulaDefinition wrapping the single expression."""
expr = items[0]
return FormulaDefinition(expression=expr)
# ==================== Conditionals ====================
def conditional_with_else(self, items):
"""value_expr if condition else else_expr"""
value_expr, condition, else_expr = items
return ConditionalExpr(
value_expr=value_expr,
condition=condition,
else_expr=else_expr,
)
def conditional_no_else(self, items):
"""value_expr if condition"""
value_expr, condition = items
return ConditionalExpr(
value_expr=value_expr,
condition=condition,
else_expr=None,
)
# ==================== Logical ====================
def or_op(self, items):
"""Fold left-associatively: a or b or c -> BinaryOp(or, BinaryOp(or, a, b), c)"""
return self._fold_left(items, "or")
def and_op(self, items):
"""Fold left-associatively: a and b and c -> BinaryOp(and, BinaryOp(and, a, b), c)"""
return self._fold_left(items, "and")
def not_op(self, items):
"""not expr"""
return UnaryOp(operator="not", operand=items[0])
# ==================== Comparisons ====================
def comparison_expr(self, items):
"""left comp_op right"""
left, op, right = items
return BinaryOp(operator=op, left=left, right=right)
def in_expr(self, items):
"""operand in [literal, ...]"""
operand = items[0]
values = list(items[1:])
return BinaryOp(operator="in", left=operand, right=LiteralNode(value=values))
def between_expr(self, items):
"""operand between low and high"""
operand, low, high = items
return BinaryOp(
operator="between",
left=operand,
right=LiteralNode(value=[low, high]),
)
def contains_expr(self, items):
left, right = items
return BinaryOp(operator="contains", left=left, right=right)
def startswith_expr(self, items):
left, right = items
return BinaryOp(operator="startswith", left=left, right=right)
def endswith_expr(self, items):
left, right = items
return BinaryOp(operator="endswith", left=left, right=right)
def isempty_expr(self, items):
return UnaryOp(operator="isempty", operand=items[0])
def isnotempty_expr(self, items):
return UnaryOp(operator="isnotempty", operand=items[0])
def isnan_expr(self, items):
return UnaryOp(operator="isnan", operand=items[0])
# ==================== Comparison operators ====================
def eq(self, items):
return "=="
def ne(self, items):
return "!="
def le(self, items):
return "<="
def lt(self, items):
return "<"
def ge(self, items):
return ">="
def gt(self, items):
return ">"
# ==================== Arithmetic ====================
def add_expr(self, items):
"""Fold left-associatively with alternating operands and operators."""
return self._fold_binary_with_ops(items)
def mul_expr(self, items):
"""Fold left-associatively with alternating operands and operators."""
return self._fold_binary_with_ops(items)
def pow_expr(self, items):
"""Fold left-associatively for power expressions (^ is left-assoc here)."""
# pow_expr items are [base, exp1, exp2, ...] — operator is always "^"
result = items[0]
for exp in items[1:]:
result = BinaryOp(operator="^", left=result, right=exp)
return result
def plus(self, items):
return "+"
def minus(self, items):
return "-"
def times(self, items):
return "*"
def divide(self, items):
return "/"
def modulo(self, items):
return "%"
def neg(self, items):
"""Unary negation: -expr"""
return UnaryOp(operator="-", operand=items[0])
# ==================== References ====================
def cross_ref_simple(self, items):
"""{ Table.Column }"""
table_col = str(items[0])
table, column = table_col.split(".", 1)
return CrossTableRef(table=table, column=column)
def cross_ref_where(self, items):
"""{ Table.Column WHERE remote_table.remote_col = local_col }"""
table_col = str(items[0])
where = items[1]
table, column = table_col.split(".", 1)
return CrossTableRef(table=table, column=column, where_clause=where)
def column_ref(self, items):
"""{ ColumnName }"""
col_name = str(items[0]).strip()
return ColumnRef(column=col_name)
def where_clause(self, items):
"""TABLE_COL_REF = COL_NAME"""
remote_table_col = str(items[0])
local_col = str(items[1]).strip()
remote_table, remote_col = remote_table_col.split(".", 1)
return WhereClause(
remote_table=remote_table,
remote_column=remote_col,
local_column=local_col,
)
# ==================== Functions ====================
def function_call(self, items):
"""func_name(arg1, arg2, ...)"""
func_name = str(items[0]).lower()
args = list(items[1:])
return FunctionCall(function_name=func_name, arguments=args)
# ==================== Literals ====================
def number_literal(self, items):
value = str(items[0])
if "." in value or "e" in value.lower():
return LiteralNode(value=float(value))
try:
return LiteralNode(value=int(value))
except ValueError:
return LiteralNode(value=float(value))
def string_literal(self, items):
raw = str(items[0])
# Remove surrounding double quotes
if raw.startswith('"') and raw.endswith('"'):
return LiteralNode(value=raw[1:-1])
return LiteralNode(value=raw)
def true_literal(self, items):
return LiteralNode(value=True)
def false_literal(self, items):
return LiteralNode(value=False)
def paren(self, items):
"""Parenthesized expression — transparent pass-through."""
return items[0]
# ==================== Helpers ====================
def _fold_left(self, items: list, op: str) -> FormulaNode:
"""
Fold a list of operands left-associatively with a fixed operator.
Args:
items: List of FormulaNode operands.
op: Operator string.
Returns:
Left-folded BinaryOp tree, or the single item if only one.
"""
result = items[0]
for operand in items[1:]:
result = BinaryOp(operator=op, left=result, right=operand)
return result
def _fold_binary_with_ops(self, items: list) -> FormulaNode:
"""
Fold a list of alternating [operand, op, operand, op, operand, ...]
left-associatively.
Args:
items: Alternating list: [expr, op_str, expr, op_str, expr, ...]
Returns:
Left-folded BinaryOp tree.
"""
result = items[0]
i = 1
while i < len(items):
op = items[i]
right = items[i + 1]
result = BinaryOp(operator=op, left=result, right=right)
i += 2
return result

View File

@@ -0,0 +1,398 @@
"""
Formula Engine — facade orchestrating parsing, DAG, and evaluation.
Coordinates:
- Parsing formula text via the DSL parser
- Registering formulas and their dependencies in the DependencyGraph
- Evaluating dirty formula columns row-by-row via FormulaEvaluator
- Updating ns_fast_access caches in the DatagridStore
"""
import logging
from typing import Any, Callable, Optional
import numpy as np
from .dataclasses import FormulaDefinition, WhereClause
from .dependency_graph import DependencyGraph
from .dsl.parser import get_parser
from .dsl.transformer import FormulaTransformer
from .evaluator import FormulaEvaluator
logger = logging.getLogger("FormulaEngine")
# Callback that returns a DatagridStore-like object for a given table name
RegistryResolver = Callable[[str], Any]
def parse_formula(text: str) -> FormulaDefinition | None:
"""Parse a formula expression string into a FormulaDefinition AST.
Args:
text: The formula expression string.
Returns:
FormulaDefinition on success, None if text is empty.
Raises:
FormulaSyntaxError: If the formula text is syntactically invalid.
"""
text = text.strip() if text else ""
if not text:
return None
parser = get_parser()
tree = parser.parse(text)
if tree is None:
return None
transformer = FormulaTransformer()
formula = transformer.transform(tree)
formula.source_text = text
return formula
class FormulaEngine:
"""
Facade for the formula calculation system.
Orchestrates formula parsing, dependency tracking, and incremental
recalculation of formula columns.
Args:
registry_resolver: Callback that takes a table name and returns
the DatagridStore for that table (used for cross-table refs).
Provided by DataGridsManager.
"""
def __init__(self, registry_resolver: Optional[RegistryResolver] = None):
self._graph = DependencyGraph()
self._registry_resolver = registry_resolver
# Cache of parsed formulas: {(table, col): FormulaDefinition}
self._formulas: dict[tuple[str, str], FormulaDefinition] = {}
def set_formula(self, table: str, col: str, formula_text: str) -> None:
"""
Parse and register a formula for a column.
Args:
table: Table name.
col: Column name.
formula_text: The formula expression string.
Raises:
FormulaSyntaxError: If the formula is syntactically invalid.
FormulaCycleError: If the formula would create a circular dependency.
"""
formula_text = formula_text.strip() if formula_text else ""
if not formula_text:
self.remove_formula(table, col)
return
formula = parse_formula(formula_text)
if formula is None:
self.remove_formula(table, col)
return
# Registers in DAG and raises FormulaCycleError if cycle detected
self._graph.add_formula(table, col, formula)
self._formulas[(table, col)] = formula
logger.debug("Formula set for %s.%s: %s", table, col, formula_text)
def remove_formula(self, table: str, col: str) -> None:
"""
Remove a formula column from the engine.
Args:
table: Table name.
col: Column name.
"""
self._graph.remove_formula(table, col)
self._formulas.pop((table, col), None)
def mark_data_changed(
self,
table: str,
col: str,
rows: Optional[list[int]] = None,
) -> None:
"""
Mark a column's data as changed, propagating dirty flags.
Call this when source data is modified so that dependent formula
columns are re-evaluated on next render.
Args:
table: Table name.
col: Column name.
rows: Specific row indices that changed. None means all rows.
"""
self._graph.mark_dirty(table, col, rows)
def recalculate_if_needed(self, table: str, store: Any) -> bool:
"""
Recalculate all dirty formula columns for a table.
Should be called at the start of ``mk_body_content_page()`` to
ensure formula columns are up-to-date before rendering.
Updates ``store.ns_fast_access`` and ``store.ns_row_data`` in place.
Args:
table: Table name.
store: The DatagridStore instance for this table.
Returns:
True if any columns were recalculated, False otherwise.
"""
dirty_nodes = self._graph.get_calculation_order(table=table)
if not dirty_nodes:
return False
for node in dirty_nodes:
formula = node.formula
if formula is None:
continue
self._evaluate_column(table, node.column, formula, store)
self._graph.clear_dirty(node.node_id)
# Rebuild ns_row_data after recalculation
if dirty_nodes and store.ns_fast_access:
self._rebuild_row_data(store)
return True
def has_formula(self, table: str, col: str) -> bool:
"""
Check if a column has a formula registered.
Args:
table: Table name.
col: Column name.
Returns:
True if the column has a registered formula.
"""
return self._graph.has_formula(table, col)
def get_formula_text(self, table: str, col: str) -> Optional[str]:
"""
Get the source text of a registered formula.
Args:
table: Table name.
col: Column name.
Returns:
Formula source text or None if not registered.
"""
formula = self._formulas.get((table, col))
return formula.source_text if formula else None
# ==================== Private helpers ====================
def _evaluate_column(
self,
table: str,
col: str,
formula: FormulaDefinition,
store: Any,
) -> None:
"""
Evaluate a formula column row-by-row and update ns_fast_access.
Args:
table: Table name.
col: Column name.
formula: The parsed FormulaDefinition.
store: The DatagridStore with ns_fast_access and ns_row_data.
"""
if store.ns_row_data is None or len(store.ns_row_data) == 0:
return
n_rows = len(store.ns_row_data)
resolver = self._make_cross_table_resolver(table)
evaluator = FormulaEvaluator(cross_table_resolver=resolver)
# Ensure ns_fast_access exists before the loop so that formula columns
# evaluated earlier in the same pass are visible to subsequent columns.
if store.ns_fast_access is None:
store.ns_fast_access = {}
results = np.empty(n_rows, dtype=object)
for row_index in range(n_rows):
# Build row_data from ns_fast_access so that formula columns evaluated
# earlier in this pass (e.g. B) are available to dependent columns (e.g. C).
row_data = {
c: arr[row_index]
for c, arr in store.ns_fast_access.items()
if arr is not None and row_index < len(arr)
}
results[row_index] = evaluator.evaluate(formula, row_data, row_index)
store.ns_fast_access[col] = results
logger.debug("Evaluated formula column %s.%s (%d rows)", table, col, n_rows)
def _rebuild_row_data(self, store: Any) -> None:
"""
Rebuild ns_row_data to include formula column results.
This ensures formula values are available to dependent formulas
in subsequent evaluation passes.
Args:
store: The DatagridStore to update.
"""
if store.ns_fast_access is None:
return
n_rows = len(store.ns_row_data)
for row_index in range(n_rows):
row = store.ns_row_data[row_index]
for col, arr in store.ns_fast_access.items():
if arr is not None and row_index < len(arr):
row[col] = arr[row_index]
def _make_cross_table_resolver(self, current_table: str):
"""
Create a cross-table resolver callback for the given table context.
Resolution strategy:
1. Explicit WHERE clause: scan remote column for matching rows.
2. Implicit join by ``id`` column: match rows where both tables share
the same id value.
3. Fallback: match by row_index.
Args:
current_table: The table that contains the formula.
Returns:
A callable ``resolver(table, column, where_clause, row_index) -> value``.
"""
def resolver(
remote_table: str,
remote_column: str,
where_clause: Optional[WhereClause],
row_index: int,
) -> Any:
if self._registry_resolver is None:
logger.warning(
"No registry_resolver set for cross-table ref %s.%s",
remote_table, remote_column,
)
return None
remote_store = self._registry_resolver(remote_table)
if remote_store is None:
logger.warning("Table '%s' not found in registry", remote_table)
return None
ns = remote_store.ns_fast_access
if not ns or remote_column not in ns:
logger.debug(
"Column '%s' not found in table '%s'", remote_column, remote_table
)
return None
remote_array = ns[remote_column]
# Strategy 1: Explicit WHERE clause
if where_clause is not None:
return self._resolve_with_where(
where_clause, remote_store, remote_column,
remote_array, current_table, row_index,
)
# Strategy 2: Implicit join by 'id' column
current_store = self._registry_resolver(current_table)
if (
current_store is not None
and current_store.ns_fast_access is not None
and "id" in current_store.ns_fast_access
and "id" in ns
):
local_id_arr = current_store.ns_fast_access["id"]
remote_id_arr = ns["id"]
if row_index < len(local_id_arr):
local_id = local_id_arr[row_index]
# Find first matching row in remote table
matches = np.where(remote_id_arr == local_id)[0]
if len(matches) > 0:
return remote_array[matches[0]]
return None
# Strategy 3: Fallback — match by row_index
if row_index < len(remote_array):
return remote_array[row_index]
return None
return resolver
def _resolve_with_where(
self,
where_clause: WhereClause,
remote_store: Any,
remote_column: str,
remote_array: Any,
current_table: str,
row_index: int,
) -> Any:
"""
Resolve a cross-table reference using an explicit WHERE clause.
Args:
where_clause: The parsed WHERE clause.
remote_store: DatagridStore for the remote table.
remote_column: Column to return value from.
remote_array: numpy array of the remote column values.
current_table: Table containing the formula.
row_index: Current row being evaluated.
Returns:
The value from the first matching remote row, or None.
"""
remote_ns = remote_store.ns_fast_access
if not remote_ns:
return None
# Get the remote key column array
remote_key_col = where_clause.remote_column
if remote_key_col not in remote_ns:
logger.debug(
"WHERE key column '%s' not found in remote table", remote_key_col
)
return None
remote_key_array = remote_ns[remote_key_col]
# Get the local value to compare
current_store = self._registry_resolver(current_table) if self._registry_resolver else None
if current_store is None or current_store.ns_fast_access is None:
return None
local_col = where_clause.local_column
if local_col not in current_store.ns_fast_access:
logger.debug("WHERE local column '%s' not found", local_col)
return None
local_array = current_store.ns_fast_access[local_col]
if row_index >= len(local_array):
return None
local_value = local_array[row_index]
# Find matching rows
try:
matches = np.where(remote_key_array == local_value)[0]
except Exception:
matches = []
if len(matches) == 0:
return None
# Return value from first match (use aggregation functions for multi-row)
return remote_array[matches[0]]

View File

@@ -0,0 +1,522 @@
"""
Formula Evaluator.
Evaluates a FormulaDefinition AST row-by-row using column data.
"""
import logging
import math
from datetime import date, datetime
from typing import Any, Callable, Optional
from .dataclasses import (
FormulaNode,
FormulaDefinition,
LiteralNode,
ColumnRef,
CrossTableRef,
BinaryOp,
UnaryOp,
FunctionCall,
ConditionalExpr,
)
logger = logging.getLogger("FormulaEvaluator")
# Type alias for the cross-table resolver callback
CrossTableResolver = Callable[[str, str, Optional[object], int], Any]
def _safe_numeric(value) -> Optional[float]:
"""Convert value to float, returning None if not possible."""
if value is None:
return None
try:
return float(value)
except (TypeError, ValueError):
return None
# ==================== Built-in function registry ====================
def _fn_round(args):
if len(args) < 1:
return None
value = args[0]
decimals = int(args[1]) if len(args) > 1 else 0
v = _safe_numeric(value)
return round(v, decimals) if v is not None else None
def _fn_abs(args):
v = _safe_numeric(args[0]) if args else None
return abs(v) if v is not None else None
def _fn_min(args):
nums = [_safe_numeric(a) for a in args]
nums = [n for n in nums if n is not None]
return min(nums) if nums else None
def _fn_max(args):
nums = [_safe_numeric(a) for a in args]
nums = [n for n in nums if n is not None]
return max(nums) if nums else None
def _fn_floor(args):
v = _safe_numeric(args[0]) if args else None
return math.floor(v) if v is not None else None
def _fn_ceil(args):
v = _safe_numeric(args[0]) if args else None
return math.ceil(v) if v is not None else None
def _fn_sqrt(args):
v = _safe_numeric(args[0]) if args else None
return math.sqrt(v) if v is not None and v >= 0 else None
def _fn_sum(args):
"""Sum of all arguments (used for inline multi-value, not aggregation)."""
nums = [_safe_numeric(a) for a in args]
nums = [n for n in nums if n is not None]
return sum(nums) if nums else None
def _fn_avg(args):
nums = [_safe_numeric(a) for a in args]
nums = [n for n in nums if n is not None]
return sum(nums) / len(nums) if nums else None
def _fn_len(args):
v = args[0] if args else None
if v is None:
return None
return len(str(v))
def _fn_upper(args):
v = args[0] if args else None
return str(v).upper() if v is not None else None
def _fn_lower(args):
v = args[0] if args else None
return str(v).lower() if v is not None else None
def _fn_trim(args):
v = args[0] if args else None
return str(v).strip() if v is not None else None
def _fn_left(args):
if len(args) < 2:
return None
v, n = args[0], args[1]
if v is None or n is None:
return None
return str(v)[:int(n)]
def _fn_right(args):
if len(args) < 2:
return None
v, n = args[0], args[1]
if v is None or n is None:
return None
n = int(n)
return str(v)[-n:] if n > 0 else ""
def _fn_concat(args):
return "".join(str(a) if a is not None else "" for a in args)
def _fn_year(args):
v = args[0] if args else None
if v is None:
return None
if isinstance(v, (datetime, date)):
return v.year
try:
return datetime.fromisoformat(str(v)).year
except (ValueError, TypeError):
return None
def _fn_month(args):
v = args[0] if args else None
if v is None:
return None
if isinstance(v, (datetime, date)):
return v.month
try:
return datetime.fromisoformat(str(v)).month
except (ValueError, TypeError):
return None
def _fn_day(args):
v = args[0] if args else None
if v is None:
return None
if isinstance(v, (datetime, date)):
return v.day
try:
return datetime.fromisoformat(str(v)).day
except (ValueError, TypeError):
return None
def _fn_today(args):
return date.today()
def _fn_datediff(args):
if len(args) < 2:
return None
d1, d2 = args[0], args[1]
if d1 is None or d2 is None:
return None
try:
if not isinstance(d1, (datetime, date)):
d1 = datetime.fromisoformat(str(d1))
if not isinstance(d2, (datetime, date)):
d2 = datetime.fromisoformat(str(d2))
delta = d1 - d2
return delta.days
except (ValueError, TypeError):
return None
def _fn_coalesce(args):
for a in args:
if a is not None:
return a
return None
def _fn_if_error(args):
# if_error(expr, fallback) - expr already evaluated, error would be None
if len(args) < 2:
return args[0] if args else None
return args[0] if args[0] is not None else args[1]
def _fn_count(args):
"""Count non-None values (used for aggregation results)."""
return sum(1 for a in args if a is not None)
# ==================== Function registry ====================
BUILTIN_FUNCTIONS: dict[str, Callable] = {
# Math
"round": _fn_round,
"abs": _fn_abs,
"min": _fn_min,
"max": _fn_max,
"floor": _fn_floor,
"ceil": _fn_ceil,
"sqrt": _fn_sqrt,
"sum": _fn_sum,
"avg": _fn_avg,
# Text
"len": _fn_len,
"upper": _fn_upper,
"lower": _fn_lower,
"trim": _fn_trim,
"left": _fn_left,
"right": _fn_right,
"concat": _fn_concat,
# Date
"year": _fn_year,
"month": _fn_month,
"day": _fn_day,
"today": _fn_today,
"datediff": _fn_datediff,
# Utility
"coalesce": _fn_coalesce,
"if_error": _fn_if_error,
"count": _fn_count,
}
# ==================== Evaluator ====================
class FormulaEvaluator:
"""
Row-by-row formula evaluator.
Evaluates a FormulaDefinition AST against a single row of data.
Args:
cross_table_resolver: Optional callback for cross-table references.
Signature: resolver(table, column, where_clause, row_index) -> value
"""
def __init__(self, cross_table_resolver: Optional[CrossTableResolver] = None):
self._cross_table_resolver = cross_table_resolver
def evaluate(
self,
formula: FormulaDefinition,
row_data: dict,
row_index: int,
) -> Any:
"""
Evaluate a formula for a single row.
Args:
formula: The parsed FormulaDefinition AST.
row_data: Dict mapping column_id -> value for the current row.
row_index: The integer index of the current row.
Returns:
The computed value, or None on error.
"""
try:
return self._eval(formula.expression, row_data, row_index)
except Exception as exc:
logger.warning(
"Formula evaluation error at row %d: %s", row_index, exc
)
return None
def _eval(self, node: FormulaNode, row_data: dict, row_index: int) -> Any:
"""
Recursively evaluate an AST node.
Args:
node: The AST node to evaluate.
row_data: Current row data dict.
row_index: Current row index.
Returns:
Evaluated value.
"""
if isinstance(node, LiteralNode):
return node.value
if isinstance(node, ColumnRef):
return self._resolve_column(node.column, row_data)
if isinstance(node, CrossTableRef):
return self._resolve_cross_table(node, row_index)
if isinstance(node, BinaryOp):
return self._eval_binary(node, row_data, row_index)
if isinstance(node, UnaryOp):
return self._eval_unary(node, row_data, row_index)
if isinstance(node, FunctionCall):
return self._eval_function(node, row_data, row_index)
if isinstance(node, ConditionalExpr):
return self._eval_conditional(node, row_data, row_index)
logger.warning("Unknown AST node type: %s", type(node).__name__)
return None
def _resolve_column(self, column_name: str, row_data: dict) -> Any:
"""Resolve a column reference in the current row."""
if column_name in row_data:
return row_data[column_name]
# Try case-insensitive match
lower_name = column_name.lower()
for key, value in row_data.items():
if str(key).lower() == lower_name:
return value
logger.debug("Column '%s' not found in row_data", column_name)
return None
def _resolve_cross_table(self, node: CrossTableRef, row_index: int) -> Any:
"""Resolve a cross-table reference."""
if self._cross_table_resolver is None:
logger.warning(
"No cross_table_resolver set for cross-table ref %s.%s",
node.table, node.column,
)
return None
return self._cross_table_resolver(
node.table, node.column, node.where_clause, row_index
)
def _eval_binary(self, node: BinaryOp, row_data: dict, row_index: int) -> Any:
"""Evaluate a binary operation."""
left = self._eval(node.left, row_data, row_index)
op = node.operator
# Short-circuit for logical operators
if op == "and":
if not self._truthy(left):
return False
right = self._eval(node.right, row_data, row_index)
return self._truthy(left) and self._truthy(right)
if op == "or":
if self._truthy(left):
return True
right = self._eval(node.right, row_data, row_index)
return self._truthy(left) or self._truthy(right)
right = self._eval(node.right, row_data, row_index)
# Arithmetic
if op == "+":
if isinstance(left, str) or isinstance(right, str):
return str(left or "") + str(right or "")
return self._num_op(left, right, lambda a, b: a + b)
if op == "-":
return self._num_op(left, right, lambda a, b: a - b)
if op == "*":
return self._num_op(left, right, lambda a, b: a * b)
if op == "/":
if right == 0 or right == 0.0:
return None # Division by zero -> None
return self._num_op(left, right, lambda a, b: a / b)
if op == "%":
if right == 0 or right == 0.0:
return None
return self._num_op(left, right, lambda a, b: a % b)
if op == "^":
return self._num_op(left, right, lambda a, b: a ** b)
# Comparison
if op == "==":
return left == right
if op == "!=":
return left != right
if op == "<":
return self._compare(left, right) < 0
if op == "<=":
return self._compare(left, right) <= 0
if op == ">":
return self._compare(left, right) > 0
if op == ">=":
return self._compare(left, right) >= 0
# String operations
if op == "contains":
return str(right) in str(left) if left is not None else False
if op == "startswith":
return str(left).startswith(str(right)) if left is not None else False
if op == "endswith":
return str(left).endswith(str(right)) if left is not None else False
# Collection operations
if op == "in":
values = right.value if isinstance(right, LiteralNode) else right
if isinstance(values, list):
return left in [v.value if isinstance(v, LiteralNode) else v for v in values]
return left in (values or [])
if op == "between":
values = right.value if isinstance(right, LiteralNode) else right
if isinstance(values, list) and len(values) == 2:
lo = values[0].value if isinstance(values[0], LiteralNode) else values[0]
hi = values[1].value if isinstance(values[1], LiteralNode) else values[1]
return lo <= left <= hi
return None
logger.warning("Unknown binary operator: %s", op)
return None
def _eval_unary(self, node: UnaryOp, row_data: dict, row_index: int) -> Any:
"""Evaluate a unary operation."""
operand = self._eval(node.operand, row_data, row_index)
op = node.operator
if op == "-":
v = _safe_numeric(operand)
return -v if v is not None else None
if op == "not":
return not self._truthy(operand)
if op == "isempty":
return operand is None or operand == "" or operand == []
if op == "isnotempty":
return operand is not None and operand != "" and operand != []
if op == "isnan":
try:
return math.isnan(float(operand))
except (TypeError, ValueError):
return False
logger.warning("Unknown unary operator: %s", op)
return None
def _eval_function(self, node: FunctionCall, row_data: dict, row_index: int) -> Any:
"""Evaluate a function call."""
name = node.function_name.lower()
args = [self._eval(arg, row_data, row_index) for arg in node.arguments]
if name in BUILTIN_FUNCTIONS:
try:
return BUILTIN_FUNCTIONS[name](args)
except Exception as exc:
logger.warning("Function '%s' error: %s", name, exc)
return None
logger.warning("Unknown function: %s", name)
return None
def _eval_conditional(self, node: ConditionalExpr, row_data: dict, row_index: int) -> Any:
"""Evaluate a conditional expression."""
condition = self._eval(node.condition, row_data, row_index)
if self._truthy(condition):
return self._eval(node.value_expr, row_data, row_index)
elif node.else_expr is not None:
return self._eval(node.else_expr, row_data, row_index)
return None
@staticmethod
def _truthy(value: Any) -> bool:
"""Convert a value to boolean for conditional evaluation."""
if value is None:
return False
if isinstance(value, bool):
return value
if isinstance(value, (int, float)):
return value != 0
if isinstance(value, str):
return len(value) > 0
return bool(value)
@staticmethod
def _num_op(left: Any, right: Any, fn: Callable) -> Any:
"""Apply a numeric binary function, returning None if inputs are non-numeric."""
a = _safe_numeric(left)
b = _safe_numeric(right)
if a is None or b is None:
return None
try:
return fn(a, b)
except (ZeroDivisionError, OverflowError, ValueError):
return None
@staticmethod
def _compare(left: Any, right: Any) -> int:
"""
Compare two values, returning -1, 0, or 1.
Handles mixed numeric/string comparisons gracefully.
"""
try:
if left < right:
return -1
elif left > right:
return 1
return 0
except TypeError:
# Fallback: compare as strings
sl, sr = str(left), str(right)
if sl < sr:
return -1
elif sl > sr:
return 1
return 0

View File

@@ -243,7 +243,7 @@ class TestConflictResolution:
css, formatted = engine.apply_format(rules, cell_value=150, row_data=row_data) css, formatted = engine.apply_format(rules, cell_value=150, row_data=row_data)
assert isinstance(css, StyleContainer) assert isinstance(css, StyleContainer)
assert "var(--color-secondary)" in css.css # Style from Rule 2 assert css.cls == "mf-formatting-secondary" # Style from Rule 2
assert formatted == "150.00 €" # Formatter from Rule 1 assert formatted == "150.00 €" # Formatter from Rule 1
# Case 2: Condition not met (value <= budget) # Case 2: Condition not met (value <= budget)
@@ -282,7 +282,7 @@ class TestConflictResolution:
css, formatted = engine.apply_format(rules, cell_value=-5.67) css, formatted = engine.apply_format(rules, cell_value=-5.67)
assert "var(--color-error)" in css.css # Rule 3 wins for style assert css.cls == "mf-formatting-error" # Rule 3 wins for style
assert formatted == "-6 €" # Rule 4 wins for formatter (precision=0) assert formatted == "-6 €" # Rule 4 wins for formatter (precision=0)
@@ -316,7 +316,7 @@ class TestWithRowData:
css, _ = engine.apply_format(rules, cell_value=42, row_data=row_data) css, _ = engine.apply_format(rules, cell_value=42, row_data=row_data)
assert isinstance(css, StyleContainer) assert isinstance(css, StyleContainer)
assert "background-color" in css.css assert css.cls == "mf-formatting-error"
class TestPresets: class TestPresets:
@@ -327,7 +327,7 @@ class TestPresets:
css, _ = engine.apply_format(rules, cell_value=42) css, _ = engine.apply_format(rules, cell_value=42)
assert "var(--color-success)" in css.css assert css.cls == "mf-formatting-success"
def test_formatter_preset(self): def test_formatter_preset(self):
"""Formatter preset is resolved correctly.""" """Formatter preset is resolved correctly."""

View File

@@ -16,9 +16,14 @@ class TestResolve:
assert result["font-weight"] == "bold" assert result["font-weight"] == "bold"
def test_resolve_preset_with_override(self): def test_resolve_preset_with_override(self):
"""Preset properties can be overridden by explicit values.""" """Preset CSS properties can be overridden by explicit values."""
resolver = StyleResolver() custom_presets = {
# "success" preset has background and color defined "success": {
"background-color": "var(--color-success)",
"color": "var(--color-success-content)",
}
}
resolver = StyleResolver(style_presets=custom_presets)
style = Style(preset="success", color="black") style = Style(preset="success", color="black")
result = resolver.resolve(style) result = resolver.resolve(style)
@@ -66,6 +71,16 @@ class TestResolve:
assert result == {} assert result == {}
def test_i_can_resolve_class_only_preset(self):
"""Default preset with __class__ only is included as-is in resolve result."""
resolver = StyleResolver()
style = Style(preset="success")
result = resolver.resolve(style)
assert result["__class__"] == "mf-formatting-success"
assert "background-color" not in result
assert "color" not in result
def test_resolve_converts_property_names(self): def test_resolve_converts_property_names(self):
"""Python attribute names are converted to CSS property names.""" """Python attribute names are converted to CSS property names."""
resolver = StyleResolver() resolver = StyleResolver()
@@ -151,11 +166,11 @@ class TestToStyleContainer:
None, None,
["background-color: red", "color: white"] ["background-color: red", "color: white"]
), ),
# Class only via preset # Class only via preset (default presets use __class__, no inline CSS)
( (
Style(preset="success"), Style(preset="success"),
None, # Default presets don't have __class__ "mf-formatting-success",
["background-color: var(--color-success)", "color: var(--color-success-content)"] []
), ),
# Empty style # Empty style
( (
@@ -246,3 +261,12 @@ class TestToStyleContainer:
assert isinstance(result, StyleContainer) assert isinstance(result, StyleContainer)
assert result.cls is None assert result.cls is None
assert result.css == "" assert result.css == ""
def test_i_can_resolve_default_preset_to_container(self):
"""Default preset with __class__ only generates cls but no css."""
resolver = StyleResolver()
style = Style(preset="error")
result = resolver.to_style_container(style)
assert result.cls == "mf-formatting-error"
assert result.css == ""

View File

View File

@@ -0,0 +1,391 @@
"""
Tests for the DependencyGraph DAG.
"""
import pytest
from myfasthtml.core.formula.dataclasses import (
ColumnRef,
ConditionalExpr,
CrossTableRef,
FormulaDefinition,
FunctionCall,
LiteralNode,
UnaryOp,
WhereClause,
)
from myfasthtml.core.formula.dependency_graph import DependencyGraph
from myfasthtml.core.formula.dsl.exceptions import FormulaCycleError
from myfasthtml.core.formula.engine import parse_formula
def make_formula(expr):
"""Create a FormulaDefinition from a raw AST node for direct AST testing."""
return FormulaDefinition(expression=expr, source_text="test")
# ==================== Add formula ====================
def test_i_can_add_simple_dependency():
"""Test that adding a formula creates edges in the graph."""
graph = DependencyGraph()
formula = parse_formula("{Price} * {Quantity}")
graph.add_formula("orders", "total", formula)
node = graph.get_node("orders", "total")
assert node is not None
assert node.formula is formula
assert node.dirty is True
# Precedents should include Price and Quantity
precedent = graph._precedents["orders.total"]
assert "orders.Price" in precedent
assert "orders.Quantity" in precedent
# Descendants are correctly set
assert "orders.total" in graph._dependents["orders.Price"]
assert "orders.total" in graph._dependents["orders.Quantity"]
def test_i_can_add_formula_and_check_dirty():
"""Test that newly added formulas are marked dirty."""
graph = DependencyGraph()
formula = parse_formula("{A} + {B}")
graph.add_formula("t", "C", formula)
node = graph.get_node("t", "C")
assert node.dirty is True
def test_i_can_update_formula():
"""Test that replacing a formula updates dependencies."""
graph = DependencyGraph()
formula1 = parse_formula("{A} + {B}")
graph.add_formula("t", "C", formula1)
formula2 = parse_formula("{X} * {Y}")
graph.add_formula("t", "C", formula2)
node = graph.get_node("t", "C")
assert node.formula is formula2
# Should no longer depend on A and B
precedent = graph._precedents.get("t.C", set())
assert "t.A" not in precedent
assert "t.B" not in precedent
# ==================== Cycle detection ====================
def test_i_cannot_create_cycle():
"""Test that circular dependencies raise FormulaCycleError."""
graph = DependencyGraph()
# A depends on B
graph.add_formula("t", "A", parse_formula("{B} + 1"))
# B depends on A -> cycle
with pytest.raises(FormulaCycleError):
graph.add_formula("t", "B", parse_formula("{A} * 2"))
def test_i_cannot_create_self_reference():
"""Test that a formula referencing its own column raises FormulaCycleError."""
graph = DependencyGraph()
with pytest.raises(FormulaCycleError):
graph.add_formula("t", "A", parse_formula("{A} + 1"))
def test_i_can_detect_long_cycle():
"""Test that a longer chain cycle is also detected."""
graph = DependencyGraph()
graph.add_formula("t", "A", parse_formula("{B} + 1"))
graph.add_formula("t", "B", parse_formula("{C} + 1"))
with pytest.raises(FormulaCycleError):
graph.add_formula("t", "C", parse_formula("{A} + 1"))
# ==================== Dirty flag propagation ====================
def test_i_can_propagate_dirty_flags():
"""Test that marking a source column dirty propagates to dependents."""
graph = DependencyGraph()
graph.add_formula("t", "B", parse_formula("{A} * 2"))
# Clear dirty flag set by add_formula
graph.clear_dirty("t.B")
assert not graph.get_node("t", "B").dirty
# Mark source A as dirty
graph.mark_dirty("t", "A")
# B depends on A, so B should become dirty
node_b = graph.get_node("t", "B")
assert node_b is not None
assert node_b.dirty is True
def test_i_can_propagate_dirty_to_chain():
"""Test that dirty flags propagate through a chain: A -> B -> C."""
graph = DependencyGraph()
graph.add_formula("t", "B", parse_formula("{A} + 1"))
graph.add_formula("t", "C", parse_formula("{B} + 1"))
# Clear all dirty flags
graph.clear_dirty("t.B")
graph.clear_dirty("t.C")
# Mark A dirty
graph.mark_dirty("t", "A")
assert graph.get_node("t", "B").dirty is True
assert graph.get_node("t", "C").dirty is True
def test_i_can_propagate_specific_rows():
"""Test that dirty propagation can be limited to specific rows."""
graph = DependencyGraph()
graph.add_formula("t", "B", parse_formula("{A} * 2"))
graph.clear_dirty("t.B")
graph.mark_dirty("t", "A", rows=[0, 2, 5])
node_b = graph.get_node("t", "B")
assert node_b.dirty is True
assert 0 in node_b.dirty_rows
assert 2 in node_b.dirty_rows
assert 5 in node_b.dirty_rows
assert 1 not in node_b.dirty_rows
# ==================== Topological ordering ====================
def test_i_can_get_calculation_order():
"""Test that dirty formula nodes are returned in topological order."""
graph = DependencyGraph()
graph.add_formula("t", "B", parse_formula("{A} + 1"))
graph.add_formula("t", "C", parse_formula("{B} + 1"))
graph.mark_dirty("t", "A")
order = graph.get_calculation_order(table="t")
node_ids = [n.node_id for n in order]
# B must come before C
assert "t.B" in node_ids
assert "t.C" in node_ids
assert node_ids.index("t.B") < node_ids.index("t.C")
def test_i_can_get_calculation_order_without_table_filter():
"""Test that get_calculation_order with no table filter returns dirty nodes across all tables.
Why: The table parameter is optional. Omitting it should return dirty formula nodes
from every table, not just one.
"""
graph = DependencyGraph()
graph.add_formula("t1", "B", parse_formula("{A} + 1"))
graph.add_formula("t2", "D", parse_formula("{C} * 2"))
order = graph.get_calculation_order()
node_ids = [n.node_id for n in order]
assert "t1.B" in node_ids
assert "t2.D" in node_ids
def test_i_can_get_calculation_order_excludes_clean_nodes():
"""Test that get_calculation_order only returns dirty formula nodes.
Why: The method filters on node.dirty (source line 166). Clean nodes must
never appear in the output, otherwise they would be recalculated unnecessarily.
"""
graph = DependencyGraph()
graph.add_formula("t", "B", parse_formula("{A} + 1"))
graph.add_formula("t", "C", parse_formula("{A} * 2"))
graph.clear_dirty("t.C")
order = graph.get_calculation_order(table="t")
node_ids = [n.node_id for n in order]
assert "t.B" in node_ids
assert "t.C" not in node_ids
def test_i_can_handle_diamond_dependency():
"""Test topological order with diamond dependency: A -> B, A -> C, B+C -> D."""
graph = DependencyGraph()
graph.add_formula("t", "B", parse_formula("{A} + 1"))
graph.add_formula("t", "C", parse_formula("{A} * 2"))
graph.add_formula("t", "D", parse_formula("{B} + {C}"))
graph.mark_dirty("t", "A")
order = graph.get_calculation_order(table="t")
node_ids = [n.node_id for n in order]
assert "t.B" in node_ids
assert "t.C" in node_ids
assert "t.D" in node_ids
# D must come after both B and C
assert node_ids.index("t.D") > node_ids.index("t.B")
assert node_ids.index("t.D") > node_ids.index("t.C")
# ==================== Remove formula ====================
def test_i_can_remove_formula():
"""Test that a formula can be removed from the graph."""
graph = DependencyGraph()
graph.add_formula("t", "B", parse_formula("{A} + 1"))
assert graph.has_formula("t", "B")
graph.remove_formula("t", "B")
assert not graph.has_formula("t", "B")
def test_i_can_remove_formula_node_kept_when_has_dependents():
"""Test that removing a formula keeps the node when other formulas still depend on it.
Why: remove_formula deletes the node only when no dependents exist (source line 145-146).
If another formula depends on the removed node it must remain as a data node.
"""
graph = DependencyGraph()
graph.add_formula("t", "B", parse_formula("{A} + 1"))
graph.add_formula("t", "C", parse_formula("{B} * 2"))
graph.remove_formula("t", "B")
assert not graph.has_formula("t", "B")
assert graph.get_node("t", "B") is not None
def test_i_can_remove_formula_and_add_back():
"""Test that a formula can be removed and re-added."""
graph = DependencyGraph()
graph.add_formula("t", "B", parse_formula("{A} + 1"))
graph.remove_formula("t", "B")
# Should not raise
graph.add_formula("t", "B", parse_formula("{X} * 2"))
assert graph.has_formula("t", "B")
# ==================== Cross-table ====================
def test_i_can_handle_cross_table_dependencies():
"""Test that cross-table references create inter-table edges."""
graph = DependencyGraph()
formula = parse_formula("{Products.Price} * {Quantity}")
graph.add_formula("orders", "total", formula)
node = graph.get_node("orders", "total")
assert node is not None
prec = graph._precedents.get("orders.total", set())
assert "Products.Price" in prec
assert "orders.Quantity" in prec
def test_i_can_extract_cross_table_ref_with_where_clause():
"""Test that CrossTableRef with a WHERE clause adds both the remote column
and the local column from the WHERE clause as dependencies.
Why: Source line 366-367 adds where_clause.local_column as a dependency.
Without this test that code path is never exercised.
"""
graph = DependencyGraph()
formula = make_formula(
CrossTableRef(
table="Products",
column="Price",
where_clause=WhereClause(
remote_table="Products",
remote_column="id",
local_column="product_id",
),
)
)
graph.add_formula("orders", "total", formula)
prec = graph._precedents.get("orders.total", set())
assert "Products.Price" in prec
assert "orders.product_id" in prec
# ==================== _extract_dependencies AST coverage ====================
def test_i_can_extract_unary_op_dependency():
"""Test that UnaryOp correctly extracts its operand's column dependency.
Why: Source lines 373-375 handle UnaryOp. Without this test that branch
is never exercised — a negated column reference would silently produce
no dependency edge.
"""
graph = DependencyGraph()
formula = make_formula(UnaryOp("-", ColumnRef("A")))
graph.add_formula("t", "B", formula)
prec = graph._precedents.get("t.B", set())
assert "t.A" in prec
def test_i_can_extract_function_call_dependencies():
"""Test that FunctionCall extracts column dependencies from all arguments.
Why: Source lines 376-378 iterate over arguments. With multiple column
arguments every dependency must appear in the precedents set.
"""
graph = DependencyGraph()
formula = make_formula(FunctionCall("SUM", [ColumnRef("A"), ColumnRef("B")]))
graph.add_formula("t", "C", formula)
prec = graph._precedents.get("t.C", set())
assert "t.A" in prec
assert "t.B" in prec
def test_i_can_extract_function_call_with_no_column_args():
"""Test that a FunctionCall with only literal arguments creates no column dependencies.
Why: FunctionCall with literal args only (e.g. NOW(1)) must not create
spurious dependency edges that would trigger unnecessary recalculations.
"""
graph = DependencyGraph()
formula = make_formula(FunctionCall("NOW", [LiteralNode(1)]))
graph.add_formula("t", "C", formula)
prec = graph._precedents.get("t.C", set())
assert len(prec) == 0, "FunctionCall with literal args should produce no column dependencies"
def test_i_can_extract_conditional_expr_all_branches():
"""Test that ConditionalExpr extracts dependencies from value_expr, condition, and else_expr.
Why: Source lines 380-384 walk all three branches. All three column references
must appear as dependencies so that any change in any branch triggers recalculation.
"""
graph = DependencyGraph()
formula = make_formula(
ConditionalExpr(
value_expr=ColumnRef("A"),
condition=ColumnRef("B"),
else_expr=ColumnRef("C"),
)
)
graph.add_formula("t", "D", formula)
prec = graph._precedents.get("t.D", set())
assert "t.A" in prec
assert "t.B" in prec
assert "t.C" in prec
def test_i_can_extract_conditional_expr_without_else():
"""Test that ConditionalExpr with else_expr=None does not crash and extracts value and condition.
Why: Source line 383 guards on ``if node.else_expr is not None``. A missing
else branch must not raise and must still extract the remaining two dependencies.
"""
graph = DependencyGraph()
formula = make_formula(
ConditionalExpr(
value_expr=ColumnRef("A"),
condition=ColumnRef("B"),
else_expr=None,
)
)
graph.add_formula("t", "C", formula)
prec = graph._precedents.get("t.C", set())
assert "t.A" in prec
assert "t.B" in prec

View File

@@ -0,0 +1,408 @@
"""
Tests for the FormulaEngine facade.
"""
import numpy as np
import pytest
from myfasthtml.core.formula.dsl.exceptions import FormulaSyntaxError, FormulaCycleError
from myfasthtml.core.formula.engine import FormulaEngine
class FakeStore:
"""Minimal DatagridStore-like object for testing."""
def __init__(self, rows):
self.ns_row_data = rows
self.ns_fast_access = self._build_fast_access(rows)
self.ns_total_rows = len(rows)
@staticmethod
def _build_fast_access(rows):
"""Build a columnar fast-access dict from a list of row dicts."""
if not rows:
return {}
return {
col: np.array([row.get(col) for row in rows], dtype=object)
for col in rows[0].keys()
}
def make_engine(store_map=None):
"""Create an engine with an optional store map for cross-table resolution."""
def resolver(table_name):
return store_map.get(table_name) if store_map else None
return FormulaEngine(registry_resolver=resolver)
# ==================== TestSetFormula ====================
class TestSetFormula:
"""Tests for set_formula: parsing, registration, and edge cases."""
def test_i_can_set_and_evaluate_formula(self):
"""Test that a formula can be set and evaluated."""
rows = [{"Price": 10, "Quantity": 3}]
store = FakeStore(rows)
engine = make_engine({"orders": store})
engine.set_formula("orders", "total", "{Price} * {Quantity}")
engine.recalculate_if_needed("orders", store)
assert "total" in store.ns_fast_access
assert store.ns_fast_access["total"][0] == 30
def test_i_can_evaluate_multiple_rows(self):
"""Test that formula evaluation works across multiple rows."""
rows = [
{"Price": 10, "Quantity": 3},
{"Price": 20, "Quantity": 2},
{"Price": 5, "Quantity": 10},
]
store = FakeStore(rows)
engine = make_engine({"orders": store})
engine.set_formula("orders", "total", "{Price} * {Quantity}")
engine.recalculate_if_needed("orders", store)
totals = store.ns_fast_access["total"]
assert totals[0] == 30
assert totals[1] == 40
assert totals[2] == 50
def test_i_cannot_set_invalid_formula(self):
"""Test that invalid formula syntax raises FormulaSyntaxError."""
engine = make_engine()
with pytest.raises(FormulaSyntaxError):
engine.set_formula("t", "col", "{Price} * * {Qty}")
def test_i_cannot_set_formula_with_cycle(self):
"""Test that a circular dependency raises FormulaCycleError."""
rows = [{"A": 1}]
store = FakeStore(rows)
engine = make_engine({"t": store})
engine.set_formula("t", "A", "{B} + 1")
with pytest.raises(FormulaCycleError):
engine.set_formula("t", "B", "{A} * 2")
@pytest.mark.parametrize("text", ["", " ", "\t\n"])
def test_i_can_set_formula_with_blank_input_removes_it(self, text):
"""Test that setting a blank or whitespace-only formula string removes it.
Why: set_formula strips the input (source line 86) before checking emptiness.
All blank variants — empty string, spaces, tabs — must behave identically.
"""
rows = [{"Price": 10}]
store = FakeStore(rows)
engine = make_engine({"t": store})
engine.set_formula("t", "col", "{Price} * 2")
assert engine.has_formula("t", "col")
engine.set_formula("t", "col", text)
assert not engine.has_formula("t", "col")
def test_i_can_replace_formula_clears_old_dependencies(self):
"""Test that replacing a formula removes old dependency edges.
Why: add_formula calls _remove_edges before adding new ones (source line 99).
After replacement, marking the old dependency dirty must NOT trigger
recalculation of the formula column.
"""
rows = [{"A": 5, "B": 3, "X": 10}]
store = FakeStore(rows)
engine = make_engine({"t": store})
engine.set_formula("t", "C", "{A} + {B}")
engine.recalculate_if_needed("t", store)
assert store.ns_fast_access["C"][0] == pytest.approx(8.0)
engine.set_formula("t", "C", "{X} * 2")
engine.recalculate_if_needed("t", store)
assert store.ns_fast_access["C"][0] == pytest.approx(20.0)
# A is no longer a dependency — changing it should NOT trigger C recalculation
rows[0]["A"] = 100
store.ns_fast_access["A"] = np.array([100], dtype=object)
engine.mark_data_changed("t", "A")
store.ns_fast_access["C"][0] = 999 # sentinel
engine.recalculate_if_needed("t", store)
assert store.ns_fast_access["C"][0] == 999, (
"C should not be recalculated when A changes after formula replacement"
)
# ==================== TestRemoveFormula ====================
class TestRemoveFormula:
"""Tests for remove_formula."""
def test_i_can_remove_formula(self):
"""Test that a formula can be removed."""
rows = [{"Price": 10, "Quantity": 3}]
store = FakeStore(rows)
engine = make_engine({"t": store})
engine.set_formula("t", "total", "{Price} * {Quantity}")
assert engine.has_formula("t", "total")
engine.remove_formula("t", "total")
assert not engine.has_formula("t", "total")
def test_i_can_remove_formula_and_set_back(self):
"""Test that a formula can be removed via empty string and then re-added."""
rows = [{"Price": 10}]
store = FakeStore(rows)
engine = make_engine({"t": store})
engine.set_formula("t", "col", "{Price} * 2")
assert engine.has_formula("t", "col")
engine.set_formula("t", "col", "")
assert not engine.has_formula("t", "col")
# ==================== TestRecalculate ====================
class TestRecalculate:
"""Tests for recalculate_if_needed: dirty tracking, return values, and row data update."""
def test_i_can_recalculate_only_dirty(self):
"""Test that only dirty formula columns are recalculated."""
rows = [{"A": 5, "B": 3}]
store = FakeStore(rows)
engine = make_engine({"t": store})
engine.set_formula("t", "C", "{A} + {B}")
engine.recalculate_if_needed("t", store)
# Manually set a different value to detect recalculation
store.ns_fast_access["C"][0] = 999
# No dirty flags → should NOT recalculate
engine.recalculate_if_needed("t", store)
assert store.ns_fast_access["C"][0] == 999 # unchanged
def test_i_can_recalculate_after_data_changed(self):
"""Test that marking data changed triggers recalculation."""
rows = [{"A": 5, "B": 3}]
store = FakeStore(rows)
engine = make_engine({"t": store})
engine.set_formula("t", "C", "{A} + {B}")
engine.recalculate_if_needed("t", store)
assert store.ns_fast_access["C"][0] == 8
# Update both storage structures (mirrors real DataGrid behaviour)
rows[0]["A"] = 10
store.ns_fast_access["A"] = np.array([10], dtype=object)
# Mark source column dirty
engine.mark_data_changed("t", "A")
engine.recalculate_if_needed("t", store)
assert store.ns_fast_access["C"][0] == 13
def test_i_can_recalculate_returns_false_when_no_dirty(self):
"""Test that recalculate_if_needed returns False when no nodes are dirty.
Why: Source line 151 returns False early when get_calculation_order is empty.
After a first successful recalculation all dirty flags are cleared, so the
second call must return False to avoid redundant work.
"""
rows = [{"A": 5}]
store = FakeStore(rows)
engine = make_engine({"t": store})
engine.set_formula("t", "B", "{A} + 1")
engine.recalculate_if_needed("t", store) # clears dirty flag
result = engine.recalculate_if_needed("t", store)
assert result is False
def test_i_can_recalculate_returns_true_when_dirty(self):
"""Test that recalculate_if_needed returns True when columns are recalculated.
Why: Source line 164 returns True after the evaluation loop. This return value
lets callers skip downstream work (e.g. rendering) when nothing changed.
"""
rows = [{"A": 5}]
store = FakeStore(rows)
engine = make_engine({"t": store})
engine.set_formula("t", "B", "{A} + 1")
result = engine.recalculate_if_needed("t", store)
assert result is True
def test_i_can_recalculate_with_empty_store(self):
"""Test that recalculate_if_needed handles a store with no rows without crashing.
Why: _evaluate_column guards on empty ns_row_data (source line 211-212).
No formula result should be written when there are no rows to process.
"""
store = FakeStore([]) # no rows
engine = make_engine({"t": store})
engine.set_formula("t", "B", "{A} + 1")
engine.recalculate_if_needed("t", store) # must not raise
assert "B" not in store.ns_fast_access
def test_i_can_verify_formula_values_appear_in_row_data(self):
"""Test that formula values are written back into ns_row_data after recalculation.
Why: _rebuild_row_data (source line 231-249) merges ns_fast_access values into
each row dict. This ensures formula results are available in row_data for
subsequent evaluation passes and for rendering.
"""
rows = [{"A": 5}]
store = FakeStore(rows)
engine = make_engine({"t": store})
engine.set_formula("t", "B", "{A} + 10")
engine.recalculate_if_needed("t", store)
assert store.ns_row_data[0]["B"] == pytest.approx(15.0)
# --- Known bug: chained formula columns ---
# _rebuild_row_data is called once AFTER the full evaluation loop, so formula
# column B is not yet in row_data when formula column C is evaluated.
# Fix needed: rebuild row_data between each column evaluation in the loop.
def test_i_can_recalculate_chain_formula_initial(self):
"""Test that C = f(B) is correct when B is itself a formula column (initial pass).
Chain: A (data) → B = A + 10 → C = B * 2
Expected: B = 15, C = 30.
"""
rows = [{"A": 5}]
store = FakeStore(rows)
engine = make_engine({"t": store})
engine.set_formula("t", "B", "{A} + 10")
engine.set_formula("t", "C", "{B} * 2")
engine.recalculate_if_needed("t", store)
assert store.ns_fast_access["B"][0] == 15
assert store.ns_fast_access["C"][0] == 30
def test_i_can_recalculate_chain_formula_after_data_change(self):
"""Test that C = f(B) stays correct after the source data column A changes.
Chain: A (data) → B = A + 10 → C = B * 2
After A changes from 5 to 10: B = 20, C = 40.
"""
rows = [{"A": 5}]
store = FakeStore(rows)
engine = make_engine({"t": store})
engine.set_formula("t", "B", "{A} + 10")
engine.set_formula("t", "C", "{B} * 2")
engine.recalculate_if_needed("t", store) # first pass (B=15, C=None per bug above)
rows[0]["A"] = 10
store.ns_fast_access["A"] = np.array([10], dtype=object)
engine.mark_data_changed("t", "A")
engine.recalculate_if_needed("t", store)
assert store.ns_fast_access["B"][0] == 20
assert store.ns_fast_access["C"][0] == 40
# ==================== TestCrossTable ====================
class TestCrossTable:
"""Tests for cross-table reference resolution strategies."""
def test_i_can_handle_cross_table_formula(self):
"""Test that cross-table references are resolved via registry (Strategy 3: row_index)."""
orders_rows = [{"Quantity": 3, "ProductId": 1}]
products_rows = [{"Price": 99.0}]
orders_store = FakeStore(orders_rows)
products_store = FakeStore(products_rows)
products_store.ns_fast_access = {"Price": np.array([99.0], dtype=object)}
engine = make_engine({
"orders": orders_store,
"products": products_store,
})
engine.set_formula("orders", "total", "{products.Price} * {Quantity}")
engine.recalculate_if_needed("orders", orders_store)
assert "total" in orders_store.ns_fast_access
assert orders_store.ns_fast_access["total"][0] == pytest.approx(297.0)
def test_i_can_resolve_cross_table_by_id_join(self):
"""Test that cross-table references are resolved via implicit id-join (Strategy 2).
Why: Strategy 2 (source lines 303-318) matches rows where both tables share
the same id value. This allows cross-table lookups without an explicit WHERE
clause when both stores expose an 'id' column in ns_fast_access.
"""
orders_rows = [{"id": 101, "Quantity": 3}, {"id": 102, "Quantity": 5}]
orders_store = FakeStore(orders_rows)
products_rows = [{"id": 101, "Price": 10.0}, {"id": 102, "Price": 20.0}]
products_store = FakeStore(products_rows)
engine = make_engine({
"orders": orders_store,
"products": products_store,
})
engine.set_formula("orders", "total", "{products.Price} * {Quantity}")
engine.recalculate_if_needed("orders", orders_store)
totals = orders_store.ns_fast_access["total"]
assert totals[0] == 30, "Row with id=101: Price=10 * Qty=3"
assert totals[1] == 100, "Row with id=102: Price=20 * Qty=5"
def test_i_can_handle_cross_table_without_registry(self):
"""Test that a cross-table formula evaluates gracefully when no registry is set.
Why: _make_cross_table_resolver guards on registry_resolver=None (source line 274).
The formula must evaluate to None without raising, preserving engine stability.
"""
rows = [{"Quantity": 3}]
store = FakeStore(rows)
engine = FormulaEngine(registry_resolver=None)
engine.set_formula("orders", "total", "{products.Price} * {Quantity}")
engine.recalculate_if_needed("orders", store) # must not raise
assert store.ns_fast_access["total"][0] is None
def test_i_can_handle_cross_table_missing_table(self):
"""Test that a cross-table formula evaluates gracefully when the remote table is absent.
Why: Source line 282-284 returns None when registry_resolver returns None for
the requested table. The engine must not crash and must produce None for the row.
"""
rows = [{"Quantity": 3}]
orders_store = FakeStore(rows)
engine = make_engine({"orders": orders_store}) # "products" not in registry
engine.set_formula("orders", "total", "{products.Price} * {Quantity}")
engine.recalculate_if_needed("orders", orders_store) # must not raise
assert orders_store.ns_fast_access["total"][0] is None
# ==================== TestGetFormulaText ====================
class TestGetFormulaText:
"""Tests for formula text retrieval."""
def test_i_can_get_formula_text(self):
"""Test that registered formula text can be retrieved."""
engine = make_engine()
engine.set_formula("t", "col", "{Price} * 2")
assert engine.get_formula_text("t", "col") == "{Price} * 2"
def test_i_can_get_formula_text_returns_none_when_not_set(self):
"""Test that get_formula_text returns None for non-formula columns."""
engine = make_engine()
assert engine.get_formula_text("t", "non_existing") is None

View File

@@ -0,0 +1,188 @@
"""
Tests for the FormulaEvaluator.
"""
import pytest
from myfasthtml.core.formula.engine import parse_formula
from myfasthtml.core.formula.evaluator import FormulaEvaluator
def make_evaluator(resolver=None):
return FormulaEvaluator(cross_table_resolver=resolver)
def eval_formula(text, row_data, row_index=0, resolver=None):
"""Helper: parse and evaluate a formula."""
formula = parse_formula(text)
evaluator = make_evaluator(resolver)
return evaluator.evaluate(formula, row_data, row_index)
# ==================== Arithmetic ====================
@pytest.mark.parametrize("formula,row_data,expected", [
("{Price} * {Quantity}", {"Price": 10, "Quantity": 3}, 30.0),
("{Price} + {Tax}", {"Price": 100, "Tax": 20}, 120.0),
("{Total} - {Discount}", {"Total": 100, "Discount": 15}, 85.0),
("{Total} / {Count}", {"Total": 100, "Count": 4}, 25.0),
("{Value} % 3", {"Value": 10}, 1.0),
("{Base} ^ 2", {"Base": 5}, 25.0),
])
def test_i_can_evaluate_simple_arithmetic(formula, row_data, expected):
"""Test that arithmetic formulas evaluate correctly."""
result = eval_formula(formula, row_data)
assert result == pytest.approx(expected)
# ==================== Functions ====================
@pytest.mark.parametrize("formula,row_data,expected", [
("round({Price} * 1.2, 2)", {"Price": 10}, 12.0),
("abs({Balance})", {"Balance": -50}, 50.0),
("upper({Name})", {"Name": "hello"}, "HELLO"),
("lower({Name})", {"Name": "WORLD"}, "world"),
("len({Description})", {"Description": "abc"}, 3),
("concat({First}, \" \", {Last})", {"First": "John", "Last": "Doe"}, "John Doe"),
("left({Code}, 3)", {"Code": "ABCDEF"}, "ABC"),
("right({Code}, 3)", {"Code": "ABCDEF"}, "DEF"),
("trim({Name})", {"Name": " hello "}, "hello"),
])
def test_i_can_evaluate_function(formula, row_data, expected):
"""Test that built-in function calls evaluate correctly."""
result = eval_formula(formula, row_data)
assert result == expected
def test_i_can_evaluate_nested_functions():
"""Test that nested function calls evaluate correctly."""
result = eval_formula("round(abs({Val}), 1)", {"Val": -3.456})
assert result == 3.5
# ==================== Conditionals ====================
def test_i_can_evaluate_conditional_true():
"""Test conditional when condition is true."""
result = eval_formula('{Price} * 0.8 if {Country} == "FR"', {"Price": 100, "Country": "FR"})
assert result == 80.0
def test_i_can_evaluate_conditional_false_no_else():
"""Test conditional returns None when condition is false and no else."""
result = eval_formula('{Price} * 0.8 if {Country} == "FR"', {"Price": 100, "Country": "DE"})
assert result is None
def test_i_can_evaluate_conditional_with_else():
"""Test conditional returns else value when condition is false."""
result = eval_formula('{Price} * 0.8 if {Country} == "FR" else {Price}', {"Price": 100, "Country": "DE"})
assert result == 100.0
def test_i_can_evaluate_chained_conditional():
"""Test chained conditionals evaluate in order."""
formula = '{Price} * 0.8 if {Country} == "FR" else {Price} * 0.9 if {Country} == "DE" else {Price}'
assert eval_formula(formula, {"Price": 100, "Country": "FR"}) == pytest.approx(80.0)
assert eval_formula(formula, {"Price": 100, "Country": "DE"}) == pytest.approx(90.0)
assert eval_formula(formula, {"Price": 100, "Country": "US"}) == pytest.approx(100.0)
# ==================== Logical operators ====================
@pytest.mark.parametrize("formula,row_data,expected", [
("{A} and {B}", {"A": True, "B": True}, True),
("{A} and {B}", {"A": True, "B": False}, False),
("{A} or {B}", {"A": False, "B": True}, True),
("{A} or {B}", {"A": False, "B": False}, False),
("not {A}", {"A": True}, False),
("not {A}", {"A": False}, True),
])
def test_i_can_evaluate_logical_operators(formula, row_data, expected):
"""Test that logical operators evaluate correctly."""
result = eval_formula(formula, row_data)
assert result == expected
# ==================== Error handling ====================
def test_i_can_handle_division_by_zero():
"""Test that division by zero returns None."""
result = eval_formula("{A} / {B}", {"A": 10, "B": 0})
assert result is None
def test_i_can_handle_missing_column():
"""Test that missing column reference returns None."""
result = eval_formula("{NonExistent} * 2", {})
assert result is None
def test_i_can_handle_none_operand():
"""Test that operations with None operands return None."""
result = eval_formula("{A} * {B}", {"A": None, "B": 5})
assert result is None
# ==================== Cross-table references ====================
def test_i_can_evaluate_cross_table_ref():
"""Test cross-table reference with mock resolver."""
def mock_resolver(table, column, where_clause, row_index):
assert table == "Products"
assert column == "Price"
return 99.0
result = eval_formula("{Products.Price} * {Quantity}", {"Quantity": 3}, resolver=mock_resolver)
assert result == pytest.approx(297.0)
def test_i_can_evaluate_cross_table_with_where():
"""Test cross-table reference with WHERE clause and mock resolver."""
def mock_resolver(table, column, where_clause, row_index):
assert where_clause is not None
assert where_clause.local_column == "ProductCode"
return 50.0
result = eval_formula(
"{Products.Price where Products.Code = ProductCode} * {Qty}",
{"ProductCode": "ABC", "Qty": 2},
resolver=mock_resolver,
)
assert result == pytest.approx(100.0)
# ==================== Aggregation ====================
def test_i_can_evaluate_aggregation():
"""Test that aggregation functions work with cross-table resolver."""
values = [10.0, 20.0, 30.0]
call_count = [0]
def mock_resolver(table, column, where_clause, row_index):
# For aggregation test, return a list to simulate multi-row match
val = values[call_count[0] % len(values)]
call_count[0] += 1
return val
formula = parse_formula("sum({OrderLines.Amount where OrderLines.OrderId = Id})")
evaluator = FormulaEvaluator(cross_table_resolver=mock_resolver)
result = evaluator.evaluate(formula, {"Id": 1}, 0)
# sum() with a single cross-table value returned by resolver
assert result is not None
# ==================== String operations ====================
@pytest.mark.parametrize("formula,row_data,expected", [
('{Name} contains "Corp"', {"Name": "Acme Corp"}, True),
('{Code} startswith "ERR"', {"Code": "ERR001"}, True),
('{File} endswith ".csv"', {"File": "data.csv"}, True),
('{Status} in ["active", "new"]', {"Status": "active"}, True),
('{Status} in ["active", "new"]', {"Status": "deleted"}, False),
])
def test_i_can_evaluate_string_operations(formula, row_data, expected):
"""Test string comparison operations."""
result = eval_formula(formula, row_data)
assert result == expected

View File

@@ -0,0 +1,188 @@
"""
Tests for the formula parser (grammar + transformer integration).
"""
import pytest
from myfasthtml.core.formula.dataclasses import (
BinaryOp,
CrossTableRef,
ConditionalExpr,
FunctionCall,
LiteralNode,
FormulaDefinition,
)
from myfasthtml.core.formula.dsl.exceptions import FormulaSyntaxError
from myfasthtml.core.formula.engine import parse_formula
# ==================== Valid formulas ====================
@pytest.mark.parametrize("formula_text", [
"{Price} * {Quantity}",
"{Price} + {Tax}",
"{Total} - {Discount}",
"{Total} / {Count}",
"{Value} % 2",
"{Base} ^ 2",
])
def test_i_can_parse_simple_arithmetic(formula_text):
"""Test that basic arithmetic formulas parse without error."""
result = parse_formula(formula_text)
assert result is not None
assert isinstance(result, FormulaDefinition)
assert isinstance(result.expression, BinaryOp)
@pytest.mark.parametrize("formula_text,expected_func", [
("round({Price} * 1.2, 2)", "round"),
("abs({Balance})", "abs"),
("upper({Name})", "upper"),
("len({Description})", "len"),
("today()", "today"),
("concat({First}, \" \", {Last})", "concat"),
])
def test_i_can_parse_function_call(formula_text, expected_func):
"""Test that function calls parse correctly."""
result = parse_formula(formula_text)
assert result is not None
assert isinstance(result.expression, FunctionCall)
assert result.expression.function_name == expected_func
def test_i_can_parse_conditional_no_else():
"""Test that suffix-if without else parses correctly."""
result = parse_formula('{Price} * 0.8 if {Country} == "FR"')
assert result is not None
expr = result.expression
assert isinstance(expr, ConditionalExpr)
assert expr.else_expr is None
assert isinstance(expr.value_expr, BinaryOp)
def test_i_can_parse_conditional_with_else():
"""Test that suffix-if with else parses correctly."""
result = parse_formula('{Price} * 0.8 if {Country} == "FR" else {Price}')
assert result is not None
expr = result.expression
assert isinstance(expr, ConditionalExpr)
assert expr.else_expr is not None
def test_i_can_parse_chained_conditional():
"""Test that chained conditionals parse correctly."""
formula = '{Price} * 0.8 if {Country} == "FR" else {Price} * 0.9 if {Country} == "DE" else {Price}'
result = parse_formula(formula)
assert result is not None
expr = result.expression
assert isinstance(expr, ConditionalExpr)
# The else_expr should be another ConditionalExpr
assert isinstance(expr.else_expr, ConditionalExpr)
def test_i_can_parse_cross_table_ref():
"""Test that cross-table references parse correctly."""
result = parse_formula("{Products.Price} * {Quantity}")
assert result is not None
expr = result.expression
assert isinstance(expr, BinaryOp)
assert isinstance(expr.left, CrossTableRef)
assert expr.left.table == "Products"
assert expr.left.column == "Price"
def test_i_can_parse_cross_table_with_where():
"""Test that cross-table references with WHERE clause parse correctly."""
result = parse_formula("{Products.Price where Products.Code = ProductCode} * {Quantity}")
assert result is not None
expr = result.expression
assert isinstance(expr, BinaryOp)
cross_ref = expr.left
assert isinstance(cross_ref, CrossTableRef)
assert cross_ref.where_clause is not None
assert cross_ref.where_clause.remote_table == "Products"
assert cross_ref.where_clause.remote_column == "Code"
assert cross_ref.where_clause.local_column == "ProductCode"
@pytest.mark.parametrize("formula_text,expected_op", [
("{A} and {B}", "and"),
("{A} or {B}", "or"),
("not {A}", "not"),
])
def test_i_can_parse_logical_operators(formula_text, expected_op):
"""Test that logical operators parse correctly."""
result = parse_formula(formula_text)
assert result is not None
if expected_op == "not":
from myfasthtml.core.formula.dataclasses import UnaryOp
assert isinstance(result.expression, UnaryOp)
else:
assert isinstance(result.expression, BinaryOp)
assert result.expression.operator == expected_op
@pytest.mark.parametrize("formula_text", [
"42",
"3.14",
'"hello"',
"true",
"false",
])
def test_i_can_parse_literals(formula_text):
"""Test that literal values parse correctly."""
result = parse_formula(formula_text)
assert result is not None
assert isinstance(result.expression, LiteralNode)
def test_i_can_parse_aggregation():
"""Test that aggregation with cross-table WHERE parses correctly."""
result = parse_formula("sum({OrderLines.Amount where OrderLines.OrderId = Id})")
assert result is not None
expr = result.expression
assert isinstance(expr, FunctionCall)
assert expr.function_name == "sum"
assert len(expr.arguments) == 1
arg = expr.arguments[0]
assert isinstance(arg, CrossTableRef)
assert arg.where_clause is not None
def test_i_can_parse_empty_formula():
"""Test that empty formula returns None."""
result = parse_formula("")
assert result is None
def test_i_can_parse_whitespace_formula():
"""Test that whitespace-only formula returns None."""
result = parse_formula(" ")
assert result is None
def test_i_can_parse_nested_functions():
"""Test that nested function calls parse correctly."""
result = parse_formula("round(avg({Q1}, {Q2}, {Q3}), 1)")
assert result is not None
expr = result.expression
assert isinstance(expr, FunctionCall)
assert expr.function_name == "round"
# ==================== Invalid formulas ====================
@pytest.mark.parametrize("formula_text", [
"{Price} * * {Quantity}", # double operator
"round(", # unclosed paren
"123 + + 456", # double operator
])
def test_i_cannot_parse_invalid_syntax(formula_text):
"""Test that invalid syntax raises FormulaSyntaxError."""
with pytest.raises(FormulaSyntaxError):
parse_formula(formula_text)
def test_i_cannot_parse_unclosed_brace():
"""Test that an unclosed brace raises FormulaSyntaxError."""
with pytest.raises(FormulaSyntaxError):
parse_formula("{Price")