import re
from dataclasses import dataclass
from itertools import zip_longest
from typing import Optional

from fastcore.basics import NotStr
from fastcore.xml import FT

from myfasthtml.core.utils import quoted_str, snake_to_pascal
from myfasthtml.test.testclient import MyFT

# Sentinel returned by the attribute lookups below when an attribute is absent.
# Always compare it with ``is``: NotStr's ``__eq__`` can return a truthy non-bool.
MISSING_ATTR = "** MISSING **"


class Predicate:
  """
  Base class of the conditions understood by :func:`matches` and :func:`find`.

  Subclasses implement :meth:`validate`, which receives the actual value and returns a
  truthy value when the condition holds.
  """

  def __init__(self, value):
    self.value = value

  def validate(self, actual):
    """Return True when ``actual`` satisfies the predicate. Must be overridden."""
    raise NotImplementedError

  def __str__(self):
    if self.value is None:
      str_value = ''
    elif isinstance(self.value, str):
      str_value = self.value
    elif isinstance(self.value, (list, tuple)) and len(self.value) == 1:
      # variadic predicates called with one argument (e.g. Contains("x")) render without the tuple
      str_value = self.value[0]
    else:
      str_value = str(self.value)
    return f"{self.__class__.__name__}({str_value})"

  def __repr__(self):
    return f"{self.__class__.__name__}({self.value if self.value is not None else ''})"

  def __eq__(self, other):
    if type(self) is not type(other):
      return False
    return self.value == other.value

  def __hash__(self):
    return hash(self.value)


class AttrPredicate(Predicate):
  """
  Predicate that validates an attribute value.
  It is given as the value of an attribute of the expected element.
  """
  pass


class StartsWith(AttrPredicate):
  """The actual value starts with ``value``."""

  def validate(self, actual):
    return actual.startswith(self.value)


class EndsWith(AttrPredicate):
  """The actual value ends with ``value``."""

  def validate(self, actual):
    return actual.endswith(self.value)


class Contains(AttrPredicate):
  """The actual value contains every one of the given values."""

  def __init__(self, *value):
    super().__init__(value)

  def validate(self, actual):
    return all(val in actual for val in self.value)


class DoesNotContain(AttrPredicate):
  """The actual value contains none of the given values."""

  def __init__(self, *value):
    super().__init__(value)

  def validate(self, actual):
    return all(val not in actual for val in self.value)


class AnyValue(AttrPredicate):
  """
  True if the attribute is present and its value is not None.
  (Presence itself is checked by the matcher before ``validate`` is called.)
  """

  def __init__(self):
    super().__init__(None)

  def validate(self, actual):
    return actual is not None


class Regex(AttrPredicate):
  """The actual value matches ``pattern`` at its start (``re.match`` semantics)."""

  def __init__(self, pattern):
    super().__init__(pattern)

  def validate(self, actual):
    return re.match(self.value, actual) is not None


class ChildrenPredicate(Predicate):
  """
  Predicate given as a child of an expected element; it is validated against the
  actual element itself rather than against one of its children.
  """

  def to_debug(self, element):
    """Return the element to display in the error report (unchanged by default)."""
    return element


class Empty(ChildrenPredicate):
  """The actual element has neither children nor attributes."""

  def __init__(self):
    super().__init__(None)

  def validate(self, actual):
    return len(actual.children) == 0 and len(actual.attrs) == 0


class NoChildren(ChildrenPredicate):
  """The actual element has no children."""

  def __init__(self):
    super().__init__(None)

  def validate(self, actual):
    return len(actual.children) == 0


class AttributeForbidden(ChildrenPredicate):
  """
  To validate that an attribute is not present in an element (or is present with value None).
  """

  def validate(self, actual):
    return self.value not in actual.attrs or actual.attrs[self.value] is None

  def to_debug(self, element):
    # NOTE(review): this mutates the caller's expected element in place; it only runs
    # while building a failure message, but a reused expected pattern keeps the marker.
    element.attrs[self.value] = "** NOT ALLOWED **"
    return element


