import re
from dataclasses import dataclass
from typing import Optional, Any, Literal

from fastcore.basics import NotStr
from fastcore.xml import FT
from fasthtml.components import Span

from myfasthtml.core.commands import Command
from myfasthtml.core.utils import quoted_str, snake_to_pascal
from myfasthtml.test.testclient import MyFT

# Placeholder returned by attribute lookups when the attribute does not exist.
MISSING_ATTR = "** MISSING **"


class Predicate:
    """Base class for conditions that can replace a literal expected value.

    Subclasses implement :meth:`validate`. Two predicates are equal when they
    have the same concrete class and the same ``value``.
    """

    def __init__(self, value):
        self.value = value

    def validate(self, actual):
        """Return a truthy value when *actual* satisfies the condition."""
        raise NotImplementedError

    def __str__(self):
        if self.value is None:
            str_value = ''
        elif isinstance(self.value, (list, tuple)) and len(self.value) == 1:
            # Single-item varargs such as ``Contains("x")`` are shown unwrapped.
            str_value = self.value[0]
        else:
            str_value = str(self.value)
        return f"{self.__class__.__name__}({str_value})"

    def __repr__(self):
        return f"{self.__class__.__name__}({self.value if self.value is not None else ''})"

    def __eq__(self, other):
        if type(self) is not type(other):
            return False
        return self.value == other.value

    def __hash__(self):
        # Lists are unhashable, so hash them as tuples. The class is part of the
        # hash so it agrees with __eq__, which also compares the concrete type.
        value = tuple(self.value) if isinstance(self.value, list) else self.value
        return hash((type(self), value))


class AttrPredicate(Predicate):
    """Predicate used as the expected value of an attribute.

    Example: ``TestObject("div", id=StartsWith("row-"))``.
    """


class StartsWith(AttrPredicate):
    """The attribute value starts with ``value``."""

    def __init__(self, value):
        super().__init__(value)

    def validate(self, actual):
        # An attribute explicitly set to None fails instead of raising AttributeError.
        return actual is not None and actual.startswith(self.value)


class EndsWith(AttrPredicate):
    """The attribute value ends with ``value``."""

    def __init__(self, value):
        super().__init__(value)

    def validate(self, actual):
        return actual is not None and actual.endswith(self.value)


class Contains(AttrPredicate):
    """The attribute value contains every one of the given values.

    By default each value is searched as a substring. With ``_word=True``, the
    actual value is split on whitespace and each value must be one of the
    resulting words.
    """

    def __init__(self, *value, _word=False):
        super().__init__(value)
        self._word = _word

    def validate(self, actual):
        if actual is None:
            return False
        haystack = actual.split() if self._word else actual
        return all(val in haystack for val in self.value)

    def __eq__(self, other):
        # ``_word`` changes the meaning of the predicate, so it takes part in equality.
        return super().__eq__(other) and self._word == other._word

    def __hash__(self):
        return hash((super().__hash__(), self._word))


class DoesNotContain(AttrPredicate):
    """The attribute value contains none of the given values (substring test)."""

    def __init__(self, *value):
        super().__init__(value)

    def validate(self, actual):
        # A None value contains nothing, so the condition trivially holds.
        return actual is None or all(val not in actual for val in self.value)


class AnyValue(AttrPredicate):
    """
    True if the attribute is present and its value is not None.
    """

    def __init__(self):
        super().__init__(None)

    def validate(self, actual):
        return actual is not None


class Regex(AttrPredicate):
    """The attribute value matches the regular expression ``pattern``.

    ``re.match`` is used, so the pattern is anchored at the start of the value
    but may match only a prefix of it.
    """

    def __init__(self, pattern):
        super().__init__(pattern)

    def validate(self, actual):
        return actual is not None and re.match(self.value, actual) is not None


class And(AttrPredicate):
    """All the given predicates are satisfied by the attribute value."""

    def __init__(self, *predicates):
        super().__init__(predicates)

    def validate(self, actual):
        return all(p.validate(actual) for p in self.value)


