From e8443f07f9399cf6c41f4608ab426f1cdfe1fc2d Mon Sep 17 00:00:00 2001 From: Kodjo Sossouvi Date: Fri, 13 Feb 2026 21:38:00 +0100 Subject: [PATCH] Introducing columns formulas --- docs/Datagrid Formulas.md | 365 ++++++++++++ src/myfasthtml/controls/DataGrid.py | 38 ++ .../controls/DataGridColumnsManager.py | 29 + .../controls/DataGridFormulaEditor.py | 67 +++ src/myfasthtml/controls/DataGridQuery.py | 2 +- src/myfasthtml/controls/DataGridsManager.py | 38 +- src/myfasthtml/controls/datagrid_objects.py | 1 + src/myfasthtml/core/constants.py | 1 + .../core/formatting/style_resolver.py | 5 +- src/myfasthtml/core/formula/__init__.py | 0 src/myfasthtml/core/formula/dataclasses.py | 79 +++ .../core/formula/dependency_graph.py | 386 +++++++++++++ src/myfasthtml/core/formula/dsl/__init__.py | 0 .../dsl/completion/FormulaCompletionEngine.py | 180 ++++++ .../core/formula/dsl/completion/__init__.py | 1 + src/myfasthtml/core/formula/dsl/definition.py | 79 +++ src/myfasthtml/core/formula/dsl/exceptions.py | 35 ++ src/myfasthtml/core/formula/dsl/grammar.py | 100 ++++ src/myfasthtml/core/formula/dsl/parser.py | 85 +++ .../core/formula/dsl/transformer.py | 274 +++++++++ src/myfasthtml/core/formula/engine.py | 398 +++++++++++++ src/myfasthtml/core/formula/evaluator.py | 522 ++++++++++++++++++ tests/core/formatting/test_engine.py | 8 +- tests/core/formatting/test_style_resolver.py | 36 +- tests/core/formula/__init__.py | 0 tests/core/formula/test_dependency_graph.py | 391 +++++++++++++ tests/core/formula/test_formula_engine.py | 408 ++++++++++++++ tests/core/formula/test_formula_evaluator.py | 188 +++++++ tests/core/formula/test_formula_parser.py | 188 +++++++ 29 files changed, 3889 insertions(+), 15 deletions(-) create mode 100644 docs/Datagrid Formulas.md create mode 100644 src/myfasthtml/controls/DataGridFormulaEditor.py create mode 100644 src/myfasthtml/core/formula/__init__.py create mode 100644 src/myfasthtml/core/formula/dataclasses.py create mode 100644 
src/myfasthtml/core/formula/dependency_graph.py create mode 100644 src/myfasthtml/core/formula/dsl/__init__.py create mode 100644 src/myfasthtml/core/formula/dsl/completion/FormulaCompletionEngine.py create mode 100644 src/myfasthtml/core/formula/dsl/completion/__init__.py create mode 100644 src/myfasthtml/core/formula/dsl/definition.py create mode 100644 src/myfasthtml/core/formula/dsl/exceptions.py create mode 100644 src/myfasthtml/core/formula/dsl/grammar.py create mode 100644 src/myfasthtml/core/formula/dsl/parser.py create mode 100644 src/myfasthtml/core/formula/dsl/transformer.py create mode 100644 src/myfasthtml/core/formula/engine.py create mode 100644 src/myfasthtml/core/formula/evaluator.py create mode 100644 tests/core/formula/__init__.py create mode 100644 tests/core/formula/test_dependency_graph.py create mode 100644 tests/core/formula/test_formula_engine.py create mode 100644 tests/core/formula/test_formula_evaluator.py create mode 100644 tests/core/formula/test_formula_parser.py diff --git a/docs/Datagrid Formulas.md b/docs/Datagrid Formulas.md new file mode 100644 index 0000000..62a2628 --- /dev/null +++ b/docs/Datagrid Formulas.md @@ -0,0 +1,365 @@ +# DataGrid Formulas + +## Overview + +The DataGrid formula system adds computed columns to the DataGrid. A formula column applies a single expression to every +row, producing derived values from existing data — within the same table or across tables. 
+ +The system is designed for: + +- **Column-level formulas**: one formula per column, applied to all rows +- **Cross-table references**: direct syntax to reference columns from other tables +- **Reactive recalculation**: dirty flag propagation with page-aware computation +- **Cell-level overrides** (planned): individual cells can override the column formula + +## Formula Language + +### Basic Syntax + +A formula is an expression that references columns with `{ColumnName}` and produces a value for each row: + +``` +{Price} * {Quantity} +``` + +References use curly braces `{}` to distinguish column names from keywords and functions. Column names are matched by ID +or title. + +### Operators + +#### Arithmetic + +| Operator | Description | Example | +|----------|----------------|------------------------| +| `+` | Addition | `{Price} + {Tax}` | +| `-` | Subtraction | `{Total} - {Discount}` | +| `*` | Multiplication | `{Price} * {Quantity}` | +| `/` | Division | `{Total} / {Count}` | +| `%` | Modulo | `{Value} % 2` | +| `^` | Power | `{Base} ^ 2` | + +#### Comparison + +| Operator | Description | Example | +|--------------|--------------------|---------------------------------| +| `==` | Equal | `{Status} == "active"` | +| `!=` | Not equal | `{Status} != "deleted"` | +| `>` | Greater than | `{Price} > 100` | +| `<` | Less than | `{Stock} < 10` | +| `>=` | Greater or equal | `{Score} >= 80` | +| `<=` | Less or equal | `{Age} <= 18` | +| `contains` | String contains | `{Name} contains "Corp"` | +| `startswith` | String starts with | `{Code} startswith "ERR"` | +| `endswith` | String ends with | `{File} endswith ".csv"` | +| `in` | Value in list | `{Status} in ["active", "new"]` | +| `between` | Value in range | `{Age} between 18 and 65` | +| `isempty` | Value is empty | `{Notes} isempty` | +| `isnotempty` | Value is not empty | `{Email} isnotempty` | +| `isnan` | Value is NaN | `{Score} isnan` | + +#### Logical + +| Operator | Description | Example | 
+|----------|-------------|---------------------------------------| +| `and` | Logical AND | `{Age} > 18 and {Status} == "active"` | +| `or` | Logical OR | `{Type} == "A" or {Type} == "B"` | +| `not` | Negation | `not {Status} == "deleted"` | + +Parentheses control precedence: `({Type} == "A" or {Type} == "B") and {Active} == True` + +### Conditions (suffix-if) + +Conditions use a **suffix-if** syntax: the result expression comes first, then the condition. This keeps the focus on +the output, not the branching logic. + +#### Simple condition (no else — result is None when false) + +``` +{Price} * 0.8 if {Country} == "FR" +``` + +#### With else + +``` +{Price} * 0.8 if {Country} == "FR" else {Price} +``` + +#### Chained conditions + +``` +{Price} * 0.8 if {Country} == "FR" else {Price} * 0.9 if {Country} == "DE" else {Price} +``` + +#### With logical operators + +``` +{Price} * 0.8 if {Country} == "FR" and {Quantity} > 10 else {Price} +``` + +#### With grouping + +``` +{Price} * 0.8 if ({Country} == "FR" or {Country} == "DE") and {Quantity} > 10 +``` + +### Functions + +#### Math + +| Function | Description | Example | +|-------------------|-----------------------|-------------------------------| +| `round(expr, n)` | Round to n decimals | `round({Price} * 1.2, 2)` | +| `abs(expr)` | Absolute value | `abs({Balance})` | +| `min(expr, expr)` | Minimum of two values | `min({Price}, {MaxPrice})` | +| `max(expr, expr)` | Maximum of two values | `max({Score}, 0)` | +| `sum(expr, ...)` | Sum of values | `sum({Q1}, {Q2}, {Q3}, {Q4})` | +| `avg(expr, ...)` | Average of values | `avg({Q1}, {Q2}, {Q3}, {Q4})` | + +#### Text + +| Function | Description | Example | +|---------------------|---------------------|--------------------------------| +| `upper(expr)` | Uppercase | `upper({Name})` | +| `lower(expr)` | Lowercase | `lower({Email})` | +| `len(expr)` | String length | `len({Description})` | +| `concat(expr, ...)` | Concatenate strings | `concat({First}, " ", {Last})` | +| 
`trim(expr)` | Remove whitespace | `trim({Input})` | +| `left(expr, n)` | First n characters | `left({Code}, 3)` | +| `right(expr, n)` | Last n characters | `right({Phone}, 4)` | + +#### Date + +| Function | Description | Example | +|------------------------|--------------------|--------------------------------| +| `year(expr)` | Extract year | `year({CreatedAt})` | +| `month(expr)` | Extract month | `month({CreatedAt})` | +| `day(expr)` | Extract day | `day({CreatedAt})` | +| `today()` | Current date | `datediff({DueDate}, today())` | +| `datediff(expr, expr)` | Difference in days | `datediff({End}, {Start})` | + +#### Aggregation (for cross-table contexts) + +| Function | Description | Example | +|---------------|--------------|-----------------------------------------------------| +| `sum(expr)` | Sum values | `sum({Orders.Amount WHERE Orders.ClientId = Id})` | +| `count(expr)` | Count values | `count({Orders.Id WHERE Orders.ClientId = Id})` | +| `avg(expr)` | Average | `avg({Reviews.Score WHERE Reviews.ProductId = Id})` | +| `min(expr)` | Minimum | `min({Bids.Price WHERE Bids.ItemId = Id})` | +| `max(expr)` | Maximum | `max({Bids.Price WHERE Bids.ItemId = Id})` | + +## Cross-Table References + +### Direct Reference + +Reference a column from another table using `{TableName.ColumnName}`: + +``` +{Products.Price} * {Quantity} +``` + +### Join Resolution (implicit) + +When referencing another table without a WHERE clause, the join is resolved automatically: + +1. **By `id` column**: if both tables have a column named `id`, rows are matched on equal `id` values +2. 
**By row index**: if no `id` column exists in both tables, rows are matched by their internal row index (stable + across sort/filter) + +### Explicit Join (WHERE clause) + +For explicit control over which row of the other table to use: + +``` +{Products.Price WHERE Products.Code = ProductCode} * {Quantity} +``` + +Inside the WHERE clause: + +- `Products.Code` refers to a column in the referenced table +- `ProductCode` (no `Table.` prefix) refers to a column in the current table + +### Aggregation with Cross-Table + +When a cross-table reference matches multiple rows, use an aggregation function: + +``` +sum({OrderLines.Amount WHERE OrderLines.OrderId = Id}) +``` + +Without aggregation, a multi-row match returns the first matching value. + +## Calculation Engine + +### Dependency Graph (DAG) + +The formula system maintains a **Directed Acyclic Graph** of dependencies between columns: + +- **Nodes**: each formula column is a node, identified by `table_name.column_id` +- **Edges**: if column A's formula references column B, an edge B → A exists ("A depends on B") +- Both directions are tracked: + - **Precedents**: columns that a formula reads from + - **Dependents**: columns that need recalculation when this column changes + +Cross-table references create edges that span DataGrid instances, managed at the `DataGridsManager` level. + +### Dirty Flag Propagation + +When a source column's data changes: + +1. The source column is marked **dirty** +2. All direct dependents are marked dirty +3. Propagation continues recursively through the DAG +4. Each dirty column maintains a **dirty row set**: the specific row indices that need recalculation + +This propagation is **immediate** (fast — only flag marking, no computation). + +### Recalculation Strategy (Hybrid) + +Actual computation is **deferred to rendering time**: + +1. On value change → dirty flags propagate instantly through the DAG +2. 
On page render (`mk_body_content_page`) → only dirty rows within the visible page (up to 1000 rows) are recalculated +3. Off-screen pages remain dirty until scrolled into view +4. Calculation follows **topological order** of the DAG to ensure precedents are computed before dependents + +### Cycle Detection + +Before adding a formula, the engine checks for cycles in the DAG using Kahn's algorithm during topological sort. If a +cycle is detected: + +- The formula is **rejected** +- The editor displays an error identifying the circular dependency chain +- The previous formula (if any) remains unchanged + +### Caching + +Each formula column caches its computed values: + +- Results are stored in `ns_fast_access[col_id]` alongside raw data columns +- The dirty row set tracks which cached values are stale +- Non-dirty rows return their cached value without re-evaluation +- Cache is invalidated per-row when source data changes + +## Evaluation + +### Row-by-Row Execution + +Formulas are evaluated **row-by-row** within the page being rendered. For each row: + +1. Resolve column references `{ColumnName}` to the cell value at the current row index +2. Resolve cross-table references `{Table.Column}` via the join mechanism +3. Evaluate the expression with resolved values +4. Store the result in the cache (`ns_fast_access`) + +### Parser + +The formula language uses a **custom grammar** parsed with Lark (consistent with the formatting DSL). The parser: + +1. Tokenizes the formula string +2. Builds an AST (Abstract Syntax Tree) +3. Transforms the AST into an evaluable representation +4. 
Extracts column references for dependency graph registration + +### Error Handling + +| Error Type | Behavior | +|-----------------------|-------------------------------------------------------| +| Syntax error | Editor highlights the error, formula not saved | +| Unknown column | Editor highlights, autocompletion suggests fixes | +| Type mismatch | Cell displays error indicator, other cells unaffected | +| Division by zero | Cell displays `#DIV/0!` or None | +| Circular dependency | Formula rejected, editor shows cycle chain | +| Cross-table not found | Editor highlights unknown table name | +| No join match | Cell displays None | + +## User Interface + +### Creating a Formula Column + +Formula columns are created and edited through the **DataGridColumnsManager**: + +1. User opens the Columns Manager panel +2. Adds a new column or edits an existing one +3. Selects column type **"Formula"** +4. A **DslEditor** (CodeMirror 5) opens for formula input +5. The editor provides: + - **Syntax highlighting**: keywords, column references, functions, operators + - **Autocompletion**: column names (current table and other tables), function names, table names + - **Validation**: real-time syntax checking and dependency cycle detection + - **Error markers**: inline error indicators with descriptions + +### Formula Column Properties + +A formula column extends `DataGridColumnState` with: + +| Property | Type | Description | +|-----------|--------------|--------------------------------------------------------| +| `formula` | `str` | The formula expression (empty string for data columns) | +| `type` | `ColumnType` | Set to `ColumnType.Formula` | + +Other properties (`title`, `visible`, `width`, `format`) remain unchanged. + +Formula columns are **read-only** in the grid body — cell values are computed, not editable. Formatting rules from the +formatting DSL apply to formula columns like any other column. 
+ +## Integration Points + +| Component | Role | +|--------------------------|----------------------------------------------------------| +| `DataGridColumnState` | Stores `formula` field and `ColumnType.Formula` type | +| `DatagridStore` | `ns_fast_access` caches formula results as numpy arrays | +| `DataGridColumnsManager` | UI for creating/editing formula columns | +| `DataGridsManager` | Hosts the global dependency DAG across all tables | +| `DslEditor` | CodeMirror 5 editor with highlighting and autocompletion | +| `FormattingEngine` | Applies formatting rules AFTER formula evaluation | +| `mk_body_content_page()` | Triggers formula computation for visible rows | +| `mk_body_cell_content()` | Reads computed values from `ns_fast_access` | + +## Syntax Summary + +``` +# Basic arithmetic +{Price} * {Quantity} + +# Function call +round({Price} * 1.2, 2) + +# Simple condition (None if false) +{Price} * 0.8 if {Country} == "FR" + +# Condition with else +{Price} * 0.8 if {Country} == "FR" else {Price} + +# Chained conditions +{Price} * 0.8 if {Country} == "FR" else {Price} * 0.9 if {Country} == "DE" else {Price} + +# Logical operators +{Price} * 0.8 if {Country} == "FR" and {Quantity} > 10 + +# Grouping +{Price} * 0.8 if ({Country} == "FR" or {Country} == "DE") and {Quantity} > 10 + +# Cross-table (implicit join on id) +{Products.Price} * {Quantity} + +# Cross-table (explicit join) +{Products.Price WHERE Products.Code = ProductCode} * {Quantity} + +# Cross-table aggregation +sum({OrderLines.Amount WHERE OrderLines.OrderId = Id}) + +# Nested functions +round(avg({Q1}, {Q2}, {Q3}, {Q4}), 1) + +# Text operations +concat(upper(left({FirstName}, 1)), ". 
", {LastName}) +``` + +## Future: Cell-Level Overrides + +The architecture supports adding cell-level formula overrides with ~20-30% additional work: + +- **Storage**: sparse dict `cell_formulas: dict[(col_id, row_index), str]` (same pattern as `cell_formats`) +- **DAG**: new node type `table.column[row]` alongside existing `table.column` nodes +- **Evaluation**: "does this cell have an override? If yes, use it. Otherwise, use the column formula." +- **Node ID scheme**: designed to be extensible from the start (`table.column` for columns, `table.column[row]` for + cells) diff --git a/src/myfasthtml/controls/DataGrid.py b/src/myfasthtml/controls/DataGrid.py index f3e346e..46d944c 100644 --- a/src/myfasthtml/controls/DataGrid.py +++ b/src/myfasthtml/controls/DataGrid.py @@ -418,9 +418,41 @@ class DataGrid(MultipleInstance): self._df_store.ns_fast_access = _init_fast_access(self._df) self._df_store.ns_row_data = _init_row_data(self._df) self._df_store.ns_total_rows = len(self._df) if self._df is not None else 0 + if init_state: + self._register_existing_formulas() return self + def _register_existing_formulas(self) -> None: + """ + Re-register all formula columns with the FormulaEngine. + + Called after data reload to ensure the engine knows about all + formula columns and their expressions. + """ + engine = self.get_formula_engine() + if engine is None: + return + table = self.get_table_name() + for col_def in self._state.columns: + if col_def.formula: + try: + engine.set_formula(table, col_def.col_id, col_def.formula) + except Exception as e: + logger.warning("Failed to register formula for %s.%s: %s", table, col_def.col_id, e) + + def _recalculate_formulas(self) -> None: + """ + Recalculate dirty formula columns before rendering. + + Called at the start of mk_body_content_page() to ensure formula + columns are up-to-date before cells are rendered. 
+ """ + engine = self.get_formula_engine() + if engine is None: + return + engine.recalculate_if_needed(self.get_table_name(), self._df_store) + def _get_format_rules(self, col_pos, row_index, col_def): """ Get format rules for a cell, returning only the most specific level defined. @@ -575,6 +607,11 @@ class DataGrid(MultipleInstance): def get_table_name(self): return f"{self._settings.namespace}.{self._settings.name}" if self._settings.namespace else self._settings.name + def get_formula_engine(self): + """Return the FormulaEngine from the DataGridsManager, if available.""" + return self._parent.get_formula_engine() + + def mk_headers(self): resize_cmd = self.commands.set_column_width() move_cmd = self.commands.move_column() @@ -701,6 +738,7 @@ class DataGrid(MultipleInstance): OPTIMIZED: Extract filter keyword once instead of 10,000 times. OPTIMIZED: Uses OptimizedDiv for rows instead of Div for faster rendering. """ + self._recalculate_formulas() df = self._get_filtered_df() if df is None: return [] diff --git a/src/myfasthtml/controls/DataGridColumnsManager.py b/src/myfasthtml/controls/DataGridColumnsManager.py index 9e3c5ed..83eba7d 100644 --- a/src/myfasthtml/controls/DataGridColumnsManager.py +++ b/src/myfasthtml/controls/DataGridColumnsManager.py @@ -99,6 +99,9 @@ class DataGridColumnsManager(MultipleInstance): col_def.type = ColumnType(v) elif k == "width": col_def.width = int(v) + elif k == "formula": + col_def.formula = v or "" + self._register_formula(col_def) else: setattr(col_def, k, v) @@ -107,6 +110,21 @@ class DataGridColumnsManager(MultipleInstance): return self.mk_all_columns() + def _register_formula(self, col_def) -> None: + """Register or remove a formula column with the FormulaEngine.""" + engine = self._parent.get_formula_engine() + if engine is None: + return + table = self._parent.get_table_name() + if col_def.formula: + try: + engine.set_formula(table, col_def.col_id, col_def.formula) + logger.debug("Registered formula for %s.%s", 
table, col_def.col_id) + except Exception as e: + logger.warning("Formula error for %s.%s: %s", table, col_def.col_id, e) + else: + engine.remove_formula(table, col_def.col_id) + def mk_column_label(self, col_def: DataGridColumnState): return Div( mk.mk( @@ -168,6 +186,17 @@ class DataGridColumnsManager(MultipleInstance): value=col_def.title, ), + *([ + Label("Formula"), + Textarea( + col_def.formula or "", + name="formula", + cls=f"textarea textarea-{size} w-full font-mono", + placeholder="{Column} * {OtherColumn}", + rows=3, + ), + ] if col_def.type == ColumnType.Formula else []), + legend="Column details", cls="fieldset border-base-300 rounded-box" ), diff --git a/src/myfasthtml/controls/DataGridFormulaEditor.py b/src/myfasthtml/controls/DataGridFormulaEditor.py new file mode 100644 index 0000000..9bc7744 --- /dev/null +++ b/src/myfasthtml/controls/DataGridFormulaEditor.py @@ -0,0 +1,67 @@ +""" +DataGridFormulaEditor — DslEditor for formula column expressions. + +Extends DslEditor with formula-specific behavior: +- Parses the formula on content change +- Registers the formula with FormulaEngine +- Triggers a body re-render on the parent DataGrid +""" +import logging + +from myfasthtml.controls.DslEditor import DslEditor +from myfasthtml.core.formula.dsl.exceptions import FormulaSyntaxError, FormulaCycleError + +logger = logging.getLogger("DataGridFormulaEditor") + + +class DataGridFormulaEditor(DslEditor): + """ + Formula editor for a specific DataGrid column. + + Args: + parent: The parent DataGrid instance. + col_def: The DataGridColumnState for the formula column. + conf: DslEditorConf for CodeMirror configuration. + _id: Optional instance ID. + """ + + def __init__(self, parent, col_def, conf=None, _id=None): + super().__init__(parent, conf=conf, _id=_id) + self._col_def = col_def + + def on_content_changed(self): + """ + Called when the formula text is changed in the editor. + + 1. Updates col_def.formula with the new text. + 2. 
Registers the formula with the FormulaEngine. + 3. Triggers a body re-render of the parent DataGrid. + """ + formula_text = self.get_content() + + # Update the column definition + self._col_def.formula = formula_text or "" + + # Register with the FormulaEngine (public accessor defined on DataGrid) + engine = self._parent.get_formula_engine() + if engine is not None: + table = self._parent.get_table_name() + try: + engine.set_formula(table, self._col_def.col_id, formula_text) + logger.debug( + "Formula updated for %s.%s: %s", + table, self._col_def.col_id, formula_text, + ) + except FormulaSyntaxError as e: + logger.debug("Formula syntax error, keeping old formula: %s", e) + return + except FormulaCycleError as e: + logger.warning("Formula cycle detected for %s.%s: %s", table, self._col_def.col_id, e) + return + except Exception as e: + logger.warning("Formula engine error for %s.%s: %s", table, self._col_def.col_id, e) + return + + # Save state and re-render the grid body + self._parent.save_state() + return self._parent.render_partial("body") diff --git a/src/myfasthtml/controls/DataGridQuery.py b/src/myfasthtml/controls/DataGridQuery.py index 3063f01..bf339ac 100644 --- a/src/myfasthtml/controls/DataGridQuery.py +++ b/src/myfasthtml/controls/DataGridQuery.py @@ -12,7 +12,7 @@ from myfasthtml.icons.fluent import brain_circuit20_regular from myfasthtml.icons.fluent_p1 import filter20_regular, search20_regular from myfasthtml.icons.fluent_p2 import dismiss_circle20_regular -logger = logging.getLogger("DataGridFilter") +logger = logging.getLogger("DataGridQuery") DG_QUERY_FILTER = "filter" DG_QUERY_SEARCH = "search" diff --git a/src/myfasthtml/controls/DataGridsManager.py b/src/myfasthtml/controls/DataGridsManager.py index 7dc67d9..8c7d18c 100644 --- a/src/myfasthtml/controls/DataGridsManager.py +++ b/src/myfasthtml/controls/DataGridsManager.py @@ -16,6 +16,7 @@ from myfasthtml.core.commands import Command from myfasthtml.core.dbmanager import DbObject from 
myfasthtml.core.formatting.dsl.completion.provider import DatagridMetadataProvider from myfasthtml.core.formatting.presets import DEFAULT_STYLE_PRESETS, DEFAULT_FORMATTER_PRESETS +from myfasthtml.core.formula.engine import FormulaEngine from myfasthtml.core.instances import InstancesManager, SingleInstance from myfasthtml.icons.fluent_p1 import table_add20_regular from myfasthtml.icons.fluent_p3 import folder_open20_regular @@ -91,6 +92,11 @@ class DataGridsManager(SingleInstance, DatagridMetadataProvider): self.style_presets: dict = DEFAULT_STYLE_PRESETS.copy() self.formatter_presets: dict = DEFAULT_FORMATTER_PRESETS.copy() self.all_tables_formats: list = [] + + # Formula engine shared across all DataGrids in this session + self._formula_engine = FormulaEngine( + registry_resolver=self._resolve_store_for_table + ) def upload_from_source(self): file_upload = FileUpload(self) @@ -167,10 +173,10 @@ class DataGridsManager(SingleInstance, DatagridMetadataProvider): def list_column_values(self, table_name, column_name): return self._registry.get_column_values(table_name, column_name) - + def get_row_count(self, table_name): return self._registry.get_row_count(table_name) - + def get_column_type(self, table_name, column_name): return self._registry.get_column_type(table_name, column_name) @@ -180,7 +186,29 @@ class DataGridsManager(SingleInstance, DatagridMetadataProvider): def list_format_presets(self) -> list[str]: return list(self.formatter_presets.keys()) - # === Presets Management === + def _resolve_store_for_table(self, table_name: str): + """ + Resolve the DatagridStore for a given table name. + + Used by FormulaEngine as the registry_resolver callback. + + Args: + table_name: Full table name in ``"namespace.name"`` format. + + Returns: + DatagridStore instance or None if not found. 
+ """ + try: + as_fullname_dict = self._registry._get_entries_as_full_name_dict() + grid_id = as_fullname_dict.get(table_name) + if grid_id is None: + return None + datagrid = InstancesManager.get(self._session, grid_id, None) + if datagrid is None: + return None + return datagrid._df_store + except Exception: + return None def get_style_presets(self) -> dict: """Get the global style presets.""" @@ -190,6 +218,10 @@ class DataGridsManager(SingleInstance, DatagridMetadataProvider): """Get the global formatter presets.""" return self.formatter_presets + def get_formula_engine(self) -> FormulaEngine: + """The FormulaEngine shared across all DataGrids in this session.""" + return self._formula_engine + def add_style_preset(self, name: str, preset: dict): """ Add or update a style preset. diff --git a/src/myfasthtml/controls/datagrid_objects.py b/src/myfasthtml/controls/datagrid_objects.py index 4de71f8..9354740 100644 --- a/src/myfasthtml/controls/datagrid_objects.py +++ b/src/myfasthtml/controls/datagrid_objects.py @@ -20,6 +20,7 @@ class DataGridColumnState: visible: bool = True width: int = DATAGRID_DEFAULT_COLUMN_WIDTH format: list = field(default_factory=list) # + formula: str = "" # formula expression for ColumnType.Formula columns @dataclass diff --git a/src/myfasthtml/core/constants.py b/src/myfasthtml/core/constants.py index 9e90051..5eede34 100644 --- a/src/myfasthtml/core/constants.py +++ b/src/myfasthtml/core/constants.py @@ -26,6 +26,7 @@ class ColumnType(Enum): Bool = "Boolean" Choice = "Choice" Enum = "Enum" + Formula = "Formula" class ViewType(Enum): diff --git a/src/myfasthtml/core/formatting/style_resolver.py b/src/myfasthtml/core/formatting/style_resolver.py index 6e36a08..8f98743 100644 --- a/src/myfasthtml/core/formatting/style_resolver.py +++ b/src/myfasthtml/core/formatting/style_resolver.py @@ -98,6 +98,9 @@ class StyleResolver: return StyleContainer(None, "") cls = props.pop("__class__", None) + if not props: + return StyleContainer(cls, "") + 
css = "; ".join(f"{key}: {value}" for key, value in props.items()) + ";" - + return StyleContainer(cls, css) diff --git a/src/myfasthtml/core/formula/__init__.py b/src/myfasthtml/core/formula/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/myfasthtml/core/formula/dataclasses.py b/src/myfasthtml/core/formula/dataclasses.py new file mode 100644 index 0000000..7bcf3cc --- /dev/null +++ b/src/myfasthtml/core/formula/dataclasses.py @@ -0,0 +1,79 @@ +from dataclasses import dataclass, field +from typing import Any, Optional + + +@dataclass +class FormulaNode: + """Base AST node for formula expressions.""" + pass + + +@dataclass +class LiteralNode(FormulaNode): + """A literal value (number, string, boolean).""" + value: Any + + +@dataclass +class ColumnRef(FormulaNode): + """Reference to a column in the current table: {ColumnName}.""" + column: str + + +@dataclass +class WhereClause: + """WHERE clause for cross-table references: WHERE remote_table.remote_col = local_col.""" + remote_table: str + remote_column: str + local_column: str + + +@dataclass +class CrossTableRef(FormulaNode): + """Reference to a column in another table: {Table.Column}.""" + table: str + column: str + where_clause: Optional[WhereClause] = None + + +@dataclass +class BinaryOp(FormulaNode): + """Binary operation: left op right.""" + operator: str + left: FormulaNode + right: FormulaNode + + +@dataclass +class UnaryOp(FormulaNode): + """Unary operation: -expr or not expr.""" + operator: str + operand: FormulaNode + + +@dataclass +class FunctionCall(FormulaNode): + """Function call: func(args...).""" + function_name: str + arguments: list = field(default_factory=list) + + +@dataclass +class ConditionalExpr(FormulaNode): + """Conditional: value_expr if condition [else else_expr]. 
+ + Chainable: val1 if cond1 else val2 if cond2 else val3 + """ + value_expr: FormulaNode + condition: FormulaNode + else_expr: Optional[FormulaNode] = None + + +@dataclass +class FormulaDefinition: + """A complete formula definition for a column.""" + expression: FormulaNode + source_text: str = "" + + def __str__(self): + return self.source_text diff --git a/src/myfasthtml/core/formula/dependency_graph.py b/src/myfasthtml/core/formula/dependency_graph.py new file mode 100644 index 0000000..9fe1330 --- /dev/null +++ b/src/myfasthtml/core/formula/dependency_graph.py @@ -0,0 +1,386 @@ +""" +Dependency Graph (DAG) for formula columns. + +Tracks column dependencies, propagates dirty flags, and provides +topological ordering for incremental recalculation. + +Node IDs use the format ``"table_name.column_id"`` for column-level +granularity, designed to be extensible to ``"table_name.column_id[row]"`` +for cell-level overrides. +""" +import logging +from collections import defaultdict, deque +from dataclasses import dataclass, field +from typing import Optional, Set + +from .dataclasses import ( + FormulaDefinition, + ColumnRef, + CrossTableRef, + BinaryOp, + UnaryOp, + FunctionCall, + ConditionalExpr, + FormulaNode, +) +from .dsl.exceptions import FormulaCycleError + +logger = logging.getLogger("DependencyGraph") + + +@dataclass +class DependencyNode: + """ + A node in the dependency graph. + + Attributes: + node_id: Unique identifier in the format ``"table.column"``. + table: Table name. + column: Column name. + dirty: Whether this node needs recalculation. + dirty_rows: Set of specific row indices that are dirty. + Empty set means all rows are dirty. + formula: The parsed FormulaDefinition for formula nodes. + """ + node_id: str + table: str + column: str + dirty: bool = False + dirty_rows: Set[int] = field(default_factory=set) + formula: Optional[FormulaDefinition] = None + + +class DependencyGraph: + """ + Directed Acyclic Graph of formula column dependencies. 
+ + Tracks which columns depend on which other columns and provides: + - Dirty flag propagation via BFS + - Topological ordering for recalculation (Kahn's algorithm) + - Cycle detection before registering new formulas + + The graph is bidirectional: + - ``_dependents[A]`` = set of nodes that depend on A (forward edges) + - ``_precedents[B]`` = set of nodes that B depends on (reverse edges) + """ + + def __init__(self): + self._nodes: dict[str, DependencyNode] = {} + # forward: A -> {B, C} means "B and C depend on A" + self._dependents: dict[str, set[str]] = defaultdict(set) + # reverse: B -> {A} means "B depends on A" + self._precedents: dict[str, set[str]] = defaultdict(set) + + def add_formula( + self, + table: str, + column: str, + formula: FormulaDefinition, + ) -> None: + """ + Register a formula for a column and add dependency edges. + + A formula that references its own column is rejected before the + graph is modified, so the previously registered edges stay intact. + + Raises: + FormulaCycleError: If adding this formula would create a cycle. + + Args: + table: Table name (may be namespaced, e.g. ``"namespace.name"``). + column: Column name. + formula: The parsed FormulaDefinition. 
+ """ + logger.debug("add_formula %s.%s:%s", table, column, formula.source_text) + + node_id = self._make_node_id(table, column) + + # Extract dependency node_ids from the formula AST + dep_ids = self._extract_dependencies(formula.expression, table) + + # Reject direct self-references before touching the graph + if node_id in dep_ids: + raise FormulaCycleError([node_id]) + + # Remove old edges to avoid stale dependencies + self._remove_edges(node_id) + + # Add new edges + for dep_id in dep_ids: + self._dependents[dep_id].add(node_id) + self._precedents[node_id].add(dep_id) + + # Ensure all referenced nodes exist (as data nodes without formulas) + for dep_id in dep_ids: + if dep_id not in self._nodes: + # Table names may contain dots ("namespace.name"): the column is the last segment + dep_table, dep_col = dep_id.rsplit(".", 1) + self._nodes[dep_id] = DependencyNode( + node_id=dep_id, + table=dep_table, + column=dep_col, + ) + + # Ensure formula node exists + node = self._get_or_create_node(table, column) + node.formula = formula + node.dirty = True # New formula -> needs evaluation + + # Detect cycles using Kahn's algorithm + # NOTE(review): if this raises, the edges added above are still registered - confirm rollback is handled in _detect_cycles() or by the caller + self._detect_cycles() + + logger.debug("Added formula for %s depending on: %s", node_id, dep_ids) + + def remove_formula(self, table: str, column: str) -> None: + """ + Remove a formula column and its edges from the graph. + + Args: + table: Table name. + column: Column name. 
+ """ + node_id = self._make_node_id(table, column) + self._remove_edges(node_id) + + if node_id in self._nodes: + node = self._nodes[node_id] + node.formula = None + node.dirty = False + node.dirty_rows.clear() + # If the node has no dependents either, remove it + if not self._dependents.get(node_id): + del self._nodes[node_id] + + logger.debug("Removed formula for %s", node_id) + + def get_calculation_order(self, table: Optional[str] = None) -> list[DependencyNode]: + """ + Return dirty formula nodes in topological order. + + Uses Kahn's algorithm (BFS-based topological sort). + Only returns nodes with a formula that are dirty. + + Args: + table: If provided, filter to only nodes for this table. 
+ + Returns: + List of dirty DependencyNode objects in calculation order. + """ + # Build in-degree map for nodes with formulas + formula_nodes = { + nid: node for nid, node in self._nodes.items() + if node.formula is not None and node.dirty + } + + if not formula_nodes: + return [] + + # Kahn's algorithm on the subgraph of formula nodes + in_degree = {nid: 0 for nid in formula_nodes} + for nid in formula_nodes: + for prec_id in self._precedents.get(nid, set()): + if prec_id in formula_nodes: + in_degree[nid] += 1 + + queue = deque([nid for nid, deg in in_degree.items() if deg == 0]) + result = [] + + while queue: + nid = queue.popleft() + node = self._nodes[nid] + if table is None or node.table == table: + result.append(node) + for dep_id in self._dependents.get(nid, set()): + if dep_id in in_degree: + in_degree[dep_id] -= 1 + if in_degree[dep_id] == 0: + queue.append(dep_id) + + return result + + def clear_dirty(self, node_id: str) -> None: + """ + Clear dirty flags for a node after successful recalculation. + + Args: + node_id: The node ID in format ``"table.column"``. + """ + if node_id in self._nodes: + node = self._nodes[node_id] + node.dirty = False + node.dirty_rows.clear() + + def get_node(self, table: str, column: str) -> Optional[DependencyNode]: + """ + Get a node by table and column. + + Args: + table: Table name. + column: Column name. + + Returns: + DependencyNode or None if not found. + """ + node_id = self._make_node_id(table, column) + return self._nodes.get(node_id) + + def has_formula(self, table: str, column: str) -> bool: + """ + Check if a column has a formula registered. + + Args: + table: Table name. + column: Column name. + + Returns: + True if the column has a formula. + """ + node = self.get_node(table, column) + return node is not None and node.formula is not None + + def mark_dirty( + self, + table: str, + column: str, + rows: Optional[list[int]] = None, + ) -> None: + """ + Mark a column (and its transitive dependents) as dirty. 
+ + Uses BFS to propagate dirty flags through the dependency graph. + + Args: + table: Table name. + column: Column name. + rows: Specific row indices to mark dirty. None means all rows. + """ + node_id = self._make_node_id(table, column) + self._mark_node_dirty(node_id, rows) + + # BFS propagation through dependents + queue = deque([node_id]) + visited = {node_id} + + while queue: + current_id = queue.popleft() + for dep_id in self._dependents.get(current_id, set()): + self._mark_node_dirty(dep_id, rows) + if dep_id not in visited: + visited.add(dep_id) + queue.append(dep_id) + + # ==================== Private helpers ==================== + + @staticmethod + def _make_node_id(table: str, column: str) -> str: + """Create a standard node ID from table and column names.""" + return f"{table}.{column}" + + def _get_or_create_node(self, table: str, column: str) -> DependencyNode: + """Get existing node or create a new one.""" + node_id = self._make_node_id(table, column) + if node_id not in self._nodes: + self._nodes[node_id] = DependencyNode( + node_id=node_id, + table=table, + column=column, + ) + return self._nodes[node_id] + + def _mark_node_dirty(self, node_id: str, rows: Optional[list[int]]) -> None: + """Mark a specific node as dirty.""" + if node_id not in self._nodes: + return + node = self._nodes[node_id] + node.dirty = True + if rows is not None: + node.dirty_rows.update(rows) + else: + node.dirty_rows.clear() # Empty = all rows dirty + + def _remove_edges(self, node_id: str) -> None: + """Remove all edges connected to a node.""" + # Remove this node from its precedents' dependents sets + for prec_id in list(self._precedents.get(node_id, set())): + self._dependents[prec_id].discard(node_id) + # Clear this node's precedents + self._precedents[node_id].clear() + + def _detect_cycles(self) -> None: + """ + Detect cycles in the full graph using Kahn's algorithm. + + Raises: + FormulaCycleError: If a cycle is detected. 
+ """ + # Only check formula nodes + formula_nodes = { + nid for nid, node in self._nodes.items() + if node.formula is not None + } + + if not formula_nodes: + return + + in_degree = {} + for nid in formula_nodes: + in_degree[nid] = 0 + for nid in formula_nodes: + for prec_id in self._precedents.get(nid, set()): + if prec_id in formula_nodes: + in_degree[nid] = in_degree.get(nid, 0) + 1 + + queue = deque([nid for nid in formula_nodes if in_degree.get(nid, 0) == 0]) + processed = set() + + while queue: + nid = queue.popleft() + processed.add(nid) + for dep_id in self._dependents.get(nid, set()): + if dep_id in formula_nodes: + in_degree[dep_id] -= 1 + if in_degree[dep_id] == 0: + queue.append(dep_id) + + cycle_nodes = formula_nodes - processed + if cycle_nodes: + raise FormulaCycleError(sorted(cycle_nodes)) + + def _extract_dependencies( + self, + node: FormulaNode, + current_table: str, + ) -> set[str]: + """ + Recursively extract all column dependency IDs from a formula AST. + + Args: + node: The AST node to walk. + current_table: The table containing this formula (for ColumnRef). + + Returns: + Set of dependency node IDs (``"table.column"`` format). 
+ """ + deps = set() + + if isinstance(node, ColumnRef): + deps.add(self._make_node_id(current_table, node.column)) + + elif isinstance(node, CrossTableRef): + deps.add(self._make_node_id(node.table, node.column)) + # Also depend on the local column used in WHERE clause + if node.where_clause is not None: + deps.add(self._make_node_id(current_table, node.where_clause.local_column)) + + elif isinstance(node, BinaryOp): + deps.update(self._extract_dependencies(node.left, current_table)) + deps.update(self._extract_dependencies(node.right, current_table)) + + elif isinstance(node, UnaryOp): + deps.update(self._extract_dependencies(node.operand, current_table)) + + elif isinstance(node, FunctionCall): + for arg in node.arguments: + deps.update(self._extract_dependencies(arg, current_table)) + + elif isinstance(node, ConditionalExpr): + deps.update(self._extract_dependencies(node.value_expr, current_table)) + deps.update(self._extract_dependencies(node.condition, current_table)) + if node.else_expr is not None: + deps.update(self._extract_dependencies(node.else_expr, current_table)) + + return deps diff --git a/src/myfasthtml/core/formula/dsl/__init__.py b/src/myfasthtml/core/formula/dsl/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/myfasthtml/core/formula/dsl/completion/FormulaCompletionEngine.py b/src/myfasthtml/core/formula/dsl/completion/FormulaCompletionEngine.py new file mode 100644 index 0000000..d158734 --- /dev/null +++ b/src/myfasthtml/core/formula/dsl/completion/FormulaCompletionEngine.py @@ -0,0 +1,180 @@ +""" +Autocompletion engine for the DataGrid Formula DSL. 
FORMULA_KEYWORDS = [
    "if", "else", "and", "or", "not", "where",
    "between", "in", "isempty", "isnotempty", "isnan",
    "contains", "startswith", "endswith",
    "true", "false",
]


