Files
MyFastHtml/src/myfasthtml/test/matcher.py

813 lines
26 KiB
Python

import re
from dataclasses import dataclass
from typing import Optional, Any
from fastcore.basics import NotStr
from fastcore.xml import FT
from myfasthtml.core.commands import Command
from myfasthtml.core.utils import quoted_str, snake_to_pascal
from myfasthtml.test.testclient import MyFT
# Sentinel value returned by `_get_attr` when the requested attribute is absent from an element.
MISSING_ATTR = "** MISSING **"
class Predicate:
    """
    Base class for all matcher predicates.

    A predicate wraps a reference ``value`` and decides, through ``validate``,
    whether an actual value satisfies it.
    """

    def __init__(self, value):
        self.value = value

    def validate(self, actual):
        """Tell whether ``actual`` satisfies the predicate (implemented by subclasses)."""
        raise NotImplementedError

    def __str__(self):
        val = self.value
        if val is None:
            shown = ''
        elif isinstance(val, (list, tuple)) and len(val) == 1:
            # a single-item sequence is displayed as its only item
            shown = val[0]
        else:
            shown = val if isinstance(val, str) else str(val)
        return f"{type(self).__name__}({shown})"

    def __repr__(self):
        shown = '' if self.value is None else self.value
        return f"{type(self).__name__}({shown})"

    def __eq__(self, other):
        # predicates are equal only when they have the exact same class and value
        return type(self) is type(other) and self.value == other.value

    def __hash__(self):
        return hash(self.value)
class AttrPredicate(Predicate):
    """
    Predicate applied to the value of a single attribute.

    It is supplied as the value of an attribute in the expected element.
    """
class StartsWith(AttrPredicate):
    """Attribute predicate satisfied when the actual value starts with ``value``."""

    def validate(self, actual):
        prefix = self.value
        return actual.startswith(prefix)
class EndsWith(AttrPredicate):
    """Attribute predicate satisfied when the actual value ends with ``value``."""

    def validate(self, actual):
        suffix = self.value
        return actual.endswith(suffix)
class Contains(AttrPredicate):
    """
    Attribute predicate satisfied when every given item is found in the actual value.

    With ``_word=True`` the actual value is split on whitespace and each item
    must be one of the resulting words; otherwise a plain ``in`` test is used.
    """

    def __init__(self, *value, _word=False):
        super().__init__(value)
        self._word = _word

    def validate(self, actual):
        haystack = actual.split() if self._word else actual
        for needle in self.value:
            if needle not in haystack:
                return False
        return True
class DoesNotContain(AttrPredicate):
    """Attribute predicate satisfied when none of the given items is found in the actual value."""

    def __init__(self, *value):
        super().__init__(value)

    def validate(self, actual):
        return not any(fragment in actual for fragment in self.value)
class AnyValue(AttrPredicate):
    """
    True if the attribute is present and its value is not None.
    """

    def __init__(self):
        super().__init__(None)

    def validate(self, actual):
        if actual is None:
            return False
        return True
class Regex(AttrPredicate):
    """Attribute predicate satisfied when the pattern matches at the start of the actual value."""

    def __init__(self, pattern):
        super().__init__(pattern)

    def validate(self, actual):
        # re.match anchors at the beginning of the string only
        found = re.match(self.value, actual)
        return found is not None
class And(AttrPredicate):
    """Attribute predicate satisfied when all the wrapped predicates are satisfied."""

    def __init__(self, *predicates):
        super().__init__(predicates)

    def validate(self, actual):
        # stops at the first predicate that rejects the value
        for predicate in self.value:
            if not predicate.validate(actual):
                return False
        return True
class ChildrenPredicate(Predicate):
    """
    Predicate supplied as a child of the expected element.
    """

    def to_debug(self, element):
        """Return the element to show in debug output; this base version returns it unchanged."""
        return element
class Empty(ChildrenPredicate):
    """Children predicate satisfied when the element has neither children nor attributes."""

    def __init__(self):
        super().__init__(None)

    def validate(self, actual):
        if len(actual.children) != 0:
            return False
        return len(actual.attrs) == 0
class NoChildren(ChildrenPredicate):
    """Children predicate satisfied when the element has no children."""

    def __init__(self):
        super().__init__(None)

    def validate(self, actual):
        child_count = len(actual.children)
        return child_count == 0