class TestObject:
  """
  Expected-side description of an object that is not (necessarily) an FT element.

  ``cls`` is either a type (compared by its ``__name__``) or a tag name string; the
  keyword arguments are the attributes the actual object must expose.
  """
  # Stop pytest from trying to collect this helper (and its subclasses) as a test class.
  __test__ = False

  def __init__(self, cls, **kwargs):
    self.cls = cls
    self.attrs = kwargs


class TestIcon(TestObject):
  """
  Expected icon: a ``div`` whose child is a NotStr matched by a regex built from ``name``.
  A name starting with a lowercase letter is converted with ``snake_to_pascal``.
  """

  def __init__(self, name: Optional[str] = ''):
    super().__init__("div")
    self.name = snake_to_pascal(name) if (name and name[0].islower()) else name
    # NOTE(review): in the copy this file was recovered from, the text of this pattern
    # and of __str__ below was stripped (everything between '<' and '>' was lost).
    # Both are a reconstruction -- confirm against version control.
    self.children = [
      TestObject(NotStr, s=Regex(f'<svg name="\\w+-{self.name}')),
    ]

  def __str__(self):
    return f'<div><svg name="{self.name}" .../></div>'


class TestCommand(TestObject):
  """Expected ``Command`` object; ``name`` is stored as its first attribute."""

  def __init__(self, name, **kwargs):
    # name goes first so it is also rendered first in error reports
    super().__init__("Command", name=name, **kwargs)


class TestScript(TestObject):
  """Expected ``script`` element whose content starts with ``script``."""

  def __init__(self, script):
    super().__init__("script")
    self.script = script
    self.children = [NotStr(self.script)]


@dataclass
class DoNotCheck:
  """Expected-side placeholder: the actual value in this position is accepted as-is."""
  desc: Optional[str] = None  # free-text note for the reader; not used by the matcher


def _get_type(x):
  """Return the comparable type name: FT tag, TestObject ``cls`` name, or Python class name."""
  if hasattr(x, "tag"):
    return x.tag
  if isinstance(x, TestObject):
    return x.cls.__name__ if isinstance(x.cls, type) else str(x.cls)
  return type(x).__name__


def _get_attr(x, attr):
  """Return ``attr`` from ``x.attrs`` when ``x`` has one, else the Python attribute; MISSING_ATTR if absent."""
  if hasattr(x, "attrs"):
    return x.attrs.get(attr, MISSING_ATTR)
  return getattr(x, attr, MISSING_ATTR)


def _get_attributes(x):
  """Get the attributes dict from an element ({} when it has none)."""
  if hasattr(x, "attrs"):
    return x.attrs
  return {}


def _get_children(x):
  """Get the children list from an element ([] when it has none)."""
  if hasattr(x, "children"):
    return x.children
  return []


def _is_error_marker_line(line):
  """True for the marker lines produced by ErrorOutput._detect_error (only spaces and '^')."""
  stripped = line.strip()
  return bool(stripped) and set(stripped) <= {"^", " "}


