Introducing columns formulas
This commit is contained in:
365
docs/Datagrid Formulas.md
Normal file
365
docs/Datagrid Formulas.md
Normal file
@@ -0,0 +1,365 @@
|
||||
# DataGrid Formulas
|
||||
|
||||
## Overview
|
||||
|
||||
The DataGrid formula system adds computed columns to the DataGrid. A formula column applies a single expression to every
|
||||
row, producing derived values from existing data — within the same table or across tables.
|
||||
|
||||
The system is designed for:
|
||||
|
||||
- **Column-level formulas**: one formula per column, applied to all rows
|
||||
- **Cross-table references**: direct syntax to reference columns from other tables
|
||||
- **Reactive recalculation**: dirty flag propagation with page-aware computation
|
||||
- **Cell-level overrides** (planned): individual cells can override the column formula
|
||||
|
||||
## Formula Language
|
||||
|
||||
### Basic Syntax
|
||||
|
||||
A formula is an expression that references columns with `{ColumnName}` and produces a value for each row:
|
||||
|
||||
```
|
||||
{Price} * {Quantity}
|
||||
```
|
||||
|
||||
References use curly braces `{}` to distinguish column names from keywords and functions. Column names are matched by ID
|
||||
or title.
|
||||
|
||||
### Operators
|
||||
|
||||
#### Arithmetic
|
||||
|
||||
| Operator | Description | Example |
|
||||
|----------|----------------|------------------------|
|
||||
| `+` | Addition | `{Price} + {Tax}` |
|
||||
| `-` | Subtraction | `{Total} - {Discount}` |
|
||||
| `*` | Multiplication | `{Price} * {Quantity}` |
|
||||
| `/` | Division | `{Total} / {Count}` |
|
||||
| `%` | Modulo | `{Value} % 2` |
|
||||
| `^` | Power | `{Base} ^ 2` |
|
||||
|
||||
#### Comparison
|
||||
|
||||
| Operator | Description | Example |
|
||||
|--------------|--------------------|---------------------------------|
|
||||
| `==` | Equal | `{Status} == "active"` |
|
||||
| `!=` | Not equal | `{Status} != "deleted"` |
|
||||
| `>` | Greater than | `{Price} > 100` |
|
||||
| `<` | Less than | `{Stock} < 10` |
|
||||
| `>=` | Greater or equal | `{Score} >= 80` |
|
||||
| `<=` | Less or equal | `{Age} <= 18` |
|
||||
| `contains` | String contains | `{Name} contains "Corp"` |
|
||||
| `startswith` | String starts with | `{Code} startswith "ERR"` |
|
||||
| `endswith` | String ends with | `{File} endswith ".csv"` |
|
||||
| `in` | Value in list | `{Status} in ["active", "new"]` |
|
||||
| `between` | Value in range | `{Age} between 18 and 65` |
|
||||
| `isempty` | Value is empty | `{Notes} isempty` |
|
||||
| `isnotempty` | Value is not empty | `{Email} isnotempty` |
|
||||
| `isnan` | Value is NaN | `{Score} isnan` |
|
||||
|
||||
#### Logical
|
||||
|
||||
| Operator | Description | Example |
|
||||
|----------|-------------|---------------------------------------|
|
||||
| `and` | Logical AND | `{Age} > 18 and {Status} == "active"` |
|
||||
| `or` | Logical OR | `{Type} == "A" or {Type} == "B"` |
|
||||
| `not` | Negation | `not {Status} == "deleted"` |
|
||||
|
||||
Parentheses control precedence: `({Type} == "A" or {Type} == "B") and {Active} == True`
|
||||
|
||||
### Conditions (suffix-if)
|
||||
|
||||
Conditions use a **suffix-if** syntax: the result expression comes first, then the condition. This keeps the focus on
|
||||
the output, not the branching logic.
|
||||
|
||||
#### Simple condition (no else — result is None when false)
|
||||
|
||||
```
|
||||
{Price} * 0.8 if {Country} == "FR"
|
||||
```
|
||||
|
||||
#### With else
|
||||
|
||||
```
|
||||
{Price} * 0.8 if {Country} == "FR" else {Price}
|
||||
```
|
||||
|
||||
#### Chained conditions
|
||||
|
||||
```
|
||||
{Price} * 0.8 if {Country} == "FR" else {Price} * 0.9 if {Country} == "DE" else {Price}
|
||||
```
|
||||
|
||||
#### With logical operators
|
||||
|
||||
```
|
||||
{Price} * 0.8 if {Country} == "FR" and {Quantity} > 10 else {Price}
|
||||
```
|
||||
|
||||
#### With grouping
|
||||
|
||||
```
|
||||
{Price} * 0.8 if ({Country} == "FR" or {Country} == "DE") and {Quantity} > 10
|
||||
```
|
||||
|
||||
### Functions
|
||||
|
||||
#### Math
|
||||
|
||||
| Function | Description | Example |
|
||||
|-------------------|-----------------------|-------------------------------|
|
||||
| `round(expr, n)` | Round to n decimals | `round({Price} * 1.2, 2)` |
|
||||
| `abs(expr)` | Absolute value | `abs({Balance})` |
|
||||
| `min(expr, expr)` | Minimum of two values | `min({Price}, {MaxPrice})` |
|
||||
| `max(expr, expr)` | Maximum of two values | `max({Score}, 0)` |
|
||||
| `sum(expr, ...)` | Sum of values | `sum({Q1}, {Q2}, {Q3}, {Q4})` |
|
||||
| `avg(expr, ...)` | Average of values | `avg({Q1}, {Q2}, {Q3}, {Q4})` |
|
||||
|
||||
#### Text
|
||||
|
||||
| Function | Description | Example |
|
||||
|---------------------|---------------------|--------------------------------|
|
||||
| `upper(expr)` | Uppercase | `upper({Name})` |
|
||||
| `lower(expr)` | Lowercase | `lower({Email})` |
|
||||
| `len(expr)` | String length | `len({Description})` |
|
||||
| `concat(expr, ...)` | Concatenate strings | `concat({First}, " ", {Last})` |
|
||||
| `trim(expr)` | Remove whitespace | `trim({Input})` |
|
||||
| `left(expr, n)` | First n characters | `left({Code}, 3)` |
|
||||
| `right(expr, n)` | Last n characters | `right({Phone}, 4)` |
|
||||
|
||||
#### Date
|
||||
|
||||
| Function | Description | Example |
|
||||
|------------------------|--------------------|--------------------------------|
|
||||
| `year(expr)` | Extract year | `year({CreatedAt})` |
|
||||
| `month(expr)` | Extract month | `month({CreatedAt})` |
|
||||
| `day(expr)` | Extract day | `day({CreatedAt})` |
|
||||
| `today()` | Current date | `datediff({DueDate}, today())` |
|
||||
| `datediff(expr, expr)` | Difference in days | `datediff({End}, {Start})` |
|
||||
|
||||
#### Aggregation (for cross-table contexts)
|
||||
|
||||
| Function | Description | Example |
|
||||
|---------------|--------------|-----------------------------------------------------|
|
||||
| `sum(expr)` | Sum values | `sum({Orders.Amount WHERE Orders.ClientId = Id})` |
|
||||
| `count(expr)` | Count values | `count({Orders.Id WHERE Orders.ClientId = Id})` |
|
||||
| `avg(expr)` | Average | `avg({Reviews.Score WHERE Reviews.ProductId = Id})` |
|
||||
| `min(expr)` | Minimum | `min({Bids.Price WHERE Bids.ItemId = Id})` |
|
||||
| `max(expr)` | Maximum | `max({Bids.Price WHERE Bids.ItemId = Id})` |
|
||||
|
||||
## Cross-Table References
|
||||
|
||||
### Direct Reference
|
||||
|
||||
Reference a column from another table using `{TableName.ColumnName}`:
|
||||
|
||||
```
|
||||
{Products.Price} * {Quantity}
|
||||
```
|
||||
|
||||
### Join Resolution (implicit)
|
||||
|
||||
When referencing another table without a WHERE clause, the join is resolved automatically:
|
||||
|
||||
1. **By `id` column**: if both tables have a column named `id`, rows are matched on equal `id` values
|
||||
2. **By row index**: if no `id` column exists in both tables, rows are matched by their internal row index (stable
|
||||
across sort/filter)
|
||||
|
||||
### Explicit Join (WHERE clause)
|
||||
|
||||
For explicit control over which row of the other table to use:
|
||||
|
||||
```
|
||||
{Products.Price WHERE Products.Code = ProductCode} * {Quantity}
|
||||
```
|
||||
|
||||
Inside the WHERE clause:
|
||||
|
||||
- `Products.Code` refers to a column in the referenced table
|
||||
- `ProductCode` (no `Table.` prefix) refers to a column in the current table
|
||||
|
||||
### Aggregation with Cross-Table
|
||||
|
||||
When a cross-table reference matches multiple rows, use an aggregation function:
|
||||
|
||||
```
|
||||
sum({OrderLines.Amount WHERE OrderLines.OrderId = Id})
|
||||
```
|
||||
|
||||
Without aggregation, a multi-row match returns the first matching value.
|
||||
|
||||
## Calculation Engine
|
||||
|
||||
### Dependency Graph (DAG)
|
||||
|
||||
The formula system maintains a **Directed Acyclic Graph** of dependencies between columns:
|
||||
|
||||
- **Nodes**: each formula column is a node, identified by `table_name.column_id`
|
||||
- **Edges**: if column A's formula references column B, an edge B → A exists ("A depends on B")
|
||||
- Both directions are tracked:
|
||||
- **Precedents**: columns that a formula reads from
|
||||
- **Dependents**: columns that need recalculation when this column changes
|
||||
|
||||
Cross-table references create edges that span DataGrid instances, managed at the `DataGridsManager` level.
|
||||
|
||||
### Dirty Flag Propagation
|
||||
|
||||
When a source column's data changes:
|
||||
|
||||
1. The source column is marked **dirty**
|
||||
2. All direct dependents are marked dirty
|
||||
3. Propagation continues recursively through the DAG
|
||||
4. Each dirty column maintains a **dirty row set**: the specific row indices that need recalculation
|
||||
|
||||
This propagation is **immediate** (fast — only flag marking, no computation).
|
||||
|
||||
### Recalculation Strategy (Hybrid)
|
||||
|
||||
Actual computation is **deferred to rendering time**:
|
||||
|
||||
1. On value change → dirty flags propagate instantly through the DAG
|
||||
2. On page render (`mk_body_content_page`) → only dirty rows within the visible page (up to 1000 rows) are recalculated
|
||||
3. Off-screen pages remain dirty until scrolled into view
|
||||
4. Calculation follows **topological order** of the DAG to ensure precedents are computed before dependents
|
||||
|
||||
### Cycle Detection
|
||||
|
||||
Before adding a formula, the engine checks for cycles in the DAG using Kahn's algorithm during topological sort. If a
|
||||
cycle is detected:
|
||||
|
||||
- The formula is **rejected**
|
||||
- The editor displays an error identifying the circular dependency chain
|
||||
- The previous formula (if any) remains unchanged
|
||||
|
||||
### Caching
|
||||
|
||||
Each formula column caches its computed values:
|
||||
|
||||
- Results are stored in `ns_fast_access[col_id]` alongside raw data columns
|
||||
- The dirty row set tracks which cached values are stale
|
||||
- Non-dirty rows return their cached value without re-evaluation
|
||||
- Cache is invalidated per-row when source data changes
|
||||
|
||||
## Evaluation
|
||||
|
||||
### Row-by-Row Execution
|
||||
|
||||
Formulas are evaluated **row-by-row** within the page being rendered. For each row:
|
||||
|
||||
1. Resolve column references `{ColumnName}` to the cell value at the current row index
|
||||
2. Resolve cross-table references `{Table.Column}` via the join mechanism
|
||||
3. Evaluate the expression with resolved values
|
||||
4. Store the result in the cache (`ns_fast_access`)
|
||||
|
||||
### Parser
|
||||
|
||||
The formula language uses a **custom grammar** parsed with Lark (consistent with the formatting DSL). The parser:
|
||||
|
||||
1. Tokenizes the formula string
|
||||
2. Builds an AST (Abstract Syntax Tree)
|
||||
3. Transforms the AST into an evaluable representation
|
||||
4. Extracts column references for dependency graph registration
|
||||
|
||||
### Error Handling
|
||||
|
||||
| Error Type | Behavior |
|
||||
|-----------------------|-------------------------------------------------------|
|
||||
| Syntax error | Editor highlights the error, formula not saved |
|
||||
| Unknown column | Editor highlights, autocompletion suggests fixes |
|
||||
| Type mismatch | Cell displays error indicator, other cells unaffected |
|
||||
| Division by zero | Cell displays `#DIV/0!` or None |
|
||||
| Circular dependency | Formula rejected, editor shows cycle chain |
|
||||
| Cross-table not found | Editor highlights unknown table name |
|
||||
| No join match | Cell displays None |
|
||||
|
||||
## User Interface
|
||||
|
||||
### Creating a Formula Column
|
||||
|
||||
Formula columns are created and edited through the **DataGridColumnsManager**:
|
||||
|
||||
1. User opens the Columns Manager panel
|
||||
2. Adds a new column or edits an existing one
|
||||
3. Selects column type **"Formula"**
|
||||
4. A **DslEditor** (CodeMirror 5) opens for formula input
|
||||
5. The editor provides:
|
||||
- **Syntax highlighting**: keywords, column references, functions, operators
|
||||
- **Autocompletion**: column names (current table and other tables), function names, table names
|
||||
- **Validation**: real-time syntax checking and dependency cycle detection
|
||||
- **Error markers**: inline error indicators with descriptions
|
||||
|
||||
### Formula Column Properties
|
||||
|
||||
A formula column extends `DataGridColumnState` with:
|
||||
|
||||
| Property | Type | Description |
|
||||
|---------------------------------------------------------------------------|---------------|------------------------------------------------|
|
||||
| `formula` | `str` or None | The formula expression (None for data columns) |
|
||||
| `col_type` | `ColumnType` | Set to `ColumnType.Formula` |
|
||||
| Other properties (`title`, `visible`, `width`, `format`) remain unchanged |
|
||||
|
||||
Formula columns are **read-only** in the grid body — cell values are computed, not editable. Formatting rules from the
|
||||
formatting DSL apply to formula columns like any other column.
|
||||
|
||||
## Integration Points
|
||||
|
||||
| Component | Role |
|
||||
|--------------------------|----------------------------------------------------------|
|
||||
| `DataGridColumnState` | Stores `formula` field and `ColumnType.Formula` type |
|
||||
| `DatagridStore` | `ns_fast_access` caches formula results as numpy arrays |
|
||||
| `DataGridColumnsManager` | UI for creating/editing formula columns |
|
||||
| `DataGridsManager` | Hosts the global dependency DAG across all tables |
|
||||
| `DslEditor` | CodeMirror 5 editor with highlighting and autocompletion |
|
||||
| `FormattingEngine` | Applies formatting rules AFTER formula evaluation |
|
||||
| `mk_body_content_page()` | Triggers formula computation for visible rows |
|
||||
| `mk_body_cell_content()` | Reads computed values from `ns_fast_access` |
|
||||
|
||||
## Syntax Summary
|
||||
|
||||
```
|
||||
# Basic arithmetic
|
||||
{Price} * {Quantity}
|
||||
|
||||
# Function call
|
||||
round({Price} * 1.2, 2)
|
||||
|
||||
# Simple condition (None if false)
|
||||
{Price} * 0.8 if {Country} == "FR"
|
||||
|
||||
# Condition with else
|
||||
{Price} * 0.8 if {Country} == "FR" else {Price}
|
||||
|
||||
# Chained conditions
|
||||
{Price} * 0.8 if {Country} == "FR" else {Price} * 0.9 if {Country} == "DE" else {Price}
|
||||
|
||||
# Logical operators
|
||||
{Price} * 0.8 if {Country} == "FR" and {Quantity} > 10
|
||||
|
||||
# Grouping
|
||||
{Price} * 0.8 if ({Country} == "FR" or {Country} == "DE") and {Quantity} > 10
|
||||
|
||||
# Cross-table (implicit join on id)
|
||||
{Products.Price} * {Quantity}
|
||||
|
||||
# Cross-table (explicit join)
|
||||
{Products.Price WHERE Products.Code = ProductCode} * {Quantity}
|
||||
|
||||
# Cross-table aggregation
|
||||
sum({OrderLines.Amount WHERE OrderLines.OrderId = Id})
|
||||
|
||||
# Nested functions
|
||||
round(avg({Q1}, {Q2}, {Q3}, {Q4}), 1)
|
||||
|
||||
# Text operations
|
||||
concat(upper(left({FirstName}, 1)), ". ", {LastName})
|
||||
```
|
||||
|
||||
## Future: Cell-Level Overrides
|
||||
|
||||
The architecture supports adding cell-level formula overrides with ~20-30% additional work:
|
||||
|
||||
- **Storage**: sparse dict `cell_formulas: dict[(col_id, row_index), str]` (same pattern as `cell_formats`)
|
||||
- **DAG**: new node type `table.column[row]` alongside existing `table.column` nodes
|
||||
- **Evaluation**: "does this cell have an override? If yes, use it. Otherwise, use the column formula."
|
||||
- **Node ID scheme**: designed to be extensible from the start (`table.column` for columns, `table.column[row]` for
|
||||
cells)
|
||||
@@ -418,9 +418,41 @@ class DataGrid(MultipleInstance):
|
||||
self._df_store.ns_fast_access = _init_fast_access(self._df)
|
||||
self._df_store.ns_row_data = _init_row_data(self._df)
|
||||
self._df_store.ns_total_rows = len(self._df) if self._df is not None else 0
|
||||
if init_state:
|
||||
self._register_existing_formulas()
|
||||
|
||||
return self
|
||||
|
||||
def _register_existing_formulas(self) -> None:
|
||||
"""
|
||||
Re-register all formula columns with the FormulaEngine.
|
||||
|
||||
Called after data reload to ensure the engine knows about all
|
||||
formula columns and their expressions.
|
||||
"""
|
||||
engine = self._get_formula_engine()
|
||||
if engine is None:
|
||||
return
|
||||
table = self.get_table_name()
|
||||
for col_def in self._state.columns:
|
||||
if col_def.formula:
|
||||
try:
|
||||
engine.set_formula(table, col_def.col_id, col_def.formula)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to register formula for %s.%s: %s", table, col_def.col_id, e)
|
||||
|
||||
def _recalculate_formulas(self) -> None:
|
||||
"""
|
||||
Recalculate dirty formula columns before rendering.
|
||||
|
||||
Called at the start of mk_body_content_page() to ensure formula
|
||||
columns are up-to-date before cells are rendered.
|
||||
"""
|
||||
engine = self._get_formula_engine()
|
||||
if engine is None:
|
||||
return
|
||||
engine.recalculate_if_needed(self.get_table_name(), self._df_store)
|
||||
|
||||
def _get_format_rules(self, col_pos, row_index, col_def):
|
||||
"""
|
||||
Get format rules for a cell, returning only the most specific level defined.
|
||||
@@ -575,6 +607,11 @@ class DataGrid(MultipleInstance):
|
||||
def get_table_name(self):
|
||||
return f"{self._settings.namespace}.{self._settings.name}" if self._settings.namespace else self._settings.name
|
||||
|
||||
def get_formula_engine(self):
|
||||
"""Return the FormulaEngine from the DataGridsManager, if available."""
|
||||
return self._parent.get_formula_engine()
|
||||
|
||||
|
||||
def mk_headers(self):
|
||||
resize_cmd = self.commands.set_column_width()
|
||||
move_cmd = self.commands.move_column()
|
||||
@@ -701,6 +738,7 @@ class DataGrid(MultipleInstance):
|
||||
OPTIMIZED: Extract filter keyword once instead of 10,000 times.
|
||||
OPTIMIZED: Uses OptimizedDiv for rows instead of Div for faster rendering.
|
||||
"""
|
||||
self._recalculate_formulas()
|
||||
df = self._get_filtered_df()
|
||||
if df is None:
|
||||
return []
|
||||
|
||||
@@ -99,6 +99,9 @@ class DataGridColumnsManager(MultipleInstance):
|
||||
col_def.type = ColumnType(v)
|
||||
elif k == "width":
|
||||
col_def.width = int(v)
|
||||
elif k == "formula":
|
||||
col_def.formula = v or ""
|
||||
self._register_formula(col_def)
|
||||
else:
|
||||
setattr(col_def, k, v)
|
||||
|
||||
@@ -107,6 +110,21 @@ class DataGridColumnsManager(MultipleInstance):
|
||||
|
||||
return self.mk_all_columns()
|
||||
|
||||
def _register_formula(self, col_def) -> None:
|
||||
"""Register or remove a formula column with the FormulaEngine."""
|
||||
engine = self._parent.get_formula_engine()
|
||||
if engine is None:
|
||||
return
|
||||
table = self._parent.get_table_name()
|
||||
if col_def.formula:
|
||||
try:
|
||||
engine.set_formula(table, col_def.col_id, col_def.formula)
|
||||
logger.debug("Registered formula for %s.%s", table, col_def.col_id)
|
||||
except Exception as e:
|
||||
logger.warning("Formula error for %s.%s: %s", table, col_def.col_id, e)
|
||||
else:
|
||||
engine.remove_formula(table, col_def.col_id)
|
||||
|
||||
def mk_column_label(self, col_def: DataGridColumnState):
|
||||
return Div(
|
||||
mk.mk(
|
||||
@@ -168,6 +186,17 @@ class DataGridColumnsManager(MultipleInstance):
|
||||
value=col_def.title,
|
||||
),
|
||||
|
||||
*([
|
||||
Label("Formula"),
|
||||
Textarea(
|
||||
col_def.formula or "",
|
||||
name="formula",
|
||||
cls=f"textarea textarea-{size} w-full font-mono",
|
||||
placeholder="{Column} * {OtherColumn}",
|
||||
rows=3,
|
||||
),
|
||||
] if col_def.type == ColumnType.Formula else []),
|
||||
|
||||
legend="Column details",
|
||||
cls="fieldset border-base-300 rounded-box"
|
||||
),
|
||||
|
||||
67
src/myfasthtml/controls/DataGridFormulaEditor.py
Normal file
67
src/myfasthtml/controls/DataGridFormulaEditor.py
Normal file
@@ -0,0 +1,67 @@
|
||||
"""
|
||||
DataGridFormulaEditor — DslEditor for formula column expressions.
|
||||
|
||||
Extends DslEditor with formula-specific behavior:
|
||||
- Parses the formula on content change
|
||||
- Registers the formula with FormulaEngine
|
||||
- Triggers a body re-render on the parent DataGrid
|
||||
"""
|
||||
import logging
|
||||
|
||||
from myfasthtml.controls.DslEditor import DslEditor
|
||||
from myfasthtml.core.formula.dsl.exceptions import FormulaSyntaxError, FormulaCycleError
|
||||
|
||||
logger = logging.getLogger("DataGridFormulaEditor")
|
||||
|
||||
|
||||
class DataGridFormulaEditor(DslEditor):
|
||||
"""
|
||||
Formula editor for a specific DataGrid column.
|
||||
|
||||
Args:
|
||||
parent: The parent DataGrid instance.
|
||||
col_def: The DataGridColumnState for the formula column.
|
||||
conf: DslEditorConf for CodeMirror configuration.
|
||||
_id: Optional instance ID.
|
||||
"""
|
||||
|
||||
def __init__(self, parent, col_def, conf=None, _id=None):
|
||||
super().__init__(parent, conf=conf, _id=_id)
|
||||
self._col_def = col_def
|
||||
|
||||
def on_content_changed(self):
|
||||
"""
|
||||
Called when the formula text is changed in the editor.
|
||||
|
||||
1. Updates col_def.formula with the new text.
|
||||
2. Registers the formula with the FormulaEngine.
|
||||
3. Triggers a body re-render of the parent DataGrid.
|
||||
"""
|
||||
formula_text = self.get_content()
|
||||
|
||||
# Update the column definition
|
||||
self._col_def.formula = formula_text or ""
|
||||
|
||||
# Register with the FormulaEngine
|
||||
engine = self._parent._get_formula_engine()
|
||||
if engine is not None:
|
||||
table = self._parent.get_table_name()
|
||||
try:
|
||||
engine.set_formula(table, self._col_def.col_id, formula_text)
|
||||
logger.debug(
|
||||
"Formula updated for %s.%s: %s",
|
||||
table, self._col_def.col_id, formula_text,
|
||||
)
|
||||
except FormulaSyntaxError as e:
|
||||
logger.debug("Formula syntax error, keeping old formula: %s", e)
|
||||
return
|
||||
except FormulaCycleError as e:
|
||||
logger.warning("Formula cycle detected for %s.%s: %s", table, self._col_def.col_id, e)
|
||||
return
|
||||
except Exception as e:
|
||||
logger.warning("Formula engine error for %s.%s: %s", table, self._col_def.col_id, e)
|
||||
return
|
||||
|
||||
# Save state and re-render the grid body
|
||||
self._parent.save_state()
|
||||
return self._parent.render_partial("body")
|
||||
@@ -12,7 +12,7 @@ from myfasthtml.icons.fluent import brain_circuit20_regular
|
||||
from myfasthtml.icons.fluent_p1 import filter20_regular, search20_regular
|
||||
from myfasthtml.icons.fluent_p2 import dismiss_circle20_regular
|
||||
|
||||
logger = logging.getLogger("DataGridFilter")
|
||||
logger = logging.getLogger("DataGridQuery")
|
||||
|
||||
DG_QUERY_FILTER = "filter"
|
||||
DG_QUERY_SEARCH = "search"
|
||||
|
||||
@@ -16,6 +16,7 @@ from myfasthtml.core.commands import Command
|
||||
from myfasthtml.core.dbmanager import DbObject
|
||||
from myfasthtml.core.formatting.dsl.completion.provider import DatagridMetadataProvider
|
||||
from myfasthtml.core.formatting.presets import DEFAULT_STYLE_PRESETS, DEFAULT_FORMATTER_PRESETS
|
||||
from myfasthtml.core.formula.engine import FormulaEngine
|
||||
from myfasthtml.core.instances import InstancesManager, SingleInstance
|
||||
from myfasthtml.icons.fluent_p1 import table_add20_regular
|
||||
from myfasthtml.icons.fluent_p3 import folder_open20_regular
|
||||
@@ -91,6 +92,11 @@ class DataGridsManager(SingleInstance, DatagridMetadataProvider):
|
||||
self.style_presets: dict = DEFAULT_STYLE_PRESETS.copy()
|
||||
self.formatter_presets: dict = DEFAULT_FORMATTER_PRESETS.copy()
|
||||
self.all_tables_formats: list = []
|
||||
|
||||
# Formula engine shared across all DataGrids in this session
|
||||
self._formula_engine = FormulaEngine(
|
||||
registry_resolver=self._resolve_store_for_table
|
||||
)
|
||||
|
||||
def upload_from_source(self):
|
||||
file_upload = FileUpload(self)
|
||||
@@ -167,10 +173,10 @@ class DataGridsManager(SingleInstance, DatagridMetadataProvider):
|
||||
|
||||
def list_column_values(self, table_name, column_name):
|
||||
return self._registry.get_column_values(table_name, column_name)
|
||||
|
||||
|
||||
def get_row_count(self, table_name):
|
||||
return self._registry.get_row_count(table_name)
|
||||
|
||||
|
||||
def get_column_type(self, table_name, column_name):
|
||||
return self._registry.get_column_type(table_name, column_name)
|
||||
|
||||
@@ -180,7 +186,29 @@ class DataGridsManager(SingleInstance, DatagridMetadataProvider):
|
||||
def list_format_presets(self) -> list[str]:
|
||||
return list(self.formatter_presets.keys())
|
||||
|
||||
# === Presets Management ===
|
||||
def _resolve_store_for_table(self, table_name: str):
|
||||
"""
|
||||
Resolve the DatagridStore for a given table name.
|
||||
|
||||
Used by FormulaEngine as the registry_resolver callback.
|
||||
|
||||
Args:
|
||||
table_name: Full table name in ``"namespace.name"`` format.
|
||||
|
||||
Returns:
|
||||
DatagridStore instance or None if not found.
|
||||
"""
|
||||
try:
|
||||
as_fullname_dict = self._registry._get_entries_as_full_name_dict()
|
||||
grid_id = as_fullname_dict.get(table_name)
|
||||
if grid_id is None:
|
||||
return None
|
||||
datagrid = InstancesManager.get(self._session, grid_id, None)
|
||||
if datagrid is None:
|
||||
return None
|
||||
return datagrid._df_store
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def get_style_presets(self) -> dict:
|
||||
"""Get the global style presets."""
|
||||
@@ -190,6 +218,10 @@ class DataGridsManager(SingleInstance, DatagridMetadataProvider):
|
||||
"""Get the global formatter presets."""
|
||||
return self.formatter_presets
|
||||
|
||||
def get_formula_engine(self) -> FormulaEngine:
|
||||
"""The FormulaEngine shared across all DataGrids in this session."""
|
||||
return self._formula_engine
|
||||
|
||||
def add_style_preset(self, name: str, preset: dict):
|
||||
"""
|
||||
Add or update a style preset.
|
||||
|
||||
@@ -20,6 +20,7 @@ class DataGridColumnState:
|
||||
visible: bool = True
|
||||
width: int = DATAGRID_DEFAULT_COLUMN_WIDTH
|
||||
format: list = field(default_factory=list) #
|
||||
formula: str = "" # formula expression for ColumnType.Formula columns
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -26,6 +26,7 @@ class ColumnType(Enum):
|
||||
Bool = "Boolean"
|
||||
Choice = "Choice"
|
||||
Enum = "Enum"
|
||||
Formula = "Formula"
|
||||
|
||||
|
||||
class ViewType(Enum):
|
||||
|
||||
@@ -98,6 +98,9 @@ class StyleResolver:
|
||||
return StyleContainer(None, "")
|
||||
|
||||
cls = props.pop("__class__", None)
|
||||
if not props:
|
||||
return StyleContainer(cls, "")
|
||||
|
||||
css = "; ".join(f"{key}: {value}" for key, value in props.items()) + ";"
|
||||
|
||||
|
||||
return StyleContainer(cls, css)
|
||||
|
||||
0
src/myfasthtml/core/formula/__init__.py
Normal file
0
src/myfasthtml/core/formula/__init__.py
Normal file
79
src/myfasthtml/core/formula/dataclasses.py
Normal file
79
src/myfasthtml/core/formula/dataclasses.py
Normal file
@@ -0,0 +1,79 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Optional
|
||||
|
||||
|
||||
@dataclass
|
||||
class FormulaNode:
|
||||
"""Base AST node for formula expressions."""
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class LiteralNode(FormulaNode):
|
||||
"""A literal value (number, string, boolean)."""
|
||||
value: Any
|
||||
|
||||
|
||||
@dataclass
|
||||
class ColumnRef(FormulaNode):
|
||||
"""Reference to a column in the current table: {ColumnName}."""
|
||||
column: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class WhereClause:
|
||||
"""WHERE clause for cross-table references: WHERE remote_table.remote_col = local_col."""
|
||||
remote_table: str
|
||||
remote_column: str
|
||||
local_column: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class CrossTableRef(FormulaNode):
|
||||
"""Reference to a column in another table: {Table.Column}."""
|
||||
table: str
|
||||
column: str
|
||||
where_clause: Optional[WhereClause] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class BinaryOp(FormulaNode):
|
||||
"""Binary operation: left op right."""
|
||||
operator: str
|
||||
left: FormulaNode
|
||||
right: FormulaNode
|
||||
|
||||
|
||||
@dataclass
|
||||
class UnaryOp(FormulaNode):
|
||||
"""Unary operation: -expr or not expr."""
|
||||
operator: str
|
||||
operand: FormulaNode
|
||||
|
||||
|
||||
@dataclass
|
||||
class FunctionCall(FormulaNode):
|
||||
"""Function call: func(args...)."""
|
||||
function_name: str
|
||||
arguments: list = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConditionalExpr(FormulaNode):
|
||||
"""Conditional: value_expr if condition [else else_expr].
|
||||
|
||||
Chainable: val1 if cond1 else val2 if cond2 else val3
|
||||
"""
|
||||
value_expr: FormulaNode
|
||||
condition: FormulaNode
|
||||
else_expr: Optional[FormulaNode] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class FormulaDefinition:
|
||||
"""A complete formula definition for a column."""
|
||||
expression: FormulaNode
|
||||
source_text: str = ""
|
||||
|
||||
def __str__(self):
|
||||
return self.source_text
|
||||
386
src/myfasthtml/core/formula/dependency_graph.py
Normal file
386
src/myfasthtml/core/formula/dependency_graph.py
Normal file
@@ -0,0 +1,386 @@
|
||||
"""
|
||||
Dependency Graph (DAG) for formula columns.
|
||||
|
||||
Tracks column dependencies, propagates dirty flags, and provides
|
||||
topological ordering for incremental recalculation.
|
||||
|
||||
Node IDs use the format ``"table_name.column_id"`` for column-level
|
||||
granularity, designed to be extensible to ``"table_name.column_id[row]"``
|
||||
for cell-level overrides.
|
||||
"""
|
||||
import logging
|
||||
from collections import defaultdict, deque
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional, Set
|
||||
|
||||
from .dataclasses import (
|
||||
FormulaDefinition,
|
||||
ColumnRef,
|
||||
CrossTableRef,
|
||||
BinaryOp,
|
||||
UnaryOp,
|
||||
FunctionCall,
|
||||
ConditionalExpr,
|
||||
FormulaNode,
|
||||
)
|
||||
from .dsl.exceptions import FormulaCycleError
|
||||
|
||||
logger = logging.getLogger("DependencyGraph")
|
||||
|
||||
|
||||
@dataclass
|
||||
class DependencyNode:
|
||||
"""
|
||||
A node in the dependency graph.
|
||||
|
||||
Attributes:
|
||||
node_id: Unique identifier in the format ``"table.column"``.
|
||||
table: Table name.
|
||||
column: Column name.
|
||||
dirty: Whether this node needs recalculation.
|
||||
dirty_rows: Set of specific row indices that are dirty.
|
||||
Empty set means all rows are dirty.
|
||||
formula: The parsed FormulaDefinition for formula nodes.
|
||||
"""
|
||||
node_id: str
|
||||
table: str
|
||||
column: str
|
||||
dirty: bool = False
|
||||
dirty_rows: Set[int] = field(default_factory=set)
|
||||
formula: Optional[FormulaDefinition] = None
|
||||
|
||||
|
||||
class DependencyGraph:
|
||||
"""
|
||||
Directed Acyclic Graph of formula column dependencies.
|
||||
|
||||
Tracks which columns depend on which other columns and provides:
|
||||
- Dirty flag propagation via BFS
|
||||
- Topological ordering for recalculation (Kahn's algorithm)
|
||||
- Cycle detection before registering new formulas
|
||||
|
||||
The graph is bidirectional:
|
||||
- ``_dependents[A]`` = set of nodes that depend on A (forward edges)
|
||||
- ``_precedents[B]`` = set of nodes that B depends on (reverse edges)
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._nodes: dict[str, DependencyNode] = {}
|
||||
# forward: A -> {B, C} means "B and C depend on A"
|
||||
self._dependents: dict[str, set[str]] = defaultdict(set)
|
||||
# reverse: B -> {A} means "B depends on A"
|
||||
self._precedents: dict[str, set[str]] = defaultdict(set)
|
||||
|
||||
def add_formula(
|
||||
self,
|
||||
table: str,
|
||||
column: str,
|
||||
formula: FormulaDefinition,
|
||||
) -> None:
|
||||
"""
|
||||
Register a formula for a column and add dependency edges.
|
||||
|
||||
Raises:
|
||||
FormulaCycleError: If adding this formula would create a cycle.
|
||||
|
||||
Args:
|
||||
table: Table name.
|
||||
column: Column name.
|
||||
formula: The parsed FormulaDefinition.
|
||||
"""
|
||||
logger.debug(f"add_formula {table}.{column}:{formula.source_text}")
|
||||
|
||||
node_id = self._make_node_id(table, column)
|
||||
|
||||
# Extract dependency node_ids from the formula AST
|
||||
dep_ids = self._extract_dependencies(formula.expression, table)
|
||||
|
||||
# Temporarily remove old edges to avoid stale dependencies
|
||||
self._remove_edges(node_id)
|
||||
|
||||
# Add new edges
|
||||
for dep_id in dep_ids:
|
||||
if dep_id == node_id:
|
||||
raise FormulaCycleError([node_id])
|
||||
self._dependents[dep_id].add(node_id)
|
||||
self._precedents[node_id].add(dep_id)
|
||||
|
||||
# Ensure all referenced nodes exist (as data nodes without formulas)
|
||||
for dep_id in dep_ids:
|
||||
if dep_id not in self._nodes:
|
||||
dep_table, dep_col = dep_id.split(".", 1)
|
||||
self._nodes[dep_id] = DependencyNode(
|
||||
node_id=dep_id,
|
||||
table=dep_table,
|
||||
column=dep_col,
|
||||
)
|
||||
|
||||
# Ensure formula node exists
|
||||
node = self._get_or_create_node(table, column)
|
||||
node.formula = formula
|
||||
node.dirty = True # New formula -> needs evaluation
|
||||
|
||||
# Detect cycles using Kahn's algorithm
|
||||
self._detect_cycles()
|
||||
|
||||
logger.debug("Added formula for %s depending on: %s", node_id, dep_ids)
|
||||
|
||||
def remove_formula(self, table: str, column: str) -> None:
|
||||
"""
|
||||
Remove a formula column and its edges from the graph.
|
||||
|
||||
Args:
|
||||
table: Table name.
|
||||
column: Column name.
|
||||
"""
|
||||
node_id = self._make_node_id(table, column)
|
||||
self._remove_edges(node_id)
|
||||
|
||||
if node_id in self._nodes:
|
||||
node = self._nodes[node_id]
|
||||
node.formula = None
|
||||
node.dirty = False
|
||||
node.dirty_rows.clear()
|
||||
# If the node has no dependents either, remove it
|
||||
if not self._dependents.get(node_id):
|
||||
del self._nodes[node_id]
|
||||
|
||||
logger.debug("Removed formula for %s", node_id)
|
||||
|
||||
def get_calculation_order(self, table: Optional[str] = None) -> list[DependencyNode]:
|
||||
"""
|
||||
Return dirty formula nodes in topological order.
|
||||
|
||||
Uses Kahn's algorithm (BFS-based topological sort).
|
||||
Only returns nodes with a formula that are dirty.
|
||||
|
||||
Args:
|
||||
table: If provided, filter to only nodes for this table.
|
||||
|
||||
Returns:
|
||||
List of dirty DependencyNode objects in calculation order.
|
||||
"""
|
||||
# Build in-degree map for nodes with formulas
|
||||
formula_nodes = {
|
||||
nid: node for nid, node in self._nodes.items()
|
||||
if node.formula is not None and node.dirty
|
||||
}
|
||||
|
||||
if not formula_nodes:
|
||||
return []
|
||||
|
||||
# Kahn's algorithm on the subgraph of formula nodes
|
||||
in_degree = {nid: 0 for nid in formula_nodes}
|
||||
for nid in formula_nodes:
|
||||
for prec_id in self._precedents.get(nid, set()):
|
||||
if prec_id in formula_nodes:
|
||||
in_degree[nid] += 1
|
||||
|
||||
queue = deque([nid for nid, deg in in_degree.items() if deg == 0])
|
||||
result = []
|
||||
|
||||
while queue:
|
||||
nid = queue.popleft()
|
||||
node = self._nodes[nid]
|
||||
if table is None or node.table == table:
|
||||
result.append(node)
|
||||
for dep_id in self._dependents.get(nid, set()):
|
||||
if dep_id in in_degree:
|
||||
in_degree[dep_id] -= 1
|
||||
if in_degree[dep_id] == 0:
|
||||
queue.append(dep_id)
|
||||
|
||||
return result
|
||||
|
||||
def clear_dirty(self, node_id: str) -> None:
|
||||
"""
|
||||
Clear dirty flags for a node after successful recalculation.
|
||||
|
||||
Args:
|
||||
node_id: The node ID in format ``"table.column"``.
|
||||
"""
|
||||
if node_id in self._nodes:
|
||||
node = self._nodes[node_id]
|
||||
node.dirty = False
|
||||
node.dirty_rows.clear()
|
||||
|
||||
def get_node(self, table: str, column: str) -> Optional[DependencyNode]:
|
||||
"""
|
||||
Get a node by table and column.
|
||||
|
||||
Args:
|
||||
table: Table name.
|
||||
column: Column name.
|
||||
|
||||
Returns:
|
||||
DependencyNode or None if not found.
|
||||
"""
|
||||
node_id = self._make_node_id(table, column)
|
||||
return self._nodes.get(node_id)
|
||||
|
||||
def has_formula(self, table: str, column: str) -> bool:
|
||||
"""
|
||||
Check if a column has a formula registered.
|
||||
|
||||
Args:
|
||||
table: Table name.
|
||||
column: Column name.
|
||||
|
||||
Returns:
|
||||
True if the column has a formula.
|
||||
"""
|
||||
node = self.get_node(table, column)
|
||||
return node is not None and node.formula is not None
|
||||
|
||||
def mark_dirty(
|
||||
self,
|
||||
table: str,
|
||||
column: str,
|
||||
rows: Optional[list[int]] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Mark a column (and its transitive dependents) as dirty.
|
||||
|
||||
Uses BFS to propagate dirty flags through the dependency graph.
|
||||
|
||||
Args:
|
||||
table: Table name.
|
||||
column: Column name.
|
||||
rows: Specific row indices to mark dirty. None means all rows.
|
||||
"""
|
||||
node_id = self._make_node_id(table, column)
|
||||
self._mark_node_dirty(node_id, rows)
|
||||
|
||||
# BFS propagation through dependents
|
||||
queue = deque([node_id])
|
||||
visited = {node_id}
|
||||
|
||||
while queue:
|
||||
current_id = queue.popleft()
|
||||
for dep_id in self._dependents.get(current_id, set()):
|
||||
self._mark_node_dirty(dep_id, rows)
|
||||
if dep_id not in visited:
|
||||
visited.add(dep_id)
|
||||
queue.append(dep_id)
|
||||
|
||||
# ==================== Private helpers ====================
|
||||
|
||||
@staticmethod
|
||||
def _make_node_id(table: str, column: str) -> str:
|
||||
"""Create a standard node ID from table and column names."""
|
||||
return f"{table}.{column}"
|
||||
|
||||
def _get_or_create_node(self, table: str, column: str) -> DependencyNode:
|
||||
"""Get existing node or create a new one."""
|
||||
node_id = self._make_node_id(table, column)
|
||||
if node_id not in self._nodes:
|
||||
self._nodes[node_id] = DependencyNode(
|
||||
node_id=node_id,
|
||||
table=table,
|
||||
column=column,
|
||||
)
|
||||
return self._nodes[node_id]
|
||||
|
||||
def _mark_node_dirty(self, node_id: str, rows: Optional[list[int]]) -> None:
|
||||
"""Mark a specific node as dirty."""
|
||||
if node_id not in self._nodes:
|
||||
return
|
||||
node = self._nodes[node_id]
|
||||
node.dirty = True
|
||||
if rows is not None:
|
||||
node.dirty_rows.update(rows)
|
||||
else:
|
||||
node.dirty_rows.clear() # Empty = all rows dirty
|
||||
|
||||
def _remove_edges(self, node_id: str) -> None:
|
||||
"""Remove all edges connected to a node."""
|
||||
# Remove this node from its precedents' dependents sets
|
||||
for prec_id in list(self._precedents.get(node_id, set())):
|
||||
self._dependents[prec_id].discard(node_id)
|
||||
# Clear this node's precedents
|
||||
self._precedents[node_id].clear()
|
||||
|
||||
def _detect_cycles(self) -> None:
|
||||
"""
|
||||
Detect cycles in the full graph using Kahn's algorithm.
|
||||
|
||||
Raises:
|
||||
FormulaCycleError: If a cycle is detected.
|
||||
"""
|
||||
# Only check formula nodes
|
||||
formula_nodes = {
|
||||
nid for nid, node in self._nodes.items()
|
||||
if node.formula is not None
|
||||
}
|
||||
|
||||
if not formula_nodes:
|
||||
return
|
||||
|
||||
in_degree = {}
|
||||
for nid in formula_nodes:
|
||||
in_degree[nid] = 0
|
||||
for nid in formula_nodes:
|
||||
for prec_id in self._precedents.get(nid, set()):
|
||||
if prec_id in formula_nodes:
|
||||
in_degree[nid] = in_degree.get(nid, 0) + 1
|
||||
|
||||
queue = deque([nid for nid in formula_nodes if in_degree.get(nid, 0) == 0])
|
||||
processed = set()
|
||||
|
||||
while queue:
|
||||
nid = queue.popleft()
|
||||
processed.add(nid)
|
||||
for dep_id in self._dependents.get(nid, set()):
|
||||
if dep_id in formula_nodes:
|
||||
in_degree[dep_id] -= 1
|
||||
if in_degree[dep_id] == 0:
|
||||
queue.append(dep_id)
|
||||
|
||||
cycle_nodes = formula_nodes - processed
|
||||
if cycle_nodes:
|
||||
raise FormulaCycleError(sorted(cycle_nodes))
|
||||
|
||||
def _extract_dependencies(
|
||||
self,
|
||||
node: FormulaNode,
|
||||
current_table: str,
|
||||
) -> set[str]:
|
||||
"""
|
||||
Recursively extract all column dependency IDs from a formula AST.
|
||||
|
||||
Args:
|
||||
node: The AST node to walk.
|
||||
current_table: The table containing this formula (for ColumnRef).
|
||||
|
||||
Returns:
|
||||
Set of dependency node IDs (``"table.column"`` format).
|
||||
"""
|
||||
deps = set()
|
||||
|
||||
if isinstance(node, ColumnRef):
|
||||
deps.add(self._make_node_id(current_table, node.column))
|
||||
|
||||
elif isinstance(node, CrossTableRef):
|
||||
deps.add(self._make_node_id(node.table, node.column))
|
||||
# Also depend on the local column used in WHERE clause
|
||||
if node.where_clause is not None:
|
||||
deps.add(self._make_node_id(current_table, node.where_clause.local_column))
|
||||
|
||||
elif isinstance(node, BinaryOp):
|
||||
deps.update(self._extract_dependencies(node.left, current_table))
|
||||
deps.update(self._extract_dependencies(node.right, current_table))
|
||||
|
||||
elif isinstance(node, UnaryOp):
|
||||
deps.update(self._extract_dependencies(node.operand, current_table))
|
||||
|
||||
elif isinstance(node, FunctionCall):
|
||||
for arg in node.arguments:
|
||||
deps.update(self._extract_dependencies(arg, current_table))
|
||||
|
||||
elif isinstance(node, ConditionalExpr):
|
||||
deps.update(self._extract_dependencies(node.value_expr, current_table))
|
||||
deps.update(self._extract_dependencies(node.condition, current_table))
|
||||
if node.else_expr is not None:
|
||||
deps.update(self._extract_dependencies(node.else_expr, current_table))
|
||||
|
||||
return deps
|
||||
0
src/myfasthtml/core/formula/dsl/__init__.py
Normal file
0
src/myfasthtml/core/formula/dsl/__init__.py
Normal file
@@ -0,0 +1,180 @@
|
||||
"""
|
||||
Autocompletion engine for the DataGrid Formula DSL.
|
||||
|
||||
Provides context-aware suggestions for:
|
||||
- Column names (after ``{``)
|
||||
- Cross-table references (``{Table.``)
|
||||
- Built-in function names
|
||||
- Keywords: ``if``, ``else``, ``and``, ``or``, ``not``, ``WHERE``
|
||||
"""
|
||||
import re
|
||||
|
||||
from myfasthtml.core.dsl.base_completion import BaseCompletionEngine
|
||||
from myfasthtml.core.dsl.types import Position, Suggestion
|
||||
from myfasthtml.core.formula.evaluator import BUILTIN_FUNCTIONS
|
||||
from myfasthtml.core.utils import make_safe_id
|
||||
|
||||
FORMULA_KEYWORDS = [
|
||||
"if", "else", "and", "or", "not", "where",
|
||||
"between", "in", "isempty", "isnotempty", "isnan",
|
||||
"contains", "startswith", "endswith",
|
||||
"true", "false",
|
||||
]
|
||||
|
||||
|
||||
class FormulaCompletionEngine(BaseCompletionEngine):
|
||||
"""
|
||||
Context-aware completion engine for formula expressions.
|
||||
|
||||
Provides suggestions for column references, functions, and keywords.
|
||||
|
||||
Args:
|
||||
provider: DataGrid metadata provider (DataGridsManager or similar).
|
||||
table_name: Name of the current table in ``"namespace.name"`` format.
|
||||
"""
|
||||
|
||||
def __init__(self, provider, table_name: str):
|
||||
super().__init__(provider)
|
||||
self.table_name = table_name
|
||||
self._id = "formula_completion_engine#" + make_safe_id(table_name)
|
||||
|
||||
def detect_scope(self, text: str, current_line: int):
|
||||
"""Formula has no scope — always the same single-expression scope."""
|
||||
return None
|
||||
|
||||
def detect_context(self, text: str, cursor: Position, scope):
|
||||
"""
|
||||
Detect completion context based on cursor position in formula text.
|
||||
|
||||
Args:
|
||||
text: The full formula text.
|
||||
cursor: Cursor position (line, ch).
|
||||
scope: Unused (formulas have no scopes).
|
||||
|
||||
Returns:
|
||||
Context string: ``"column_ref"``, ``"cross_table"``,
|
||||
``"function"``, ``"keyword"``, or ``"general"``.
|
||||
"""
|
||||
# Get text up to cursor
|
||||
lines = text.split("\n")
|
||||
line_idx = min(cursor.line, len(lines) - 1)
|
||||
line_text = lines[line_idx]
|
||||
text_before = line_text[:cursor.ch]
|
||||
|
||||
# Check if we are inside a { ... } reference
|
||||
last_brace = text_before.rfind("{")
|
||||
if last_brace >= 0:
|
||||
inside = text_before[last_brace + 1:]
|
||||
if "}" not in inside:
|
||||
if "." in inside:
|
||||
return "cross_table"
|
||||
return "column_ref"
|
||||
|
||||
# Check if we are typing a function name (alphanumeric at word start)
|
||||
word_match = re.search(r"[a-z_][a-z0-9_]*$", text_before, re.IGNORECASE)
|
||||
if word_match:
|
||||
return "function_or_keyword"
|
||||
|
||||
return "general"
|
||||
|
||||
def get_suggestions(self, text: str, cursor: Position, scope, context) -> list:
|
||||
"""
|
||||
Generate suggestions based on the detected context.
|
||||
|
||||
Args:
|
||||
text: The full formula text.
|
||||
cursor: Cursor position.
|
||||
scope: Unused.
|
||||
context: String from ``detect_context``.
|
||||
|
||||
Returns:
|
||||
List of Suggestion objects.
|
||||
"""
|
||||
suggestions = []
|
||||
|
||||
if context == "column_ref":
|
||||
# Suggest columns from the current table
|
||||
suggestions += self._column_suggestions(self.table_name)
|
||||
|
||||
elif context == "cross_table":
|
||||
# Get the table name prefix from text_before
|
||||
lines = text.split("\n")
|
||||
line_text = lines[min(cursor.line, len(lines) - 1)]
|
||||
text_before = line_text[:cursor.ch]
|
||||
last_brace = text_before.rfind("{")
|
||||
inside = text_before[last_brace + 1:] if last_brace >= 0 else ""
|
||||
dot_pos = inside.rfind(".")
|
||||
table_prefix = inside[:dot_pos] if dot_pos >= 0 else ""
|
||||
|
||||
# Suggest columns from the referenced table
|
||||
if table_prefix:
|
||||
suggestions += self._column_suggestions(table_prefix)
|
||||
else:
|
||||
suggestions += self._table_suggestions()
|
||||
|
||||
elif context == "function_or_keyword":
|
||||
suggestions += self._function_suggestions()
|
||||
suggestions += self._keyword_suggestions()
|
||||
|
||||
else: # general
|
||||
suggestions += self._function_suggestions()
|
||||
suggestions += self._keyword_suggestions()
|
||||
suggestions += [
|
||||
Suggestion(
|
||||
label="{",
|
||||
detail="Column reference",
|
||||
insert_text="{",
|
||||
)
|
||||
]
|
||||
|
||||
return suggestions
|
||||
|
||||
# ==================== Private helpers ====================
|
||||
|
||||
def _column_suggestions(self, table_name: str) -> list:
|
||||
"""Get column name suggestions for a table."""
|
||||
try:
|
||||
columns = self.provider.list_columns(table_name)
|
||||
return [
|
||||
Suggestion(
|
||||
label=col,
|
||||
detail=f"Column from {table_name}",
|
||||
insert_text=col,
|
||||
)
|
||||
for col in (columns or [])
|
||||
]
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
def _table_suggestions(self) -> list:
|
||||
"""Get table name suggestions."""
|
||||
try:
|
||||
tables = self.provider.list_tables()
|
||||
return [
|
||||
Suggestion(
|
||||
label=t,
|
||||
detail="Table",
|
||||
insert_text=t,
|
||||
)
|
||||
for t in (tables or [])
|
||||
]
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
def _function_suggestions(self) -> list:
|
||||
"""Get built-in function name suggestions."""
|
||||
return [
|
||||
Suggestion(
|
||||
label=name,
|
||||
detail="Function",
|
||||
insert_text=f"{name}(",
|
||||
)
|
||||
for name in sorted(BUILTIN_FUNCTIONS.keys())
|
||||
]
|
||||
|
||||
def _keyword_suggestions(self) -> list:
|
||||
"""Get keyword suggestions."""
|
||||
return [
|
||||
Suggestion(label=kw, detail="Keyword", insert_text=kw)
|
||||
for kw in FORMULA_KEYWORDS
|
||||
]
|
||||
1
src/myfasthtml/core/formula/dsl/completion/__init__.py
Normal file
1
src/myfasthtml/core/formula/dsl/completion/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
|
||||
79
src/myfasthtml/core/formula/dsl/definition.py
Normal file
79
src/myfasthtml/core/formula/dsl/definition.py
Normal file
@@ -0,0 +1,79 @@
|
||||
"""
|
||||
FormulaDSL definition for the DslEditor control.
|
||||
|
||||
Provides the Lark grammar and derived completions for the
|
||||
DataGrid Formula DSL (CodeMirror 5 Simple Mode).
|
||||
"""
|
||||
|
||||
from functools import cached_property
|
||||
from typing import Dict, Any
|
||||
|
||||
from myfasthtml.core.dsl.base import DSLDefinition
|
||||
from .grammar import FORMULA_GRAMMAR
|
||||
|
||||
|
||||
class FormulaDSL(DSLDefinition):
|
||||
"""
|
||||
DSL definition for DataGrid formula expressions.
|
||||
|
||||
Uses the Lark grammar from grammar.py to drive syntax highlighting
|
||||
and autocompletion in the DslEditor.
|
||||
"""
|
||||
|
||||
name: str = "Formula DSL"
|
||||
|
||||
def get_grammar(self) -> str:
|
||||
"""Return the Lark grammar for the formula DSL."""
|
||||
return FORMULA_GRAMMAR
|
||||
|
||||
@cached_property
|
||||
def simple_mode_config(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Return a hand-tuned CodeMirror 5 Simple Mode config for formula syntax.
|
||||
|
||||
Overrides the base class to provide optimized highlighting rules
|
||||
for column references, operators, functions, and keywords.
|
||||
"""
|
||||
return {
|
||||
"start": [
|
||||
# Column references: {ColumnName} or {Table.Column}
|
||||
{
|
||||
"regex": r"\{[A-Za-z_][A-Za-z0-9_.]*(?:\s+where\s+[A-Za-z_][A-Za-z0-9_.]*\s*=\s*[A-Za-z_][A-Za-z0-9_]*)?\}",
|
||||
"token": "variable-2",
|
||||
},
|
||||
# Function names before parenthesis
|
||||
{
|
||||
"regex": r"[a-z_][a-z0-9_]*(?=\s*\()",
|
||||
"token": "keyword",
|
||||
},
|
||||
# Keywords: if, else, and, or, not, where, between, in
|
||||
{
|
||||
"regex": r"\b(if|else|and|or|not|where|between|in|isempty|isnotempty|isnan|contains|startswith|endswith|true|false)\b",
|
||||
"token": "keyword",
|
||||
},
|
||||
# Numbers
|
||||
{
|
||||
"regex": r"[+-]?(\d+\.?\d*|\.\d+)([eE][+-]?\d+)?",
|
||||
"token": "number",
|
||||
},
|
||||
# Strings
|
||||
{
|
||||
"regex": r'"[^"\\]*"',
|
||||
"token": "string",
|
||||
},
|
||||
# Operators
|
||||
{
|
||||
"regex": r"[=!<>]=?|[+\-*/%^]",
|
||||
"token": "operator",
|
||||
},
|
||||
# Parentheses and brackets
|
||||
{
|
||||
"regex": r"[()[\],]",
|
||||
"token": "punctuation",
|
||||
},
|
||||
],
|
||||
"meta": {
|
||||
"dontIndentStates": ["comment"],
|
||||
"lineComment": "#",
|
||||
},
|
||||
}
|
||||
35
src/myfasthtml/core/formula/dsl/exceptions.py
Normal file
35
src/myfasthtml/core/formula/dsl/exceptions.py
Normal file
@@ -0,0 +1,35 @@
|
||||
class FormulaError(Exception):
|
||||
"""Base exception for formula errors."""
|
||||
pass
|
||||
|
||||
|
||||
class FormulaSyntaxError(FormulaError):
|
||||
"""Raised when the formula has syntax errors."""
|
||||
|
||||
def __init__(self, message, line=None, column=None, context=None):
|
||||
self.message = message
|
||||
self.line = line
|
||||
self.column = column
|
||||
self.context = context
|
||||
super().__init__(self._format_message())
|
||||
|
||||
def _format_message(self):
|
||||
parts = [self.message]
|
||||
if self.line is not None:
|
||||
parts.append(f"at line {self.line}")
|
||||
if self.column is not None:
|
||||
parts.append(f"col {self.column}")
|
||||
return " ".join(parts)
|
||||
|
||||
|
||||
class FormulaValidationError(FormulaError):
|
||||
"""Raised when the formula is syntactically correct but semantically invalid."""
|
||||
pass
|
||||
|
||||
|
||||
class FormulaCycleError(FormulaError):
|
||||
"""Raised when formula dependencies contain a cycle."""
|
||||
|
||||
def __init__(self, cycle_nodes):
|
||||
self.cycle_nodes = cycle_nodes
|
||||
super().__init__(f"Circular dependency detected involving: {', '.join(cycle_nodes)}")
|
||||
100
src/myfasthtml/core/formula/dsl/grammar.py
Normal file
100
src/myfasthtml/core/formula/dsl/grammar.py
Normal file
@@ -0,0 +1,100 @@
|
||||
FORMULA_GRAMMAR = r"""
|
||||
start: expression
|
||||
|
||||
// ==================== Top-level expression ====================
|
||||
|
||||
?expression: conditional_expr
|
||||
|
||||
// Suffix-if: value_expr if condition [else expression]
|
||||
// Right-associative for chaining: a if c1 else b if c2 else d
|
||||
?conditional_expr: or_expr "if" or_expr "else" conditional_expr -> conditional_with_else
|
||||
| or_expr "if" or_expr -> conditional_no_else
|
||||
| or_expr
|
||||
|
||||
// ==================== Logical ====================
|
||||
|
||||
?or_expr: and_expr ("or" and_expr)* -> or_op
|
||||
?and_expr: not_expr ("and" not_expr)* -> and_op
|
||||
?not_expr: "not" not_expr -> not_op
|
||||
| comparison
|
||||
|
||||
// ==================== Comparison ====================
|
||||
|
||||
?comparison: addition comp_op addition -> comparison_expr
|
||||
| addition "in" "[" literal ("," literal)* "]" -> in_expr
|
||||
| addition "between" addition "and" addition -> between_expr
|
||||
| addition "contains" addition -> contains_expr
|
||||
| addition "startswith" addition -> startswith_expr
|
||||
| addition "endswith" addition -> endswith_expr
|
||||
| addition "isempty" -> isempty_expr
|
||||
| addition "isnotempty" -> isnotempty_expr
|
||||
| addition "isnan" -> isnan_expr
|
||||
| addition
|
||||
|
||||
comp_op: "==" -> eq
|
||||
| "!=" -> ne
|
||||
| "<=" -> le
|
||||
| "<" -> lt
|
||||
| ">=" -> ge
|
||||
| ">" -> gt
|
||||
|
||||
// ==================== Arithmetic ====================
|
||||
|
||||
?addition: multiplication (add_op multiplication)* -> add_expr
|
||||
?multiplication: power (mul_op power)* -> mul_expr
|
||||
?power: unary ("^" unary)* -> pow_expr
|
||||
|
||||
add_op: "+" -> plus
|
||||
| "-" -> minus
|
||||
|
||||
mul_op: "*" -> times
|
||||
| "/" -> divide
|
||||
| "%" -> modulo
|
||||
|
||||
?unary: "-" unary -> neg
|
||||
| atom
|
||||
|
||||
// ==================== Atoms ====================
|
||||
|
||||
?atom: function_call
|
||||
| cross_table_ref
|
||||
| column_ref
|
||||
| literal
|
||||
| "(" expression ")" -> paren
|
||||
|
||||
// ==================== References ====================
|
||||
|
||||
// Cross-table must be checked before column_ref since both use { }
|
||||
// TABLE_NAME.COL_NAME with optional WHERE clause
|
||||
// Note: whitespace around "where" and "=" is handled by %ignore
|
||||
cross_table_ref: "{" TABLE_COL_REF "}" -> cross_ref_simple
|
||||
| "{" TABLE_COL_REF "where" where_clause "}" -> cross_ref_where
|
||||
|
||||
column_ref: "{" COL_NAME "}"
|
||||
|
||||
where_clause: TABLE_COL_REF "=" COL_NAME
|
||||
|
||||
// TABLE_COL_REF matches "TableName.ColumnName" (dot-separated, no spaces)
|
||||
TABLE_COL_REF: /[A-Za-z_][A-Za-z0-9_]*\.[A-Za-z_][A-Za-z0-9_]*/
|
||||
COL_NAME: /[A-Za-z_][A-Za-z0-9_ ]*/
|
||||
|
||||
// ==================== Functions ====================
|
||||
|
||||
function_call: FUNC_NAME "(" [expression ("," expression)*] ")"
|
||||
FUNC_NAME: /[a-z_][a-z0-9_]*/
|
||||
|
||||
// ==================== Literals ====================
|
||||
|
||||
?literal: NUMBER -> number_literal
|
||||
| ESCAPED_STRING -> string_literal
|
||||
| "true"i -> true_literal
|
||||
| "false"i -> false_literal
|
||||
|
||||
// ==================== Terminals ====================
|
||||
|
||||
NUMBER: /[+-]?(\d+\.?\d*|\.\d+)([eE][+-]?\d+)?/
|
||||
ESCAPED_STRING: "\"" /[^"\\]*/ "\""
|
||||
|
||||
%ignore /[ \t\f]+/
|
||||
%ignore /\#[^\n]*/
|
||||
"""
|
||||
85
src/myfasthtml/core/formula/dsl/parser.py
Normal file
85
src/myfasthtml/core/formula/dsl/parser.py
Normal file
@@ -0,0 +1,85 @@
|
||||
"""
|
||||
Formula DSL parser using Lark.
|
||||
|
||||
Handles parsing of formula expression strings into a Lark AST.
|
||||
No indentation handling needed — formulas are single-line expressions.
|
||||
"""
|
||||
from lark import Lark, UnexpectedInput
|
||||
|
||||
from .exceptions import FormulaSyntaxError
|
||||
from .grammar import FORMULA_GRAMMAR
|
||||
|
||||
|
||||
class FormulaParser:
|
||||
"""
|
||||
Parser for the DataGrid formula language.
|
||||
|
||||
Uses Lark LALR parser without indentation handling.
|
||||
|
||||
Example:
|
||||
parser = FormulaParser()
|
||||
tree = parser.parse("{Price} * {Quantity}")
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._parser = Lark(
|
||||
FORMULA_GRAMMAR,
|
||||
parser="lalr",
|
||||
propagate_positions=False,
|
||||
)
|
||||
|
||||
def parse(self, text: str):
|
||||
"""
|
||||
Parse a formula expression string into a Lark Tree.
|
||||
|
||||
Args:
|
||||
text: The formula expression text.
|
||||
|
||||
Returns:
|
||||
lark.Tree: The parsed AST.
|
||||
|
||||
Raises:
|
||||
FormulaSyntaxError: If the text has syntax errors.
|
||||
"""
|
||||
text = text.strip()
|
||||
if not text:
|
||||
return None
|
||||
|
||||
try:
|
||||
return self._parser.parse(text)
|
||||
except UnexpectedInput as e:
|
||||
context = None
|
||||
if hasattr(e, "get_context"):
|
||||
context = e.get_context(text)
|
||||
|
||||
raise FormulaSyntaxError(
|
||||
message=self._format_error_message(e),
|
||||
line=getattr(e, "line", None),
|
||||
column=getattr(e, "column", None),
|
||||
context=context,
|
||||
) from e
|
||||
|
||||
def _format_error_message(self, error: UnexpectedInput) -> str:
|
||||
"""Format a user-friendly error message from a Lark exception."""
|
||||
if hasattr(error, "expected"):
|
||||
expected = list(error.expected)
|
||||
if len(expected) == 1:
|
||||
return f"Expected {expected[0]}"
|
||||
elif len(expected) <= 5:
|
||||
return f"Expected one of: {', '.join(expected)}"
|
||||
else:
|
||||
return "Unexpected input"
|
||||
|
||||
return str(error)
|
||||
|
||||
|
||||
# Singleton parser instance
|
||||
_parser_instance = None
|
||||
|
||||
|
||||
def get_parser() -> FormulaParser:
|
||||
"""Get the singleton FormulaParser instance."""
|
||||
global _parser_instance
|
||||
if _parser_instance is None:
|
||||
_parser_instance = FormulaParser()
|
||||
return _parser_instance
|
||||
274
src/myfasthtml/core/formula/dsl/transformer.py
Normal file
274
src/myfasthtml/core/formula/dsl/transformer.py
Normal file
@@ -0,0 +1,274 @@
|
||||
"""
|
||||
Formula DSL Transformer.
|
||||
|
||||
Converts a Lark AST tree into FormulaDefinition and related AST dataclasses.
|
||||
"""
|
||||
from lark import Transformer
|
||||
|
||||
from ..dataclasses import (
|
||||
FormulaDefinition,
|
||||
FormulaNode,
|
||||
LiteralNode,
|
||||
ColumnRef,
|
||||
CrossTableRef,
|
||||
WhereClause,
|
||||
BinaryOp,
|
||||
UnaryOp,
|
||||
FunctionCall,
|
||||
ConditionalExpr,
|
||||
)
|
||||
|
||||
|
||||
class FormulaTransformer(Transformer):
|
||||
"""
|
||||
Transforms the Lark parse tree into FormulaDefinition AST dataclasses.
|
||||
|
||||
Handles left-associative folding for arithmetic and logical operators.
|
||||
Handles right-associative chaining for conditional expressions.
|
||||
"""
|
||||
|
||||
# ==================== Top-level ====================
|
||||
|
||||
def start(self, items):
|
||||
"""Return the FormulaDefinition wrapping the single expression."""
|
||||
expr = items[0]
|
||||
return FormulaDefinition(expression=expr)
|
||||
|
||||
# ==================== Conditionals ====================
|
||||
|
||||
def conditional_with_else(self, items):
|
||||
"""value_expr if condition else else_expr"""
|
||||
value_expr, condition, else_expr = items
|
||||
return ConditionalExpr(
|
||||
value_expr=value_expr,
|
||||
condition=condition,
|
||||
else_expr=else_expr,
|
||||
)
|
||||
|
||||
def conditional_no_else(self, items):
|
||||
"""value_expr if condition"""
|
||||
value_expr, condition = items
|
||||
return ConditionalExpr(
|
||||
value_expr=value_expr,
|
||||
condition=condition,
|
||||
else_expr=None,
|
||||
)
|
||||
|
||||
# ==================== Logical ====================
|
||||
|
||||
def or_op(self, items):
|
||||
"""Fold left-associatively: a or b or c -> BinaryOp(or, BinaryOp(or, a, b), c)"""
|
||||
return self._fold_left(items, "or")
|
||||
|
||||
def and_op(self, items):
|
||||
"""Fold left-associatively: a and b and c -> BinaryOp(and, BinaryOp(and, a, b), c)"""
|
||||
return self._fold_left(items, "and")
|
||||
|
||||
def not_op(self, items):
|
||||
"""not expr"""
|
||||
return UnaryOp(operator="not", operand=items[0])
|
||||
|
||||
# ==================== Comparisons ====================
|
||||
|
||||
def comparison_expr(self, items):
|
||||
"""left comp_op right"""
|
||||
left, op, right = items
|
||||
return BinaryOp(operator=op, left=left, right=right)
|
||||
|
||||
def in_expr(self, items):
|
||||
"""operand in [literal, ...]"""
|
||||
operand = items[0]
|
||||
values = list(items[1:])
|
||||
return BinaryOp(operator="in", left=operand, right=LiteralNode(value=values))
|
||||
|
||||
def between_expr(self, items):
|
||||
"""operand between low and high"""
|
||||
operand, low, high = items
|
||||
return BinaryOp(
|
||||
operator="between",
|
||||
left=operand,
|
||||
right=LiteralNode(value=[low, high]),
|
||||
)
|
||||
|
||||
def contains_expr(self, items):
|
||||
left, right = items
|
||||
return BinaryOp(operator="contains", left=left, right=right)
|
||||
|
||||
def startswith_expr(self, items):
|
||||
left, right = items
|
||||
return BinaryOp(operator="startswith", left=left, right=right)
|
||||
|
||||
def endswith_expr(self, items):
|
||||
left, right = items
|
||||
return BinaryOp(operator="endswith", left=left, right=right)
|
||||
|
||||
def isempty_expr(self, items):
|
||||
return UnaryOp(operator="isempty", operand=items[0])
|
||||
|
||||
def isnotempty_expr(self, items):
|
||||
return UnaryOp(operator="isnotempty", operand=items[0])
|
||||
|
||||
def isnan_expr(self, items):
|
||||
return UnaryOp(operator="isnan", operand=items[0])
|
||||
|
||||
# ==================== Comparison operators ====================
|
||||
|
||||
def eq(self, items):
|
||||
return "=="
|
||||
|
||||
def ne(self, items):
|
||||
return "!="
|
||||
|
||||
def le(self, items):
|
||||
return "<="
|
||||
|
||||
def lt(self, items):
|
||||
return "<"
|
||||
|
||||
def ge(self, items):
|
||||
return ">="
|
||||
|
||||
def gt(self, items):
|
||||
return ">"
|
||||
|
||||
# ==================== Arithmetic ====================
|
||||
|
||||
def add_expr(self, items):
|
||||
"""Fold left-associatively with alternating operands and operators."""
|
||||
return self._fold_binary_with_ops(items)
|
||||
|
||||
def mul_expr(self, items):
|
||||
"""Fold left-associatively with alternating operands and operators."""
|
||||
return self._fold_binary_with_ops(items)
|
||||
|
||||
def pow_expr(self, items):
|
||||
"""Fold left-associatively for power expressions (^ is left-assoc here)."""
|
||||
# pow_expr items are [base, exp1, exp2, ...] — operator is always "^"
|
||||
result = items[0]
|
||||
for exp in items[1:]:
|
||||
result = BinaryOp(operator="^", left=result, right=exp)
|
||||
return result
|
||||
|
||||
def plus(self, items):
|
||||
return "+"
|
||||
|
||||
def minus(self, items):
|
||||
return "-"
|
||||
|
||||
def times(self, items):
|
||||
return "*"
|
||||
|
||||
def divide(self, items):
|
||||
return "/"
|
||||
|
||||
def modulo(self, items):
|
||||
return "%"
|
||||
|
||||
def neg(self, items):
|
||||
"""Unary negation: -expr"""
|
||||
return UnaryOp(operator="-", operand=items[0])
|
||||
|
||||
# ==================== References ====================
|
||||
|
||||
def cross_ref_simple(self, items):
|
||||
"""{ Table.Column }"""
|
||||
table_col = str(items[0])
|
||||
table, column = table_col.split(".", 1)
|
||||
return CrossTableRef(table=table, column=column)
|
||||
|
||||
def cross_ref_where(self, items):
|
||||
"""{ Table.Column WHERE remote_table.remote_col = local_col }"""
|
||||
table_col = str(items[0])
|
||||
where = items[1]
|
||||
table, column = table_col.split(".", 1)
|
||||
return CrossTableRef(table=table, column=column, where_clause=where)
|
||||
|
||||
def column_ref(self, items):
|
||||
"""{ ColumnName }"""
|
||||
col_name = str(items[0]).strip()
|
||||
return ColumnRef(column=col_name)
|
||||
|
||||
def where_clause(self, items):
|
||||
"""TABLE_COL_REF = COL_NAME"""
|
||||
remote_table_col = str(items[0])
|
||||
local_col = str(items[1]).strip()
|
||||
remote_table, remote_col = remote_table_col.split(".", 1)
|
||||
return WhereClause(
|
||||
remote_table=remote_table,
|
||||
remote_column=remote_col,
|
||||
local_column=local_col,
|
||||
)
|
||||
|
||||
# ==================== Functions ====================
|
||||
|
||||
def function_call(self, items):
|
||||
"""func_name(arg1, arg2, ...)"""
|
||||
func_name = str(items[0]).lower()
|
||||
args = list(items[1:])
|
||||
return FunctionCall(function_name=func_name, arguments=args)
|
||||
|
||||
# ==================== Literals ====================
|
||||
|
||||
def number_literal(self, items):
|
||||
value = str(items[0])
|
||||
if "." in value or "e" in value.lower():
|
||||
return LiteralNode(value=float(value))
|
||||
try:
|
||||
return LiteralNode(value=int(value))
|
||||
except ValueError:
|
||||
return LiteralNode(value=float(value))
|
||||
|
||||
def string_literal(self, items):
|
||||
raw = str(items[0])
|
||||
# Remove surrounding double quotes
|
||||
if raw.startswith('"') and raw.endswith('"'):
|
||||
return LiteralNode(value=raw[1:-1])
|
||||
return LiteralNode(value=raw)
|
||||
|
||||
def true_literal(self, items):
|
||||
return LiteralNode(value=True)
|
||||
|
||||
def false_literal(self, items):
|
||||
return LiteralNode(value=False)
|
||||
|
||||
def paren(self, items):
|
||||
"""Parenthesized expression — transparent pass-through."""
|
||||
return items[0]
|
||||
|
||||
# ==================== Helpers ====================
|
||||
|
||||
def _fold_left(self, items: list, op: str) -> FormulaNode:
|
||||
"""
|
||||
Fold a list of operands left-associatively with a fixed operator.
|
||||
|
||||
Args:
|
||||
items: List of FormulaNode operands.
|
||||
op: Operator string.
|
||||
|
||||
Returns:
|
||||
Left-folded BinaryOp tree, or the single item if only one.
|
||||
"""
|
||||
result = items[0]
|
||||
for operand in items[1:]:
|
||||
result = BinaryOp(operator=op, left=result, right=operand)
|
||||
return result
|
||||
|
||||
def _fold_binary_with_ops(self, items: list) -> FormulaNode:
|
||||
"""
|
||||
Fold a list of alternating [operand, op, operand, op, operand, ...]
|
||||
left-associatively.
|
||||
|
||||
Args:
|
||||
items: Alternating list: [expr, op_str, expr, op_str, expr, ...]
|
||||
|
||||
Returns:
|
||||
Left-folded BinaryOp tree.
|
||||
"""
|
||||
result = items[0]
|
||||
i = 1
|
||||
while i < len(items):
|
||||
op = items[i]
|
||||
right = items[i + 1]
|
||||
result = BinaryOp(operator=op, left=result, right=right)
|
||||
i += 2
|
||||
return result
|
||||
398
src/myfasthtml/core/formula/engine.py
Normal file
398
src/myfasthtml/core/formula/engine.py
Normal file
@@ -0,0 +1,398 @@
|
||||
"""
|
||||
Formula Engine — facade orchestrating parsing, DAG, and evaluation.
|
||||
|
||||
Coordinates:
|
||||
- Parsing formula text via the DSL parser
|
||||
- Registering formulas and their dependencies in the DependencyGraph
|
||||
- Evaluating dirty formula columns row-by-row via FormulaEvaluator
|
||||
- Updating ns_fast_access caches in the DatagridStore
|
||||
"""
|
||||
import logging
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .dataclasses import FormulaDefinition, WhereClause
|
||||
from .dependency_graph import DependencyGraph
|
||||
from .dsl.parser import get_parser
|
||||
from .dsl.transformer import FormulaTransformer
|
||||
from .evaluator import FormulaEvaluator
|
||||
|
||||
logger = logging.getLogger("FormulaEngine")
|
||||
|
||||
# Callback that returns a DatagridStore-like object for a given table name
|
||||
RegistryResolver = Callable[[str], Any]
|
||||
|
||||
|
||||
def parse_formula(text: str) -> FormulaDefinition | None:
|
||||
"""Parse a formula expression string into a FormulaDefinition AST.
|
||||
|
||||
Args:
|
||||
text: The formula expression string.
|
||||
|
||||
Returns:
|
||||
FormulaDefinition on success, None if text is empty.
|
||||
|
||||
Raises:
|
||||
FormulaSyntaxError: If the formula text is syntactically invalid.
|
||||
"""
|
||||
text = text.strip() if text else ""
|
||||
if not text:
|
||||
return None
|
||||
|
||||
parser = get_parser()
|
||||
tree = parser.parse(text)
|
||||
if tree is None:
|
||||
return None
|
||||
|
||||
transformer = FormulaTransformer()
|
||||
formula = transformer.transform(tree)
|
||||
formula.source_text = text
|
||||
return formula
|
||||
|
||||
|
||||
class FormulaEngine:
|
||||
"""
|
||||
Facade for the formula calculation system.
|
||||
|
||||
Orchestrates formula parsing, dependency tracking, and incremental
|
||||
recalculation of formula columns.
|
||||
|
||||
Args:
|
||||
registry_resolver: Callback that takes a table name and returns
|
||||
the DatagridStore for that table (used for cross-table refs).
|
||||
Provided by DataGridsManager.
|
||||
"""
|
||||
|
||||
def __init__(self, registry_resolver: Optional[RegistryResolver] = None):
|
||||
self._graph = DependencyGraph()
|
||||
self._registry_resolver = registry_resolver
|
||||
# Cache of parsed formulas: {(table, col): FormulaDefinition}
|
||||
self._formulas: dict[tuple[str, str], FormulaDefinition] = {}
|
||||
|
||||
def set_formula(self, table: str, col: str, formula_text: str) -> None:
|
||||
"""
|
||||
Parse and register a formula for a column.
|
||||
|
||||
Args:
|
||||
table: Table name.
|
||||
col: Column name.
|
||||
formula_text: The formula expression string.
|
||||
|
||||
Raises:
|
||||
FormulaSyntaxError: If the formula is syntactically invalid.
|
||||
FormulaCycleError: If the formula would create a circular dependency.
|
||||
"""
|
||||
formula_text = formula_text.strip() if formula_text else ""
|
||||
if not formula_text:
|
||||
self.remove_formula(table, col)
|
||||
return
|
||||
|
||||
formula = parse_formula(formula_text)
|
||||
if formula is None:
|
||||
self.remove_formula(table, col)
|
||||
return
|
||||
|
||||
# Registers in DAG and raises FormulaCycleError if cycle detected
|
||||
self._graph.add_formula(table, col, formula)
|
||||
self._formulas[(table, col)] = formula
|
||||
|
||||
logger.debug("Formula set for %s.%s: %s", table, col, formula_text)
|
||||
|
||||
def remove_formula(self, table: str, col: str) -> None:
|
||||
"""
|
||||
Remove a formula column from the engine.
|
||||
|
||||
Args:
|
||||
table: Table name.
|
||||
col: Column name.
|
||||
"""
|
||||
self._graph.remove_formula(table, col)
|
||||
self._formulas.pop((table, col), None)
|
||||
|
||||
def mark_data_changed(
|
||||
self,
|
||||
table: str,
|
||||
col: str,
|
||||
rows: Optional[list[int]] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Mark a column's data as changed, propagating dirty flags.
|
||||
|
||||
Call this when source data is modified so that dependent formula
|
||||
columns are re-evaluated on next render.
|
||||
|
||||
Args:
|
||||
table: Table name.
|
||||
col: Column name.
|
||||
rows: Specific row indices that changed. None means all rows.
|
||||
"""
|
||||
self._graph.mark_dirty(table, col, rows)
|
||||
|
||||
def recalculate_if_needed(self, table: str, store: Any) -> bool:
|
||||
"""
|
||||
Recalculate all dirty formula columns for a table.
|
||||
|
||||
Should be called at the start of ``mk_body_content_page()`` to
|
||||
ensure formula columns are up-to-date before rendering.
|
||||
|
||||
Updates ``store.ns_fast_access`` and ``store.ns_row_data`` in place.
|
||||
|
||||
Args:
|
||||
table: Table name.
|
||||
store: The DatagridStore instance for this table.
|
||||
|
||||
Returns:
|
||||
True if any columns were recalculated, False otherwise.
|
||||
"""
|
||||
dirty_nodes = self._graph.get_calculation_order(table=table)
|
||||
|
||||
if not dirty_nodes:
|
||||
return False
|
||||
|
||||
for node in dirty_nodes:
|
||||
formula = node.formula
|
||||
if formula is None:
|
||||
continue
|
||||
self._evaluate_column(table, node.column, formula, store)
|
||||
self._graph.clear_dirty(node.node_id)
|
||||
|
||||
# Rebuild ns_row_data after recalculation
|
||||
if dirty_nodes and store.ns_fast_access:
|
||||
self._rebuild_row_data(store)
|
||||
|
||||
return True
|
||||
|
||||
def has_formula(self, table: str, col: str) -> bool:
|
||||
"""
|
||||
Check if a column has a formula registered.
|
||||
|
||||
Args:
|
||||
table: Table name.
|
||||
col: Column name.
|
||||
|
||||
Returns:
|
||||
True if the column has a registered formula.
|
||||
"""
|
||||
return self._graph.has_formula(table, col)
|
||||
|
||||
def get_formula_text(self, table: str, col: str) -> Optional[str]:
|
||||
"""
|
||||
Get the source text of a registered formula.
|
||||
|
||||
Args:
|
||||
table: Table name.
|
||||
col: Column name.
|
||||
|
||||
Returns:
|
||||
Formula source text or None if not registered.
|
||||
"""
|
||||
formula = self._formulas.get((table, col))
|
||||
return formula.source_text if formula else None
|
||||
|
||||
# ==================== Private helpers ====================
|
||||
|
||||
def _evaluate_column(
|
||||
self,
|
||||
table: str,
|
||||
col: str,
|
||||
formula: FormulaDefinition,
|
||||
store: Any,
|
||||
) -> None:
|
||||
"""
|
||||
Evaluate a formula column row-by-row and update ns_fast_access.
|
||||
|
||||
Args:
|
||||
table: Table name.
|
||||
col: Column name.
|
||||
formula: The parsed FormulaDefinition.
|
||||
store: The DatagridStore with ns_fast_access and ns_row_data.
|
||||
"""
|
||||
if store.ns_row_data is None or len(store.ns_row_data) == 0:
|
||||
return
|
||||
|
||||
n_rows = len(store.ns_row_data)
|
||||
resolver = self._make_cross_table_resolver(table)
|
||||
evaluator = FormulaEvaluator(cross_table_resolver=resolver)
|
||||
|
||||
# Ensure ns_fast_access exists before the loop so that formula columns
|
||||
# evaluated earlier in the same pass are visible to subsequent columns.
|
||||
if store.ns_fast_access is None:
|
||||
store.ns_fast_access = {}
|
||||
|
||||
results = np.empty(n_rows, dtype=object)
|
||||
|
||||
for row_index in range(n_rows):
|
||||
# Build row_data from ns_fast_access so that formula columns evaluated
|
||||
# earlier in this pass (e.g. B) are available to dependent columns (e.g. C).
|
||||
row_data = {
|
||||
c: arr[row_index]
|
||||
for c, arr in store.ns_fast_access.items()
|
||||
if arr is not None and row_index < len(arr)
|
||||
}
|
||||
results[row_index] = evaluator.evaluate(formula, row_data, row_index)
|
||||
|
||||
store.ns_fast_access[col] = results
|
||||
|
||||
logger.debug("Evaluated formula column %s.%s (%d rows)", table, col, n_rows)
|
||||
|
||||
def _rebuild_row_data(self, store: Any) -> None:
|
||||
"""
|
||||
Rebuild ns_row_data to include formula column results.
|
||||
|
||||
This ensures formula values are available to dependent formulas
|
||||
in subsequent evaluation passes.
|
||||
|
||||
Args:
|
||||
store: The DatagridStore to update.
|
||||
"""
|
||||
if store.ns_fast_access is None:
|
||||
return
|
||||
|
||||
n_rows = len(store.ns_row_data)
|
||||
for row_index in range(n_rows):
|
||||
row = store.ns_row_data[row_index]
|
||||
for col, arr in store.ns_fast_access.items():
|
||||
if arr is not None and row_index < len(arr):
|
||||
row[col] = arr[row_index]
|
||||
|
||||
def _make_cross_table_resolver(self, current_table: str):
|
||||
"""
|
||||
Create a cross-table resolver callback for the given table context.
|
||||
|
||||
Resolution strategy:
|
||||
1. Explicit WHERE clause: scan remote column for matching rows.
|
||||
2. Implicit join by ``id`` column: match rows where both tables share
|
||||
the same id value.
|
||||
3. Fallback: match by row_index.
|
||||
|
||||
Args:
|
||||
current_table: The table that contains the formula.
|
||||
|
||||
Returns:
|
||||
A callable ``resolver(table, column, where_clause, row_index) -> value``.
|
||||
"""
|
||||
|
||||
def resolver(
|
||||
remote_table: str,
|
||||
remote_column: str,
|
||||
where_clause: Optional[WhereClause],
|
||||
row_index: int,
|
||||
) -> Any:
|
||||
if self._registry_resolver is None:
|
||||
logger.warning(
|
||||
"No registry_resolver set for cross-table ref %s.%s",
|
||||
remote_table, remote_column,
|
||||
)
|
||||
return None
|
||||
|
||||
remote_store = self._registry_resolver(remote_table)
|
||||
if remote_store is None:
|
||||
logger.warning("Table '%s' not found in registry", remote_table)
|
||||
return None
|
||||
|
||||
ns = remote_store.ns_fast_access
|
||||
if not ns or remote_column not in ns:
|
||||
logger.debug(
|
||||
"Column '%s' not found in table '%s'", remote_column, remote_table
|
||||
)
|
||||
return None
|
||||
|
||||
remote_array = ns[remote_column]
|
||||
|
||||
# Strategy 1: Explicit WHERE clause
|
||||
if where_clause is not None:
|
||||
return self._resolve_with_where(
|
||||
where_clause, remote_store, remote_column,
|
||||
remote_array, current_table, row_index,
|
||||
)
|
||||
|
||||
# Strategy 2: Implicit join by 'id' column
|
||||
current_store = self._registry_resolver(current_table)
|
||||
if (
|
||||
current_store is not None
|
||||
and current_store.ns_fast_access is not None
|
||||
and "id" in current_store.ns_fast_access
|
||||
and "id" in ns
|
||||
):
|
||||
local_id_arr = current_store.ns_fast_access["id"]
|
||||
remote_id_arr = ns["id"]
|
||||
if row_index < len(local_id_arr):
|
||||
local_id = local_id_arr[row_index]
|
||||
# Find first matching row in remote table
|
||||
matches = np.where(remote_id_arr == local_id)[0]
|
||||
if len(matches) > 0:
|
||||
return remote_array[matches[0]]
|
||||
return None
|
||||
|
||||
# Strategy 3: Fallback — match by row_index
|
||||
if row_index < len(remote_array):
|
||||
return remote_array[row_index]
|
||||
return None
|
||||
|
||||
return resolver
|
||||
|
||||
def _resolve_with_where(
|
||||
self,
|
||||
where_clause: WhereClause,
|
||||
remote_store: Any,
|
||||
remote_column: str,
|
||||
remote_array: Any,
|
||||
current_table: str,
|
||||
row_index: int,
|
||||
) -> Any:
|
||||
"""
|
||||
Resolve a cross-table reference using an explicit WHERE clause.
|
||||
|
||||
Args:
|
||||
where_clause: The parsed WHERE clause.
|
||||
remote_store: DatagridStore for the remote table.
|
||||
remote_column: Column to return value from.
|
||||
remote_array: numpy array of the remote column values.
|
||||
current_table: Table containing the formula.
|
||||
row_index: Current row being evaluated.
|
||||
|
||||
Returns:
|
||||
The value from the first matching remote row, or None.
|
||||
"""
|
||||
remote_ns = remote_store.ns_fast_access
|
||||
if not remote_ns:
|
||||
return None
|
||||
|
||||
# Get the remote key column array
|
||||
remote_key_col = where_clause.remote_column
|
||||
if remote_key_col not in remote_ns:
|
||||
logger.debug(
|
||||
"WHERE key column '%s' not found in remote table", remote_key_col
|
||||
)
|
||||
return None
|
||||
|
||||
remote_key_array = remote_ns[remote_key_col]
|
||||
|
||||
# Get the local value to compare
|
||||
current_store = self._registry_resolver(current_table) if self._registry_resolver else None
|
||||
if current_store is None or current_store.ns_fast_access is None:
|
||||
return None
|
||||
|
||||
local_col = where_clause.local_column
|
||||
if local_col not in current_store.ns_fast_access:
|
||||
logger.debug("WHERE local column '%s' not found", local_col)
|
||||
return None
|
||||
|
||||
local_array = current_store.ns_fast_access[local_col]
|
||||
if row_index >= len(local_array):
|
||||
return None
|
||||
|
||||
local_value = local_array[row_index]
|
||||
|
||||
# Find matching rows
|
||||
try:
|
||||
matches = np.where(remote_key_array == local_value)[0]
|
||||
except Exception:
|
||||
matches = []
|
||||
|
||||
if len(matches) == 0:
|
||||
return None
|
||||
|
||||
# Return value from first match (use aggregation functions for multi-row)
|
||||
return remote_array[matches[0]]
|
||||
522
src/myfasthtml/core/formula/evaluator.py
Normal file
522
src/myfasthtml/core/formula/evaluator.py
Normal file
@@ -0,0 +1,522 @@
|
||||
"""
|
||||
Formula Evaluator.
|
||||
|
||||
Evaluates a FormulaDefinition AST row-by-row using column data.
|
||||
"""
|
||||
import logging
|
||||
import math
|
||||
from datetime import date, datetime
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
from .dataclasses import (
|
||||
FormulaNode,
|
||||
FormulaDefinition,
|
||||
LiteralNode,
|
||||
ColumnRef,
|
||||
CrossTableRef,
|
||||
BinaryOp,
|
||||
UnaryOp,
|
||||
FunctionCall,
|
||||
ConditionalExpr,
|
||||
)
|
||||
|
||||
logger = logging.getLogger("FormulaEvaluator")
|
||||
|
||||
# Type alias for the cross-table resolver callback
|
||||
CrossTableResolver = Callable[[str, str, Optional[object], int], Any]
|
||||
|
||||
|
||||
def _safe_numeric(value) -> Optional[float]:
|
||||
"""Convert value to float, returning None if not possible."""
|
||||
if value is None:
|
||||
return None
|
||||
try:
|
||||
return float(value)
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
|
||||
|
||||
# ==================== Built-in function registry ====================
|
||||
|
||||
def _fn_round(args):
|
||||
if len(args) < 1:
|
||||
return None
|
||||
value = args[0]
|
||||
decimals = int(args[1]) if len(args) > 1 else 0
|
||||
v = _safe_numeric(value)
|
||||
return round(v, decimals) if v is not None else None
|
||||
|
||||
|
||||
def _fn_abs(args):
|
||||
v = _safe_numeric(args[0]) if args else None
|
||||
return abs(v) if v is not None else None
|
||||
|
||||
|
||||
def _fn_min(args):
|
||||
nums = [_safe_numeric(a) for a in args]
|
||||
nums = [n for n in nums if n is not None]
|
||||
return min(nums) if nums else None
|
||||
|
||||
|
||||
def _fn_max(args):
|
||||
nums = [_safe_numeric(a) for a in args]
|
||||
nums = [n for n in nums if n is not None]
|
||||
return max(nums) if nums else None
|
||||
|
||||
|
||||
def _fn_floor(args):
|
||||
v = _safe_numeric(args[0]) if args else None
|
||||
return math.floor(v) if v is not None else None
|
||||
|
||||
|
||||
def _fn_ceil(args):
|
||||
v = _safe_numeric(args[0]) if args else None
|
||||
return math.ceil(v) if v is not None else None
|
||||
|
||||
|
||||
def _fn_sqrt(args):
|
||||
v = _safe_numeric(args[0]) if args else None
|
||||
return math.sqrt(v) if v is not None and v >= 0 else None
|
||||
|
||||
|
||||
def _fn_sum(args):
|
||||
"""Sum of all arguments (used for inline multi-value, not aggregation)."""
|
||||
nums = [_safe_numeric(a) for a in args]
|
||||
nums = [n for n in nums if n is not None]
|
||||
return sum(nums) if nums else None
|
||||
|
||||
|
||||
def _fn_avg(args):
|
||||
nums = [_safe_numeric(a) for a in args]
|
||||
nums = [n for n in nums if n is not None]
|
||||
return sum(nums) / len(nums) if nums else None
|
||||
|
||||
|
||||
def _fn_len(args):
|
||||
v = args[0] if args else None
|
||||
if v is None:
|
||||
return None
|
||||
return len(str(v))
|
||||
|
||||
|
||||
def _fn_upper(args):
|
||||
v = args[0] if args else None
|
||||
return str(v).upper() if v is not None else None
|
||||
|
||||
|
||||
def _fn_lower(args):
|
||||
v = args[0] if args else None
|
||||
return str(v).lower() if v is not None else None
|
||||
|
||||
|
||||
def _fn_trim(args):
|
||||
v = args[0] if args else None
|
||||
return str(v).strip() if v is not None else None
|
||||
|
||||
|
||||
def _fn_left(args):
|
||||
if len(args) < 2:
|
||||
return None
|
||||
v, n = args[0], args[1]
|
||||
if v is None or n is None:
|
||||
return None
|
||||
return str(v)[:int(n)]
|
||||
|
||||
|
||||
def _fn_right(args):
|
||||
if len(args) < 2:
|
||||
return None
|
||||
v, n = args[0], args[1]
|
||||
if v is None or n is None:
|
||||
return None
|
||||
n = int(n)
|
||||
return str(v)[-n:] if n > 0 else ""
|
||||
|
||||
|
||||
def _fn_concat(args):
|
||||
return "".join(str(a) if a is not None else "" for a in args)
|
||||
|
||||
|
||||
def _fn_year(args):
|
||||
v = args[0] if args else None
|
||||
if v is None:
|
||||
return None
|
||||
if isinstance(v, (datetime, date)):
|
||||
return v.year
|
||||
try:
|
||||
return datetime.fromisoformat(str(v)).year
|
||||
except (ValueError, TypeError):
|
||||
return None
|
||||
|
||||
|
||||
def _fn_month(args):
|
||||
v = args[0] if args else None
|
||||
if v is None:
|
||||
return None
|
||||
if isinstance(v, (datetime, date)):
|
||||
return v.month
|
||||
try:
|
||||
return datetime.fromisoformat(str(v)).month
|
||||
except (ValueError, TypeError):
|
||||
return None
|
||||
|
||||
|
||||
def _fn_day(args):
|
||||
v = args[0] if args else None
|
||||
if v is None:
|
||||
return None
|
||||
if isinstance(v, (datetime, date)):
|
||||
return v.day
|
||||
try:
|
||||
return datetime.fromisoformat(str(v)).day
|
||||
except (ValueError, TypeError):
|
||||
return None
|
||||
|
||||
|
||||
def _fn_today(args):
|
||||
return date.today()
|
||||
|
||||
|
||||
def _fn_datediff(args):
|
||||
if len(args) < 2:
|
||||
return None
|
||||
d1, d2 = args[0], args[1]
|
||||
if d1 is None or d2 is None:
|
||||
return None
|
||||
try:
|
||||
if not isinstance(d1, (datetime, date)):
|
||||
d1 = datetime.fromisoformat(str(d1))
|
||||
if not isinstance(d2, (datetime, date)):
|
||||
d2 = datetime.fromisoformat(str(d2))
|
||||
delta = d1 - d2
|
||||
return delta.days
|
||||
except (ValueError, TypeError):
|
||||
return None
|
||||
|
||||
|
||||
def _fn_coalesce(args):
|
||||
for a in args:
|
||||
if a is not None:
|
||||
return a
|
||||
return None
|
||||
|
||||
|
||||
def _fn_if_error(args):
|
||||
# if_error(expr, fallback) - expr already evaluated, error would be None
|
||||
if len(args) < 2:
|
||||
return args[0] if args else None
|
||||
return args[0] if args[0] is not None else args[1]
|
||||
|
||||
|
||||
def _fn_count(args):
|
||||
"""Count non-None values (used for aggregation results)."""
|
||||
return sum(1 for a in args if a is not None)
|
||||
|
||||
|
||||
# ==================== Function registry ====================
|
||||
|
||||
BUILTIN_FUNCTIONS: dict[str, Callable] = {
|
||||
# Math
|
||||
"round": _fn_round,
|
||||
"abs": _fn_abs,
|
||||
"min": _fn_min,
|
||||
"max": _fn_max,
|
||||
"floor": _fn_floor,
|
||||
"ceil": _fn_ceil,
|
||||
"sqrt": _fn_sqrt,
|
||||
"sum": _fn_sum,
|
||||
"avg": _fn_avg,
|
||||
# Text
|
||||
"len": _fn_len,
|
||||
"upper": _fn_upper,
|
||||
"lower": _fn_lower,
|
||||
"trim": _fn_trim,
|
||||
"left": _fn_left,
|
||||
"right": _fn_right,
|
||||
"concat": _fn_concat,
|
||||
# Date
|
||||
"year": _fn_year,
|
||||
"month": _fn_month,
|
||||
"day": _fn_day,
|
||||
"today": _fn_today,
|
||||
"datediff": _fn_datediff,
|
||||
# Utility
|
||||
"coalesce": _fn_coalesce,
|
||||
"if_error": _fn_if_error,
|
||||
"count": _fn_count,
|
||||
}
|
||||
|
||||
|
||||
# ==================== Evaluator ====================
|
||||
|
||||
class FormulaEvaluator:
|
||||
"""
|
||||
Row-by-row formula evaluator.
|
||||
|
||||
Evaluates a FormulaDefinition AST against a single row of data.
|
||||
|
||||
Args:
|
||||
cross_table_resolver: Optional callback for cross-table references.
|
||||
Signature: resolver(table, column, where_clause, row_index) -> value
|
||||
"""
|
||||
|
||||
def __init__(self, cross_table_resolver: Optional[CrossTableResolver] = None):
|
||||
self._cross_table_resolver = cross_table_resolver
|
||||
|
||||
def evaluate(
|
||||
self,
|
||||
formula: FormulaDefinition,
|
||||
row_data: dict,
|
||||
row_index: int,
|
||||
) -> Any:
|
||||
"""
|
||||
Evaluate a formula for a single row.
|
||||
|
||||
Args:
|
||||
formula: The parsed FormulaDefinition AST.
|
||||
row_data: Dict mapping column_id -> value for the current row.
|
||||
row_index: The integer index of the current row.
|
||||
|
||||
Returns:
|
||||
The computed value, or None on error.
|
||||
"""
|
||||
try:
|
||||
return self._eval(formula.expression, row_data, row_index)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"Formula evaluation error at row %d: %s", row_index, exc
|
||||
)
|
||||
return None
|
||||
|
||||
def _eval(self, node: FormulaNode, row_data: dict, row_index: int) -> Any:
|
||||
"""
|
||||
Recursively evaluate an AST node.
|
||||
|
||||
Args:
|
||||
node: The AST node to evaluate.
|
||||
row_data: Current row data dict.
|
||||
row_index: Current row index.
|
||||
|
||||
Returns:
|
||||
Evaluated value.
|
||||
"""
|
||||
if isinstance(node, LiteralNode):
|
||||
return node.value
|
||||
|
||||
if isinstance(node, ColumnRef):
|
||||
return self._resolve_column(node.column, row_data)
|
||||
|
||||
if isinstance(node, CrossTableRef):
|
||||
return self._resolve_cross_table(node, row_index)
|
||||
|
||||
if isinstance(node, BinaryOp):
|
||||
return self._eval_binary(node, row_data, row_index)
|
||||
|
||||
if isinstance(node, UnaryOp):
|
||||
return self._eval_unary(node, row_data, row_index)
|
||||
|
||||
if isinstance(node, FunctionCall):
|
||||
return self._eval_function(node, row_data, row_index)
|
||||
|
||||
if isinstance(node, ConditionalExpr):
|
||||
return self._eval_conditional(node, row_data, row_index)
|
||||
|
||||
logger.warning("Unknown AST node type: %s", type(node).__name__)
|
||||
return None
|
||||
|
||||
def _resolve_column(self, column_name: str, row_data: dict) -> Any:
|
||||
"""Resolve a column reference in the current row."""
|
||||
if column_name in row_data:
|
||||
return row_data[column_name]
|
||||
# Try case-insensitive match
|
||||
lower_name = column_name.lower()
|
||||
for key, value in row_data.items():
|
||||
if str(key).lower() == lower_name:
|
||||
return value
|
||||
logger.debug("Column '%s' not found in row_data", column_name)
|
||||
return None
|
||||
|
||||
def _resolve_cross_table(self, node: CrossTableRef, row_index: int) -> Any:
|
||||
"""Resolve a cross-table reference."""
|
||||
if self._cross_table_resolver is None:
|
||||
logger.warning(
|
||||
"No cross_table_resolver set for cross-table ref %s.%s",
|
||||
node.table, node.column,
|
||||
)
|
||||
return None
|
||||
return self._cross_table_resolver(
|
||||
node.table, node.column, node.where_clause, row_index
|
||||
)
|
||||
|
||||
def _eval_binary(self, node: BinaryOp, row_data: dict, row_index: int) -> Any:
|
||||
"""Evaluate a binary operation."""
|
||||
left = self._eval(node.left, row_data, row_index)
|
||||
op = node.operator
|
||||
|
||||
# Short-circuit for logical operators
|
||||
if op == "and":
|
||||
if not self._truthy(left):
|
||||
return False
|
||||
right = self._eval(node.right, row_data, row_index)
|
||||
return self._truthy(left) and self._truthy(right)
|
||||
|
||||
if op == "or":
|
||||
if self._truthy(left):
|
||||
return True
|
||||
right = self._eval(node.right, row_data, row_index)
|
||||
return self._truthy(left) or self._truthy(right)
|
||||
|
||||
right = self._eval(node.right, row_data, row_index)
|
||||
|
||||
# Arithmetic
|
||||
if op == "+":
|
||||
if isinstance(left, str) or isinstance(right, str):
|
||||
return str(left or "") + str(right or "")
|
||||
return self._num_op(left, right, lambda a, b: a + b)
|
||||
if op == "-":
|
||||
return self._num_op(left, right, lambda a, b: a - b)
|
||||
if op == "*":
|
||||
return self._num_op(left, right, lambda a, b: a * b)
|
||||
if op == "/":
|
||||
if right == 0 or right == 0.0:
|
||||
return None # Division by zero -> None
|
||||
return self._num_op(left, right, lambda a, b: a / b)
|
||||
if op == "%":
|
||||
if right == 0 or right == 0.0:
|
||||
return None
|
||||
return self._num_op(left, right, lambda a, b: a % b)
|
||||
if op == "^":
|
||||
return self._num_op(left, right, lambda a, b: a ** b)
|
||||
|
||||
# Comparison
|
||||
if op == "==":
|
||||
return left == right
|
||||
if op == "!=":
|
||||
return left != right
|
||||
if op == "<":
|
||||
return self._compare(left, right) < 0
|
||||
if op == "<=":
|
||||
return self._compare(left, right) <= 0
|
||||
if op == ">":
|
||||
return self._compare(left, right) > 0
|
||||
if op == ">=":
|
||||
return self._compare(left, right) >= 0
|
||||
|
||||
# String operations
|
||||
if op == "contains":
|
||||
return str(right) in str(left) if left is not None else False
|
||||
if op == "startswith":
|
||||
return str(left).startswith(str(right)) if left is not None else False
|
||||
if op == "endswith":
|
||||
return str(left).endswith(str(right)) if left is not None else False
|
||||
|
||||
# Collection operations
|
||||
if op == "in":
|
||||
values = right.value if isinstance(right, LiteralNode) else right
|
||||
if isinstance(values, list):
|
||||
return left in [v.value if isinstance(v, LiteralNode) else v for v in values]
|
||||
return left in (values or [])
|
||||
if op == "between":
|
||||
values = right.value if isinstance(right, LiteralNode) else right
|
||||
if isinstance(values, list) and len(values) == 2:
|
||||
lo = values[0].value if isinstance(values[0], LiteralNode) else values[0]
|
||||
hi = values[1].value if isinstance(values[1], LiteralNode) else values[1]
|
||||
return lo <= left <= hi
|
||||
return None
|
||||
|
||||
logger.warning("Unknown binary operator: %s", op)
|
||||
return None
|
||||
|
||||
def _eval_unary(self, node: UnaryOp, row_data: dict, row_index: int) -> Any:
|
||||
"""Evaluate a unary operation."""
|
||||
operand = self._eval(node.operand, row_data, row_index)
|
||||
op = node.operator
|
||||
|
||||
if op == "-":
|
||||
v = _safe_numeric(operand)
|
||||
return -v if v is not None else None
|
||||
if op == "not":
|
||||
return not self._truthy(operand)
|
||||
if op == "isempty":
|
||||
return operand is None or operand == "" or operand == []
|
||||
if op == "isnotempty":
|
||||
return operand is not None and operand != "" and operand != []
|
||||
if op == "isnan":
|
||||
try:
|
||||
return math.isnan(float(operand))
|
||||
except (TypeError, ValueError):
|
||||
return False
|
||||
|
||||
logger.warning("Unknown unary operator: %s", op)
|
||||
return None
|
||||
|
||||
def _eval_function(self, node: FunctionCall, row_data: dict, row_index: int) -> Any:
|
||||
"""Evaluate a function call."""
|
||||
name = node.function_name.lower()
|
||||
args = [self._eval(arg, row_data, row_index) for arg in node.arguments]
|
||||
|
||||
if name in BUILTIN_FUNCTIONS:
|
||||
try:
|
||||
return BUILTIN_FUNCTIONS[name](args)
|
||||
except Exception as exc:
|
||||
logger.warning("Function '%s' error: %s", name, exc)
|
||||
return None
|
||||
|
||||
logger.warning("Unknown function: %s", name)
|
||||
return None
|
||||
|
||||
def _eval_conditional(self, node: ConditionalExpr, row_data: dict, row_index: int) -> Any:
|
||||
"""Evaluate a conditional expression."""
|
||||
condition = self._eval(node.condition, row_data, row_index)
|
||||
if self._truthy(condition):
|
||||
return self._eval(node.value_expr, row_data, row_index)
|
||||
elif node.else_expr is not None:
|
||||
return self._eval(node.else_expr, row_data, row_index)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _truthy(value: Any) -> bool:
|
||||
"""Convert a value to boolean for conditional evaluation."""
|
||||
if value is None:
|
||||
return False
|
||||
if isinstance(value, bool):
|
||||
return value
|
||||
if isinstance(value, (int, float)):
|
||||
return value != 0
|
||||
if isinstance(value, str):
|
||||
return len(value) > 0
|
||||
return bool(value)
|
||||
|
||||
@staticmethod
|
||||
def _num_op(left: Any, right: Any, fn: Callable) -> Any:
|
||||
"""Apply a numeric binary function, returning None if inputs are non-numeric."""
|
||||
a = _safe_numeric(left)
|
||||
b = _safe_numeric(right)
|
||||
if a is None or b is None:
|
||||
return None
|
||||
try:
|
||||
return fn(a, b)
|
||||
except (ZeroDivisionError, OverflowError, ValueError):
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _compare(left: Any, right: Any) -> int:
|
||||
"""
|
||||
Compare two values, returning -1, 0, or 1.
|
||||
|
||||
Handles mixed numeric/string comparisons gracefully.
|
||||
"""
|
||||
try:
|
||||
if left < right:
|
||||
return -1
|
||||
elif left > right:
|
||||
return 1
|
||||
return 0
|
||||
except TypeError:
|
||||
# Fallback: compare as strings
|
||||
sl, sr = str(left), str(right)
|
||||
if sl < sr:
|
||||
return -1
|
||||
elif sl > sr:
|
||||
return 1
|
||||
return 0
|
||||
@@ -243,7 +243,7 @@ class TestConflictResolution:
|
||||
css, formatted = engine.apply_format(rules, cell_value=150, row_data=row_data)
|
||||
|
||||
assert isinstance(css, StyleContainer)
|
||||
assert "var(--color-secondary)" in css.css # Style from Rule 2
|
||||
assert css.cls == "mf-formatting-secondary" # Style from Rule 2
|
||||
assert formatted == "150.00 €" # Formatter from Rule 1
|
||||
|
||||
# Case 2: Condition not met (value <= budget)
|
||||
@@ -282,7 +282,7 @@ class TestConflictResolution:
|
||||
|
||||
css, formatted = engine.apply_format(rules, cell_value=-5.67)
|
||||
|
||||
assert "var(--color-error)" in css.css # Rule 3 wins for style
|
||||
assert css.cls == "mf-formatting-error" # Rule 3 wins for style
|
||||
assert formatted == "-6 €" # Rule 4 wins for formatter (precision=0)
|
||||
|
||||
|
||||
@@ -316,7 +316,7 @@ class TestWithRowData:
|
||||
css, _ = engine.apply_format(rules, cell_value=42, row_data=row_data)
|
||||
|
||||
assert isinstance(css, StyleContainer)
|
||||
assert "background-color" in css.css
|
||||
assert css.cls == "mf-formatting-error"
|
||||
|
||||
|
||||
class TestPresets:
|
||||
@@ -327,7 +327,7 @@ class TestPresets:
|
||||
|
||||
css, _ = engine.apply_format(rules, cell_value=42)
|
||||
|
||||
assert "var(--color-success)" in css.css
|
||||
assert css.cls == "mf-formatting-success"
|
||||
|
||||
def test_formatter_preset(self):
|
||||
"""Formatter preset is resolved correctly."""
|
||||
|
||||
@@ -16,9 +16,14 @@ class TestResolve:
|
||||
assert result["font-weight"] == "bold"
|
||||
|
||||
def test_resolve_preset_with_override(self):
|
||||
"""Preset properties can be overridden by explicit values."""
|
||||
resolver = StyleResolver()
|
||||
# "success" preset has background and color defined
|
||||
"""Preset CSS properties can be overridden by explicit values."""
|
||||
custom_presets = {
|
||||
"success": {
|
||||
"background-color": "var(--color-success)",
|
||||
"color": "var(--color-success-content)",
|
||||
}
|
||||
}
|
||||
resolver = StyleResolver(style_presets=custom_presets)
|
||||
style = Style(preset="success", color="black")
|
||||
result = resolver.resolve(style)
|
||||
|
||||
@@ -66,6 +71,16 @@ class TestResolve:
|
||||
|
||||
assert result == {}
|
||||
|
||||
def test_i_can_resolve_class_only_preset(self):
|
||||
"""Default preset with __class__ only is included as-is in resolve result."""
|
||||
resolver = StyleResolver()
|
||||
style = Style(preset="success")
|
||||
result = resolver.resolve(style)
|
||||
|
||||
assert result["__class__"] == "mf-formatting-success"
|
||||
assert "background-color" not in result
|
||||
assert "color" not in result
|
||||
|
||||
def test_resolve_converts_property_names(self):
|
||||
"""Python attribute names are converted to CSS property names."""
|
||||
resolver = StyleResolver()
|
||||
@@ -151,11 +166,11 @@ class TestToStyleContainer:
|
||||
None,
|
||||
["background-color: red", "color: white"]
|
||||
),
|
||||
# Class only via preset
|
||||
# Class only via preset (default presets use __class__, no inline CSS)
|
||||
(
|
||||
Style(preset="success"),
|
||||
None, # Default presets don't have __class__
|
||||
["background-color: var(--color-success)", "color: var(--color-success-content)"]
|
||||
"mf-formatting-success",
|
||||
[]
|
||||
),
|
||||
# Empty style
|
||||
(
|
||||
@@ -246,3 +261,12 @@ class TestToStyleContainer:
|
||||
assert isinstance(result, StyleContainer)
|
||||
assert result.cls is None
|
||||
assert result.css == ""
|
||||
|
||||
def test_i_can_resolve_default_preset_to_container(self):
|
||||
"""Default preset with __class__ only generates cls but no css."""
|
||||
resolver = StyleResolver()
|
||||
style = Style(preset="error")
|
||||
result = resolver.to_style_container(style)
|
||||
|
||||
assert result.cls == "mf-formatting-error"
|
||||
assert result.css == ""
|
||||
|
||||
0
tests/core/formula/__init__.py
Normal file
0
tests/core/formula/__init__.py
Normal file
391
tests/core/formula/test_dependency_graph.py
Normal file
391
tests/core/formula/test_dependency_graph.py
Normal file
@@ -0,0 +1,391 @@
|
||||
"""
|
||||
Tests for the DependencyGraph DAG.
|
||||
"""
|
||||
import pytest
|
||||
|
||||
from myfasthtml.core.formula.dataclasses import (
|
||||
ColumnRef,
|
||||
ConditionalExpr,
|
||||
CrossTableRef,
|
||||
FormulaDefinition,
|
||||
FunctionCall,
|
||||
LiteralNode,
|
||||
UnaryOp,
|
||||
WhereClause,
|
||||
)
|
||||
from myfasthtml.core.formula.dependency_graph import DependencyGraph
|
||||
from myfasthtml.core.formula.dsl.exceptions import FormulaCycleError
|
||||
from myfasthtml.core.formula.engine import parse_formula
|
||||
|
||||
|
||||
def make_formula(expr):
|
||||
"""Create a FormulaDefinition from a raw AST node for direct AST testing."""
|
||||
return FormulaDefinition(expression=expr, source_text="test")
|
||||
|
||||
|
||||
# ==================== Add formula ====================
|
||||
|
||||
def test_i_can_add_simple_dependency():
|
||||
"""Test that adding a formula creates edges in the graph."""
|
||||
graph = DependencyGraph()
|
||||
formula = parse_formula("{Price} * {Quantity}")
|
||||
graph.add_formula("orders", "total", formula)
|
||||
|
||||
node = graph.get_node("orders", "total")
|
||||
assert node is not None
|
||||
assert node.formula is formula
|
||||
assert node.dirty is True
|
||||
|
||||
# Precedents should include Price and Quantity
|
||||
precedent = graph._precedents["orders.total"]
|
||||
assert "orders.Price" in precedent
|
||||
assert "orders.Quantity" in precedent
|
||||
|
||||
# Descendants are correctly set
|
||||
assert "orders.total" in graph._dependents["orders.Price"]
|
||||
assert "orders.total" in graph._dependents["orders.Quantity"]
|
||||
|
||||
|
||||
def test_i_can_add_formula_and_check_dirty():
|
||||
"""Test that newly added formulas are marked dirty."""
|
||||
graph = DependencyGraph()
|
||||
formula = parse_formula("{A} + {B}")
|
||||
graph.add_formula("t", "C", formula)
|
||||
node = graph.get_node("t", "C")
|
||||
assert node.dirty is True
|
||||
|
||||
|
||||
def test_i_can_update_formula():
|
||||
"""Test that replacing a formula updates dependencies."""
|
||||
graph = DependencyGraph()
|
||||
formula1 = parse_formula("{A} + {B}")
|
||||
graph.add_formula("t", "C", formula1)
|
||||
|
||||
formula2 = parse_formula("{X} * {Y}")
|
||||
graph.add_formula("t", "C", formula2)
|
||||
|
||||
node = graph.get_node("t", "C")
|
||||
assert node.formula is formula2
|
||||
# Should no longer depend on A and B
|
||||
precedent = graph._precedents.get("t.C", set())
|
||||
assert "t.A" not in precedent
|
||||
assert "t.B" not in precedent
|
||||
|
||||
|
||||
# ==================== Cycle detection ====================
|
||||
|
||||
def test_i_cannot_create_cycle():
|
||||
"""Test that circular dependencies raise FormulaCycleError."""
|
||||
graph = DependencyGraph()
|
||||
# A depends on B
|
||||
graph.add_formula("t", "A", parse_formula("{B} + 1"))
|
||||
# B depends on A -> cycle
|
||||
with pytest.raises(FormulaCycleError):
|
||||
graph.add_formula("t", "B", parse_formula("{A} * 2"))
|
||||
|
||||
|
||||
def test_i_cannot_create_self_reference():
|
||||
"""Test that a formula referencing its own column raises FormulaCycleError."""
|
||||
graph = DependencyGraph()
|
||||
with pytest.raises(FormulaCycleError):
|
||||
graph.add_formula("t", "A", parse_formula("{A} + 1"))
|
||||
|
||||
|
||||
def test_i_can_detect_long_cycle():
|
||||
"""Test that a longer chain cycle is also detected."""
|
||||
graph = DependencyGraph()
|
||||
graph.add_formula("t", "A", parse_formula("{B} + 1"))
|
||||
graph.add_formula("t", "B", parse_formula("{C} + 1"))
|
||||
with pytest.raises(FormulaCycleError):
|
||||
graph.add_formula("t", "C", parse_formula("{A} + 1"))
|
||||
|
||||
|
||||
# ==================== Dirty flag propagation ====================
|
||||
|
||||
def test_i_can_propagate_dirty_flags():
|
||||
"""Test that marking a source column dirty propagates to dependents."""
|
||||
graph = DependencyGraph()
|
||||
graph.add_formula("t", "B", parse_formula("{A} * 2"))
|
||||
# Clear dirty flag set by add_formula
|
||||
graph.clear_dirty("t.B")
|
||||
assert not graph.get_node("t", "B").dirty
|
||||
|
||||
# Mark source A as dirty
|
||||
graph.mark_dirty("t", "A")
|
||||
|
||||
# B depends on A, so B should become dirty
|
||||
node_b = graph.get_node("t", "B")
|
||||
assert node_b is not None
|
||||
assert node_b.dirty is True
|
||||
|
||||
|
||||
def test_i_can_propagate_dirty_to_chain():
|
||||
"""Test that dirty flags propagate through a chain: A -> B -> C."""
|
||||
graph = DependencyGraph()
|
||||
graph.add_formula("t", "B", parse_formula("{A} + 1"))
|
||||
graph.add_formula("t", "C", parse_formula("{B} + 1"))
|
||||
# Clear all dirty flags
|
||||
graph.clear_dirty("t.B")
|
||||
graph.clear_dirty("t.C")
|
||||
|
||||
# Mark A dirty
|
||||
graph.mark_dirty("t", "A")
|
||||
|
||||
assert graph.get_node("t", "B").dirty is True
|
||||
assert graph.get_node("t", "C").dirty is True
|
||||
|
||||
|
||||
def test_i_can_propagate_specific_rows():
|
||||
"""Test that dirty propagation can be limited to specific rows."""
|
||||
graph = DependencyGraph()
|
||||
graph.add_formula("t", "B", parse_formula("{A} * 2"))
|
||||
graph.clear_dirty("t.B")
|
||||
|
||||
graph.mark_dirty("t", "A", rows=[0, 2, 5])
|
||||
|
||||
node_b = graph.get_node("t", "B")
|
||||
assert node_b.dirty is True
|
||||
assert 0 in node_b.dirty_rows
|
||||
assert 2 in node_b.dirty_rows
|
||||
assert 5 in node_b.dirty_rows
|
||||
assert 1 not in node_b.dirty_rows
|
||||
|
||||
|
||||
# ==================== Topological ordering ====================
|
||||
|
||||
def test_i_can_get_calculation_order():
|
||||
"""Test that dirty formula nodes are returned in topological order."""
|
||||
graph = DependencyGraph()
|
||||
graph.add_formula("t", "B", parse_formula("{A} + 1"))
|
||||
graph.add_formula("t", "C", parse_formula("{B} + 1"))
|
||||
graph.mark_dirty("t", "A")
|
||||
|
||||
order = graph.get_calculation_order(table="t")
|
||||
node_ids = [n.node_id for n in order]
|
||||
|
||||
# B must come before C
|
||||
assert "t.B" in node_ids
|
||||
assert "t.C" in node_ids
|
||||
assert node_ids.index("t.B") < node_ids.index("t.C")
|
||||
|
||||
|
||||
def test_i_can_get_calculation_order_without_table_filter():
|
||||
"""Test that get_calculation_order with no table filter returns dirty nodes across all tables.
|
||||
|
||||
Why: The table parameter is optional. Omitting it should return dirty formula nodes
|
||||
from every table, not just one.
|
||||
"""
|
||||
graph = DependencyGraph()
|
||||
graph.add_formula("t1", "B", parse_formula("{A} + 1"))
|
||||
graph.add_formula("t2", "D", parse_formula("{C} * 2"))
|
||||
|
||||
order = graph.get_calculation_order()
|
||||
node_ids = [n.node_id for n in order]
|
||||
|
||||
assert "t1.B" in node_ids
|
||||
assert "t2.D" in node_ids
|
||||
|
||||
|
||||
def test_i_can_get_calculation_order_excludes_clean_nodes():
|
||||
"""Test that get_calculation_order only returns dirty formula nodes.
|
||||
|
||||
Why: The method filters on node.dirty (source line 166). Clean nodes must
|
||||
never appear in the output, otherwise they would be recalculated unnecessarily.
|
||||
"""
|
||||
graph = DependencyGraph()
|
||||
graph.add_formula("t", "B", parse_formula("{A} + 1"))
|
||||
graph.add_formula("t", "C", parse_formula("{A} * 2"))
|
||||
graph.clear_dirty("t.C")
|
||||
|
||||
order = graph.get_calculation_order(table="t")
|
||||
node_ids = [n.node_id for n in order]
|
||||
|
||||
assert "t.B" in node_ids
|
||||
assert "t.C" not in node_ids
|
||||
|
||||
|
||||
def test_i_can_handle_diamond_dependency():
|
||||
"""Test topological order with diamond dependency: A -> B, A -> C, B+C -> D."""
|
||||
graph = DependencyGraph()
|
||||
graph.add_formula("t", "B", parse_formula("{A} + 1"))
|
||||
graph.add_formula("t", "C", parse_formula("{A} * 2"))
|
||||
graph.add_formula("t", "D", parse_formula("{B} + {C}"))
|
||||
graph.mark_dirty("t", "A")
|
||||
|
||||
order = graph.get_calculation_order(table="t")
|
||||
node_ids = [n.node_id for n in order]
|
||||
|
||||
assert "t.B" in node_ids
|
||||
assert "t.C" in node_ids
|
||||
assert "t.D" in node_ids
|
||||
# D must come after both B and C
|
||||
assert node_ids.index("t.D") > node_ids.index("t.B")
|
||||
assert node_ids.index("t.D") > node_ids.index("t.C")
|
||||
|
||||
|
||||
# ==================== Remove formula ====================
|
||||
|
||||
def test_i_can_remove_formula():
|
||||
"""Test that a formula can be removed from the graph."""
|
||||
graph = DependencyGraph()
|
||||
graph.add_formula("t", "B", parse_formula("{A} + 1"))
|
||||
assert graph.has_formula("t", "B")
|
||||
|
||||
graph.remove_formula("t", "B")
|
||||
assert not graph.has_formula("t", "B")
|
||||
|
||||
|
||||
def test_i_can_remove_formula_node_kept_when_has_dependents():
|
||||
"""Test that removing a formula keeps the node when other formulas still depend on it.
|
||||
|
||||
Why: remove_formula deletes the node only when no dependents exist (source line 145-146).
|
||||
If another formula depends on the removed node it must remain as a data node.
|
||||
"""
|
||||
graph = DependencyGraph()
|
||||
graph.add_formula("t", "B", parse_formula("{A} + 1"))
|
||||
graph.add_formula("t", "C", parse_formula("{B} * 2"))
|
||||
|
||||
graph.remove_formula("t", "B")
|
||||
|
||||
assert not graph.has_formula("t", "B")
|
||||
assert graph.get_node("t", "B") is not None
|
||||
|
||||
|
||||
def test_i_can_remove_formula_and_add_back():
|
||||
"""Test that a formula can be removed and re-added."""
|
||||
graph = DependencyGraph()
|
||||
graph.add_formula("t", "B", parse_formula("{A} + 1"))
|
||||
graph.remove_formula("t", "B")
|
||||
# Should not raise
|
||||
graph.add_formula("t", "B", parse_formula("{X} * 2"))
|
||||
assert graph.has_formula("t", "B")
|
||||
|
||||
|
||||
# ==================== Cross-table ====================
|
||||
|
||||
def test_i_can_handle_cross_table_dependencies():
|
||||
"""Test that cross-table references create inter-table edges."""
|
||||
graph = DependencyGraph()
|
||||
formula = parse_formula("{Products.Price} * {Quantity}")
|
||||
graph.add_formula("orders", "total", formula)
|
||||
|
||||
node = graph.get_node("orders", "total")
|
||||
assert node is not None
|
||||
prec = graph._precedents.get("orders.total", set())
|
||||
assert "Products.Price" in prec
|
||||
assert "orders.Quantity" in prec
|
||||
|
||||
|
||||
def test_i_can_extract_cross_table_ref_with_where_clause():
|
||||
"""Test that CrossTableRef with a WHERE clause adds both the remote column
|
||||
and the local column from the WHERE clause as dependencies.
|
||||
|
||||
Why: Source line 366-367 adds where_clause.local_column as a dependency.
|
||||
Without this test that code path is never exercised.
|
||||
"""
|
||||
graph = DependencyGraph()
|
||||
formula = make_formula(
|
||||
CrossTableRef(
|
||||
table="Products",
|
||||
column="Price",
|
||||
where_clause=WhereClause(
|
||||
remote_table="Products",
|
||||
remote_column="id",
|
||||
local_column="product_id",
|
||||
),
|
||||
)
|
||||
)
|
||||
graph.add_formula("orders", "total", formula)
|
||||
|
||||
prec = graph._precedents.get("orders.total", set())
|
||||
assert "Products.Price" in prec
|
||||
assert "orders.product_id" in prec
|
||||
|
||||
|
||||
# ==================== _extract_dependencies AST coverage ====================
|
||||
|
||||
def test_i_can_extract_unary_op_dependency():
|
||||
"""Test that UnaryOp correctly extracts its operand's column dependency.
|
||||
|
||||
Why: Source lines 373-375 handle UnaryOp. Without this test that branch
|
||||
is never exercised — a negated column reference would silently produce
|
||||
no dependency edge.
|
||||
"""
|
||||
graph = DependencyGraph()
|
||||
formula = make_formula(UnaryOp("-", ColumnRef("A")))
|
||||
graph.add_formula("t", "B", formula)
|
||||
|
||||
prec = graph._precedents.get("t.B", set())
|
||||
assert "t.A" in prec
|
||||
|
||||
|
||||
def test_i_can_extract_function_call_dependencies():
|
||||
"""Test that FunctionCall extracts column dependencies from all arguments.
|
||||
|
||||
Why: Source lines 376-378 iterate over arguments. With multiple column
|
||||
arguments every dependency must appear in the precedents set.
|
||||
"""
|
||||
graph = DependencyGraph()
|
||||
formula = make_formula(FunctionCall("SUM", [ColumnRef("A"), ColumnRef("B")]))
|
||||
graph.add_formula("t", "C", formula)
|
||||
|
||||
prec = graph._precedents.get("t.C", set())
|
||||
assert "t.A" in prec
|
||||
assert "t.B" in prec
|
||||
|
||||
|
||||
def test_i_can_extract_function_call_with_no_column_args():
|
||||
"""Test that a FunctionCall with only literal arguments creates no column dependencies.
|
||||
|
||||
Why: FunctionCall with literal args only (e.g. NOW(1)) must not create
|
||||
spurious dependency edges that would trigger unnecessary recalculations.
|
||||
"""
|
||||
graph = DependencyGraph()
|
||||
formula = make_formula(FunctionCall("NOW", [LiteralNode(1)]))
|
||||
graph.add_formula("t", "C", formula)
|
||||
|
||||
prec = graph._precedents.get("t.C", set())
|
||||
assert len(prec) == 0, "FunctionCall with literal args should produce no column dependencies"
|
||||
|
||||
|
||||
def test_i_can_extract_conditional_expr_all_branches():
|
||||
"""Test that ConditionalExpr extracts dependencies from value_expr, condition, and else_expr.
|
||||
|
||||
Why: Source lines 380-384 walk all three branches. All three column references
|
||||
must appear as dependencies so that any change in any branch triggers recalculation.
|
||||
"""
|
||||
graph = DependencyGraph()
|
||||
formula = make_formula(
|
||||
ConditionalExpr(
|
||||
value_expr=ColumnRef("A"),
|
||||
condition=ColumnRef("B"),
|
||||
else_expr=ColumnRef("C"),
|
||||
)
|
||||
)
|
||||
graph.add_formula("t", "D", formula)
|
||||
|
||||
prec = graph._precedents.get("t.D", set())
|
||||
assert "t.A" in prec
|
||||
assert "t.B" in prec
|
||||
assert "t.C" in prec
|
||||
|
||||
|
||||
def test_i_can_extract_conditional_expr_without_else():
|
||||
"""Test that ConditionalExpr with else_expr=None does not crash and extracts value and condition.
|
||||
|
||||
Why: Source line 383 guards on ``if node.else_expr is not None``. A missing
|
||||
else branch must not raise and must still extract the remaining two dependencies.
|
||||
"""
|
||||
graph = DependencyGraph()
|
||||
formula = make_formula(
|
||||
ConditionalExpr(
|
||||
value_expr=ColumnRef("A"),
|
||||
condition=ColumnRef("B"),
|
||||
else_expr=None,
|
||||
)
|
||||
)
|
||||
graph.add_formula("t", "C", formula)
|
||||
|
||||
prec = graph._precedents.get("t.C", set())
|
||||
assert "t.A" in prec
|
||||
assert "t.B" in prec
|
||||
408
tests/core/formula/test_formula_engine.py
Normal file
408
tests/core/formula/test_formula_engine.py
Normal file
@@ -0,0 +1,408 @@
|
||||
"""
|
||||
Tests for the FormulaEngine facade.
|
||||
"""
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from myfasthtml.core.formula.dsl.exceptions import FormulaSyntaxError, FormulaCycleError
|
||||
from myfasthtml.core.formula.engine import FormulaEngine
|
||||
|
||||
|
||||
class FakeStore:
|
||||
"""Minimal DatagridStore-like object for testing."""
|
||||
|
||||
def __init__(self, rows):
|
||||
self.ns_row_data = rows
|
||||
self.ns_fast_access = self._build_fast_access(rows)
|
||||
self.ns_total_rows = len(rows)
|
||||
|
||||
@staticmethod
|
||||
def _build_fast_access(rows):
|
||||
"""Build a columnar fast-access dict from a list of row dicts."""
|
||||
if not rows:
|
||||
return {}
|
||||
return {
|
||||
col: np.array([row.get(col) for row in rows], dtype=object)
|
||||
for col in rows[0].keys()
|
||||
}
|
||||
|
||||
|
||||
def make_engine(store_map=None):
|
||||
"""Create an engine with an optional store map for cross-table resolution."""
|
||||
|
||||
def resolver(table_name):
|
||||
return store_map.get(table_name) if store_map else None
|
||||
|
||||
return FormulaEngine(registry_resolver=resolver)
|
||||
|
||||
|
||||
# ==================== TestSetFormula ====================
|
||||
|
||||
class TestSetFormula:
|
||||
"""Tests for set_formula: parsing, registration, and edge cases."""
|
||||
|
||||
def test_i_can_set_and_evaluate_formula(self):
|
||||
"""Test that a formula can be set and evaluated."""
|
||||
rows = [{"Price": 10, "Quantity": 3}]
|
||||
store = FakeStore(rows)
|
||||
|
||||
engine = make_engine({"orders": store})
|
||||
engine.set_formula("orders", "total", "{Price} * {Quantity}")
|
||||
engine.recalculate_if_needed("orders", store)
|
||||
|
||||
assert "total" in store.ns_fast_access
|
||||
assert store.ns_fast_access["total"][0] == 30
|
||||
|
||||
def test_i_can_evaluate_multiple_rows(self):
|
||||
"""Test that formula evaluation works across multiple rows."""
|
||||
rows = [
|
||||
{"Price": 10, "Quantity": 3},
|
||||
{"Price": 20, "Quantity": 2},
|
||||
{"Price": 5, "Quantity": 10},
|
||||
]
|
||||
store = FakeStore(rows)
|
||||
|
||||
engine = make_engine({"orders": store})
|
||||
engine.set_formula("orders", "total", "{Price} * {Quantity}")
|
||||
engine.recalculate_if_needed("orders", store)
|
||||
|
||||
totals = store.ns_fast_access["total"]
|
||||
assert totals[0] == 30
|
||||
assert totals[1] == 40
|
||||
assert totals[2] == 50
|
||||
|
||||
def test_i_cannot_set_invalid_formula(self):
|
||||
"""Test that invalid formula syntax raises FormulaSyntaxError."""
|
||||
engine = make_engine()
|
||||
with pytest.raises(FormulaSyntaxError):
|
||||
engine.set_formula("t", "col", "{Price} * * {Qty}")
|
||||
|
||||
def test_i_cannot_set_formula_with_cycle(self):
|
||||
"""Test that a circular dependency raises FormulaCycleError."""
|
||||
rows = [{"A": 1}]
|
||||
store = FakeStore(rows)
|
||||
engine = make_engine({"t": store})
|
||||
|
||||
engine.set_formula("t", "A", "{B} + 1")
|
||||
with pytest.raises(FormulaCycleError):
|
||||
engine.set_formula("t", "B", "{A} * 2")
|
||||
|
||||
@pytest.mark.parametrize("text", ["", " ", "\t\n"])
|
||||
def test_i_can_set_formula_with_blank_input_removes_it(self, text):
|
||||
"""Test that setting a blank or whitespace-only formula string removes it.
|
||||
|
||||
Why: set_formula strips the input (source line 86) before checking emptiness.
|
||||
All blank variants — empty string, spaces, tabs — must behave identically.
|
||||
"""
|
||||
rows = [{"Price": 10}]
|
||||
store = FakeStore(rows)
|
||||
engine = make_engine({"t": store})
|
||||
|
||||
engine.set_formula("t", "col", "{Price} * 2")
|
||||
assert engine.has_formula("t", "col")
|
||||
|
||||
engine.set_formula("t", "col", text)
|
||||
assert not engine.has_formula("t", "col")
|
||||
|
||||
def test_i_can_replace_formula_clears_old_dependencies(self):
|
||||
"""Test that replacing a formula removes old dependency edges.
|
||||
|
||||
Why: add_formula calls _remove_edges before adding new ones (source line 99).
|
||||
After replacement, marking the old dependency dirty must NOT trigger
|
||||
recalculation of the formula column.
|
||||
"""
|
||||
rows = [{"A": 5, "B": 3, "X": 10}]
|
||||
store = FakeStore(rows)
|
||||
engine = make_engine({"t": store})
|
||||
|
||||
engine.set_formula("t", "C", "{A} + {B}")
|
||||
engine.recalculate_if_needed("t", store)
|
||||
assert store.ns_fast_access["C"][0] == pytest.approx(8.0)
|
||||
|
||||
engine.set_formula("t", "C", "{X} * 2")
|
||||
engine.recalculate_if_needed("t", store)
|
||||
assert store.ns_fast_access["C"][0] == pytest.approx(20.0)
|
||||
|
||||
# A is no longer a dependency — changing it should NOT trigger C recalculation
|
||||
rows[0]["A"] = 100
|
||||
store.ns_fast_access["A"] = np.array([100], dtype=object)
|
||||
engine.mark_data_changed("t", "A")
|
||||
|
||||
store.ns_fast_access["C"][0] = 999 # sentinel
|
||||
engine.recalculate_if_needed("t", store)
|
||||
assert store.ns_fast_access["C"][0] == 999, (
|
||||
"C should not be recalculated when A changes after formula replacement"
|
||||
)
|
||||
|
||||
|
||||
# ==================== TestRemoveFormula ====================
|
||||
|
||||
class TestRemoveFormula:
|
||||
"""Tests for remove_formula."""
|
||||
|
||||
def test_i_can_remove_formula(self):
|
||||
"""Test that a formula can be removed."""
|
||||
rows = [{"Price": 10, "Quantity": 3}]
|
||||
store = FakeStore(rows)
|
||||
engine = make_engine({"t": store})
|
||||
|
||||
engine.set_formula("t", "total", "{Price} * {Quantity}")
|
||||
assert engine.has_formula("t", "total")
|
||||
|
||||
engine.remove_formula("t", "total")
|
||||
assert not engine.has_formula("t", "total")
|
||||
|
||||
def test_i_can_remove_formula_and_set_back(self):
|
||||
"""Test that a formula can be removed via empty string and then re-added."""
|
||||
rows = [{"Price": 10}]
|
||||
store = FakeStore(rows)
|
||||
engine = make_engine({"t": store})
|
||||
|
||||
engine.set_formula("t", "col", "{Price} * 2")
|
||||
assert engine.has_formula("t", "col")
|
||||
|
||||
engine.set_formula("t", "col", "")
|
||||
assert not engine.has_formula("t", "col")
|
||||
|
||||
|
||||
# ==================== TestRecalculate ====================
|
||||
|
||||
class TestRecalculate:
|
||||
"""Tests for recalculate_if_needed: dirty tracking, return values, and row data update."""
|
||||
|
||||
def test_i_can_recalculate_only_dirty(self):
|
||||
"""Test that only dirty formula columns are recalculated."""
|
||||
rows = [{"A": 5, "B": 3}]
|
||||
store = FakeStore(rows)
|
||||
engine = make_engine({"t": store})
|
||||
|
||||
engine.set_formula("t", "C", "{A} + {B}")
|
||||
engine.recalculate_if_needed("t", store)
|
||||
|
||||
# Manually set a different value to detect recalculation
|
||||
store.ns_fast_access["C"][0] = 999
|
||||
|
||||
# No dirty flags → should NOT recalculate
|
||||
engine.recalculate_if_needed("t", store)
|
||||
assert store.ns_fast_access["C"][0] == 999 # unchanged
|
||||
|
||||
def test_i_can_recalculate_after_data_changed(self):
|
||||
"""Test that marking data changed triggers recalculation."""
|
||||
rows = [{"A": 5, "B": 3}]
|
||||
store = FakeStore(rows)
|
||||
engine = make_engine({"t": store})
|
||||
|
||||
engine.set_formula("t", "C", "{A} + {B}")
|
||||
engine.recalculate_if_needed("t", store)
|
||||
assert store.ns_fast_access["C"][0] == 8
|
||||
|
||||
# Update both storage structures (mirrors real DataGrid behaviour)
|
||||
rows[0]["A"] = 10
|
||||
store.ns_fast_access["A"] = np.array([10], dtype=object)
|
||||
|
||||
# Mark source column dirty
|
||||
engine.mark_data_changed("t", "A")
|
||||
engine.recalculate_if_needed("t", store)
|
||||
assert store.ns_fast_access["C"][0] == 13
|
||||
|
||||
def test_i_can_recalculate_returns_false_when_no_dirty(self):
|
||||
"""Test that recalculate_if_needed returns False when no nodes are dirty.
|
||||
|
||||
Why: Source line 151 returns False early when get_calculation_order is empty.
|
||||
After a first successful recalculation all dirty flags are cleared, so the
|
||||
second call must return False to avoid redundant work.
|
||||
"""
|
||||
rows = [{"A": 5}]
|
||||
store = FakeStore(rows)
|
||||
engine = make_engine({"t": store})
|
||||
|
||||
engine.set_formula("t", "B", "{A} + 1")
|
||||
engine.recalculate_if_needed("t", store) # clears dirty flag
|
||||
|
||||
result = engine.recalculate_if_needed("t", store)
|
||||
assert result is False
|
||||
|
||||
def test_i_can_recalculate_returns_true_when_dirty(self):
|
||||
"""Test that recalculate_if_needed returns True when columns are recalculated.
|
||||
|
||||
Why: Source line 164 returns True after the evaluation loop. This return value
|
||||
lets callers skip downstream work (e.g. rendering) when nothing changed.
|
||||
"""
|
||||
rows = [{"A": 5}]
|
||||
store = FakeStore(rows)
|
||||
engine = make_engine({"t": store})
|
||||
|
||||
engine.set_formula("t", "B", "{A} + 1")
|
||||
|
||||
result = engine.recalculate_if_needed("t", store)
|
||||
assert result is True
|
||||
|
||||
def test_i_can_recalculate_with_empty_store(self):
|
||||
"""Test that recalculate_if_needed handles a store with no rows without crashing.
|
||||
|
||||
Why: _evaluate_column guards on empty ns_row_data (source line 211-212).
|
||||
No formula result should be written when there are no rows to process.
|
||||
"""
|
||||
store = FakeStore([]) # no rows
|
||||
engine = make_engine({"t": store})
|
||||
|
||||
engine.set_formula("t", "B", "{A} + 1")
|
||||
engine.recalculate_if_needed("t", store) # must not raise
|
||||
|
||||
assert "B" not in store.ns_fast_access
|
||||
|
||||
def test_i_can_verify_formula_values_appear_in_row_data(self):
|
||||
"""Test that formula values are written back into ns_row_data after recalculation.
|
||||
|
||||
Why: _rebuild_row_data (source line 231-249) merges ns_fast_access values into
|
||||
each row dict. This ensures formula results are available in row_data for
|
||||
subsequent evaluation passes and for rendering.
|
||||
"""
|
||||
rows = [{"A": 5}]
|
||||
store = FakeStore(rows)
|
||||
engine = make_engine({"t": store})
|
||||
|
||||
engine.set_formula("t", "B", "{A} + 10")
|
||||
engine.recalculate_if_needed("t", store)
|
||||
|
||||
assert store.ns_row_data[0]["B"] == pytest.approx(15.0)
|
||||
|
||||
# --- Known bug: chained formula columns ---
|
||||
# _rebuild_row_data is called once AFTER the full evaluation loop, so formula
|
||||
# column B is not yet in row_data when formula column C is evaluated.
|
||||
# Fix needed: rebuild row_data between each column evaluation in the loop.
|
||||
|
||||
def test_i_can_recalculate_chain_formula_initial(self):
|
||||
"""Test that C = f(B) is correct when B is itself a formula column (initial pass).
|
||||
|
||||
Chain: A (data) → B = A + 10 → C = B * 2
|
||||
Expected: B = 15, C = 30.
|
||||
"""
|
||||
rows = [{"A": 5}]
|
||||
store = FakeStore(rows)
|
||||
engine = make_engine({"t": store})
|
||||
|
||||
engine.set_formula("t", "B", "{A} + 10")
|
||||
engine.set_formula("t", "C", "{B} * 2")
|
||||
engine.recalculate_if_needed("t", store)
|
||||
|
||||
assert store.ns_fast_access["B"][0] == 15
|
||||
assert store.ns_fast_access["C"][0] == 30
|
||||
|
||||
def test_i_can_recalculate_chain_formula_after_data_change(self):
|
||||
"""Test that C = f(B) stays correct after the source data column A changes.
|
||||
|
||||
Chain: A (data) → B = A + 10 → C = B * 2
|
||||
After A changes from 5 to 10: B = 20, C = 40.
|
||||
"""
|
||||
rows = [{"A": 5}]
|
||||
store = FakeStore(rows)
|
||||
engine = make_engine({"t": store})
|
||||
|
||||
engine.set_formula("t", "B", "{A} + 10")
|
||||
engine.set_formula("t", "C", "{B} * 2")
|
||||
engine.recalculate_if_needed("t", store) # first pass (B=15, C=None per bug above)
|
||||
|
||||
rows[0]["A"] = 10
|
||||
store.ns_fast_access["A"] = np.array([10], dtype=object)
|
||||
engine.mark_data_changed("t", "A")
|
||||
engine.recalculate_if_needed("t", store)
|
||||
|
||||
assert store.ns_fast_access["B"][0] == 20
|
||||
assert store.ns_fast_access["C"][0] == 40
|
||||
|
||||
|
||||
# ==================== TestCrossTable ====================
|
||||
|
||||
class TestCrossTable:
|
||||
"""Tests for cross-table reference resolution strategies."""
|
||||
|
||||
def test_i_can_handle_cross_table_formula(self):
|
||||
"""Test that cross-table references are resolved via registry (Strategy 3: row_index)."""
|
||||
orders_rows = [{"Quantity": 3, "ProductId": 1}]
|
||||
products_rows = [{"Price": 99.0}]
|
||||
|
||||
orders_store = FakeStore(orders_rows)
|
||||
products_store = FakeStore(products_rows)
|
||||
products_store.ns_fast_access = {"Price": np.array([99.0], dtype=object)}
|
||||
|
||||
engine = make_engine({
|
||||
"orders": orders_store,
|
||||
"products": products_store,
|
||||
})
|
||||
engine.set_formula("orders", "total", "{products.Price} * {Quantity}")
|
||||
engine.recalculate_if_needed("orders", orders_store)
|
||||
|
||||
assert "total" in orders_store.ns_fast_access
|
||||
assert orders_store.ns_fast_access["total"][0] == pytest.approx(297.0)
|
||||
|
||||
def test_i_can_resolve_cross_table_by_id_join(self):
|
||||
"""Test that cross-table references are resolved via implicit id-join (Strategy 2).
|
||||
|
||||
Why: Strategy 2 (source lines 303-318) matches rows where both tables share
|
||||
the same id value. This allows cross-table lookups without an explicit WHERE
|
||||
clause when both stores expose an 'id' column in ns_fast_access.
|
||||
"""
|
||||
orders_rows = [{"id": 101, "Quantity": 3}, {"id": 102, "Quantity": 5}]
|
||||
orders_store = FakeStore(orders_rows)
|
||||
|
||||
products_rows = [{"id": 101, "Price": 10.0}, {"id": 102, "Price": 20.0}]
|
||||
products_store = FakeStore(products_rows)
|
||||
|
||||
engine = make_engine({
|
||||
"orders": orders_store,
|
||||
"products": products_store,
|
||||
})
|
||||
engine.set_formula("orders", "total", "{products.Price} * {Quantity}")
|
||||
engine.recalculate_if_needed("orders", orders_store)
|
||||
|
||||
totals = orders_store.ns_fast_access["total"]
|
||||
assert totals[0] == 30, "Row with id=101: Price=10 * Qty=3"
|
||||
assert totals[1] == 100, "Row with id=102: Price=20 * Qty=5"
|
||||
|
||||
def test_i_can_handle_cross_table_without_registry(self):
|
||||
"""Test that a cross-table formula evaluates gracefully when no registry is set.
|
||||
|
||||
Why: _make_cross_table_resolver guards on registry_resolver=None (source line 274).
|
||||
The formula must evaluate to None without raising, preserving engine stability.
|
||||
"""
|
||||
rows = [{"Quantity": 3}]
|
||||
store = FakeStore(rows)
|
||||
|
||||
engine = FormulaEngine(registry_resolver=None)
|
||||
engine.set_formula("orders", "total", "{products.Price} * {Quantity}")
|
||||
engine.recalculate_if_needed("orders", store) # must not raise
|
||||
|
||||
assert store.ns_fast_access["total"][0] is None
|
||||
|
||||
def test_i_can_handle_cross_table_missing_table(self):
|
||||
"""Test that a cross-table formula evaluates gracefully when the remote table is absent.
|
||||
|
||||
Why: Source line 282-284 returns None when registry_resolver returns None for
|
||||
the requested table. The engine must not crash and must produce None for the row.
|
||||
"""
|
||||
rows = [{"Quantity": 3}]
|
||||
orders_store = FakeStore(rows)
|
||||
|
||||
engine = make_engine({"orders": orders_store}) # "products" not in registry
|
||||
engine.set_formula("orders", "total", "{products.Price} * {Quantity}")
|
||||
engine.recalculate_if_needed("orders", orders_store) # must not raise
|
||||
|
||||
assert orders_store.ns_fast_access["total"][0] is None
|
||||
|
||||
|
||||
# ==================== TestGetFormulaText ====================
|
||||
|
||||
class TestGetFormulaText:
|
||||
"""Tests for formula text retrieval."""
|
||||
|
||||
def test_i_can_get_formula_text(self):
|
||||
"""Test that registered formula text can be retrieved."""
|
||||
engine = make_engine()
|
||||
engine.set_formula("t", "col", "{Price} * 2")
|
||||
assert engine.get_formula_text("t", "col") == "{Price} * 2"
|
||||
|
||||
def test_i_can_get_formula_text_returns_none_when_not_set(self):
|
||||
"""Test that get_formula_text returns None for non-formula columns."""
|
||||
engine = make_engine()
|
||||
assert engine.get_formula_text("t", "non_existing") is None
|
||||
188
tests/core/formula/test_formula_evaluator.py
Normal file
188
tests/core/formula/test_formula_evaluator.py
Normal file
@@ -0,0 +1,188 @@
|
||||
"""
|
||||
Tests for the FormulaEvaluator.
|
||||
"""
|
||||
import pytest
|
||||
|
||||
from myfasthtml.core.formula.engine import parse_formula
|
||||
from myfasthtml.core.formula.evaluator import FormulaEvaluator
|
||||
|
||||
|
||||
def make_evaluator(resolver=None):
|
||||
return FormulaEvaluator(cross_table_resolver=resolver)
|
||||
|
||||
|
||||
def eval_formula(text, row_data, row_index=0, resolver=None):
|
||||
"""Helper: parse and evaluate a formula."""
|
||||
formula = parse_formula(text)
|
||||
evaluator = make_evaluator(resolver)
|
||||
return evaluator.evaluate(formula, row_data, row_index)
|
||||
|
||||
|
||||
# ==================== Arithmetic ====================
|
||||
|
||||
@pytest.mark.parametrize("formula,row_data,expected", [
|
||||
("{Price} * {Quantity}", {"Price": 10, "Quantity": 3}, 30.0),
|
||||
("{Price} + {Tax}", {"Price": 100, "Tax": 20}, 120.0),
|
||||
("{Total} - {Discount}", {"Total": 100, "Discount": 15}, 85.0),
|
||||
("{Total} / {Count}", {"Total": 100, "Count": 4}, 25.0),
|
||||
("{Value} % 3", {"Value": 10}, 1.0),
|
||||
("{Base} ^ 2", {"Base": 5}, 25.0),
|
||||
])
|
||||
def test_i_can_evaluate_simple_arithmetic(formula, row_data, expected):
|
||||
"""Test that arithmetic formulas evaluate correctly."""
|
||||
result = eval_formula(formula, row_data)
|
||||
assert result == pytest.approx(expected)
|
||||
|
||||
|
||||
# ==================== Functions ====================
|
||||
|
||||
@pytest.mark.parametrize("formula,row_data,expected", [
|
||||
("round({Price} * 1.2, 2)", {"Price": 10}, 12.0),
|
||||
("abs({Balance})", {"Balance": -50}, 50.0),
|
||||
("upper({Name})", {"Name": "hello"}, "HELLO"),
|
||||
("lower({Name})", {"Name": "WORLD"}, "world"),
|
||||
("len({Description})", {"Description": "abc"}, 3),
|
||||
("concat({First}, \" \", {Last})", {"First": "John", "Last": "Doe"}, "John Doe"),
|
||||
("left({Code}, 3)", {"Code": "ABCDEF"}, "ABC"),
|
||||
("right({Code}, 3)", {"Code": "ABCDEF"}, "DEF"),
|
||||
("trim({Name})", {"Name": " hello "}, "hello"),
|
||||
])
|
||||
def test_i_can_evaluate_function(formula, row_data, expected):
|
||||
"""Test that built-in function calls evaluate correctly."""
|
||||
result = eval_formula(formula, row_data)
|
||||
assert result == expected
|
||||
|
||||
|
||||
def test_i_can_evaluate_nested_functions():
|
||||
"""Test that nested function calls evaluate correctly."""
|
||||
result = eval_formula("round(abs({Val}), 1)", {"Val": -3.456})
|
||||
assert result == 3.5
|
||||
|
||||
|
||||
# ==================== Conditionals ====================
|
||||
|
||||
def test_i_can_evaluate_conditional_true():
|
||||
"""Test conditional when condition is true."""
|
||||
result = eval_formula('{Price} * 0.8 if {Country} == "FR"', {"Price": 100, "Country": "FR"})
|
||||
assert result == 80.0
|
||||
|
||||
|
||||
def test_i_can_evaluate_conditional_false_no_else():
|
||||
"""Test conditional returns None when condition is false and no else."""
|
||||
result = eval_formula('{Price} * 0.8 if {Country} == "FR"', {"Price": 100, "Country": "DE"})
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_i_can_evaluate_conditional_with_else():
|
||||
"""Test conditional returns else value when condition is false."""
|
||||
result = eval_formula('{Price} * 0.8 if {Country} == "FR" else {Price}', {"Price": 100, "Country": "DE"})
|
||||
assert result == 100.0
|
||||
|
||||
|
||||
def test_i_can_evaluate_chained_conditional():
|
||||
"""Test chained conditionals evaluate in order."""
|
||||
formula = '{Price} * 0.8 if {Country} == "FR" else {Price} * 0.9 if {Country} == "DE" else {Price}'
|
||||
assert eval_formula(formula, {"Price": 100, "Country": "FR"}) == pytest.approx(80.0)
|
||||
assert eval_formula(formula, {"Price": 100, "Country": "DE"}) == pytest.approx(90.0)
|
||||
assert eval_formula(formula, {"Price": 100, "Country": "US"}) == pytest.approx(100.0)
|
||||
|
||||
|
||||
# ==================== Logical operators ====================
|
||||
|
||||
@pytest.mark.parametrize("formula,row_data,expected", [
|
||||
("{A} and {B}", {"A": True, "B": True}, True),
|
||||
("{A} and {B}", {"A": True, "B": False}, False),
|
||||
("{A} or {B}", {"A": False, "B": True}, True),
|
||||
("{A} or {B}", {"A": False, "B": False}, False),
|
||||
("not {A}", {"A": True}, False),
|
||||
("not {A}", {"A": False}, True),
|
||||
])
|
||||
def test_i_can_evaluate_logical_operators(formula, row_data, expected):
|
||||
"""Test that logical operators evaluate correctly."""
|
||||
result = eval_formula(formula, row_data)
|
||||
assert result == expected
|
||||
|
||||
|
||||
# ==================== Error handling ====================
|
||||
|
||||
def test_i_can_handle_division_by_zero():
|
||||
"""Test that division by zero returns None."""
|
||||
result = eval_formula("{A} / {B}", {"A": 10, "B": 0})
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_i_can_handle_missing_column():
|
||||
"""Test that missing column reference returns None."""
|
||||
result = eval_formula("{NonExistent} * 2", {})
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_i_can_handle_none_operand():
|
||||
"""Test that operations with None operands return None."""
|
||||
result = eval_formula("{A} * {B}", {"A": None, "B": 5})
|
||||
assert result is None
|
||||
|
||||
|
||||
# ==================== Cross-table references ====================
|
||||
|
||||
def test_i_can_evaluate_cross_table_ref():
|
||||
"""Test cross-table reference with mock resolver."""
|
||||
|
||||
def mock_resolver(table, column, where_clause, row_index):
|
||||
assert table == "Products"
|
||||
assert column == "Price"
|
||||
return 99.0
|
||||
|
||||
result = eval_formula("{Products.Price} * {Quantity}", {"Quantity": 3}, resolver=mock_resolver)
|
||||
assert result == pytest.approx(297.0)
|
||||
|
||||
|
||||
def test_i_can_evaluate_cross_table_with_where():
|
||||
"""Test cross-table reference with WHERE clause and mock resolver."""
|
||||
|
||||
def mock_resolver(table, column, where_clause, row_index):
|
||||
assert where_clause is not None
|
||||
assert where_clause.local_column == "ProductCode"
|
||||
return 50.0
|
||||
|
||||
result = eval_formula(
|
||||
"{Products.Price where Products.Code = ProductCode} * {Qty}",
|
||||
{"ProductCode": "ABC", "Qty": 2},
|
||||
resolver=mock_resolver,
|
||||
)
|
||||
assert result == pytest.approx(100.0)
|
||||
|
||||
|
||||
# ==================== Aggregation ====================
|
||||
|
||||
def test_i_can_evaluate_aggregation():
|
||||
"""Test that aggregation functions work with cross-table resolver."""
|
||||
values = [10.0, 20.0, 30.0]
|
||||
call_count = [0]
|
||||
|
||||
def mock_resolver(table, column, where_clause, row_index):
|
||||
# For aggregation test, return a list to simulate multi-row match
|
||||
val = values[call_count[0] % len(values)]
|
||||
call_count[0] += 1
|
||||
return val
|
||||
|
||||
formula = parse_formula("sum({OrderLines.Amount where OrderLines.OrderId = Id})")
|
||||
evaluator = FormulaEvaluator(cross_table_resolver=mock_resolver)
|
||||
result = evaluator.evaluate(formula, {"Id": 1}, 0)
|
||||
# sum() with a single cross-table value returned by resolver
|
||||
assert result is not None
|
||||
|
||||
|
||||
# ==================== String operations ====================
|
||||
|
||||
@pytest.mark.parametrize("formula,row_data,expected", [
|
||||
('{Name} contains "Corp"', {"Name": "Acme Corp"}, True),
|
||||
('{Code} startswith "ERR"', {"Code": "ERR001"}, True),
|
||||
('{File} endswith ".csv"', {"File": "data.csv"}, True),
|
||||
('{Status} in ["active", "new"]', {"Status": "active"}, True),
|
||||
('{Status} in ["active", "new"]', {"Status": "deleted"}, False),
|
||||
])
|
||||
def test_i_can_evaluate_string_operations(formula, row_data, expected):
|
||||
"""Test string comparison operations."""
|
||||
result = eval_formula(formula, row_data)
|
||||
assert result == expected
|
||||
188
tests/core/formula/test_formula_parser.py
Normal file
188
tests/core/formula/test_formula_parser.py
Normal file
@@ -0,0 +1,188 @@
|
||||
"""
|
||||
Tests for the formula parser (grammar + transformer integration).
|
||||
"""
|
||||
import pytest
|
||||
|
||||
from myfasthtml.core.formula.dataclasses import (
|
||||
BinaryOp,
|
||||
CrossTableRef,
|
||||
ConditionalExpr,
|
||||
FunctionCall,
|
||||
LiteralNode,
|
||||
FormulaDefinition,
|
||||
)
|
||||
from myfasthtml.core.formula.dsl.exceptions import FormulaSyntaxError
|
||||
from myfasthtml.core.formula.engine import parse_formula
|
||||
|
||||
|
||||
# ==================== Valid formulas ====================
|
||||
|
||||
@pytest.mark.parametrize("formula_text", [
|
||||
"{Price} * {Quantity}",
|
||||
"{Price} + {Tax}",
|
||||
"{Total} - {Discount}",
|
||||
"{Total} / {Count}",
|
||||
"{Value} % 2",
|
||||
"{Base} ^ 2",
|
||||
])
|
||||
def test_i_can_parse_simple_arithmetic(formula_text):
|
||||
"""Test that basic arithmetic formulas parse without error."""
|
||||
result = parse_formula(formula_text)
|
||||
assert result is not None
|
||||
assert isinstance(result, FormulaDefinition)
|
||||
assert isinstance(result.expression, BinaryOp)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("formula_text,expected_func", [
|
||||
("round({Price} * 1.2, 2)", "round"),
|
||||
("abs({Balance})", "abs"),
|
||||
("upper({Name})", "upper"),
|
||||
("len({Description})", "len"),
|
||||
("today()", "today"),
|
||||
("concat({First}, \" \", {Last})", "concat"),
|
||||
])
|
||||
def test_i_can_parse_function_call(formula_text, expected_func):
|
||||
"""Test that function calls parse correctly."""
|
||||
result = parse_formula(formula_text)
|
||||
assert result is not None
|
||||
assert isinstance(result.expression, FunctionCall)
|
||||
assert result.expression.function_name == expected_func
|
||||
|
||||
|
||||
def test_i_can_parse_conditional_no_else():
|
||||
"""Test that suffix-if without else parses correctly."""
|
||||
result = parse_formula('{Price} * 0.8 if {Country} == "FR"')
|
||||
assert result is not None
|
||||
expr = result.expression
|
||||
assert isinstance(expr, ConditionalExpr)
|
||||
assert expr.else_expr is None
|
||||
assert isinstance(expr.value_expr, BinaryOp)
|
||||
|
||||
|
||||
def test_i_can_parse_conditional_with_else():
|
||||
"""Test that suffix-if with else parses correctly."""
|
||||
result = parse_formula('{Price} * 0.8 if {Country} == "FR" else {Price}')
|
||||
assert result is not None
|
||||
expr = result.expression
|
||||
assert isinstance(expr, ConditionalExpr)
|
||||
assert expr.else_expr is not None
|
||||
|
||||
|
||||
def test_i_can_parse_chained_conditional():
|
||||
"""Test that chained conditionals parse correctly."""
|
||||
formula = '{Price} * 0.8 if {Country} == "FR" else {Price} * 0.9 if {Country} == "DE" else {Price}'
|
||||
result = parse_formula(formula)
|
||||
assert result is not None
|
||||
expr = result.expression
|
||||
assert isinstance(expr, ConditionalExpr)
|
||||
# The else_expr should be another ConditionalExpr
|
||||
assert isinstance(expr.else_expr, ConditionalExpr)
|
||||
|
||||
|
||||
def test_i_can_parse_cross_table_ref():
|
||||
"""Test that cross-table references parse correctly."""
|
||||
result = parse_formula("{Products.Price} * {Quantity}")
|
||||
assert result is not None
|
||||
expr = result.expression
|
||||
assert isinstance(expr, BinaryOp)
|
||||
assert isinstance(expr.left, CrossTableRef)
|
||||
assert expr.left.table == "Products"
|
||||
assert expr.left.column == "Price"
|
||||
|
||||
|
||||
def test_i_can_parse_cross_table_with_where():
|
||||
"""Test that cross-table references with WHERE clause parse correctly."""
|
||||
result = parse_formula("{Products.Price where Products.Code = ProductCode} * {Quantity}")
|
||||
assert result is not None
|
||||
expr = result.expression
|
||||
assert isinstance(expr, BinaryOp)
|
||||
cross_ref = expr.left
|
||||
assert isinstance(cross_ref, CrossTableRef)
|
||||
assert cross_ref.where_clause is not None
|
||||
assert cross_ref.where_clause.remote_table == "Products"
|
||||
assert cross_ref.where_clause.remote_column == "Code"
|
||||
assert cross_ref.where_clause.local_column == "ProductCode"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("formula_text,expected_op", [
|
||||
("{A} and {B}", "and"),
|
||||
("{A} or {B}", "or"),
|
||||
("not {A}", "not"),
|
||||
])
|
||||
def test_i_can_parse_logical_operators(formula_text, expected_op):
|
||||
"""Test that logical operators parse correctly."""
|
||||
result = parse_formula(formula_text)
|
||||
assert result is not None
|
||||
if expected_op == "not":
|
||||
from myfasthtml.core.formula.dataclasses import UnaryOp
|
||||
assert isinstance(result.expression, UnaryOp)
|
||||
else:
|
||||
assert isinstance(result.expression, BinaryOp)
|
||||
assert result.expression.operator == expected_op
|
||||
|
||||
|
||||
@pytest.mark.parametrize("formula_text", [
|
||||
"42",
|
||||
"3.14",
|
||||
'"hello"',
|
||||
"true",
|
||||
"false",
|
||||
])
|
||||
def test_i_can_parse_literals(formula_text):
|
||||
"""Test that literal values parse correctly."""
|
||||
result = parse_formula(formula_text)
|
||||
assert result is not None
|
||||
assert isinstance(result.expression, LiteralNode)
|
||||
|
||||
|
||||
def test_i_can_parse_aggregation():
|
||||
"""Test that aggregation with cross-table WHERE parses correctly."""
|
||||
result = parse_formula("sum({OrderLines.Amount where OrderLines.OrderId = Id})")
|
||||
assert result is not None
|
||||
expr = result.expression
|
||||
assert isinstance(expr, FunctionCall)
|
||||
assert expr.function_name == "sum"
|
||||
assert len(expr.arguments) == 1
|
||||
arg = expr.arguments[0]
|
||||
assert isinstance(arg, CrossTableRef)
|
||||
assert arg.where_clause is not None
|
||||
|
||||
|
||||
def test_i_can_parse_empty_formula():
|
||||
"""Test that empty formula returns None."""
|
||||
result = parse_formula("")
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_i_can_parse_whitespace_formula():
|
||||
"""Test that whitespace-only formula returns None."""
|
||||
result = parse_formula(" ")
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_i_can_parse_nested_functions():
|
||||
"""Test that nested function calls parse correctly."""
|
||||
result = parse_formula("round(avg({Q1}, {Q2}, {Q3}), 1)")
|
||||
assert result is not None
|
||||
expr = result.expression
|
||||
assert isinstance(expr, FunctionCall)
|
||||
assert expr.function_name == "round"
|
||||
|
||||
|
||||
# ==================== Invalid formulas ====================
|
||||
|
||||
@pytest.mark.parametrize("formula_text", [
|
||||
"{Price} * * {Quantity}", # double operator
|
||||
"round(", # unclosed paren
|
||||
"123 + + 456", # double operator
|
||||
])
|
||||
def test_i_cannot_parse_invalid_syntax(formula_text):
|
||||
"""Test that invalid syntax raises FormulaSyntaxError."""
|
||||
with pytest.raises(FormulaSyntaxError):
|
||||
parse_formula(formula_text)
|
||||
|
||||
|
||||
def test_i_cannot_parse_unclosed_brace():
|
||||
"""Test that an unclosed brace raises FormulaSyntaxError."""
|
||||
with pytest.raises(FormulaSyntaxError):
|
||||
parse_formula("{Price")
|
||||
Reference in New Issue
Block a user