class AttributeForbidden(ChildrenPredicate):
    """
    Children predicate validating that an attribute is not present in an element.

    An attribute whose value is None is treated as absent.
    """

    def validate(self, actual):
        name = self.value
        if name not in actual.attrs:
            return True
        return actual.attrs[name] is None

    def to_debug(self, element):
        # flag the forbidden attribute on the element so that it appears in the debug output
        element.attrs[self.value] = "** NOT ALLOWED **"
        return element
class HasHtmx(ChildrenPredicate):
    """
    Children predicate validating the htmx attributes of an element.

    The expected attributes come from ``command.get_htmx_params()`` (when a
    command is given) merged with the keyword arguments, the latter winning.
    Any ``hx_`` found in a key is rewritten to ``hx-``.
    """

    def __init__(self, command: Command = None, **htmx_params):
        super().__init__(None)
        self.command = command
        merged = (command.get_htmx_params() | htmx_params) if command else htmx_params
        normalized = {}
        for key, val in merged.items():
            normalized[key.replace("hx_", "hx-")] = val
        self.htmx_params = normalized

    def validate(self, actual):
        for key, val in self.htmx_params.items():
            if not actual.attrs.get(key) == val:
                return False
        return True

    def to_debug(self, element):
        # copy the expected htmx attributes onto the element so that they appear in the debug output
        for key, val in self.htmx_params.items():
            element.attrs[key] = val
        return element
class TestObject:
    """
    Expected-side description of an object: its class (or class name) and
    the attributes to check, given as keyword arguments.
    """

    def __init__(self, cls, **kwargs):
        self.cls, self.attrs = cls, kwargs
class TestIcon(TestObject):
    """
    Expected ``div`` wrapping an svg icon.

    The single child is a ``NotStr`` whose content must match the regex
    ``<svg name="\\w+-<name>``. A name starting with a lowercase letter is
    converted with ``snake_to_pascal``. When a command is given, its htmx
    parameters are added to the expected attributes.
    """

    def __init__(self, name: Optional[str] = '', command=None):
        super().__init__("div")
        if name and name[0].islower():
            self.name = snake_to_pascal(name)
        else:
            self.name = name
        svg_pattern = Regex(f'<svg name="\\w+-{self.name}')
        self.children = [TestObject(NotStr, s=svg_pattern)]
        if command:
            self.attrs |= command.get_htmx_params()

    def __str__(self):
        return f'<div><svg name="{self.name}" .../></div>'
class TestIconNotStr(TestObject):
    """
    Expected svg icon given directly as a ``NotStr`` (no wrapping element).

    The ``s`` attribute must match the regex ``<svg name="\\w+-<name>``. A name
    starting with a lowercase letter is converted with ``snake_to_pascal``.
    """

    def __init__(self, name: Optional[str] = ''):
        super().__init__(NotStr)
        if name and name[0].islower():
            self.name = snake_to_pascal(name)
        else:
            self.name = name
        svg_pattern = Regex(f'<svg name="\\w+-{self.name}')
        self.attrs["s"] = svg_pattern

    def __str__(self):
        return f'<svg name="{self.name}" .../>'
class TestCommand(TestObject):
    """Expected ``Command`` object, identified by its name plus optional extra attributes."""

    def __init__(self, name, **kwargs):
        super().__init__("Command", **kwargs)
        # "name" is listed first so that it leads the attribute order
        self.attrs = {"name": name, **kwargs}
class TestScript(TestObject):
    """Expected ``script`` element whose single child is the script text wrapped in ``NotStr``."""

    def __init__(self, script):
        super().__init__("script")
        self.script = script
        script_body = NotStr(self.script)
        self.children = [script_body]
@dataclass
class DoNotCheck:
    """Marker used on the expected side: the corresponding actual value is not checked."""
    # optional free-text description
    desc: Optional[str] = None
@dataclass
class Skip:
    """Marker used among expected children: actual children matching ``element`` are skipped."""
    # pattern of the children to skip
    element: Any
    # optional free-text description
    desc: Optional[str] = None
def _get_type(x):
if hasattr(x, "tag"):
return x.tag
if isinstance(x, (TestObject, TestIcon)):
return x.cls.__name__ if isinstance(x.cls, type) else str(x.cls)
return type(x).__name__
def _get_attr(x, attr):
    """
    Return the value of ``attr`` for an element, or ``MISSING_ATTR`` when absent.

    Elements exposing an ``attrs`` mapping are looked up in that mapping;
    any other object is looked up with a regular attribute access.
    """
    if hasattr(x, "attrs"):
        return x.attrs.get(attr, MISSING_ATTR)
    return getattr(x, attr, MISSING_ATTR)