class FormulaCompletionEngine(BaseCompletionEngine):
    """
    Context-aware completion engine for formula expressions.

    Provides suggestions for column references, functions, and keywords.

    Args:
        provider: DataGrid metadata provider (DataGridsManager or similar).
        table_name: Name of the current table in ``"namespace.name"`` format.
    """

    def __init__(self, provider, table_name: str):
        super().__init__(provider)
        self.table_name = table_name
        self._id = "formula_completion_engine#" + make_safe_id(table_name)

    def detect_scope(self, text: str, current_line: int):
        """Formula has no scope — always the same single-expression scope."""
        return None

    def detect_context(self, text: str, cursor: Position, scope):
        """
        Detect completion context based on cursor position in formula text.

        Args:
            text: The full formula text.
            cursor: Cursor position (line, ch).
            scope: Unused (formulas have no scopes).

        Returns:
            Context string, one of:
            - ``"column_ref"``: inside an unclosed ``{`` with no dot typed yet
            - ``"cross_table"``: inside an unclosed ``{`` containing a dot
            - ``"function_or_keyword"``: the cursor follows an identifier
            - ``"general"``: anything else
        """
        text_before = self._text_before_cursor(text, cursor)

        # Inside a { ... } reference that has not been closed yet?
        inside = self._after_last_brace(text_before)
        if inside is not None and "}" not in inside:
            return "cross_table" if "." in inside else "column_ref"

        # Typing an identifier: could be a function name or a keyword.
        if re.search(r"[a-z_][a-z0-9_]*$", text_before, re.IGNORECASE):
            return "function_or_keyword"

        return "general"

    def get_suggestions(self, text: str, cursor: Position, scope, context) -> list:
        """
        Generate suggestions based on the detected context.

        Args:
            text: The full formula text.
            cursor: Cursor position.
            scope: Unused.
            context: String from ``detect_context``.

        Returns:
            List of Suggestion objects.
        """
        if context == "column_ref":
            # Columns of the table that owns the formula
            return self._column_suggestions(self.table_name)

        if context == "cross_table":
            # Everything before the last dot inside the braces is the table
            inside = self._after_last_brace(self._text_before_cursor(text, cursor)) or ""
            dot_pos = inside.rfind(".")
            table_prefix = inside[:dot_pos] if dot_pos >= 0 else ""
            if table_prefix:
                return self._column_suggestions(table_prefix)
            return self._table_suggestions()

        suggestions = self._function_suggestions() + self._keyword_suggestions()
        if context != "function_or_keyword":
            # General context: also offer to open a column reference
            suggestions.append(
                Suggestion(
                    label="{",
                    detail="Column reference",
                    insert_text="{",
                )
            )
        return suggestions

    # ==================== Private helpers ====================

    @staticmethod
    def _text_before_cursor(text: str, cursor) -> str:
        """Return the part of the cursor's line located before the cursor."""
        lines = text.split("\n")
        # Clamp the line index so a cursor past the end uses the last line
        line_text = lines[min(cursor.line, len(lines) - 1)]
        return line_text[:cursor.ch]

    @staticmethod
    def _after_last_brace(text_before: str):
        """Return the text following the last ``{``, or None without a brace."""
        brace_pos = text_before.rfind("{")
        if brace_pos < 0:
            return None
        return text_before[brace_pos + 1:]

    def _column_suggestions(self, table_name: str) -> list:
        """Get column name suggestions for a table (empty on provider failure)."""
        try:
            columns = self.provider.list_columns(table_name)
            return [
                Suggestion(
                    label=col,
                    detail=f"Column from {table_name}",
                    insert_text=col,
                )
                for col in (columns or [])
            ]
        except Exception:
            # Best effort: completion must never break the editor
            return []

    def _table_suggestions(self) -> list:
        """Get table name suggestions (empty on provider failure)."""
        try:
            tables = self.provider.list_tables()
            return [
                Suggestion(
                    label=t,
                    detail="Table",
                    insert_text=t,
                )
                for t in (tables or [])
            ]
        except Exception:
            # Best effort: completion must never break the editor
            return []

    def _function_suggestions(self) -> list:
        """Get built-in function name suggestions, sorted by name."""
        return [
            Suggestion(
                label=name,
                detail="Function",
                insert_text=f"{name}(",
            )
            for name in sorted(BUILTIN_FUNCTIONS.keys())
        ]

    def _keyword_suggestions(self) -> list:
        """Get keyword suggestions."""
        return [
            Suggestion(label=kw, detail="Keyword", insert_text=kw)
            for kw in FORMULA_KEYWORDS
        ]