class ChildrenPredicate(Predicate):
    """
    Predicate given as a child of an expected element.

    It is validated against the whole actual element, not against one of its
    children.
    """

    def to_debug(self, element):
        """Return *element* (possibly annotated) for display in a failure report."""
        return element


class Empty(ChildrenPredicate):
    """The element has neither children nor attributes."""

    def __init__(self):
        super().__init__(None)

    def validate(self, actual):
        return len(actual.children) == 0 and len(actual.attrs) == 0


class NoChildren(ChildrenPredicate):
    """The element has no children (attributes are not checked)."""

    def __init__(self):
        super().__init__(None)

    def validate(self, actual):
        return len(actual.children) == 0


class AttributeForbidden(ChildrenPredicate):
    """
    To validate that an attribute is not present in an element.

    ``value`` is the attribute name. An attribute explicitly set to None counts
    as absent.
    """

    def __init__(self, value):
        super().__init__(value)

    def validate(self, actual):
        return self.value not in actual.attrs or actual.attrs[self.value] is None

    def to_debug(self, element):
        # Marks the forbidden attribute on *element* in place so it shows up in the report.
        element.attrs[self.value] = "** NOT ALLOWED **"
        return element


class HasHtmx(ChildrenPredicate):
    """The element carries the given htmx attributes with exactly these values.

    When *command* is given, its ``get_htmx_params()`` are used as a base and
    the keyword arguments override them.
    """

    def __init__(self, command: Command = None, **htmx_params):
        super().__init__(None)
        self.command = command
        params = (command.get_htmx_params() | htmx_params) if command else htmx_params
        self.htmx_params = {self._normalize_key(k): v for k, v in params.items()}

    @staticmethod
    def _normalize_key(key):
        # Keyword arguments cannot contain dashes, so ``hx_swap_oob`` is turned
        # into the attribute name ``hx-swap-oob`` (every underscore, not just
        # the one after ``hx``). Keys already in attribute form are unchanged.
        return key.replace("_", "-") if key.startswith("hx_") else key

    def validate(self, actual):
        return all(actual.attrs.get(k) == v for k, v in self.htmx_params.items())

    def to_debug(self, element):
        # Writes the expected htmx attributes onto *element* in place for the report.
        element.attrs.update(self.htmx_params)
        return element

    def __eq__(self, other):
        # ``value`` is always None here; the parameters are what distinguish instances.
        return super().__eq__(other) and self.htmx_params == other.htmx_params

    def __hash__(self):
        return hash((type(self), tuple(self.htmx_params)))


class TestObject:
    """Hand-written expected element for the matcher.

    ``cls`` is either a tag name (str) or a Python type; keyword arguments
    become the expected attributes.
    """

    # Not a pytest test class, despite the ``Test`` prefix.
    __test__ = False

    def __init__(self, cls, **kwargs):
        self.cls = cls
        self.attrs = kwargs


def _icon_name(name):
    """Pass names starting with a lowercase letter through ``snake_to_pascal``; return others unchanged."""
    return snake_to_pascal(name) if (name and name[0].islower()) else name


def _svg_name_regex(name):
    """Return a :class:`Regex` matching the ``<svg name="...">`` summary of an icon called *name*.

    NOTE(review): reconstructed. The original pattern was lost when ``<...>``
    text was stripped from this file. The pattern is written as the opening of
    a self-closing ``<svg name="...">`` tag because the matcher appends ``" />"``
    to such patterns when displaying them. An empty name matches any icon, and
    an optional ``something-`` prefix before the name is accepted. Confirm the
    exact icon-name format against version control.
    """
    if not name:
        return Regex(r'<svg name="[^"]*"')
    return Regex(f'<svg name="(?:[^"]*-)?{re.escape(name)}"')