def _get_attributes(x):
"""Get the attributes dict from an element."""
if hasattr(x, "attrs"):
return x.attrs
return {}
def _get_children(x):
"""Get the children list from an element."""
if hasattr(x, "children"):
return x.children
return []
def _str_element(element, expected=None, keep_open=None):
    """
    Render an element on a single line for debug output.

    Anything without a ``tag`` is rendered with ``quoted_str``. A tagged element
    is rendered as ``(tag "attr"="value" ...``, showing only the attributes
    that are declared on ``expected`` (on the element itself when no expected
    element is given).

    ``keep_open`` drives the end of the line:
      * ``False``: the tag is closed, with `` ...`` standing for children if any
      * ``True``: the tag is left open and followed by ``...``
      * ``None``: the tag is closed only when every child is a Predicate
    """
    if not hasattr(element, "tag"):
        return quoted_str(element)

    # with no expected element, the element is compared to itself
    reference = element if expected is None else expected

    rendered_attrs = []
    for name in _get_attributes(reference):
        if name is None:
            continue
        rendered_attrs.append(f'"{name}"="{_get_attr(element, name)}"')
    attrs_text = " ".join(rendered_attrs)

    text = f"({element.tag} {attrs_text}"
    if keep_open is True:
        text += " ..." if attrs_text else "..."
    elif keep_open is False:
        text += " ...)" if len(element.children) > 0 else ")"
    elif all(isinstance(child, Predicate) for child in element.children):
        # nothing but predicates (or nothing at all) below: the tag can be closed
        text += ")"
    return text