class FormulaDSL(DSLDefinition):
    """
    DSL definition for DataGrid formula expressions.

    Uses the Lark grammar from grammar.py to drive syntax highlighting
    and autocompletion in the DslEditor.
    """

    name: str = "Formula DSL"

    def get_grammar(self) -> str:
        """Return the Lark grammar for the formula DSL."""
        return FORMULA_GRAMMAR

    @cached_property
    def simple_mode_config(self) -> Dict[str, Any]:
        """
        Return a hand-tuned CodeMirror 5 Simple Mode config for formula syntax.

        Overrides the base class to provide highlighting rules for comments,
        column references, operators, functions, and keywords. Rules are
        tried in order at each position, so the more specific ones come first.
        """
        return {
            "start": [
                # Line comments: the grammar ignores "# ..." up to end of line
                {
                    "regex": r"#.*",
                    "token": "comment",
                },
                # Column references: {ColumnName} or {Table.Column}
                {
                    "regex": r"\{[A-Za-z_][A-Za-z0-9_.]*(?:\s+where\s+[A-Za-z_][A-Za-z0-9_.]*\s*=\s*[A-Za-z_][A-Za-z0-9_]*)?\}",
                    "token": "variable-2",
                },
                # Function names before parenthesis
                {
                    "regex": r"[a-z_][a-z0-9_]*(?=\s*\()",
                    "token": "keyword",
                },
                # Keywords: if, else, and, or, not, where, between, in
                {
                    "regex": r"\b(if|else|and|or|not|where|between|in|isempty|isnotempty|isnan|contains|startswith|endswith|true|false)\b",
                    "token": "keyword",
                },
                # Numbers
                {
                    "regex": r"[+-]?(\d+\.?\d*|\.\d+)([eE][+-]?\d+)?",
                    "token": "number",
                },
                # Strings
                {
                    "regex": r'"[^"\\]*"',
                    "token": "string",
                },
                # Operators
                {
                    "regex": r"[=!<>]=?|[+\-*/%^]",
                    "token": "operator",
                },
                # Parentheses and brackets
                {
                    "regex": r"[()[\],]",
                    "token": "punctuation",
                },
            ],
            "meta": {
                "dontIndentStates": ["comment"],
                "lineComment": "#",
            },
        }