class ErrorOutput:
  """
  Render ``element`` as indented text lines, with ``^`` marker lines under the parts that
  do not match ``expected``.

  ``path`` is the dotted path built by :class:`Matcher`; every segment but the last is
  rendered as an enclosing, still-open parent.
  """
  _INDENT = "  "

  def __init__(self, path, element, expected):
    self.path = path
    self.element = element
    self.expected = expected
    self.output = []
    self.indent = ""

  @staticmethod
  def _unconstruct_path_item(item):
    """Split one path segment (``tag#id``, ``tag[class=...]``, ``tag[name=...]``) into (tag, attr, value)."""
    if "#" in item:
      elt_name, _, elt_id = item.partition("#")
      return elt_name, "id", elt_id
    if match := re.match(r'(\w+)\[(class|name)=([^]]+)]', item):
      return match.groups()
    return item, None, None

  def __str__(self):
    return f"ErrorOutput({self.output})"

  def compute(self):
    """Fill ``self.output``: first the parent path, then the element and its differences."""
    self._render_path()
    if hasattr(self.expected, "tag") and hasattr(self.element, "tag"):
      self._render_ft()
    elif isinstance(self.expected, TestObject):
      cls = _get_type(self.element)
      attrs = {attr_name: _get_attr(self.element, attr_name) for attr_name in self.expected.attrs}
      self._add_to_output(f"({cls} {_str_attrs(attrs)})")
      self._add_error(self.element, self.expected)
    else:
      self._add_to_output(str(self.element))
      self._add_error(self.element, self.expected)

  def _render_path(self):
    """Render each parent path segment as an open tag, indenting one level per segment."""
    for segment in self.path.split(".")[:-1]:
      elt_name, attr_name, attr_value = self._unconstruct_path_item(segment)
      self._add_to_output(self._str_element(MyFT(elt_name, {attr_name: attr_value}), keep_open=True))
      self.indent += self._INDENT

  def _render_ft(self):
    """Render an FT element, its differences, and one summary line per expected child."""
    self._add_to_output(self._str_element(self.element, self.expected))
    self._add_error(self.element, self.expected)

    expected_children = [c for c in self.expected.children if not isinstance(c, ChildrenPredicate)]
    if not expected_children:
      return

    self.indent += self._INDENT
    for index, expected_child in enumerate(expected_children):
      if index >= len(self.element.children):
        # When there are fewer children than expected, we display a placeholder
        self._add_to_output("! ** MISSING ** !")
        continue
      element_child = self.element.children[index]
      self._add_to_output(self._str_element(element_child, expected_child, keep_open=False))
      # differences are only detailed when the expected child is itself an FT element
      if hasattr(expected_child, "tag"):
        self._add_error(element_child, expected_child)
    self.indent = self.indent[:-len(self._INDENT)]
    self._add_to_output(")")

  def _add_error(self, element, expected):
    """Append the marker line for ``element`` vs ``expected`` when there is a difference."""
    error_str = self._detect_error(element, expected)
    if error_str:
      self._add_to_output(error_str)

  def _add_to_output(self, msg):
    self.output.append(f"{self.indent}{msg}")

  @staticmethod
  def _str_element(element, expected=None, keep_open=None):
    """
    Render ``element`` as ``(tag "attr"="value" ...``, showing only the attributes that
    ``expected`` declares (MISSING_ATTR when absent). Non-FT values use ``quoted_str``.

    keep_open: True renders an open parent (``...``); False a one-line summary (``...)``
    when the element has children); None closes the tag unless ``expected`` has children
    that ``compute`` will render (and close) below it.
    """
    if expected is None:
      expected = element  # compare to itself if no expected element is provided

    if not hasattr(element, "tag"):
      return quoted_str(element)

    # _get_attributes: `expected` may be a plain value (e.g. a text child) with no attrs
    attr_names = [name for name in _get_attributes(expected) if name is not None]
    elt_attrs_str = " ".join(f'"{name}"="{element.attrs.get(name, MISSING_ATTR)}"' for name in attr_names)
    tag_str = f"({element.tag} {elt_attrs_str}"

    if keep_open is False:
      tag_str += " ...)" if len(element.children) > 0 else ")"
    elif keep_open is True:
      tag_str += "..." if elt_attrs_str == "" else " ..."
    else:
      # mirror compute(): it renders (and closes) only the non-ChildrenPredicate expected children
      renderable = [c for c in _get_children(expected) if not isinstance(c, ChildrenPredicate)]
      if not renderable:
        tag_str += ")"
    return tag_str

  def _detect_error(self, element, expected):
    """
    Return a line with ``^`` under each mismatching part of ``element``, or None.

    For FT elements and TestObjects the markers line up with the ``(tag "attr"="value"``
    rendering (tag name first, then each expected attribute); for plain values the whole
    ``str(element)`` is underlined.
    """
    if not (hasattr(expected, "tag") or isinstance(expected, TestObject)):
      return None if self._matches(element, expected) else len(str(element)) * "^"

    element_type = _get_type(element)
    type_error = (" " if element_type == _get_type(expected) else "^") * len(element_type)
    attr_markers = []
    for name in _get_attributes(expected):
      value = _get_attr(element, name)
      marker = " " if self._matches(value, _get_attr(expected, name)) else "^"
      attr_markers.append(len(f'"{name}"="{value}"') * marker)
    attrs_error = " ".join(attr_markers)

    if type_error.strip() or attrs_error.strip():
      return f" {type_error} {attrs_error}"
    return None

  @staticmethod
  def _matches(element, expected):
    """Equality, or predicate validation when ``expected`` is a Predicate."""
    if element == expected:
      return True
    if isinstance(expected, Predicate):
      return expected.validate(element)
    return False