class ErrorOutput:
    """
    Render an element as indented text lines, with ``^`` markers under the
    parts that differ from the expected element.

    The lines are accumulated in ``self.output`` by ``compute()``.
    """

    def __init__(self, path, element, expected):
        """
        Args:
            path: Dot-separated path leading to the element.
            element: The element to render.
            expected: The element it is compared against.
        """
        self.path = path
        self.element = element
        self.expected = expected
        self.output = []  # rendered lines, filled by compute()
        self.indent = ""  # prefix prepended to every line added to the output

    @staticmethod
    def _unconstruct_path_item(item):
        """
        Split one path item into ``(element name, attribute name, attribute value)``.

        Recognised forms are ``tag[class=value]``, ``tag[name=value]`` and
        ``tag#id``; anything else yields ``(item, None, None)``.
        """
        # The bracket form is tested first so that a '#' inside a name/class
        # value is not mistaken for an id separator.
        if match := re.match(r'(\w+)\[(class|name)=([^]]+)]', item):
            return match.groups()
        if "#" in item:
            # maxsplit=1: only the first '#' separates the tag from the id,
            # so an id that itself contains '#' no longer breaks the unpacking
            elt_name, elt_id = item.split("#", 1)
            return elt_name, "id", elt_id
        return item, None, None

    def __str__(self):
        return f"ErrorOutput({self.output})"

    def compute(self):
        """Fill ``self.output`` with the path hierarchy followed by the rendered element."""
        # first render the path hierarchy
        self._render_path()

        # then render the element
        if hasattr(self.expected, "tag") and hasattr(self.element, "tag"):
            self._render_tagged_element()
        elif isinstance(self.expected, TestObject):
            cls = _get_type(self.element)
            attrs = {attr_name: _get_attr(self.element, attr_name) for attr_name in self.expected.attrs}
            self._add_to_output(f"({cls} {_str_attrs(attrs)})")
            self._add_error_marker(self.element, self.expected)
        else:
            self._add_to_output(str(self.element))
            self._add_error_marker(self.element, self.expected)

    def _render_path(self):
        """Render every ancestor of the element (all path items but the last), one per line."""
        for item in self.path.split(".")[:-1]:
            elt_name, attr_name, attr_value = self._unconstruct_path_item(item)
            path_str = _str_element(MyFT(elt_name, {attr_name: attr_value}), keep_open=True)
            self._add_to_output(f"{path_str}")
            self.indent += " "

    def _render_tagged_element(self):
        """Render a tagged element, its error marker, and one line per expected child."""
        # display the tag and its attributes
        self._add_to_output(_str_element(self.element, self.expected))
        self._add_error_marker(self.element, self.expected)

        # render the children
        expected_children = [c for c in self.expected.children if not isinstance(c, ChildrenPredicate)]
        if not expected_children:
            return

        # The indentation is saved and restored rather than sliced off, so the
        # closing parenthesis always lines up with the opening tag.
        outer_indent = self.indent
        self.indent += " "
        actual_children = self.element.children
        for index, expected_child in enumerate(expected_children):
            if index >= len(actual_children):
                # When there are fewer children than expected, we display a placeholder
                self._add_to_output("! ** MISSING ** !")
                continue

            element_child = actual_children[index]
            self._add_to_output(_str_element(element_child, expected_child, keep_open=False))

            # manage errors (only when the expected child is a tagged element)
            if hasattr(expected_child, "tag"):
                self._add_error_marker(element_child, expected_child)

        self.indent = outer_indent
        self._add_to_output(")")

    def _add_error_marker(self, element, expected):
        """Append the marker line showing where ``element`` differs from ``expected``, if any."""
        error_str = self._detect_error(element, expected)
        if error_str:
            self._add_to_output(error_str)

    def _add_to_output(self, msg):
        self.output.append(f"{self.indent}{msg}")

    def _detect_error(self, element, expected):
        """
        Detect errors between element and expected, returning a visual marker string.

        Handles both tagged elements and TestObjects (type and attribute
        markers) and simple values (one marker spanning the whole value).
        Returns None when nothing differs.
        """
        # For elements with structure (tagged element or TestObject)
        if hasattr(expected, "tag") or isinstance(expected, TestObject):
            element_type = _get_type(element)
            expected_type = _get_type(expected)
            type_error = (" " if element_type == expected_type else "^") * len(element_type)

            attr_names = list(_get_attributes(expected))
            element_attrs = {attr_name: _get_attr(element, attr_name) for attr_name in attr_names}
            expected_attrs = {attr_name: _get_attr(expected, attr_name) for attr_name in attr_names}
            attrs_in_error = {attr_name for attr_name, attr_value in element_attrs.items() if
                              not self._matches(attr_value, expected_attrs[attr_name])}
            # each attribute is underlined over the exact width of its rendering
            attrs_error = " ".join(
                len(f'"{name}"="{value}"') * ("^" if name in attrs_in_error else " ")
                for name, value in element_attrs.items()
            )
            if type_error.strip() or attrs_error.strip():
                return f" {type_error} {attrs_error}"
            return None

        # For simple values
        if not self._matches(element, expected):
            return len(str(element)) * "^"
        return None

    @staticmethod
    def _matches(element, expected):
        """Tell whether ``element`` equals ``expected`` or satisfies it when it is a Predicate."""
        if element == expected:
            return True
        if isinstance(expected, Predicate):
            try:
                return expected.validate(element)
            except (AttributeError, TypeError):
                # A predicate that cannot even evaluate the value (e.g. a string
                # predicate applied to None) is reported as a mismatch instead of
                # crashing while the error message is being built.
                return False
        return False
class ErrorComparisonOutput:
    """
    Side-by-side rendering of two ``ErrorOutput`` objects (actual | expected).
    """

    def __init__(self, actual_error_output, expected_error_output):
        self.actual_error_output = actual_error_output
        self.expected_error_output = expected_error_output

    @staticmethod
    def adjust(to_adjust, reference):
        """
        Return a copy of ``to_adjust`` with an empty line inserted at the index
        of every ``reference`` line that contains an error marker (``^^``).

        The input list is left untouched, so adjusting is safe to repeat.
        """
        adjusted = list(to_adjust)
        for index, ref_line in enumerate(reference):
            if "^^" in ref_line:
                # insert an empty line to stay aligned with the marker line
                adjusted.insert(index, "")
        return adjusted

    def render(self):
        """
        Return the two outputs rendered as ``actual | expected`` lines.

        The outputs are computed on demand. Rendering does not modify them,
        so calling this method several times gives the same text.
        """
        # init if needed
        if not self.actual_error_output.output:
            self.actual_error_output.compute()
        if not self.expected_error_output.output:
            self.expected_error_output.compute()

        # work on copies: the outputs are cached on the ErrorOutput objects
        actual = list(self.actual_error_output.output)
        expected = self.adjust(self.expected_error_output.output, actual)

        # pad the shorter side so that no line is silently dropped
        height = max(len(actual), len(expected))
        actual += [""] * (height - len(actual))
        expected += [""] * (height - len(expected))

        actual_max_length = max((len(line) for line in actual), default=0)
        return "\n".join(
            f"{a:<{actual_max_length}} | {e}".rstrip()
            for a, e in zip(actual, expected)
        )