class TestLabel(TestObject):
    """Expected ``span`` holding an optional icon followed by a ``Span`` with the label text."""

    def __init__(self, label: str, icon: str = None, command=None):
        super().__init__("span")
        self.label = label
        self.icon = _icon_name(icon)
        self.children = []
        if self.icon:
            self.children.append(TestIcon(self.icon, wrapper="span"))
        self.children.append(Span(label))
        if command:
            self.attrs |= command.get_htmx_params()

    def __str__(self):
        # NOTE(review): reconstructed; the original markup in these strings was lost.
        icon_str = f'<svg name="{self.icon}" />' if self.icon else ""
        return f"<span>{icon_str}{self.label}</span>"


class TestIcon(TestObject):
    """Expected icon: a ``div``/``span`` wrapper around a NotStr SVG whose name matches *name*.

    NOTE(review): body reconstructed after the original was lost in extraction
    (see ``_svg_name_regex``); the *command* handling mirrors ``TestLabel``.
    """

    def __init__(self, name: Optional[str] = '', wrapper: Literal["div", "span"] = "div", command=None):
        super().__init__(wrapper)
        self.wrapper = wrapper
        self.name = _icon_name(name)
        self.children = [
            TestObject(NotStr, s=_svg_name_regex(self.name)),
        ]
        if command:
            self.attrs |= command.get_htmx_params()

    def __str__(self):
        return f'<{self.wrapper}><svg name="{self.name}" /></{self.wrapper}>'


class TestIconNotStr(TestObject):
    """Expected bare NotStr SVG icon (no wrapper element) whose name matches *name*.

    NOTE(review): body reconstructed after the original was lost in extraction.
    """

    def __init__(self, name: Optional[str] = ''):
        super().__init__(NotStr)
        self.name = _icon_name(name)
        self.attrs["s"] = _svg_name_regex(self.name)

    def __str__(self):
        return f'<svg name="{self.name}" />'


class TestCommand(TestObject):
    """Expected ``Command`` object identified by *name*, plus any other expected attributes."""

    def __init__(self, name, **kwargs):
        # Keep ``name`` as the first attribute: attribute order drives the order of debug output.
        super().__init__("Command", name=name, **kwargs)
class TestScript(TestObject):
    """Expected ``script`` element whose only child is the raw script text (as NotStr)."""

    def __init__(self, script):
        super().__init__("script")
        self.script = script
        self.children = [
            NotStr(self.script),
        ]


@dataclass
class DoNotCheck:
    """Expected value that matches anything; ``desc`` is an optional note for readers."""

    desc: str = None


@dataclass
class Skip:
    """Expected child matching zero or more consecutive actual children that match ``element``."""

    element: Any
    desc: str = None


def _get_type(x):
    """Return the tag (FT), the expected class name (TestObject) or the Python type name of *x*."""
    if hasattr(x, "tag"):
        return x.tag
    if isinstance(x, TestObject):
        return x.cls.__name__ if isinstance(x.cls, type) else str(x.cls)
    return type(x).__name__


def _get_attr(x, attr):
    """Return attribute *attr* of *x*, or ``MISSING_ATTR`` when it does not exist.

    Objects with an ``attrs`` dict (FT, TestObject) are looked up in that dict,
    other objects with ``getattr``. Icon-like values are summarised so that
    expected and actual are comparable:

    - a TestObject whose ``s`` is a :class:`Regex` yields the pattern followed by ``" />"``;
    - a NotStr whose ``s`` contains a ``name="..."`` attribute yields ``<svg name="..." />``.
    """
    if attr == "s" and isinstance(x, TestObject) and isinstance(x.attrs.get("s"), Regex):
        return x.attrs["s"].value + " />"

    if hasattr(x, "attrs"):
        return x.attrs.get(attr, MISSING_ATTR)

    if not hasattr(x, attr):
        return MISSING_ATTR

    value = getattr(x, attr)
    if isinstance(x, NotStr) and attr == "s" and isinstance(value, str):
        # Special case for NotStr: return the name of the svg.
        # NOTE(review): the original f-string was lost in extraction; the format
        # is reconstructed to mirror the ``" />"`` suffix used for Regex above.
        match = re.search(r'name\s*=\s*["\']([^"\']+)["\']', value)
        if match:
            return f'<svg name="{match.group(1)}" />'
    return value