class ErrorComparisonOutput:
  """Lay out an actual and an expected ErrorOutput side by side, separated by `` | ``."""

  def __init__(self, actual_error_output, expected_error_output):
    self.actual_error_output = actual_error_output
    self.expected_error_output = expected_error_output

  @staticmethod
  def adjust(to_adjust, reference):
    """
    Insert an empty line into ``to_adjust`` (in place) at each index where ``reference``
    has an error marker line, so both columns stay aligned. Returns ``to_adjust``.
    """
    for index, ref_line in enumerate(reference):
      if _is_error_marker_line(ref_line):
        to_adjust.insert(index, "")
    return to_adjust

  def render(self):
    """Compute both outputs if needed and return the two-column comparison text."""
    if not self.actual_error_output.output:
      self.actual_error_output.compute()
    if not self.expected_error_output.output:
      self.expected_error_output.compute()

    actual = self.actual_error_output.output
    # adjust a copy: rendering twice must not insert the alignment blanks twice
    expected = self.adjust(list(self.expected_error_output.output), actual)
    width = max(map(len, actual), default=0)
    # zip_longest so lines beyond the shorter column are still shown
    return "\n".join(
      f"{a:<{width}} | {e}".rstrip()
      for a, e in zip_longest(actual, expected, fillvalue="")
    )