class Matcher:
    """
    Matcher class for comparing actual and expected elements.

    Provides flexible comparison with support for predicates, nested structures,
    and detailed error reporting. Every mismatch raises an ``AssertionError``
    explicitly (no ``assert`` statement), so the checks remain active when
    Python runs with ``-O``.
    """

    def __init__(self):
        # dot-separated path of the element currently being compared
        self.path = ""

    def matches(self, actual, expected):
        """
        Compare actual and expected elements.

        Args:
            actual: The actual element to compare
            expected: The expected element or pattern

        Returns:
            True if elements match, raises AssertionError otherwise
        """
        if actual is not None and expected is None:
            self._assert_error("Actual is not None while expected is None", _actual=actual)

        if isinstance(expected, DoNotCheck):
            return True

        if actual is None and expected is not None:
            self._assert_error("Actual is None while expected is ", _expected=expected)

        # set the path
        current_path = self._get_current_path(actual)
        original_path = self.path
        self.path = f"{self.path}.{current_path}" if self.path else current_path

        try:
            self._match_elements(actual, expected)
        finally:
            # restore the original path for sibling comparisons
            self.path = original_path

        return True

    def _match_elements(self, actual, expected):
        """Dispatch the comparison according to the kind of the expected value."""
        if isinstance(expected, TestObject) or hasattr(expected, "tag"):
            self._match_element(actual, expected)
            return

        if isinstance(expected, Predicate):
            if not expected.validate(actual):
                self._assert_error(f"The condition '{expected}' is not satisfied.",
                                   _actual=actual,
                                   _expected=expected)
            return

        if not (_get_type(actual) == _get_type(expected)):
            self._assert_error("The types are different.", _actual=actual, _expected=expected)

        if isinstance(expected, (list, tuple)):
            self._match_list(actual, expected)
        elif isinstance(expected, dict):
            self._match_dict(actual, expected)
        elif isinstance(expected, NotStr):
            self._match_notstr(actual, expected)
        elif not (actual == expected):
            self._assert_error("The values are different",
                               _actual=actual,
                               _expected=expected)

    def _match_element(self, actual, expected):
        """Match a TestObject or a tagged element: type, children predicates, attributes, children."""
        # Validate the type/tag
        if not (_get_type(actual) == _get_type(expected)):
            self._assert_error("The types are different.",
                               _actual=_get_type(actual),
                               _expected=_get_type(expected))

        # Special conditions (ChildrenPredicate)
        expected_children = _get_children(expected)
        for predicate in [c for c in expected_children if isinstance(c, ChildrenPredicate)]:
            if not predicate.validate(actual):
                # to_debug() is only called on failure, as it may alter the expected element
                self._assert_error(f"The condition '{predicate}' is not satisfied.",
                                   _actual=actual,
                                   _expected=predicate.to_debug(expected))

        # Compare the attributes
        expected_attrs = _get_attributes(expected)
        for expected_attr, expected_value in expected_attrs.items():
            actual_value = _get_attr(actual, expected_attr)

            # Check if attribute exists
            if actual_value == MISSING_ATTR:
                self._assert_error(
                    f"'{expected_attr}' is not found in Actual. (attributes: {self._str_attrs(actual)})",
                    _actual=actual,
                    _expected=expected)

            # Handle Predicate values
            if isinstance(expected_value, Predicate):
                if not expected_value.validate(actual_value):
                    self._assert_error(f"The condition '{expected_value}' is not satisfied.",
                                       _actual=actual,
                                       _expected=expected)

            # Handle TestObject recursive matching
            elif isinstance(expected, TestObject):
                try:
                    self.matches(actual_value, expected_value)
                except AssertionError as e:
                    # reuse the message of the nested error, qualified with the attribute name
                    match = re.search(r"Error : (.+?)\n", str(e))
                    if match:
                        self._assert_error(f"{match.group(1)} for '{expected_attr}'.",
                                           _actual=actual_value,
                                           _expected=expected_value)
                    else:
                        self._assert_error(f"The values are different for '{expected_attr}'.",
                                           _actual=actual_value,
                                           _expected=expected_value)

            # Handle regular value comparison
            elif not (actual_value == expected_value):
                self._assert_error(f"The values are different for '{expected_attr}'.",
                                   _actual=actual,
                                   _expected=expected)

        # Compare the children (only if present)
        if expected_children:
            # Filter out Predicate children
            expected_children = [c for c in expected_children if not isinstance(c, Predicate)]
            actual_children = _get_children(actual)

            if len(actual_children) < len(expected_children):
                self._assert_error("Actual is lesser than expected.", _actual=actual, _expected=expected)

            actual_child_index, expected_child_index = 0, 0
            while expected_child_index < len(expected_children):
                if actual_child_index >= len(actual_children):
                    self._assert_error("Nothing more to skip.", _actual=actual, _expected=expected)

                actual_child = actual_children[actual_child_index]
                expected_child = expected_children[expected_child_index]

                if isinstance(expected_child, Skip):
                    try:
                        self._match_element(actual_child, expected_child.element)
                    except AssertionError:
                        # not an element to skip: try to match with the following expected child
                        expected_child_index += 1
                    else:
                        # this is the element to skip: skip it and stay on the same expected child
                        actual_child_index += 1
                    continue

                self.matches(actual_child, expected_child)

                actual_child_index += 1
                expected_child_index += 1

    def _match_list(self, actual, expected):
        """Match list or tuple: same length, items compared pairwise."""
        if len(actual) < len(expected):
            self._assert_error("Actual is smaller than expected: ", _actual=actual, _expected=expected)
        if len(actual) > len(expected):
            self._assert_error("Actual is bigger than expected: ", _actual=actual, _expected=expected)

        for actual_child, expected_child in zip(actual, expected):
            self.matches(actual_child, expected_child)

    def _match_dict(self, actual, expected):
        """Match dictionary: same size, every expected key present with a matching value."""
        if len(actual) < len(expected):
            self._assert_error("Actual is smaller than expected: ", _actual=actual, _expected=expected)
        if len(actual) > len(expected):
            self._assert_error("Actual is bigger than expected: ", _actual=actual, _expected=expected)

        for k, v in expected.items():
            if k not in actual:
                # same size but different keys: report it instead of failing with a KeyError
                self._assert_error(f"'{k}' is not found in Actual.", _actual=actual, _expected=expected)
            self.matches(actual[k], v)

    def _match_notstr(self, actual, expected):
        """Match NotStr type: the actual content, left-stripped, must start with the expected content."""
        to_compare = _get_attr(actual, "s").lstrip('\n').lstrip()
        if not to_compare.startswith(expected.s):
            self._assert_error("Notstr values are different: ",
                               _actual=to_compare,
                               _expected=expected.s)

    def _print_path(self):
        """Format the current path for error messages."""
        return f"Path : '{self.path}'\n" if self.path else ""

    def _debug_compare(self, a, b):
        """Generate a comparison debug output."""
        actual_out = ErrorOutput(self.path, a, b)
        expected_out = ErrorOutput(self.path, b, b)

        comparison_out = ErrorComparisonOutput(actual_out, expected_out)
        return comparison_out.render()

    def _error_msg(self, msg, _actual=None, _expected=None):
        """Generate an error message with debug information."""
        if _actual is None and _expected is None:
            debug_info = ""
        elif _actual is None:
            debug_info = self._debug(_expected)
        elif _expected is None:
            debug_info = self._debug(_actual)
        else:
            debug_info = self._debug_compare(_actual, _expected)

        return f"{self._print_path()}Error : {msg}\n{debug_info}"

    def _assert_error(self, msg, _actual=None, _expected=None):
        """
        Raise an assertion error with formatted message.

        Raises:
            AssertionError: always. It is raised explicitly because an
                ``assert False`` statement is removed when Python runs with ``-O``.
        """
        raise AssertionError(self._error_msg(msg, _actual=_actual, _expected=_expected))

    @staticmethod
    def _get_current_path(elt):
        """
        Get the path representation of an element.

        A tagged element is rendered as its tag, qualified by the first of
        ``id``, ``name`` or ``class`` found in its attributes; any other
        object is rendered as its class name.
        """
        if hasattr(elt, "tag"):
            res = f"{elt.tag}"
            if "id" in elt.attrs:
                res += f"#{elt.attrs['id']}"
            elif "name" in elt.attrs:
                res += f"[name={elt.attrs['name']}]"
            elif "class" in elt.attrs:
                res += f"[class={elt.attrs['class']}]"
            return res
        return elt.__class__.__name__

    @staticmethod
    def _str_attrs(element):
        """Format attributes as a string."""
        attrs = _get_attributes(element)
        return " ".join(f'"{attr_name}"="{attr_value}"' for attr_name, attr_value in attrs.items())

    @staticmethod
    def _debug(elt):
        """Format an element for debug output."""
        # only None is rendered as "None": falsy values such as 0 or "" are real values
        return _str_element(elt, keep_open=False) if elt is not None else "None"