def _get_attributes(x):
    """Get the attributes dict from an element ({} when it has none)."""
    return x.attrs if hasattr(x, "attrs") else {}


def _get_children(x):
    """Get the children list from an element ([] when it has none)."""
    return x.children if hasattr(x, "children") else []


def _str_element(element, expected=None, keep_open=None):
    """Render *element* as a one-line ``(tag "attr"="value" ...`` string.

    Only the attributes present on *expected* (defaults to *element*) are shown.
    *keep_open*:

    - False: always close the tag, adding ``...`` when the element has children;
    - True: leave the tag open with a trailing ``...``;
    - None: close the tag only when it has no non-Predicate children.

    Non-FT values are rendered with ``quoted_str``.
    """
    if expected is None:
        # Compare to itself if no expected element is provided.
        expected = element

    if not hasattr(element, "tag"):
        return quoted_str(element)

    attr_names = [name for name in _get_attributes(expected) if name is not None]
    elt_attrs_str = _str_attrs({name: _get_attr(element, name) for name in attr_names})
    tag_str = f"({element.tag} {elt_attrs_str}"

    if keep_open is False:
        tag_str += " ...)" if len(element.children) > 0 else ")"
    elif keep_open is True:
        tag_str += "..." if elt_attrs_str == "" else " ..."
    else:
        not_special_children = [c for c in element.children if not isinstance(c, Predicate)]
        if len(not_special_children) == 0:
            tag_str += ")"
    return tag_str


class ErrorOutput:
    """Render *element* as indented lines, with ``^`` marker lines where it differs from *expected*.

    ``path`` is the dotted matcher path. Every path item except the last is
    rendered as an open ancestor line before the element itself.
    """

    _INDENT = "  "

    def __init__(self, path, element, expected):
        self.path = path
        self.element = element
        self.expected = expected
        self.output = []
        self.indent = ""

    @staticmethod
    def _unconstruct_path_item(item):
        """Split a path item (``tag``, ``tag#id``, ``tag[class=..]``, ``tag[name=..]``) into (tag, attr, value)."""
        if "#" in item:
            # maxsplit=1 so that ids containing '#' do not break the unpacking.
            elt_name, elt_id = item.split("#", 1)
            return elt_name, "id", elt_id
        if match := re.match(r'(\w+)\[(class|name)=([^]]+)]', item):
            return match.groups()
        return item, None, None

    def __str__(self):
        return f"ErrorOutput({self.output})"

    def compute(self):
        """Fill ``self.output`` with the rendered lines."""
        # First render the path hierarchy.
        for item in self.path.split(".")[:-1]:
            elt_name, attr_name, attr_value = self._unconstruct_path_item(item)
            self._add_to_output(_str_element(MyFT(elt_name, {attr_name: attr_value}), keep_open=True))
            self.indent += self._INDENT

        # Then render the element.
        if hasattr(self.expected, "tag") and hasattr(self.element, "tag"):
            self._compute_ft()
        elif isinstance(self.expected, TestObject):
            cls = _get_type(self.element)
            attrs = {attr_name: _get_attr(self.element, attr_name) for attr_name in self.expected.attrs}
            self._add_to_output(f"({cls} {_str_attrs(attrs)})")
            self._add_error(self.element, self.expected)
        else:
            self._add_to_output(str(self.element))
            self._add_error(self.element, self.expected)

    def _compute_ft(self):
        """Render an FT element, its first-level children and their error markers."""
        self._add_to_output(_str_element(self.element, self.expected))
        self._add_error(self.element, self.expected)

        expected_children = [c for c in self.expected.children if not isinstance(c, ChildrenPredicate)]
        if not expected_children:
            return

        self.indent += self._INDENT
        actual_children = self.element.children
        for index, expected_child in enumerate(expected_children):
            if index >= len(actual_children):
                # When there are fewer children than expected, display a placeholder.
                self._add_to_output("! ** MISSING ** !")
                continue
            element_child = actual_children[index]
            self._add_to_output(_str_element(element_child, expected_child, keep_open=False))
            # Only FT children get their own error markers.
            if hasattr(expected_child, "tag"):
                self._add_error(element_child, expected_child)
        self.indent = self.indent[:-len(self._INDENT)]
        self._add_to_output(")")

    def _add_error(self, element, expected):
        """Append the marker line for *element* vs *expected*, if there is any difference."""
        error_str = self._detect_error(element, expected)
        if error_str:
            self._add_to_output(error_str)

    def _add_to_output(self, msg):
        self.output.append(f"{self.indent}{msg}")

    def _detect_error(self, element, expected):
        """
        Detect errors between element and expected, returning a visual marker string.

        For FT elements and TestObjects, the markers line up under the type and
        under each attribute rendered by ``_str_element``. For simple values,
        the whole value is underlined. Returns None when nothing differs.
        """
        if hasattr(expected, "tag") or isinstance(expected, TestObject):
            element_type = _get_type(element)
            type_marker = " " if element_type == _get_type(expected) else "^"
            type_error = type_marker * len(element_type)

            attr_names = list(_get_attributes(expected))
            element_attrs = {name: _get_attr(element, name) for name in attr_names}
            expected_attrs = {name: _get_attr(expected, name) for name in attr_names}
            attrs_in_error = {name for name, value in element_attrs.items()
                              if not self._matches(value, expected_attrs[name])}
            attrs_error = " ".join(
                len(f'"{name}"="{value}"') * ("^" if name in attrs_in_error else " ")
                for name, value in element_attrs.items()
            )
            if type_error.strip() or attrs_error.strip():
                return f" {type_error} {attrs_error}"
            return None

        if not self._matches(element, expected):
            return len(str(element)) * "^"
        return None

    @staticmethod
    def _matches(element, expected):
        """True when *element* equals *expected* or satisfies it as a Predicate."""
        if element == expected:
            return True
        if isinstance(expected, Predicate):
            return expected.validate(element)
        return False