class FormulaError(Exception):
    """Base exception for formula errors."""


class FormulaSyntaxError(FormulaError):
    """
    Raised when the formula has syntax errors.

    Attributes:
        message: Raw error description (without position).
        line: Line of the error, when known.
        column: Column of the error, when known.
        context: Optional excerpt of the offending text.
    """

    def __init__(self, message, line=None, column=None, context=None):
        self.message = message
        self.line = line
        self.column = column
        self.context = context
        super().__init__(self._format_message())

    def _format_message(self):
        """Build the message, appending line/column when they are known."""
        parts = [self.message]
        if self.line is not None:
            parts.append(f"at line {self.line}")
        if self.column is not None:
            parts.append(f"col {self.column}")
        return " ".join(parts)


class FormulaValidationError(FormulaError):
    """Raised when the formula is syntactically correct but semantically invalid."""


class FormulaCycleError(FormulaError):
    """
    Raised when formula dependencies contain a cycle.

    Attributes:
        cycle_nodes: List of the node ids involved in the cycle.
    """

    def __init__(self, cycle_nodes):
        # Copy into a list so later changes by the caller cannot alter the
        # error, and so any iterable of ids is accepted.
        self.cycle_nodes = list(cycle_nodes)
        joined = ", ".join(str(node) for node in self.cycle_nodes)
        super().__init__(f"Circular dependency detected involving: {joined}")

    def __reduce__(self):
        # Exceptions are rebuilt from ``args`` by copy/pickle; ``args`` holds
        # the formatted message, not the node list, so rebuild from the nodes.
        return (type(self), (self.cycle_nodes,))
| addition "isnan" -> isnan_expr + | addition + + comp_op: "==" -> eq + | "!=" -> ne + | "<=" -> le + | "<" -> lt + | ">=" -> ge + | ">" -> gt + + // ==================== Arithmetic ==================== + + ?addition: multiplication (add_op multiplication)* -> add_expr + ?multiplication: power (mul_op power)* -> mul_expr + ?power: unary ("^" unary)* -> pow_expr + + add_op: "+" -> plus + | "-" -> minus + + mul_op: "*" -> times + | "/" -> divide + | "%" -> modulo + + ?unary: "-" unary -> neg + | atom + + // ==================== Atoms ==================== + + ?atom: function_call + | cross_table_ref + | column_ref + | literal + | "(" expression ")" -> paren + + // ==================== References ==================== + + // Cross-table must be checked before column_ref since both use { } + // TABLE_NAME.COL_NAME with optional WHERE clause + // Note: whitespace around "where" and "=" is handled by %ignore + cross_table_ref: "{" TABLE_COL_REF "}" -> cross_ref_simple + | "{" TABLE_COL_REF "where" where_clause "}" -> cross_ref_where + + column_ref: "{" COL_NAME "}" + + where_clause: TABLE_COL_REF "=" COL_NAME + + // TABLE_COL_REF matches "TableName.ColumnName" (dot-separated, no spaces) + TABLE_COL_REF: /[A-Za-z_][A-Za-z0-9_]*\.[A-Za-z_][A-Za-z0-9_]*/ + COL_NAME: /[A-Za-z_][A-Za-z0-9_ ]*/ + + // ==================== Functions ==================== + + function_call: FUNC_NAME "(" [expression ("," expression)*] ")" + FUNC_NAME: /[a-z_][a-z0-9_]*/ + + // ==================== Literals ==================== + + ?literal: NUMBER -> number_literal + | ESCAPED_STRING -> string_literal + | "true"i -> true_literal + | "false"i -> false_literal + + // ==================== Terminals ==================== + + NUMBER: /[+-]?(\d+\.?\d*|\.\d+)([eE][+-]?\d+)?/ + ESCAPED_STRING: "\"" /[^"\\]*/ "\"" + + %ignore /[ \t\f]+/ + %ignore /\#[^\n]*/ +""" diff --git a/src/myfasthtml/core/formula/dsl/parser.py b/src/myfasthtml/core/formula/dsl/parser.py new file mode 100644 index 0000000..2f422f4 --- 
class FormulaParser:
    """
    Parser for the DataGrid formula language.

    Uses Lark LALR parser without indentation handling.

    Example:
        parser = FormulaParser()
        tree = parser.parse("{Price} * {Quantity}")
    """

    def __init__(self):
        self._parser = Lark(
            FORMULA_GRAMMAR,
            parser="lalr",
            propagate_positions=False,
        )

    def parse(self, text: str):
        """
        Parse a formula expression string into a Lark Tree.

        Args:
            text: The formula expression text.

        Returns:
            lark.Tree: The parsed AST, or None when the text is empty or blank.

        Raises:
            FormulaSyntaxError: If the text has syntax errors.
        """
        text = (text or "").strip()
        if not text:
            return None

        try:
            return self._parser.parse(text)
        except UnexpectedInput as e:
            context = e.get_context(text) if hasattr(e, "get_context") else None
            raise FormulaSyntaxError(
                message=self._format_error_message(e),
                line=getattr(e, "line", None),
                column=getattr(e, "column", None),
                context=context,
            ) from e

    def _format_error_message(self, error: "UnexpectedInput") -> str:
        """
        Format a user-friendly error message from a Lark exception.

        Expected token names are sorted so the same error always produces
        the same message. Errors without any expected token fall back to the
        exception's own text.
        """
        expected = sorted(str(name) for name in (getattr(error, "expected", None) or ()))
        if not expected:
            return str(error)
        if len(expected) == 1:
            return f"Expected {expected[0]}"
        if len(expected) <= 5:
            return f"Expected one of: {', '.join(expected)}"
        # Too many alternatives to be helpful
        return "Unexpected input"


# Singleton parser instance
_parser_instance = None


def get_parser() -> FormulaParser:
    """Get the singleton FormulaParser instance, creating it on first use."""
    global _parser_instance
    if _parser_instance is None:
        _parser_instance = FormulaParser()
    return _parser_instance
class FormulaTransformer(Transformer):
    """
    Transforms the Lark parse tree into FormulaDefinition AST dataclasses.

    Handles left-associative folding for arithmetic and logical operators.
    Handles right-associative chaining for conditional expressions.
    """

    # ==================== Top-level ====================

    def start(self, items):
        """Return the FormulaDefinition wrapping the single expression."""
        return FormulaDefinition(expression=items[0])

    # ==================== Conditionals ====================

    def conditional_with_else(self, items):
        """value_expr if condition else else_expr"""
        value_expr, condition, else_expr = items
        return ConditionalExpr(
            value_expr=value_expr,
            condition=condition,
            else_expr=else_expr,
        )

    def conditional_no_else(self, items):
        """value_expr if condition"""
        value_expr, condition = items
        return ConditionalExpr(
            value_expr=value_expr,
            condition=condition,
            else_expr=None,
        )

    # ==================== Logical ====================

    def or_op(self, items):
        """Fold left-associatively: a or b or c -> BinaryOp(or, BinaryOp(or, a, b), c)"""
        return self._fold_left(items, "or")

    def and_op(self, items):
        """Fold left-associatively: a and b and c -> BinaryOp(and, BinaryOp(and, a, b), c)"""
        return self._fold_left(items, "and")

    def not_op(self, items):
        """not expr"""
        return UnaryOp(operator="not", operand=items[0])

    # ==================== Comparisons ====================

    def comparison_expr(self, items):
        """left comp_op right"""
        left, op, right = items
        return BinaryOp(operator=op, left=left, right=right)

    def in_expr(self, items):
        """operand in [literal, ...] — candidates are wrapped in a list literal."""
        operand = items[0]
        values = list(items[1:])
        return BinaryOp(operator="in", left=operand, right=LiteralNode(value=values))

    def between_expr(self, items):
        """operand between low and high — bounds are wrapped in a list literal."""
        operand, low, high = items
        return BinaryOp(
            operator="between",
            left=operand,
            right=LiteralNode(value=[low, high]),
        )

    def contains_expr(self, items):
        """left contains right"""
        left, right = items
        return BinaryOp(operator="contains", left=left, right=right)

    def startswith_expr(self, items):
        """left startswith right"""
        left, right = items
        return BinaryOp(operator="startswith", left=left, right=right)

    def endswith_expr(self, items):
        """left endswith right"""
        left, right = items
        return BinaryOp(operator="endswith", left=left, right=right)

    def isempty_expr(self, items):
        """operand isempty"""
        return UnaryOp(operator="isempty", operand=items[0])

    def isnotempty_expr(self, items):
        """operand isnotempty"""
        return UnaryOp(operator="isnotempty", operand=items[0])

    def isnan_expr(self, items):
        """operand isnan"""
        return UnaryOp(operator="isnan", operand=items[0])

    # ==================== Comparison operators ====================

    def eq(self, items):
        return "=="

    def ne(self, items):
        return "!="

    def le(self, items):
        return "<="

    def lt(self, items):
        return "<"

    def ge(self, items):
        return ">="

    def gt(self, items):
        return ">"

    # ==================== Arithmetic ====================

    def add_expr(self, items):
        """Fold left-associatively with alternating operands and operators."""
        return self._fold_binary_with_ops(items)

    def mul_expr(self, items):
        """Fold left-associatively with alternating operands and operators."""
        return self._fold_binary_with_ops(items)

    def pow_expr(self, items):
        """Fold left-associatively for power expressions (^ is left-assoc here)."""
        # items are [base, exp1, exp2, ...]; the operator is always "^"
        return self._fold_left(items, "^")

    def plus(self, items):
        return "+"

    def minus(self, items):
        return "-"

    def times(self, items):
        return "*"

    def divide(self, items):
        return "/"

    def modulo(self, items):
        return "%"

    def neg(self, items):
        """Unary negation: -expr"""
        return UnaryOp(operator="-", operand=items[0])

    # ==================== References ====================

    def cross_ref_simple(self, items):
        """{ Table.Column }"""
        table, column = str(items[0]).split(".", 1)
        return CrossTableRef(table=table, column=column)

    def cross_ref_where(self, items):
        """{ Table.Column WHERE remote_table.remote_col = local_col }"""
        table, column = str(items[0]).split(".", 1)
        return CrossTableRef(table=table, column=column, where_clause=items[1])

    def column_ref(self, items):
        """{ ColumnName }"""
        # COL_NAME may contain spaces; drop the ones padding the braces
        return ColumnRef(column=str(items[0]).strip())

    def where_clause(self, items):
        """TABLE_COL_REF = COL_NAME"""
        remote_table, remote_col = str(items[0]).split(".", 1)
        return WhereClause(
            remote_table=remote_table,
            remote_column=remote_col,
            local_column=str(items[1]).strip(),
        )

    # ==================== Functions ====================

    def function_call(self, items):
        """func_name(arg1, arg2, ...)"""
        func_name = str(items[0]).lower()
        # The argument list is an optional ``[...]`` group in the grammar.
        # Lark fills an unmatched optional group with a ``None`` placeholder
        # (maybe_placeholders defaults to True), so ``f()`` would otherwise
        # be parsed as a call with one ``None`` argument.
        args = [arg for arg in items[1:] if arg is not None]
        return FunctionCall(function_name=func_name, arguments=args)

    # ==================== Literals ====================

    def number_literal(self, items):
        """Numeric literal: int unless it has a fraction or an exponent."""
        value = str(items[0])
        if "." in value or "e" in value.lower():
            return LiteralNode(value=float(value))
        try:
            return LiteralNode(value=int(value))
        except ValueError:
            return LiteralNode(value=float(value))

    def string_literal(self, items):
        """String literal with its surrounding double quotes removed."""
        raw = str(items[0])
        if raw.startswith('"') and raw.endswith('"'):
            return LiteralNode(value=raw[1:-1])
        return LiteralNode(value=raw)

    def true_literal(self, items):
        return LiteralNode(value=True)

    def false_literal(self, items):
        return LiteralNode(value=False)

    def paren(self, items):
        """Parenthesized expression — transparent pass-through."""
        return items[0]

    # ==================== Helpers ====================

    def _fold_left(self, items: list, op: str) -> "FormulaNode":
        """
        Fold a list of operands left-associatively with a fixed operator.

        Args:
            items: List of FormulaNode operands.
            op: Operator string.

        Returns:
            Left-folded BinaryOp tree, or the single item if only one.
        """
        result = items[0]
        for operand in items[1:]:
            result = BinaryOp(operator=op, left=result, right=operand)
        return result

    def _fold_binary_with_ops(self, items: list) -> "FormulaNode":
        """
        Fold a list of alternating [operand, op, operand, op, operand, ...]
        left-associatively.

        Args:
            items: Alternating list: [expr, op_str, expr, op_str, expr, ...]

        Returns:
            Left-folded BinaryOp tree.
        """
        result = items[0]
        # Walk the (operator, operand) pairs that follow the first operand
        for index in range(1, len(items), 2):
            result = BinaryOp(operator=items[index], left=result, right=items[index + 1])
        return result
def parse_formula(text: str) -> "FormulaDefinition | None":
    """Turn a formula expression string into a FormulaDefinition AST.

    Args:
        text: The formula expression string.

    Returns:
        The FormulaDefinition, with ``source_text`` set to the stripped
        input; None when the text is empty or blank.

    Raises:
        FormulaSyntaxError: If the formula text is syntactically invalid.
    """
    source = (text or "").strip()
    if not source:
        return None

    tree = get_parser().parse(source)
    if tree is None:
        return None

    definition = FormulaTransformer().transform(tree)
    # Keep the text next to the AST so it can be shown again in the editor
    definition.source_text = source
    return definition