class Matcher:
  """
  Matcher class for comparing actual and expected elements.

  Provides flexible comparison with support for predicates, nested structures, and
  detailed error reporting. Mismatches raise AssertionError explicitly (not via ``assert``)
  so the checks still run under ``python -O``.
  """

  def __init__(self):
    self.path = ""

  def matches(self, actual, expected):
    """
    Compare actual and expected elements.

    Args:
      actual: The actual element to compare
      expected: The expected element or pattern

    Returns:
      True if elements match.

    Raises:
      AssertionError: describing the first mismatch found.
    """
    if actual is not None and expected is None:
      self._assert_error("Actual is not None while expected is None", _actual=actual)
    if isinstance(expected, DoNotCheck):
      return True
    if actual is None and expected is not None:
      self._assert_error("Actual is None while expected is ", _expected=expected)

    original_path = self.path
    current_path = self._get_current_path(actual)
    self.path = f"{original_path}.{current_path}" if original_path else current_path
    try:
      self._match_elements(actual, expected)
    finally:
      # restore the original path for sibling comparisons
      self.path = original_path
    return True

  def _match_elements(self, actual, expected):
    """Dispatch on the kind of ``expected`` (element, predicate, container, NotStr, value)."""
    if isinstance(expected, TestObject) or hasattr(expected, "tag"):
      self._match_element(actual, expected)
      return

    if isinstance(expected, Predicate):
      if not expected.validate(actual):
        self._assert_error(f"The condition '{expected}' is not satisfied.", _actual=actual, _expected=expected)
      return

    if _get_type(actual) != _get_type(expected):
      self._assert_error("The types are different.", _actual=actual, _expected=expected)

    if isinstance(expected, (list, tuple)):
      self._match_list(actual, expected)
    elif isinstance(expected, dict):
      self._match_dict(actual, expected)
    elif isinstance(expected, NotStr):
      self._match_notstr(actual, expected)
    elif not (actual == expected):
      self._assert_error("The values are different", _actual=actual, _expected=expected)

  def _match_element(self, actual, expected):
    """Match a TestObject or FT element: type, ChildrenPredicates, attributes, then children."""
    if _get_type(actual) != _get_type(expected):
      self._assert_error("The types are different.", _actual=_get_type(actual), _expected=_get_type(expected))

    # Special conditions (ChildrenPredicate)
    expected_children = _get_children(expected)
    for predicate in [c for c in expected_children if isinstance(c, ChildrenPredicate)]:
      if not predicate.validate(actual):
        self._assert_error(f"The condition '{predicate}' is not satisfied.",
                           _actual=actual, _expected=predicate.to_debug(expected))

    # Compare the attributes
    for expected_attr, expected_value in _get_attributes(expected).items():
      actual_value = _get_attr(actual, expected_attr)

      if actual_value is MISSING_ATTR:
        self._assert_error(f"'{expected_attr}' is not found in Actual. (attributes: {self._str_attrs(actual)})",
                           _actual=actual, _expected=expected)

      if isinstance(expected_value, Predicate):
        if not expected_value.validate(actual_value):
          self._assert_error(f"The condition '{expected_value}' is not satisfied.",
                             _actual=actual, _expected=expected)
      elif isinstance(expected, TestObject):
        # TestObject attributes may hold nested structures: match them recursively
        try:
          self.matches(actual_value, expected_value)
        except AssertionError as e:
          match = re.search(r"Error : (.+?)\n", str(e))
          if match:
            self._assert_error(f"{match.group(1)} for '{expected_attr}'.",
                               _actual=actual_value, _expected=expected_value)
          else:
            self._assert_error(f"The values are different for '{expected_attr}'.",
                               _actual=actual_value, _expected=expected_value)
      elif not (actual_value == expected_value):
        self._assert_error(f"The values are different for '{expected_attr}'.", _actual=actual, _expected=expected)

    # Compare the children (only if present)
    if expected_children:
      # NOTE(review): every Predicate child is skipped here (AttrPredicates included),
      # whereas ErrorOutput only skips ChildrenPredicates -- confirm which is intended.
      expected_children = [c for c in expected_children if not isinstance(c, Predicate)]
      actual_children = _get_children(actual)
      if len(actual_children) < len(expected_children):
        self._assert_error("Actual is lesser than expected.", _actual=actual, _expected=expected)
      for actual_child, expected_child in zip(actual_children, expected_children):
        self.matches(actual_child, expected_child)

  def _match_list(self, actual, expected):
    """Match list or tuple: same length, then element by element."""
    if len(actual) < len(expected):
      self._assert_error("Actual is smaller than expected: ", _actual=actual, _expected=expected)
    if len(actual) > len(expected):
      self._assert_error("Actual is bigger than expected: ", _actual=actual, _expected=expected)
    for actual_child, expected_child in zip(actual, expected):
      self.matches(actual_child, expected_child)

  def _match_dict(self, actual, expected):
    """Match dictionary: same size, every expected key present, values matched recursively."""
    if len(actual) < len(expected):
      self._assert_error("Actual is smaller than expected: ", _actual=actual, _expected=expected)
    if len(actual) > len(expected):
      self._assert_error("Actual is bigger than expected: ", _actual=actual, _expected=expected)
    for key, expected_value in expected.items():
      if key not in actual:
        self._assert_error(f"'{key}' is not found in Actual.", _actual=actual, _expected=expected)
      self.matches(actual[key], expected_value)

  def _match_notstr(self, actual, expected):
    """Match NotStr: the actual content, leading whitespace removed, starts with the expected one."""
    to_compare = _get_attr(actual, "s").lstrip('\n').lstrip()
    if not to_compare.startswith(expected.s):
      self._assert_error("Notstr values are different: ", _actual=to_compare, _expected=expected.s)

  def _print_path(self):
    """Format the current path for error messages."""
    return f"Path : '{self.path}'\n" if self.path else ""

  def _debug_compare(self, a, b):
    """Generate the side-by-side actual/expected debug output."""
    actual_out = ErrorOutput(self.path, a, b)
    expected_out = ErrorOutput(self.path, b, b)
    return ErrorComparisonOutput(actual_out, expected_out).render()

  def _error_msg(self, msg, _actual=None, _expected=None):
    """Generate an error message with debug information."""
    if _actual is None and _expected is None:
      debug_info = ""
    elif _actual is None:
      debug_info = self._debug(_expected)
    elif _expected is None:
      debug_info = self._debug(_actual)
    else:
      debug_info = self._debug_compare(_actual, _expected)
    return f"{self._print_path()}Error : {msg}\n{debug_info}"

  def _assert_error(self, msg, _actual=None, _expected=None):
    """Raise an AssertionError with the formatted message (explicit raise: survives ``-O``)."""
    raise AssertionError(self._error_msg(msg, _actual=_actual, _expected=_expected))

  @staticmethod
  def _get_current_path(elt):
    """Path segment for ``elt``: ``tag#id``, ``tag[name=..]``, ``tag[class=..]``, ``tag``, or class name."""
    if not hasattr(elt, "tag"):
      return elt.__class__.__name__
    res = f"{elt.tag}"
    if "id" in elt.attrs:
      res += f"#{elt.attrs['id']}"
    elif "name" in elt.attrs:
      res += f"[name={elt.attrs['name']}]"
    elif "class" in elt.attrs:
      res += f"[class={elt.attrs['class']}]"
    return res

  @staticmethod
  def _str_attrs(element):
    """Format the attributes of ``element`` as a string."""
    return _str_attrs(_get_attributes(element))

  @staticmethod
  def _debug(elt):
    """Format an element for debug output (only None itself is shown as "None")."""
    return "None" if elt is None else str(elt)