class ErrorComparisonOutput:
    """Side-by-side "Actual | Expected" rendering of two :class:`ErrorOutput` objects."""

    def __init__(self, actual_error_output, expected_error_output):
        self.actual_error_output = actual_error_output
        self.expected_error_output = expected_error_output

    @staticmethod
    def _is_marker_line(line):
        """True for lines made only of ``^`` markers and spaces (at least one ``^``)."""
        stripped = line.strip()
        return bool(stripped) and set(stripped) <= {"^", " "}

    @staticmethod
    def adjust(to_adjust, reference):
        """Insert a blank line into *to_adjust* wherever *reference* has a marker line.

        This keeps both columns aligned. *to_adjust* is modified in place and
        returned.
        """
        for index, ref_line in enumerate(reference):
            if ErrorComparisonOutput._is_marker_line(ref_line):
                to_adjust.insert(index, "")
        return to_adjust

    def render(self):
        """Return the two outputs as aligned columns under an ``Actual | Expected`` header."""
        from itertools import zip_longest

        # Compute the outputs if needed.
        if not self.actual_error_output.output:
            self.actual_error_output.compute()
        if not self.expected_error_output.output:
            self.expected_error_output.compute()

        actual = self.actual_error_output.output
        # Adjust a copy so that rendering twice does not insert the blank lines again.
        expected = self.adjust(list(self.expected_error_output.output), actual)

        actual_max_length = max(map(len, actual), default=0)
        expected_max_length = max(map(len, expected), default=0)

        output = [f"{' Actual ':=^{actual_max_length}} | {' Expected ':=^{expected_max_length}}"]
        # zip_longest: never drop trailing lines when the columns differ in length.
        for a, e in zip_longest(actual, expected, fillvalue=""):
            output.append(f"{a:<{actual_max_length}} | {e}".rstrip())
        return "\n".join(output)