+ """ + + def __init__(self, registry_resolver: Optional[RegistryResolver] = None): + self._graph = DependencyGraph() + self._registry_resolver = registry_resolver + # Cache of parsed formulas: {(table, col): FormulaDefinition} + self._formulas: dict[tuple[str, str], FormulaDefinition] = {} + + def set_formula(self, table: str, col: str, formula_text: str) -> None: + """ + Parse and register a formula for a column. + + Args: + table: Table name. + col: Column name. + formula_text: The formula expression string. + + Raises: + FormulaSyntaxError: If the formula is syntactically invalid. + FormulaCycleError: If the formula would create a circular dependency. + """ + formula_text = formula_text.strip() if formula_text else "" + if not formula_text: + self.remove_formula(table, col) + return + + formula = parse_formula(formula_text) + if formula is None: + self.remove_formula(table, col) + return + + # Registers in DAG and raises FormulaCycleError if cycle detected + self._graph.add_formula(table, col, formula) + self._formulas[(table, col)] = formula + + logger.debug("Formula set for %s.%s: %s", table, col, formula_text) + + def remove_formula(self, table: str, col: str) -> None: + """ + Remove a formula column from the engine. + + Args: + table: Table name. + col: Column name. + """ + self._graph.remove_formula(table, col) + self._formulas.pop((table, col), None) + + def mark_data_changed( + self, + table: str, + col: str, + rows: Optional[list[int]] = None, + ) -> None: + """ + Mark a column's data as changed, propagating dirty flags. + + Call this when source data is modified so that dependent formula + columns are re-evaluated on next render. + + Args: + table: Table name. + col: Column name. + rows: Specific row indices that changed. None means all rows. + """ + self._graph.mark_dirty(table, col, rows) + + def recalculate_if_needed(self, table: str, store: Any) -> bool: + """ + Recalculate all dirty formula columns for a table. 
+ + Should be called at the start of ``mk_body_content_page()`` to + ensure formula columns are up-to-date before rendering. + + Updates ``store.ns_fast_access`` and ``store.ns_row_data`` in place. + + Args: + table: Table name. + store: The DatagridStore instance for this table. + + Returns: + True if any columns were recalculated, False otherwise. + """ + dirty_nodes = self._graph.get_calculation_order(table=table) + + if not dirty_nodes: + return False + + for node in dirty_nodes: + formula = node.formula + if formula is None: + continue + self._evaluate_column(table, node.column, formula, store) + self._graph.clear_dirty(node.node_id) + + # Rebuild ns_row_data after recalculation + if dirty_nodes and store.ns_fast_access: + self._rebuild_row_data(store) + + return True + + def has_formula(self, table: str, col: str) -> bool: + """ + Check if a column has a formula registered. + + Args: + table: Table name. + col: Column name. + + Returns: + True if the column has a registered formula. + """ + return self._graph.has_formula(table, col) + + def get_formula_text(self, table: str, col: str) -> Optional[str]: + """ + Get the source text of a registered formula. + + Args: + table: Table name. + col: Column name. + + Returns: + Formula source text or None if not registered. + """ + formula = self._formulas.get((table, col)) + return formula.source_text if formula else None + + # ==================== Private helpers ==================== + + def _evaluate_column( + self, + table: str, + col: str, + formula: FormulaDefinition, + store: Any, + ) -> None: + """ + Evaluate a formula column row-by-row and update ns_fast_access. + + Args: + table: Table name. + col: Column name. + formula: The parsed FormulaDefinition. + store: The DatagridStore with ns_fast_access and ns_row_data. 
+ """ + if store.ns_row_data is None or len(store.ns_row_data) == 0: + return + + n_rows = len(store.ns_row_data) + resolver = self._make_cross_table_resolver(table) + evaluator = FormulaEvaluator(cross_table_resolver=resolver) + + # Ensure ns_fast_access exists before the loop so that formula columns + # evaluated earlier in the same pass are visible to subsequent columns. + if store.ns_fast_access is None: + store.ns_fast_access = {} + + results = np.empty(n_rows, dtype=object) + + for row_index in range(n_rows): + # Build row_data from ns_fast_access so that formula columns evaluated + # earlier in this pass (e.g. B) are available to dependent columns (e.g. C). + row_data = { + c: arr[row_index] + for c, arr in store.ns_fast_access.items() + if arr is not None and row_index < len(arr) + } + results[row_index] = evaluator.evaluate(formula, row_data, row_index) + + store.ns_fast_access[col] = results + + logger.debug("Evaluated formula column %s.%s (%d rows)", table, col, n_rows) + + def _rebuild_row_data(self, store: Any) -> None: + """ + Rebuild ns_row_data to include formula column results. + + This ensures formula values are available to dependent formulas + in subsequent evaluation passes. + + Args: + store: The DatagridStore to update. + """ + if store.ns_fast_access is None: + return + + n_rows = len(store.ns_row_data) + for row_index in range(n_rows): + row = store.ns_row_data[row_index] + for col, arr in store.ns_fast_access.items(): + if arr is not None and row_index < len(arr): + row[col] = arr[row_index] + + def _make_cross_table_resolver(self, current_table: str): + """ + Create a cross-table resolver callback for the given table context. + + Resolution strategy: + 1. Explicit WHERE clause: scan remote column for matching rows. + 2. Implicit join by ``id`` column: match rows where both tables share + the same id value. + 3. Fallback: match by row_index. + + Args: + current_table: The table that contains the formula. 
+ + Returns: + A callable ``resolver(table, column, where_clause, row_index) -> value``. + """ + + def resolver( + remote_table: str, + remote_column: str, + where_clause: Optional[WhereClause], + row_index: int, + ) -> Any: + if self._registry_resolver is None: + logger.warning( + "No registry_resolver set for cross-table ref %s.%s", + remote_table, remote_column, + ) + return None + + remote_store = self._registry_resolver(remote_table) + if remote_store is None: + logger.warning("Table '%s' not found in registry", remote_table) + return None + + ns = remote_store.ns_fast_access + if not ns or remote_column not in ns: + logger.debug( + "Column '%s' not found in table '%s'", remote_column, remote_table + ) + return None + + remote_array = ns[remote_column] + + # Strategy 1: Explicit WHERE clause + if where_clause is not None: + return self._resolve_with_where( + where_clause, remote_store, remote_column, + remote_array, current_table, row_index, + ) + + # Strategy 2: Implicit join by 'id' column + current_store = self._registry_resolver(current_table) + if ( + current_store is not None + and current_store.ns_fast_access is not None + and "id" in current_store.ns_fast_access + and "id" in ns + ): + local_id_arr = current_store.ns_fast_access["id"] + remote_id_arr = ns["id"] + if row_index < len(local_id_arr): + local_id = local_id_arr[row_index] + # Find first matching row in remote table + matches = np.where(remote_id_arr == local_id)[0] + if len(matches) > 0: + return remote_array[matches[0]] + return None + + # Strategy 3: Fallback — match by row_index + if row_index < len(remote_array): + return remote_array[row_index] + return None + + return resolver + + def _resolve_with_where( + self, + where_clause: WhereClause, + remote_store: Any, + remote_column: str, + remote_array: Any, + current_table: str, + row_index: int, + ) -> Any: + """ + Resolve a cross-table reference using an explicit WHERE clause. + + Args: + where_clause: The parsed WHERE clause. 
+ remote_store: DatagridStore for the remote table. + remote_column: Column to return value from. + remote_array: numpy array of the remote column values. + current_table: Table containing the formula. + row_index: Current row being evaluated. + + Returns: + The value from the first matching remote row, or None. + """ + remote_ns = remote_store.ns_fast_access + if not remote_ns: + return None + + # Get the remote key column array + remote_key_col = where_clause.remote_column + if remote_key_col not in remote_ns: + logger.debug( + "WHERE key column '%s' not found in remote table", remote_key_col + ) + return None + + remote_key_array = remote_ns[remote_key_col] + + # Get the local value to compare + current_store = self._registry_resolver(current_table) if self._registry_resolver else None + if current_store is None or current_store.ns_fast_access is None: + return None + + local_col = where_clause.local_column + if local_col not in current_store.ns_fast_access: + logger.debug("WHERE local column '%s' not found", local_col) + return None + + local_array = current_store.ns_fast_access[local_col] + if row_index >= len(local_array): + return None + + local_value = local_array[row_index] + + # Find matching rows + try: + matches = np.where(remote_key_array == local_value)[0] + except Exception: + matches = [] + + if len(matches) == 0: + return None + + # Return value from first match (use aggregation functions for multi-row) + return remote_array[matches[0]] diff --git a/src/myfasthtml/core/formula/evaluator.py b/src/myfasthtml/core/formula/evaluator.py new file mode 100644 index 0000000..47ad9bc --- /dev/null +++ b/src/myfasthtml/core/formula/evaluator.py @@ -0,0 +1,522 @@ +""" +Formula Evaluator. + +Evaluates a FormulaDefinition AST row-by-row using column data. 
logger = logging.getLogger("FormulaEvaluator")

# Type alias for the cross-table resolver callback
CrossTableResolver = Callable[[str, str, Optional[object], int], Any]


def _safe_numeric(value) -> Optional[float]:
    """Convert value to float, returning None if not possible."""
    if value is None:
        return None
    try:
        return float(value)
    except (TypeError, ValueError):
        return None


def _numeric_args(args) -> list:
    """Return the arguments that convert to float, in order, as floats."""
    return [n for n in map(_safe_numeric, args) if n is not None]


def _first_numeric(args) -> Optional[float]:
    """Return the first argument as a float, or None if absent/non-numeric."""
    return _safe_numeric(args[0]) if args else None


def _coerce_date(value):
    """Return value as a date/datetime, parsing ISO strings; None on failure."""
    if value is None:
        return None
    if isinstance(value, (datetime, date)):
        return value
    try:
        return datetime.fromisoformat(str(value))
    except (ValueError, TypeError):
        return None


def _date_part(args, attribute: str):
    """Return ``year``/``month``/``day`` of the first argument, or None."""
    parsed = _coerce_date(args[0] if args else None)
    return getattr(parsed, attribute) if parsed is not None else None


# ==================== Built-in function registry ====================

def _fn_round(args):
    """round(value[, decimals]) - decimals defaults to 0."""
    if len(args) < 1:
        return None
    decimals = int(args[1]) if len(args) > 1 else 0
    v = _safe_numeric(args[0])
    return round(v, decimals) if v is not None else None


def _fn_abs(args):
    v = _first_numeric(args)
    return abs(v) if v is not None else None


def _fn_min(args):
    nums = _numeric_args(args)
    return min(nums) if nums else None


def _fn_max(args):
    nums = _numeric_args(args)
    return max(nums) if nums else None


def _fn_floor(args):
    v = _first_numeric(args)
    return math.floor(v) if v is not None else None


def _fn_ceil(args):
    v = _first_numeric(args)
    return math.ceil(v) if v is not None else None


def _fn_sqrt(args):
    v = _first_numeric(args)
    return math.sqrt(v) if v is not None and v >= 0 else None


def _fn_sum(args):
    """Sum of all arguments (used for inline multi-value, not aggregation)."""
    nums = _numeric_args(args)
    return sum(nums) if nums else None


def _fn_avg(args):
    nums = _numeric_args(args)
    return sum(nums) / len(nums) if nums else None


def _fn_len(args):
    v = args[0] if args else None
    if v is None:
        return None
    return len(str(v))


def _fn_upper(args):
    v = args[0] if args else None
    return str(v).upper() if v is not None else None


def _fn_lower(args):
    v = args[0] if args else None
    return str(v).lower() if v is not None else None


def _fn_trim(args):
    v = args[0] if args else None
    return str(v).strip() if v is not None else None


def _fn_left(args):
    """left(text, n) - first n characters; n <= 0 yields "" (like right())."""
    if len(args) < 2:
        return None
    v, n = args[0], args[1]
    if v is None or n is None:
        return None
    n = int(n)
    # Guard against negative n, which would otherwise slice from the end.
    return str(v)[:n] if n > 0 else ""


def _fn_right(args):
    """right(text, n) - last n characters; n <= 0 yields ""."""
    if len(args) < 2:
        return None
    v, n = args[0], args[1]
    if v is None or n is None:
        return None
    n = int(n)
    return str(v)[-n:] if n > 0 else ""


def _fn_concat(args):
    return "".join(str(a) if a is not None else "" for a in args)


def _fn_year(args):
    return _date_part(args, "year")


def _fn_month(args):
    return _date_part(args, "month")


def _fn_day(args):
    return _date_part(args, "day")


def _fn_today(args):
    return date.today()


def _fn_datediff(args):
    """datediff(d1, d2) - whole days of ``d1 - d2``; None if either is unusable."""
    if len(args) < 2:
        return None
    d1, d2 = _coerce_date(args[0]), _coerce_date(args[1])
    if d1 is None or d2 is None:
        return None
    # A plain date cannot be subtracted from a datetime: when the two kinds
    # are mixed, promote the plain date to midnight of that day.
    if isinstance(d1, datetime) != isinstance(d2, datetime):
        if not isinstance(d1, datetime):
            d1 = datetime(d1.year, d1.month, d1.day)
        else:
            d2 = datetime(d2.year, d2.month, d2.day)
    try:
        return (d1 - d2).days
    except TypeError:
        # e.g. timezone-aware minus naive datetime
        return None


def _fn_coalesce(args):
    for a in args:
        if a is not None:
            return a
    return None


def _fn_if_error(args):
    # if_error(expr, fallback) - expr already evaluated, error would be None
    if len(args) < 2:
        return args[0] if args else None
    return args[0] if args[0] is not None else args[1]


def _fn_count(args):
    """Count non-None values (used for aggregation results)."""
    return sum(1 for a in args if a is not None)


# ==================== Function registry ====================

BUILTIN_FUNCTIONS: dict[str, Callable] = {
    # Math
    "round": _fn_round,
    "abs": _fn_abs,
    "min": _fn_min,
    "max": _fn_max,
    "floor": _fn_floor,
    "ceil": _fn_ceil,
    "sqrt": _fn_sqrt,
    "sum": _fn_sum,
    "avg": _fn_avg,
    # Text
    "len": _fn_len,
    "upper": _fn_upper,
    "lower": _fn_lower,
    "trim": _fn_trim,
    "left": _fn_left,
    "right": _fn_right,
    "concat": _fn_concat,
    # Date
    "year": _fn_year,
    "month": _fn_month,
    "day": _fn_day,
    "today": _fn_today,
    "datediff": _fn_datediff,
    # Utility
    "coalesce": _fn_coalesce,
    "if_error": _fn_if_error,
    "count": _fn_count,
}

# ==================== Operator tables ====================

# Numeric operators other than "+" (which doubles as string concatenation).
_ARITHMETIC: dict[str, Callable] = {
    "-": lambda a, b: a - b,
    "*": lambda a, b: a * b,
    "/": lambda a, b: a / b,
    "%": lambda a, b: a % b,
    "^": lambda a, b: a ** b,
}

# Ordering operators, expressed as a predicate on the -1/0/1 result of _compare.
_ORDERING: dict[str, Callable] = {
    "<": lambda c: c < 0,
    "<=": lambda c: c <= 0,
    ">": lambda c: c > 0,
    ">=": lambda c: c >= 0,
}

# String predicates; both operands are converted with str() before the call.
_STRING_TESTS: dict[str, Callable] = {
    "contains": lambda text, part: part in text,
    "startswith": lambda text, part: text.startswith(part),
    "endswith": lambda text, part: text.endswith(part),
}


# ==================== Evaluator ====================