def matches(actual, expected, path=""):
    """
    Compare actual and expected elements.

    Convenience wrapper: creates a fresh Matcher, seeds it with ``path`` and
    delegates the comparison to it.

    Args:
        actual: The actual element to compare
        expected: The expected element or pattern
        path: Optional initial path for error reporting

    Returns:
        True if elements match, raises AssertionError otherwise
    """
    engine = Matcher()
    engine.path = path
    outcome = engine.matches(actual, expected)
    return outcome
def find(ft, expected):
    """
    Find all occurrences of an expected element within a FastHTML tree.

    Args:
        ft: A FastHTML element, or a list/tuple/set of elements, to search in
        expected: The element pattern to find

    Returns:
        List of matching elements, each element being listed before its
        descendants. The list is empty when nothing matches.
    """

    def _elements_match(candidate, pattern):
        """Check if a candidate fits the pattern, based on tag, attributes, and children."""
        if isinstance(pattern, DoNotCheck):
            return True

        if isinstance(pattern, NotStr):
            to_compare = _get_attr(candidate, "s").lstrip('\n').lstrip()
            return to_compare.startswith(pattern.s)

        if isinstance(candidate, NotStr) and _get_type(candidate) != _get_type(pattern):
            return False  # to manage the unexpected __eq__ behavior of NotStr

        if not isinstance(pattern, (TestObject, FT)):
            return candidate == pattern

        # Compare tags
        if _get_type(candidate) != _get_type(pattern):
            return False

        # Compare attributes: every expected attribute must fit. The loop only
        # exits early on a mismatch, so the remaining attributes and the
        # children are always checked too.
        for attr_name, expected_attr_value in _get_attributes(pattern).items():
            actual_attr_value = _get_attr(candidate, attr_name)

            # attribute is missing
            if actual_attr_value == MISSING_ATTR:
                return False

            # manage predicate values
            if isinstance(expected_attr_value, Predicate):
                if not expected_attr_value.validate(actual_attr_value):
                    return False

            # finally compare values
            elif not (actual_attr_value == expected_attr_value):
                return False

        # Compare children recursively
        candidate_children = _get_children(candidate)
        for expected_child in _get_children(pattern):
            if isinstance(expected_child, ChildrenPredicate):
                # as in Matcher, a children predicate applies to the element itself
                if not expected_child.validate(candidate):
                    return False
            elif isinstance(expected_child, Predicate):
                # as in Matcher, other predicates are not treated as children
                continue
            # Check if this expected child exists somewhere in the candidate's children
            elif not any(_elements_match(child, expected_child) for child in candidate_children):
                return False

        return True

    def _search_tree(current, pattern):
        """Recursively search for pattern in the tree rooted at current."""
        # Check if current element matches
        found = [current] if _elements_match(current, pattern) else []

        # Recursively search in children, in the case that the pattern also appears in children
        for child in _get_children(current):
            found.extend(_search_tree(child, pattern))

        return found

    # Normalize input to list
    elements_to_search = ft if isinstance(ft, (list, tuple, set)) else [ft]

    # Search in all provided elements
    all_matches = []
    for element in elements_to_search:
        all_matches.extend(_search_tree(element, expected))

    return all_matches
def find_one(ft, expected):
    """
    Find the single occurrence of an expected element within a FastHTML tree.

    Args:
        ft: A FastHTML element or list of elements to search in
        expected: The element pattern to find

    Returns:
        The only matching element

    Raises:
        AssertionError: If zero or several elements match. It is raised
            explicitly so the check is kept when Python runs with ``-O``.
    """
    found = find(ft, expected)
    if len(found) != 1:
        raise AssertionError(f"Found {len(found)} elements for '{expected}'")
    return found[0]
def _str_attrs(attrs: dict):
return " ".join(f'"{attr_name}"="{attr_value}"' for attr_name, attr_value in attrs.items())