class Matcher:
    """
    Matcher class for comparing actual and expected elements.

    Provides flexible comparison with support for predicates, nested structures,
    and detailed error reporting. Failures raise AssertionError explicitly, so
    the checks still run under ``python -O``.
    """

    def __init__(self):
        self.path = ""

    def matches(self, actual, expected):
        """
        Compare actual and expected elements.

        Args:
            actual: The actual element to compare
            expected: The expected element or pattern

        Returns:
            True if elements match, raises AssertionError otherwise
        """
        if actual is not None and expected is None:
            self._assert_error("Actual is not None while expected is None", _actual=actual)

        if isinstance(expected, DoNotCheck):
            return True

        if actual is None and expected is not None:
            self._assert_error("Actual is None while expected is not None", _expected=expected)

        current_path = self._get_current_path(actual)
        original_path = self.path
        self.path = f"{self.path}.{current_path}" if self.path else current_path
        try:
            self._match_elements(actual, expected)
        finally:
            # Restore the original path for sibling comparisons.
            self.path = original_path
        return True

    def _match_elements(self, actual, expected):
        """Internal method that performs the actual comparison logic."""
        if isinstance(expected, TestObject) or hasattr(expected, "tag"):
            self._match_element(actual, expected)
            return

        if isinstance(expected, Predicate):
            if not expected.validate(actual):
                self._assert_error(f"The condition '{expected}' is not satisfied.",
                                   _actual=actual, _expected=expected)
            return

        if _get_type(actual) != _get_type(expected):
            self._assert_error("The types are different.", _actual=actual, _expected=expected)

        if isinstance(expected, (list, tuple)):
            self._match_list(actual, expected)
        elif isinstance(expected, dict):
            self._match_dict(actual, expected)
        elif isinstance(expected, NotStr):
            self._match_notstr(actual, expected)
        elif not (actual == expected):
            self._assert_error("The values are different", _actual=actual, _expected=expected)

    def _match_element(self, actual, expected):
        """Match a TestObject or FT element: type, ChildrenPredicates, attributes, then children."""
        if _get_type(actual) != _get_type(expected):
            self._assert_error("The types are different.", _actual=_get_type(actual), _expected=_get_type(expected))

        expected_children = _get_children(expected)
        for predicate in [c for c in expected_children if isinstance(c, ChildrenPredicate)]:
            if not predicate.validate(actual):
                self._assert_error(f"The condition '{predicate}' is not satisfied.",
                                   _actual=actual, _expected=predicate.to_debug(expected))

        self._match_attributes(actual, expected)

        # Compare the children (only if present).
        if expected_children:
            self._match_children(actual, expected, expected_children)

    def _match_attributes(self, actual, expected):
        """Check every expected attribute: presence, then predicate / recursive / plain equality."""
        for expected_attr, expected_value in _get_attributes(expected).items():
            actual_value = _get_attr(actual, expected_attr)

            if actual_value == MISSING_ATTR:
                self._assert_error(f"'{expected_attr}' is not found in Actual. "
                                   f"(attributes: {self._str_attrs(actual)})",
                                   _actual=actual, _expected=expected)

            if isinstance(expected_value, Predicate):
                if not expected_value.validate(actual_value):
                    self._assert_error(f"The condition '{expected_value}' is not satisfied.",
                                       _actual=actual, _expected=expected)
            elif isinstance(expected, TestObject):
                self._match_test_object_attr(actual_value, expected_value, expected_attr)
            elif not (actual_value == expected_value):
                self._assert_error(f"The values are different for '{expected_attr}'.",
                                   _actual=actual, _expected=expected)

    def _match_test_object_attr(self, actual_value, expected_value, expected_attr):
        """Recursively match one TestObject attribute, re-raising with the attribute name in the message."""
        try:
            self.matches(actual_value, expected_value)
        except AssertionError as e:
            match = re.search(r"Error : (.+?)\n", str(e))
            if match:
                msg = f"{match.group(1)} for '{expected_attr}'."
            else:
                msg = f"The values are different for '{expected_attr}'."
            self._assert_error(msg, _actual=actual_value, _expected=expected_value)

    def _match_children(self, actual, expected, expected_children):
        """Match children in order; a Skip consumes zero or more matching actual children."""
        expected_children = [c for c in expected_children if not isinstance(c, Predicate)]
        actual_children = _get_children(actual)

        # Skip entries may consume no actual child, so they do not count towards the minimum.
        required = sum(not isinstance(c, Skip) for c in expected_children)
        if len(actual_children) < required:
            self._assert_error("Actual is lesser than expected.", _actual=actual, _expected=expected)

        actual_index = expected_index = 0
        while expected_index < len(expected_children):
            expected_child = expected_children[expected_index]

            if isinstance(expected_child, Skip):
                if actual_index < len(actual_children) and self._can_skip(actual_children[actual_index],
                                                                          expected_child):
                    # This is an element to skip: consume it and stay on the same Skip.
                    actual_index += 1
                else:
                    # Otherwise move on to the next expected child.
                    expected_index += 1
                continue

            if actual_index >= len(actual_children):
                self._assert_error("Nothing more to skip.", _actual=actual, _expected=expected)

            self.matches(actual_children[actual_index], expected_child)
            actual_index += 1
            expected_index += 1

    def _can_skip(self, actual_child, skip):
        """True when *actual_child* matches the element described by *skip*."""
        try:
            self.matches(actual_child, skip.element)
        except AssertionError:
            return False
        return True

    def _match_list(self, actual, expected):
        """Match list or tuple."""
        if len(actual) < len(expected):
            self._assert_error("Actual is smaller than expected: ", _actual=actual, _expected=expected)
        if len(actual) > len(expected):
            self._assert_error("Actual is bigger than expected: ", _actual=actual, _expected=expected)
        for actual_child, expected_child in zip(actual, expected):
            self.matches(actual_child, expected_child)

    def _match_dict(self, actual, expected):
        """Match dictionary (same size, every expected key present with a matching value)."""
        if len(actual) < len(expected):
            self._assert_error("Actual is smaller than expected: ", _actual=actual, _expected=expected)
        if len(actual) > len(expected):
            self._assert_error("Actual is bigger than expected: ", _actual=actual, _expected=expected)
        for key, expected_value in expected.items():
            if key not in actual:
                # Same size but different keys: report it instead of raising KeyError.
                self._assert_error(f"Key '{key}' is not found in Actual.", _actual=actual, _expected=expected)
            self.matches(actual[key], expected_value)

    def _match_notstr(self, actual, expected):
        """Match NotStr type: the (left-stripped) actual content must start with the expected text."""
        to_compare = _get_attr(actual, "s").lstrip()
        if not to_compare.startswith(expected.s):
            self._assert_error("Notstr values are different: ", _actual=to_compare, _expected=expected.s)

    def _print_path(self):
        """Format the current path for error messages."""
        return f"Path : '{self.path}'\n" if self.path else ""

    def _debug_compare(self, a, b):
        """Generate a comparison debug output."""
        actual_out = ErrorOutput(self.path, a, b)
        expected_out = ErrorOutput(self.path, b, b)
        return ErrorComparisonOutput(actual_out, expected_out).render()

    def _error_msg(self, msg, _actual=None, _expected=None):
        """Generate an error message with debug information."""
        if _actual is None and _expected is None:
            debug_info = ""
        elif _actual is None:
            debug_info = self._debug(_expected)
        elif _expected is None:
            debug_info = self._debug(_actual)
        else:
            debug_info = self._debug_compare(_actual, _expected)
        return f"{self._print_path()}Error : {msg}\n{debug_info}"

    def _assert_error(self, msg, _actual=None, _expected=None):
        """Raise an AssertionError with formatted message (explicit raise: not stripped by ``-O``)."""
        raise AssertionError(self._error_msg(msg, _actual=_actual, _expected=_expected))

    @staticmethod
    def _get_current_path(elt):
        """Get the path representation of an element (``tag#id``, ``tag[name=..]``, ``tag[class=..]``)."""
        if not hasattr(elt, "tag"):
            return elt.__class__.__name__
        res = f"{elt.tag}"
        if "id" in elt.attrs:
            res += f"#{elt.attrs['id']}"
        elif "name" in elt.attrs:
            res += f"[name={elt.attrs['name']}]"
        elif "class" in elt.attrs:
            res += f"[class={elt.attrs['class']}]"
        return res

    @staticmethod
    def _str_attrs(element):
        """Format attributes as a string."""
        return _str_attrs(_get_attributes(element))

    @staticmethod
    def _debug(elt):
        """Format an element for debug output (falsy values such as "" or 0 are still shown)."""
        return _str_element(elt, keep_open=False) if elt is not None else "None"