class FormulaEvaluator:
    """
    Row-by-row formula evaluator.

    Evaluates a FormulaDefinition AST against a single row of data.

    Args:
        cross_table_resolver: Optional callback for cross-table references.
            Signature: resolver(table, column, where_clause, row_index) -> value
    """

    def __init__(self, cross_table_resolver: Optional[CrossTableResolver] = None):
        self._cross_table_resolver = cross_table_resolver

    def evaluate(
        self,
        formula: FormulaDefinition,
        row_data: dict,
        row_index: int,
    ) -> Any:
        """
        Evaluate a formula for a single row.

        Args:
            formula: The parsed FormulaDefinition AST.
            row_data: Dict mapping column_id -> value for the current row.
            row_index: The integer index of the current row.

        Returns:
            The computed value, or None on error.
        """
        try:
            return self._eval(formula.expression, row_data, row_index)
        except Exception as exc:
            # Top-level boundary: one bad row must not abort the whole column.
            logger.warning(
                "Formula evaluation error at row %d: %s", row_index, exc
            )
            return None

    def _eval(self, node: FormulaNode, row_data: dict, row_index: int) -> Any:
        """
        Recursively evaluate an AST node.

        Args:
            node: The AST node to evaluate.
            row_data: Current row data dict.
            row_index: Current row index.

        Returns:
            Evaluated value, or None for an unknown node type.
        """
        if isinstance(node, LiteralNode):
            return node.value

        if isinstance(node, ColumnRef):
            return self._resolve_column(node.column, row_data)

        if isinstance(node, CrossTableRef):
            return self._resolve_cross_table(node, row_index)

        if isinstance(node, BinaryOp):
            return self._eval_binary(node, row_data, row_index)

        if isinstance(node, UnaryOp):
            return self._eval_unary(node, row_data, row_index)

        if isinstance(node, FunctionCall):
            return self._eval_function(node, row_data, row_index)

        if isinstance(node, ConditionalExpr):
            return self._eval_conditional(node, row_data, row_index)

        logger.warning("Unknown AST node type: %s", type(node).__name__)
        return None

    def _resolve_column(self, column_name: str, row_data: dict) -> Any:
        """Resolve a column reference in the current row (exact, then case-insensitive)."""
        if column_name in row_data:
            return row_data[column_name]
        # Try case-insensitive match
        lower_name = column_name.lower()
        for key, value in row_data.items():
            if str(key).lower() == lower_name:
                return value
        logger.debug("Column '%s' not found in row_data", column_name)
        return None

    def _resolve_cross_table(self, node: CrossTableRef, row_index: int) -> Any:
        """Resolve a cross-table reference through the configured resolver."""
        if self._cross_table_resolver is None:
            logger.warning(
                "No cross_table_resolver set for cross-table ref %s.%s",
                node.table, node.column,
            )
            return None
        return self._cross_table_resolver(
            node.table, node.column, node.where_clause, row_index
        )

    def _eval_binary(self, node: BinaryOp, row_data: dict, row_index: int) -> Any:
        """Evaluate a binary operation, short-circuiting ``and`` / ``or``."""
        left = self._eval(node.left, row_data, row_index)
        op = node.operator

        # Short-circuit: the right operand is evaluated only when needed.
        if op == "and":
            if not self._truthy(left):
                return False
            return self._truthy(self._eval(node.right, row_data, row_index))

        if op == "or":
            if self._truthy(left):
                return True
            return self._truthy(self._eval(node.right, row_data, row_index))

        right = self._eval(node.right, row_data, row_index)
        return self._apply_operator(op, left, right)

    def _apply_operator(self, op: str, left: Any, right: Any) -> Any:
        """
        Apply a non-logical binary operator to two already-evaluated operands.

        Returns None for an unknown operator, for non-numeric operands of an
        arithmetic operator, and for division/modulo by zero.
        """
        # "+" concatenates as soon as either side is a string.
        if op == "+":
            if isinstance(left, str) or isinstance(right, str):
                return self._text(left) + self._text(right)
            return self._num_op(left, right, lambda a, b: a + b)

        # Division/modulo by zero raise ZeroDivisionError, which _num_op
        # turns into None.
        if op in _ARITHMETIC:
            return self._num_op(left, right, _ARITHMETIC[op])

        # Comparison
        if op == "==":
            return left == right
        if op == "!=":
            return left != right
        if op in _ORDERING:
            return _ORDERING[op](self._compare(left, right))

        # String operations
        if op in _STRING_TESTS:
            if left is None:
                return False
            return _STRING_TESTS[op](str(left), str(right))

        # Collection operations
        if op == "in":
            values = right.value if isinstance(right, LiteralNode) else right
            if isinstance(values, list):
                return left in [v.value if isinstance(v, LiteralNode) else v for v in values]
            return left in (values or [])
        if op == "between":
            values = right.value if isinstance(right, LiteralNode) else right
            if isinstance(values, list) and len(values) == 2:
                lo = values[0].value if isinstance(values[0], LiteralNode) else values[0]
                hi = values[1].value if isinstance(values[1], LiteralNode) else values[1]
                return lo <= left <= hi
            return None

        logger.warning("Unknown binary operator: %s", op)
        return None

    def _eval_unary(self, node: UnaryOp, row_data: dict, row_index: int) -> Any:
        """Evaluate a unary operation."""
        operand = self._eval(node.operand, row_data, row_index)
        op = node.operator

        if op == "-":
            v = _safe_numeric(operand)
            return -v if v is not None else None
        if op == "not":
            return not self._truthy(operand)
        if op == "isempty":
            return operand is None or operand == "" or operand == []
        if op == "isnotempty":
            return operand is not None and operand != "" and operand != []
        if op == "isnan":
            try:
                return math.isnan(float(operand))
            except (TypeError, ValueError):
                return False

        logger.warning("Unknown unary operator: %s", op)
        return None

    def _eval_function(self, node: FunctionCall, row_data: dict, row_index: int) -> Any:
        """Evaluate a function call; arguments are evaluated eagerly, names are case-insensitive."""
        name = node.function_name.lower()
        args = [self._eval(arg, row_data, row_index) for arg in node.arguments]

        fn = BUILTIN_FUNCTIONS.get(name)
        if fn is None:
            logger.warning("Unknown function: %s", name)
            return None

        try:
            return fn(args)
        except Exception as exc:
            logger.warning("Function '%s' error: %s", name, exc)
            return None

    def _eval_conditional(self, node: ConditionalExpr, row_data: dict, row_index: int) -> Any:
        """Evaluate a conditional expression; only the selected branch is evaluated."""
        condition = self._eval(node.condition, row_data, row_index)
        if self._truthy(condition):
            return self._eval(node.value_expr, row_data, row_index)
        elif node.else_expr is not None:
            return self._eval(node.else_expr, row_data, row_index)
        return None

    @staticmethod
    def _text(value: Any) -> str:
        """Render a value for concatenation: None becomes "", everything else str()."""
        return "" if value is None else str(value)

    @staticmethod
    def _truthy(value: Any) -> bool:
        """Convert a value to boolean for conditional evaluation."""
        if value is None:
            return False
        if isinstance(value, bool):
            return value
        if isinstance(value, (int, float)):
            return value != 0
        if isinstance(value, str):
            return len(value) > 0
        return bool(value)

    @staticmethod
    def _num_op(left: Any, right: Any, fn: Callable) -> Any:
        """
        Apply a numeric binary function to two operands.

        Returns None if either input is non-numeric, if the operation raises
        (division by zero, overflow, domain error), or if the result is a
        complex number (negative base raised to a fractional power).
        """
        a = _safe_numeric(left)
        b = _safe_numeric(right)
        if a is None or b is None:
            return None
        try:
            result = fn(a, b)
        except (ZeroDivisionError, OverflowError, ValueError):
            return None
        return None if isinstance(result, complex) else result

    @staticmethod
    def _compare(left: Any, right: Any) -> int:
        """
        Compare two values, returning -1, 0, or 1.

        Values that cannot be ordered directly are compared by their
        string representation.
        """
        try:
            if left < right:
                return -1
            elif left > right:
                return 1
            return 0
        except TypeError:
            # Fallback: compare as strings
            sl, sr = str(left), str(right)
            if sl < sr:
                return -1
            elif sl > sr:
                return 1
            return 0
test_resolve_preset_with_override(self): - """Preset properties can be overridden by explicit values.""" - resolver = StyleResolver() - # "success" preset has background and color defined + """Preset CSS properties can be overridden by explicit values.""" + custom_presets = { + "success": { + "background-color": "var(--color-success)", + "color": "var(--color-success-content)", + } + } + resolver = StyleResolver(style_presets=custom_presets) style = Style(preset="success", color="black") result = resolver.resolve(style) @@ -66,6 +71,16 @@ class TestResolve: assert result == {} + def test_i_can_resolve_class_only_preset(self): + """Default preset with __class__ only is included as-is in resolve result.""" + resolver = StyleResolver() + style = Style(preset="success") + result = resolver.resolve(style) + + assert result["__class__"] == "mf-formatting-success" + assert "background-color" not in result + assert "color" not in result + def test_resolve_converts_property_names(self): """Python attribute names are converted to CSS property names.""" resolver = StyleResolver() @@ -151,11 +166,11 @@ class TestToStyleContainer: None, ["background-color: red", "color: white"] ), - # Class only via preset + # Class only via preset (default presets use __class__, no inline CSS) ( Style(preset="success"), - None, # Default presets don't have __class__ - ["background-color: var(--color-success)", "color: var(--color-success-content)"] + "mf-formatting-success", + [] ), # Empty style ( @@ -246,3 +261,12 @@ class TestToStyleContainer: assert isinstance(result, StyleContainer) assert result.cls is None assert result.css == "" + + def test_i_can_resolve_default_preset_to_container(self): + """Default preset with __class__ only generates cls but no css.""" + resolver = StyleResolver() + style = Style(preset="error") + result = resolver.to_style_container(style) + + assert result.cls == "mf-formatting-error" + assert result.css == "" diff --git a/tests/core/formula/__init__.py 
def make_formula(expr):
    """Wrap a raw AST node in a FormulaDefinition so tests can bypass the parser."""
    return FormulaDefinition(expression=expr, source_text="test")


# ==================== Add formula ====================

def test_i_can_add_simple_dependency():
    """Adding a formula registers the node and wires edges in both directions."""
    dag = DependencyGraph()
    total_formula = parse_formula("{Price} * {Quantity}")
    dag.add_formula("orders", "total", total_formula)

    total_node = dag.get_node("orders", "total")
    assert total_node is not None
    assert total_node.formula is total_formula
    assert total_node.dirty is True

    # The formula column depends on both source columns...
    sources = dag._precedents["orders.total"]
    assert "orders.Price" in sources
    assert "orders.Quantity" in sources

    # ...and each source column lists the formula column as a dependent.
    assert "orders.total" in dag._dependents["orders.Price"]
    assert "orders.total" in dag._dependents["orders.Quantity"]


def test_i_can_add_formula_and_check_dirty():
    """A freshly added formula node starts out dirty."""
    dag = DependencyGraph()
    dag.add_formula("t", "C", parse_formula("{A} + {B}"))

    assert dag.get_node("t", "C").dirty is True


def test_i_can_update_formula():
    """Re-adding a formula for the same column replaces its dependencies."""
    dag = DependencyGraph()
    original = parse_formula("{A} + {B}")
    dag.add_formula("t", "C", original)

    replacement = parse_formula("{X} * {Y}")
    dag.add_formula("t", "C", replacement)

    assert dag.get_node("t", "C").formula is replacement
    # The old inputs must be gone from the precedents.
    sources = dag._precedents.get("t.C", set())
    assert "t.A" not in sources
    assert "t.B" not in sources


# ==================== Cycle detection ====================

def test_i_cannot_create_cycle():
    """A two-node loop (A -> B -> A) is rejected with FormulaCycleError."""
    dag = DependencyGraph()
    dag.add_formula("t", "A", parse_formula("{B} + 1"))

    # Making B depend on A would close the loop.
    with pytest.raises(FormulaCycleError):
        dag.add_formula("t", "B", parse_formula("{A} * 2"))


def test_i_cannot_create_self_reference():
    """A formula that reads its own column is rejected with FormulaCycleError."""
    dag = DependencyGraph()

    with pytest.raises(FormulaCycleError):
        dag.add_formula("t", "A", parse_formula("{A} + 1"))


def test_i_can_detect_long_cycle():
    """A three-node loop (A -> B -> C -> A) is rejected as well."""
    dag = DependencyGraph()
    dag.add_formula("t", "A", parse_formula("{B} + 1"))
    dag.add_formula("t", "B", parse_formula("{C} + 1"))

    with pytest.raises(FormulaCycleError):
        dag.add_formula("t", "C", parse_formula("{A} + 1"))


# ==================== Dirty flag propagation ====================

def test_i_can_propagate_dirty_flags():
    """Dirtying a source column dirties the formula column that reads it."""
    dag = DependencyGraph()
    dag.add_formula("t", "B", parse_formula("{A} * 2"))
    # add_formula leaves the node dirty; reset it first.
    dag.clear_dirty("t.B")
    assert not dag.get_node("t", "B").dirty

    dag.mark_dirty("t", "A")

    dependent = dag.get_node("t", "B")
    assert dependent is not None
    assert dependent.dirty is True


def test_i_can_propagate_dirty_to_chain():
    """Dirty flags travel through a chain of formulas: A -> B -> C."""
    dag = DependencyGraph()
    dag.add_formula("t", "B", parse_formula("{A} + 1"))
    dag.add_formula("t", "C", parse_formula("{B} + 1"))
    dag.clear_dirty("t.B")
    dag.clear_dirty("t.C")

    dag.mark_dirty("t", "A")

    assert dag.get_node("t", "B").dirty is True
    assert dag.get_node("t", "C").dirty is True


def test_i_can_propagate_specific_rows():
    """Row-limited dirtying records exactly the given rows on the dependent."""
    dag = DependencyGraph()
    dag.add_formula("t", "B", parse_formula("{A} * 2"))
    dag.clear_dirty("t.B")

    dag.mark_dirty("t", "A", rows=[0, 2, 5])

    dependent = dag.get_node("t", "B")
    assert dependent.dirty is True
    assert 0 in dependent.dirty_rows
    assert 2 in dependent.dirty_rows
    assert 5 in dependent.dirty_rows
    assert 1 not in dependent.dirty_rows


# ==================== Topological ordering ====================

def test_i_can_get_calculation_order():
    """Dirty formula nodes come back with precedents before dependents."""
    dag = DependencyGraph()
    dag.add_formula("t", "B", parse_formula("{A} + 1"))
    dag.add_formula("t", "C", parse_formula("{B} + 1"))
    dag.mark_dirty("t", "A")

    ordered_ids = [node.node_id for node in dag.get_calculation_order(table="t")]

    assert "t.B" in ordered_ids
    assert "t.C" in ordered_ids
    # C reads B, so B has to be computed first.
    assert ordered_ids.index("t.B") < ordered_ids.index("t.C")


def test_i_can_get_calculation_order_without_table_filter():
    """Without a table filter, dirty nodes of every table are returned.

    Why: the ``table`` argument is optional; leaving it out must not restrict
    the result to a single table.
    """
    dag = DependencyGraph()
    dag.add_formula("t1", "B", parse_formula("{A} + 1"))
    dag.add_formula("t2", "D", parse_formula("{C} * 2"))

    ordered_ids = [node.node_id for node in dag.get_calculation_order()]

    assert "t1.B" in ordered_ids
    assert "t2.D" in ordered_ids


def test_i_can_get_calculation_order_excludes_clean_nodes():
    """Only dirty formula nodes appear in the calculation order.

    Why: a clean node in the output would be recalculated for nothing.
    """
    dag = DependencyGraph()
    dag.add_formula("t", "B", parse_formula("{A} + 1"))
    dag.add_formula("t", "C", parse_formula("{A} * 2"))
    dag.clear_dirty("t.C")

    ordered_ids = [node.node_id for node in dag.get_calculation_order(table="t")]

    assert "t.B" in ordered_ids
    assert "t.C" not in ordered_ids


def test_i_can_handle_diamond_dependency():
    """Diamond shape (A -> B, A -> C, B + C -> D): D is ordered after B and C."""
    dag = DependencyGraph()
    dag.add_formula("t", "B", parse_formula("{A} + 1"))
    dag.add_formula("t", "C", parse_formula("{A} * 2"))
    dag.add_formula("t", "D", parse_formula("{B} + {C}"))
    dag.mark_dirty("t", "A")

    ordered_ids = [node.node_id for node in dag.get_calculation_order(table="t")]

    assert "t.B" in ordered_ids
    assert "t.C" in ordered_ids
    assert "t.D" in ordered_ids
    position_of_d = ordered_ids.index("t.D")
    assert position_of_d > ordered_ids.index("t.B")
    assert position_of_d > ordered_ids.index("t.C")


# ==================== Remove formula ====================

def test_i_can_remove_formula():
    """A registered formula is gone after remove_formula."""
    dag = DependencyGraph()
    dag.add_formula("t", "B", parse_formula("{A} + 1"))
    assert dag.has_formula("t", "B")

    dag.remove_formula("t", "B")

    assert not dag.has_formula("t", "B")


def test_i_can_remove_formula_node_kept_when_has_dependents():
    """Removing a formula keeps its node while other formulas still read it.

    Why: the column is still an input of another formula, so it has to
    survive as a plain data node.
    """
    dag = DependencyGraph()
    dag.add_formula("t", "B", parse_formula("{A} + 1"))
    dag.add_formula("t", "C", parse_formula("{B} * 2"))

    dag.remove_formula("t", "B")

    assert not dag.has_formula("t", "B")
    assert dag.get_node("t", "B") is not None


def test_i_can_remove_formula_and_add_back():
    """A column can receive a new formula after its previous one was removed."""
    dag = DependencyGraph()
    dag.add_formula("t", "B", parse_formula("{A} + 1"))
    dag.remove_formula("t", "B")

    # Must not raise.
    dag.add_formula("t", "B", parse_formula("{X} * 2"))

    assert dag.has_formula("t", "B")


# ==================== Cross-table ====================