def matches(actual, expected, path=""):
  """
  Compare actual and expected elements.

  This is a convenience wrapper around the Matcher class.

  Args:
    actual: The actual element to compare
    expected: The expected element or pattern
    path: Optional initial path for error reporting

  Returns:
    True if elements match, raises AssertionError otherwise
  """
  matcher = Matcher()
  matcher.path = path
  return matcher.matches(actual, expected)


def find(ft, expected):
  """
  Find all occurrences of an expected element within a FastHTML tree.

  Args:
    ft: A FastHTML element or list/tuple/set of elements to search in
    expected: The element pattern to find

  Returns:
    List of matching elements (a match's descendants are searched too)

  Raises:
    AssertionError: If no matching elements are found
  """

  def _elements_match(candidate, pattern):
    """Check whether ``candidate`` matches ``pattern`` by type, every attribute, and children."""
    if isinstance(pattern, DoNotCheck):
      return True
    if isinstance(pattern, NotStr):
      to_compare = _get_attr(candidate, "s").lstrip('\n').lstrip()
      return to_compare.startswith(pattern.s)
    if isinstance(candidate, NotStr) and _get_type(candidate) != _get_type(pattern):
      return False  # to manage the unexpected __eq__ behavior of NotStr
    if not isinstance(pattern, (TestObject, FT)):
      return candidate == pattern
    if _get_type(candidate) != _get_type(pattern):
      return False

    # every expected attribute must be present and match
    for attr_name, expected_attr_value in _get_attributes(pattern).items():
      actual_attr_value = _get_attr(candidate, attr_name)
      if actual_attr_value is MISSING_ATTR:
        return False
      if isinstance(expected_attr_value, Predicate):
        if not expected_attr_value.validate(actual_attr_value):
          return False
      elif not (actual_attr_value == expected_attr_value):
        return False

    # every expected child must match some actual child; ChildrenPredicates apply to the candidate itself
    candidate_children = _get_children(candidate)
    for expected_child in _get_children(pattern):
      if isinstance(expected_child, ChildrenPredicate):
        if not expected_child.validate(candidate):
          return False
      elif not any(_elements_match(child, expected_child) for child in candidate_children):
        return False
    return True

  def _search_tree(current, pattern):
    """Recursively collect ``current`` and any descendants that match ``pattern``."""
    found = [current] if _elements_match(current, pattern) else []
    for child in _get_children(current):
      found.extend(_search_tree(child, pattern))
    return found

  elements_to_search = ft if isinstance(ft, (list, tuple, set)) else [ft]
  all_matches = []
  for element in elements_to_search:
    all_matches.extend(_search_tree(element, expected))

  if not all_matches:
    raise AssertionError(f"No element found for '{expected}'")
  return all_matches


def find_one(ft, expected):
  """
  Return the single element of ``ft`` matching ``expected``.

  Raises:
    AssertionError: if no element, or more than one element, matches.
  """
  found = find(ft, expected)
  if len(found) != 1:
    raise AssertionError(f"Found {len(found)} elements for '{expected}'")
  return found[0]


def _str_attrs(attrs: dict):
  """Format an attributes dict as ``"name"="value"`` pairs separated by spaces."""
  return " ".join(f'"{attr_name}"="{attr_value}"' for attr_name, attr_value in attrs.items())