def matches(actual, expected, path=""):
    """
    Compare actual and expected elements.

    This is a convenience wrapper around the Matcher class.

    Args:
        actual: The actual element to compare
        expected: The expected element or pattern
        path: Optional initial path for error reporting

    Returns:
        True if elements match, raises AssertionError otherwise
    """
    matcher = Matcher()
    matcher.path = path
    return matcher.matches(actual, expected)


def find(ft, expected):
    """
    Find all occurrences of an expected element within a FastHTML tree.

    An element matches when it has the expected type, all expected attributes
    (equal, or satisfying the Predicate given as value), and every expected
    child matches some actual child (in any order). The search is depth-first.
    An element is reported even if an ancestor of it also matched.

    Args:
        ft: A FastHTML element or a list/tuple/set of elements to search in
        expected: The element pattern to find

    Returns:
        List of matching elements in depth-first order; empty when nothing matches.
    """

    def _elements_match(actual, expected):
        """Check if two elements are the same based on tag, attributes, and children."""
        if isinstance(expected, DoNotCheck):
            return True

        if isinstance(expected, NotStr):
            if not isinstance(actual, NotStr):
                return False
            to_compare = _get_attr(actual, "s").lstrip()
            return to_compare.startswith(expected.s)

        if isinstance(actual, NotStr) and _get_type(actual) != _get_type(expected):
            return False  # to manage the unexpected __eq__ behavior of NotStr

        if not isinstance(expected, (TestObject, FT)):
            return actual == expected

        # Compare tags.
        if _get_type(actual) != _get_type(expected):
            return False

        # Compare attributes: every one must match, not just the first.
        for attr_name, expected_attr_value in _get_attributes(expected).items():
            actual_attr_value = _get_attr(actual, attr_name)
            if actual_attr_value == MISSING_ATTR:
                return False
            if isinstance(expected_attr_value, Predicate):
                if not expected_attr_value.validate(actual_attr_value):
                    return False
            elif not (actual_attr_value == expected_attr_value):
                return False

        # Compare children recursively: each expected child must exist somewhere among the actual children.
        actual_children = _get_children(actual)
        return all(
            any(_elements_match(actual_child, expected_child) for actual_child in actual_children)
            for expected_child in _get_children(expected)
        )

    def _search_tree(current, pattern):
        """Recursively search for pattern in the tree rooted at current."""
        found = [current] if _elements_match(current, pattern) else []
        # Also search the children, in case the pattern appears again below.
        for child in _get_children(current):
            found.extend(_search_tree(child, pattern))
        return found

    elements_to_search = ft if isinstance(ft, (list, tuple, set)) else [ft]

    all_matches = []
    for element in elements_to_search:
        all_matches.extend(_search_tree(element, expected))
    return all_matches


def find_one(ft, expected):
    """Return the single element of *ft* matching *expected* (see :func:`find`).

    Raises:
        AssertionError: if zero or several elements match.
    """
    found = find(ft, expected)
    if len(found) != 1:
        raise AssertionError(f"Found {len(found)} elements for '{expected}'")
    return found[0]


def _str_attrs(attrs: dict):
    """Format a dict of attributes as space-separated ``"name"="value"`` pairs."""
    return " ".join(f'"{attr_name}"="{attr_value}"' for attr_name, attr_value in attrs.items())