def test_i_can_handle_cross_table_dependencies():
    """A cross-table reference produces an edge to the remote table's column."""
    dag = DependencyGraph()
    dag.add_formula("orders", "total", parse_formula("{Products.Price} * {Quantity}"))

    assert dag.get_node("orders", "total") is not None
    sources = dag._precedents.get("orders.total", set())
    assert "Products.Price" in sources
    assert "orders.Quantity" in sources


def test_i_can_extract_cross_table_ref_with_where_clause():
    """A CrossTableRef with a WHERE clause depends on the remote column and
    on the local column named in the clause.

    Why: the local side of the WHERE clause is an input too; without this
    test that path is never exercised.
    """
    dag = DependencyGraph()
    join = WhereClause(
        remote_table="Products",
        remote_column="id",
        local_column="product_id",
    )
    reference = CrossTableRef(table="Products", column="Price", where_clause=join)
    dag.add_formula("orders", "total", make_formula(reference))

    sources = dag._precedents.get("orders.total", set())
    assert "Products.Price" in sources
    assert "orders.product_id" in sources


# ==================== _extract_dependencies AST coverage ====================

def test_i_can_extract_unary_op_dependency():
    """The operand of a UnaryOp is recorded as a dependency.

    Why: otherwise a negated column reference would silently produce no
    dependency edge.
    """
    dag = DependencyGraph()
    negated = UnaryOp("-", ColumnRef("A"))
    dag.add_formula("t", "B", make_formula(negated))

    sources = dag._precedents.get("t.B", set())
    assert "t.A" in sources


def test_i_can_extract_function_call_dependencies():
    """Every column argument of a FunctionCall is recorded as a dependency.

    Why: with several column arguments, each one must show up in the
    precedents set.
    """
    dag = DependencyGraph()
    call = FunctionCall("SUM", [ColumnRef("A"), ColumnRef("B")])
    dag.add_formula("t", "C", make_formula(call))

    sources = dag._precedents.get("t.C", set())
    assert "t.A" in sources
    assert "t.B" in sources
+ """ + graph = DependencyGraph() + formula = make_formula(FunctionCall("NOW", [LiteralNode(1)])) + graph.add_formula("t", "C", formula) + + prec = graph._precedents.get("t.C", set()) + assert len(prec) == 0, "FunctionCall with literal args should produce no column dependencies" + + +def test_i_can_extract_conditional_expr_all_branches(): + """Test that ConditionalExpr extracts dependencies from value_expr, condition, and else_expr. + + Why: Source lines 380-384 walk all three branches. All three column references + must appear as dependencies so that any change in any branch triggers recalculation. + """ + graph = DependencyGraph() + formula = make_formula( + ConditionalExpr( + value_expr=ColumnRef("A"), + condition=ColumnRef("B"), + else_expr=ColumnRef("C"), + ) + ) + graph.add_formula("t", "D", formula) + + prec = graph._precedents.get("t.D", set()) + assert "t.A" in prec + assert "t.B" in prec + assert "t.C" in prec + + +def test_i_can_extract_conditional_expr_without_else(): + """Test that ConditionalExpr with else_expr=None does not crash and extracts value and condition. + + Why: Source line 383 guards on ``if node.else_expr is not None``. A missing + else branch must not raise and must still extract the remaining two dependencies. + """ + graph = DependencyGraph() + formula = make_formula( + ConditionalExpr( + value_expr=ColumnRef("A"), + condition=ColumnRef("B"), + else_expr=None, + ) + ) + graph.add_formula("t", "C", formula) + + prec = graph._precedents.get("t.C", set()) + assert "t.A" in prec + assert "t.B" in prec diff --git a/tests/core/formula/test_formula_engine.py b/tests/core/formula/test_formula_engine.py new file mode 100644 index 0000000..875e4fd --- /dev/null +++ b/tests/core/formula/test_formula_engine.py @@ -0,0 +1,408 @@ +""" +Tests for the FormulaEngine facade. 
+""" +import numpy as np +import pytest + +from myfasthtml.core.formula.dsl.exceptions import FormulaSyntaxError, FormulaCycleError +from myfasthtml.core.formula.engine import FormulaEngine + + +class FakeStore: + """Minimal DatagridStore-like object for testing.""" + + def __init__(self, rows): + self.ns_row_data = rows + self.ns_fast_access = self._build_fast_access(rows) + self.ns_total_rows = len(rows) + + @staticmethod + def _build_fast_access(rows): + """Build a columnar fast-access dict from a list of row dicts.""" + if not rows: + return {} + return { + col: np.array([row.get(col) for row in rows], dtype=object) + for col in rows[0].keys() + } + + +def make_engine(store_map=None): + """Create an engine with an optional store map for cross-table resolution.""" + + def resolver(table_name): + return store_map.get(table_name) if store_map else None + + return FormulaEngine(registry_resolver=resolver) + + +# ==================== TestSetFormula ==================== + +class TestSetFormula: + """Tests for set_formula: parsing, registration, and edge cases.""" + + def test_i_can_set_and_evaluate_formula(self): + """Test that a formula can be set and evaluated.""" + rows = [{"Price": 10, "Quantity": 3}] + store = FakeStore(rows) + + engine = make_engine({"orders": store}) + engine.set_formula("orders", "total", "{Price} * {Quantity}") + engine.recalculate_if_needed("orders", store) + + assert "total" in store.ns_fast_access + assert store.ns_fast_access["total"][0] == 30 + + def test_i_can_evaluate_multiple_rows(self): + """Test that formula evaluation works across multiple rows.""" + rows = [ + {"Price": 10, "Quantity": 3}, + {"Price": 20, "Quantity": 2}, + {"Price": 5, "Quantity": 10}, + ] + store = FakeStore(rows) + + engine = make_engine({"orders": store}) + engine.set_formula("orders", "total", "{Price} * {Quantity}") + engine.recalculate_if_needed("orders", store) + + totals = store.ns_fast_access["total"] + assert totals[0] == 30 + assert totals[1] == 40 + 
assert totals[2] == 50 + + def test_i_cannot_set_invalid_formula(self): + """Test that invalid formula syntax raises FormulaSyntaxError.""" + engine = make_engine() + with pytest.raises(FormulaSyntaxError): + engine.set_formula("t", "col", "{Price} * * {Qty}") + + def test_i_cannot_set_formula_with_cycle(self): + """Test that a circular dependency raises FormulaCycleError.""" + rows = [{"A": 1}] + store = FakeStore(rows) + engine = make_engine({"t": store}) + + engine.set_formula("t", "A", "{B} + 1") + with pytest.raises(FormulaCycleError): + engine.set_formula("t", "B", "{A} * 2") + + @pytest.mark.parametrize("text", ["", " ", "\t\n"]) + def test_i_can_set_formula_with_blank_input_removes_it(self, text): + """Test that setting a blank or whitespace-only formula string removes it. + + Why: set_formula strips the input (source line 86) before checking emptiness. + All blank variants — empty string, spaces, tabs — must behave identically. + """ + rows = [{"Price": 10}] + store = FakeStore(rows) + engine = make_engine({"t": store}) + + engine.set_formula("t", "col", "{Price} * 2") + assert engine.has_formula("t", "col") + + engine.set_formula("t", "col", text) + assert not engine.has_formula("t", "col") + + def test_i_can_replace_formula_clears_old_dependencies(self): + """Test that replacing a formula removes old dependency edges. + + Why: add_formula calls _remove_edges before adding new ones (source line 99). + After replacement, marking the old dependency dirty must NOT trigger + recalculation of the formula column. 
+ """ + rows = [{"A": 5, "B": 3, "X": 10}] + store = FakeStore(rows) + engine = make_engine({"t": store}) + + engine.set_formula("t", "C", "{A} + {B}") + engine.recalculate_if_needed("t", store) + assert store.ns_fast_access["C"][0] == pytest.approx(8.0) + + engine.set_formula("t", "C", "{X} * 2") + engine.recalculate_if_needed("t", store) + assert store.ns_fast_access["C"][0] == pytest.approx(20.0) + + # A is no longer a dependency — changing it should NOT trigger C recalculation + rows[0]["A"] = 100 + store.ns_fast_access["A"] = np.array([100], dtype=object) + engine.mark_data_changed("t", "A") + + store.ns_fast_access["C"][0] = 999 # sentinel + engine.recalculate_if_needed("t", store) + assert store.ns_fast_access["C"][0] == 999, ( + "C should not be recalculated when A changes after formula replacement" + ) + + +# ==================== TestRemoveFormula ==================== + +class TestRemoveFormula: + """Tests for remove_formula.""" + + def test_i_can_remove_formula(self): + """Test that a formula can be removed.""" + rows = [{"Price": 10, "Quantity": 3}] + store = FakeStore(rows) + engine = make_engine({"t": store}) + + engine.set_formula("t", "total", "{Price} * {Quantity}") + assert engine.has_formula("t", "total") + + engine.remove_formula("t", "total") + assert not engine.has_formula("t", "total") + + def test_i_can_remove_formula_and_set_back(self): + """Test that a formula is removed when it is set to an empty string. NOTE(review): despite the test name, the formula is never set back afterwards.""" + rows = [{"Price": 10}] + store = FakeStore(rows) + engine = make_engine({"t": store}) + + engine.set_formula("t", "col", "{Price} * 2") + assert engine.has_formula("t", "col") + + engine.set_formula("t", "col", "") + assert not engine.has_formula("t", "col") + + +# ==================== TestRecalculate ==================== + +class TestRecalculate: + """Tests for recalculate_if_needed: dirty tracking, return values, and row data update.""" + + def test_i_can_recalculate_only_dirty(self): + """Test that only dirty formula columns
are recalculated.""" + rows = [{"A": 5, "B": 3}] + store = FakeStore(rows) + engine = make_engine({"t": store}) + + engine.set_formula("t", "C", "{A} + {B}") + engine.recalculate_if_needed("t", store) + + # Manually set a different value to detect recalculation + store.ns_fast_access["C"][0] = 999 + + # No dirty flags → should NOT recalculate + engine.recalculate_if_needed("t", store) + assert store.ns_fast_access["C"][0] == 999 # unchanged + + def test_i_can_recalculate_after_data_changed(self): + """Test that marking data changed triggers recalculation.""" + rows = [{"A": 5, "B": 3}] + store = FakeStore(rows) + engine = make_engine({"t": store}) + + engine.set_formula("t", "C", "{A} + {B}") + engine.recalculate_if_needed("t", store) + assert store.ns_fast_access["C"][0] == 8 + + # Update both storage structures (mirrors real DataGrid behaviour) + rows[0]["A"] = 10 + store.ns_fast_access["A"] = np.array([10], dtype=object) + + # Mark source column dirty + engine.mark_data_changed("t", "A") + engine.recalculate_if_needed("t", store) + assert store.ns_fast_access["C"][0] == 13 + + def test_i_can_recalculate_returns_false_when_no_dirty(self): + """Test that recalculate_if_needed returns False when no nodes are dirty. + + Why: Source line 151 returns False early when get_calculation_order is empty. + After a first successful recalculation all dirty flags are cleared, so the + second call must return False to avoid redundant work. + """ + rows = [{"A": 5}] + store = FakeStore(rows) + engine = make_engine({"t": store}) + + engine.set_formula("t", "B", "{A} + 1") + engine.recalculate_if_needed("t", store) # clears dirty flag + + result = engine.recalculate_if_needed("t", store) + assert result is False + + def test_i_can_recalculate_returns_true_when_dirty(self): + """Test that recalculate_if_needed returns True when columns are recalculated. + + Why: Source line 164 returns True after the evaluation loop. This return value + lets callers skip downstream work (e.g. 
rendering) when nothing changed. + """ + rows = [{"A": 5}] + store = FakeStore(rows) + engine = make_engine({"t": store}) + + engine.set_formula("t", "B", "{A} + 1") + + result = engine.recalculate_if_needed("t", store) + assert result is True + + def test_i_can_recalculate_with_empty_store(self): + """Test that recalculate_if_needed handles a store with no rows without crashing. + + Why: _evaluate_column guards on empty ns_row_data (source line 211-212). + No formula result should be written when there are no rows to process. + """ + store = FakeStore([]) # no rows + engine = make_engine({"t": store}) + + engine.set_formula("t", "B", "{A} + 1") + engine.recalculate_if_needed("t", store) # must not raise + + assert "B" not in store.ns_fast_access + + def test_i_can_verify_formula_values_appear_in_row_data(self): + """Test that formula values are written back into ns_row_data after recalculation. + + Why: _rebuild_row_data (source line 231-249) merges ns_fast_access values into + each row dict. This ensures formula results are available in row_data for + subsequent evaluation passes and for rendering. + """ + rows = [{"A": 5}] + store = FakeStore(rows) + engine = make_engine({"t": store}) + + engine.set_formula("t", "B", "{A} + 10") + engine.recalculate_if_needed("t", store) + + assert store.ns_row_data[0]["B"] == pytest.approx(15.0) + + # --- Known bug: chained formula columns --- + # _rebuild_row_data is called once AFTER the full evaluation loop, so formula + # column B is not yet in row_data when formula column C is evaluated. + # Fix needed: rebuild row_data between each column evaluation in the loop. + + def test_i_can_recalculate_chain_formula_initial(self): + """Test that C = f(B) is correct when B is itself a formula column (initial pass). + + Chain: A (data) → B = A + 10 → C = B * 2 + Expected: B = 15, C = 30. 
+ """ + rows = [{"A": 5}] + store = FakeStore(rows) + engine = make_engine({"t": store}) + + engine.set_formula("t", "B", "{A} + 10") + engine.set_formula("t", "C", "{B} * 2") + engine.recalculate_if_needed("t", store) + + assert store.ns_fast_access["B"][0] == 15 + assert store.ns_fast_access["C"][0] == 30 + + def test_i_can_recalculate_chain_formula_after_data_change(self): + """Test that C = f(B) stays correct after the source data column A changes. + + Chain: A (data) → B = A + 10 → C = B * 2 + After A changes from 5 to 10: B = 20, C = 40. + """ + rows = [{"A": 5}] + store = FakeStore(rows) + engine = make_engine({"t": store}) + + engine.set_formula("t", "B", "{A} + 10") + engine.set_formula("t", "C", "{B} * 2") + engine.recalculate_if_needed("t", store) # first pass (B=15, C=None per bug above) + + rows[0]["A"] = 10 + store.ns_fast_access["A"] = np.array([10], dtype=object) + engine.mark_data_changed("t", "A") + engine.recalculate_if_needed("t", store) + + assert store.ns_fast_access["B"][0] == 20 + assert store.ns_fast_access["C"][0] == 40 + + +# ==================== TestCrossTable ==================== + +class TestCrossTable: + """Tests for cross-table reference resolution strategies.""" + + def test_i_can_handle_cross_table_formula(self): + """Test that cross-table references are resolved via registry (Strategy 3: row_index).""" + orders_rows = [{"Quantity": 3, "ProductId": 1}] + products_rows = [{"Price": 99.0}] + + orders_store = FakeStore(orders_rows) + products_store = FakeStore(products_rows) + products_store.ns_fast_access = {"Price": np.array([99.0], dtype=object)} + + engine = make_engine({ + "orders": orders_store, + "products": products_store, + }) + engine.set_formula("orders", "total", "{products.Price} * {Quantity}") + engine.recalculate_if_needed("orders", orders_store) + + assert "total" in orders_store.ns_fast_access + assert orders_store.ns_fast_access["total"][0] == pytest.approx(297.0) + + def 
test_i_can_resolve_cross_table_by_id_join(self): + """Test that cross-table references are resolved via implicit id-join (Strategy 2). + + Why: Strategy 2 (source lines 303-318) matches rows where both tables share + the same id value. This allows cross-table lookups without an explicit WHERE + clause when both stores expose an 'id' column in ns_fast_access. + """ + orders_rows = [{"id": 101, "Quantity": 3}, {"id": 102, "Quantity": 5}] + orders_store = FakeStore(orders_rows) + + products_rows = [{"id": 101, "Price": 10.0}, {"id": 102, "Price": 20.0}] + products_store = FakeStore(products_rows) + + engine = make_engine({ + "orders": orders_store, + "products": products_store, + }) + engine.set_formula("orders", "total", "{products.Price} * {Quantity}") + engine.recalculate_if_needed("orders", orders_store) + + totals = orders_store.ns_fast_access["total"] + assert totals[0] == 30, "Row with id=101: Price=10 * Qty=3" + assert totals[1] == 100, "Row with id=102: Price=20 * Qty=5" + + def test_i_can_handle_cross_table_without_registry(self): + """Test that a cross-table formula evaluates gracefully when no registry is set. + + Why: _make_cross_table_resolver guards on registry_resolver=None (source line 274). + The formula must evaluate to None without raising, preserving engine stability. + """ + rows = [{"Quantity": 3}] + store = FakeStore(rows) + + engine = FormulaEngine(registry_resolver=None) + engine.set_formula("orders", "total", "{products.Price} * {Quantity}") + engine.recalculate_if_needed("orders", store) # must not raise + + assert store.ns_fast_access["total"][0] is None + + def test_i_can_handle_cross_table_missing_table(self): + """Test that a cross-table formula evaluates gracefully when the remote table is absent. + + Why: Source line 282-284 returns None when registry_resolver returns None for + the requested table. The engine must not crash and must produce None for the row. 
+ """ + rows = [{"Quantity": 3}] + orders_store = FakeStore(rows) + + engine = make_engine({"orders": orders_store}) # "products" not in registry + engine.set_formula("orders", "total", "{products.Price} * {Quantity}") + engine.recalculate_if_needed("orders", orders_store) # must not raise + + assert orders_store.ns_fast_access["total"][0] is None + + +# ==================== TestGetFormulaText ==================== + +class TestGetFormulaText: + """Tests for formula text retrieval.""" + + def test_i_can_get_formula_text(self): + """Test that registered formula text can be retrieved.""" + engine = make_engine() + engine.set_formula("t", "col", "{Price} * 2") + assert engine.get_formula_text("t", "col") == "{Price} * 2" + + def test_i_can_get_formula_text_returns_none_when_not_set(self): + """Test that get_formula_text returns None for non-formula columns.""" + engine = make_engine() + assert engine.get_formula_text("t", "non_existing") is None diff --git a/tests/core/formula/test_formula_evaluator.py b/tests/core/formula/test_formula_evaluator.py new file mode 100644 index 0000000..8e51ee6 --- /dev/null +++ b/tests/core/formula/test_formula_evaluator.py @@ -0,0 +1,188 @@ +""" +Tests for the FormulaEvaluator. 
+""" +import pytest + +from myfasthtml.core.formula.engine import parse_formula +from myfasthtml.core.formula.evaluator import FormulaEvaluator + + +def make_evaluator(resolver=None): + return FormulaEvaluator(cross_table_resolver=resolver) + + +def eval_formula(text, row_data, row_index=0, resolver=None): + """Helper: parse and evaluate a formula.""" + formula = parse_formula(text) + evaluator = make_evaluator(resolver) + return evaluator.evaluate(formula, row_data, row_index) + + +# ==================== Arithmetic ==================== + +@pytest.mark.parametrize("formula,row_data,expected", [ + ("{Price} * {Quantity}", {"Price": 10, "Quantity": 3}, 30.0), + ("{Price} + {Tax}", {"Price": 100, "Tax": 20}, 120.0), + ("{Total} - {Discount}", {"Total": 100, "Discount": 15}, 85.0), + ("{Total} / {Count}", {"Total": 100, "Count": 4}, 25.0), + ("{Value} % 3", {"Value": 10}, 1.0), + ("{Base} ^ 2", {"Base": 5}, 25.0), +]) +def test_i_can_evaluate_simple_arithmetic(formula, row_data, expected): + """Test that arithmetic formulas evaluate correctly.""" + result = eval_formula(formula, row_data) + assert result == pytest.approx(expected) + + +# ==================== Functions ==================== + +@pytest.mark.parametrize("formula,row_data,expected", [ + ("round({Price} * 1.2, 2)", {"Price": 10}, 12.0), + ("abs({Balance})", {"Balance": -50}, 50.0), + ("upper({Name})", {"Name": "hello"}, "HELLO"), + ("lower({Name})", {"Name": "WORLD"}, "world"), + ("len({Description})", {"Description": "abc"}, 3), + ("concat({First}, \" \", {Last})", {"First": "John", "Last": "Doe"}, "John Doe"), + ("left({Code}, 3)", {"Code": "ABCDEF"}, "ABC"), + ("right({Code}, 3)", {"Code": "ABCDEF"}, "DEF"), + ("trim({Name})", {"Name": " hello "}, "hello"), +]) +def test_i_can_evaluate_function(formula, row_data, expected): + """Test that built-in function calls evaluate correctly.""" + result = eval_formula(formula, row_data) + assert result == expected + + +def test_i_can_evaluate_nested_functions(): + 
"""Test that nested function calls evaluate correctly.""" + result = eval_formula("round(abs({Val}), 1)", {"Val": -3.456}) + assert result == 3.5 + + +# ==================== Conditionals ==================== + +def test_i_can_evaluate_conditional_true(): + """Test conditional when condition is true.""" + result = eval_formula('{Price} * 0.8 if {Country} == "FR"', {"Price": 100, "Country": "FR"}) + assert result == 80.0 + + +def test_i_can_evaluate_conditional_false_no_else(): + """Test conditional returns None when condition is false and no else.""" + result = eval_formula('{Price} * 0.8 if {Country} == "FR"', {"Price": 100, "Country": "DE"}) + assert result is None + + +def test_i_can_evaluate_conditional_with_else(): + """Test conditional returns else value when condition is false.""" + result = eval_formula('{Price} * 0.8 if {Country} == "FR" else {Price}', {"Price": 100, "Country": "DE"}) + assert result == 100.0 + + +def test_i_can_evaluate_chained_conditional(): + """Test chained conditionals evaluate in order.""" + formula = '{Price} * 0.8 if {Country} == "FR" else {Price} * 0.9 if {Country} == "DE" else {Price}' + assert eval_formula(formula, {"Price": 100, "Country": "FR"}) == pytest.approx(80.0) + assert eval_formula(formula, {"Price": 100, "Country": "DE"}) == pytest.approx(90.0) + assert eval_formula(formula, {"Price": 100, "Country": "US"}) == pytest.approx(100.0) + + +# ==================== Logical operators ==================== + +@pytest.mark.parametrize("formula,row_data,expected", [ + ("{A} and {B}", {"A": True, "B": True}, True), + ("{A} and {B}", {"A": True, "B": False}, False), + ("{A} or {B}", {"A": False, "B": True}, True), + ("{A} or {B}", {"A": False, "B": False}, False), + ("not {A}", {"A": True}, False), + ("not {A}", {"A": False}, True), +]) +def test_i_can_evaluate_logical_operators(formula, row_data, expected): + """Test that logical operators evaluate correctly.""" + result = eval_formula(formula, row_data) + assert result == expected 
+ + +# ==================== Error handling ==================== + +def test_i_can_handle_division_by_zero(): + """Test that division by zero returns None.""" + result = eval_formula("{A} / {B}", {"A": 10, "B": 0}) + assert result is None + + +def test_i_can_handle_missing_column(): + """Test that missing column reference returns None.""" + result = eval_formula("{NonExistent} * 2", {}) + assert result is None + + +def test_i_can_handle_none_operand(): + """Test that operations with None operands return None.""" + result = eval_formula("{A} * {B}", {"A": None, "B": 5}) + assert result is None + + +# ==================== Cross-table references ==================== + +def test_i_can_evaluate_cross_table_ref(): + """Test cross-table reference with mock resolver.""" + + def mock_resolver(table, column, where_clause, row_index): + assert table == "Products" + assert column == "Price" + return 99.0 + + result = eval_formula("{Products.Price} * {Quantity}", {"Quantity": 3}, resolver=mock_resolver) + assert result == pytest.approx(297.0) + + +def test_i_can_evaluate_cross_table_with_where(): + """Test cross-table reference with WHERE clause and mock resolver.""" + + def mock_resolver(table, column, where_clause, row_index): + assert where_clause is not None + assert where_clause.local_column == "ProductCode" + return 50.0 + + result = eval_formula( + "{Products.Price where Products.Code = ProductCode} * {Qty}", + {"ProductCode": "ABC", "Qty": 2}, + resolver=mock_resolver, + ) + assert result == pytest.approx(100.0) + + +# ==================== Aggregation ==================== + +def test_i_can_evaluate_aggregation(): + """Test that aggregation functions work with cross-table resolver.""" + values = [10.0, 20.0, 30.0] + call_count = [0] + + def mock_resolver(table, column, where_clause, row_index): + # For aggregation test, return a list to simulate multi-row match + val = values[call_count[0] % len(values)] + call_count[0] += 1 + return val + + formula = 
parse_formula("sum({OrderLines.Amount where OrderLines.OrderId = Id})") + evaluator = FormulaEvaluator(cross_table_resolver=mock_resolver) + result = evaluator.evaluate(formula, {"Id": 1}, 0) + # sum() with a single cross-table value returned by resolver + assert result is not None + + +# ==================== String operations ==================== + +@pytest.mark.parametrize("formula,row_data,expected", [ + ('{Name} contains "Corp"', {"Name": "Acme Corp"}, True), + ('{Code} startswith "ERR"', {"Code": "ERR001"}, True), + ('{File} endswith ".csv"', {"File": "data.csv"}, True), + ('{Status} in ["active", "new"]', {"Status": "active"}, True), + ('{Status} in ["active", "new"]', {"Status": "deleted"}, False), +]) +def test_i_can_evaluate_string_operations(formula, row_data, expected): + """Test string comparison operations.""" + result = eval_formula(formula, row_data) + assert result == expected diff --git a/tests/core/formula/test_formula_parser.py b/tests/core/formula/test_formula_parser.py new file mode 100644 index 0000000..4b99f35 --- /dev/null +++ b/tests/core/formula/test_formula_parser.py @@ -0,0 +1,188 @@ +""" +Tests for the formula parser (grammar + transformer integration). 
+""" +import pytest + +from myfasthtml.core.formula.dataclasses import ( + BinaryOp, + CrossTableRef, + ConditionalExpr, + FunctionCall, + LiteralNode, + FormulaDefinition, +) +from myfasthtml.core.formula.dsl.exceptions import FormulaSyntaxError +from myfasthtml.core.formula.engine import parse_formula + + +# ==================== Valid formulas ==================== + +@pytest.mark.parametrize("formula_text", [ + "{Price} * {Quantity}", + "{Price} + {Tax}", + "{Total} - {Discount}", + "{Total} / {Count}", + "{Value} % 2", + "{Base} ^ 2", +]) +def test_i_can_parse_simple_arithmetic(formula_text): + """Test that basic arithmetic formulas parse without error.""" + result = parse_formula(formula_text) + assert result is not None + assert isinstance(result, FormulaDefinition) + assert isinstance(result.expression, BinaryOp) + + +@pytest.mark.parametrize("formula_text,expected_func", [ + ("round({Price} * 1.2, 2)", "round"), + ("abs({Balance})", "abs"), + ("upper({Name})", "upper"), + ("len({Description})", "len"), + ("today()", "today"), + ("concat({First}, \" \", {Last})", "concat"), +]) +def test_i_can_parse_function_call(formula_text, expected_func): + """Test that function calls parse correctly.""" + result = parse_formula(formula_text) + assert result is not None + assert isinstance(result.expression, FunctionCall) + assert result.expression.function_name == expected_func + + +def test_i_can_parse_conditional_no_else(): + """Test that suffix-if without else parses correctly.""" + result = parse_formula('{Price} * 0.8 if {Country} == "FR"') + assert result is not None + expr = result.expression + assert isinstance(expr, ConditionalExpr) + assert expr.else_expr is None + assert isinstance(expr.value_expr, BinaryOp) + + +def test_i_can_parse_conditional_with_else(): + """Test that suffix-if with else parses correctly.""" + result = parse_formula('{Price} * 0.8 if {Country} == "FR" else {Price}') + assert result is not None + expr = result.expression + assert 
isinstance(expr, ConditionalExpr) + assert expr.else_expr is not None + + +def test_i_can_parse_chained_conditional(): + """Test that chained conditionals parse correctly.""" + formula = '{Price} * 0.8 if {Country} == "FR" else {Price} * 0.9 if {Country} == "DE" else {Price}' + result = parse_formula(formula) + assert result is not None + expr = result.expression + assert isinstance(expr, ConditionalExpr) + # The else_expr should be another ConditionalExpr + assert isinstance(expr.else_expr, ConditionalExpr) + + +def test_i_can_parse_cross_table_ref(): + """Test that cross-table references parse correctly.""" + result = parse_formula("{Products.Price} * {Quantity}") + assert result is not None + expr = result.expression + assert isinstance(expr, BinaryOp) + assert isinstance(expr.left, CrossTableRef) + assert expr.left.table == "Products" + assert expr.left.column == "Price" + + +def test_i_can_parse_cross_table_with_where(): + """Test that cross-table references with WHERE clause parse correctly.""" + result = parse_formula("{Products.Price where Products.Code = ProductCode} * {Quantity}") + assert result is not None + expr = result.expression + assert isinstance(expr, BinaryOp) + cross_ref = expr.left + assert isinstance(cross_ref, CrossTableRef) + assert cross_ref.where_clause is not None + assert cross_ref.where_clause.remote_table == "Products" + assert cross_ref.where_clause.remote_column == "Code" + assert cross_ref.where_clause.local_column == "ProductCode" + + +@pytest.mark.parametrize("formula_text,expected_op", [ + ("{A} and {B}", "and"), + ("{A} or {B}", "or"), + ("not {A}", "not"), +]) +def test_i_can_parse_logical_operators(formula_text, expected_op): + """Test that logical operators parse correctly.""" + result = parse_formula(formula_text) + assert result is not None + if expected_op == "not": + from myfasthtml.core.formula.dataclasses import UnaryOp + assert isinstance(result.expression, UnaryOp) + else: + assert isinstance(result.expression, 
BinaryOp) + assert result.expression.operator == expected_op + + +@pytest.mark.parametrize("formula_text", [ + "42", + "3.14", + '"hello"', + "true", + "false", +]) +def test_i_can_parse_literals(formula_text): + """Test that literal values parse correctly.""" + result = parse_formula(formula_text) + assert result is not None + assert isinstance(result.expression, LiteralNode) + + +def test_i_can_parse_aggregation(): + """Test that aggregation with cross-table WHERE parses correctly.""" + result = parse_formula("sum({OrderLines.Amount where OrderLines.OrderId = Id})") + assert result is not None + expr = result.expression + assert isinstance(expr, FunctionCall) + assert expr.function_name == "sum" + assert len(expr.arguments) == 1 + arg = expr.arguments[0] + assert isinstance(arg, CrossTableRef) + assert arg.where_clause is not None + + +def test_i_can_parse_empty_formula(): + """Test that empty formula returns None.""" + result = parse_formula("") + assert result is None + + +def test_i_can_parse_whitespace_formula(): + """Test that whitespace-only formula returns None.""" + result = parse_formula(" ") + assert result is None + + +def test_i_can_parse_nested_functions(): + """Test that nested function calls parse correctly.""" + result = parse_formula("round(avg({Q1}, {Q2}, {Q3}), 1)") + assert result is not None + expr = result.expression + assert isinstance(expr, FunctionCall) + assert expr.function_name == "round" + + +# ==================== Invalid formulas ==================== + +@pytest.mark.parametrize("formula_text", [ + "{Price} * * {Quantity}", # double operator + "round(", # unclosed paren + "123 + + 456", # double operator +]) +def test_i_cannot_parse_invalid_syntax(formula_text): + """Test that invalid syntax raises FormulaSyntaxError.""" + with pytest.raises(FormulaSyntaxError): + parse_formula(formula_text) + + +def test_i_cannot_parse_unclosed_brace(): + """Test that an unclosed brace raises FormulaSyntaxError.""" + with 
pytest.raises(FormulaSyntaxError): + parse_formula("{Price")