Added unit tests for Layout.py

This commit is contained in:
2025-11-30 22:48:11 +01:00
parent 96ed447eae
commit 93cb477c21
8 changed files with 712 additions and 327 deletions

View File

@@ -1,9 +1,11 @@
import re
from dataclasses import dataclass
from typing import Optional
from fastcore.basics import NotStr
from fastcore.xml import FT
from myfasthtml.core.utils import quoted_str
from myfasthtml.core.utils import quoted_str, snake_to_pascal
from myfasthtml.test.testclient import MyFT
MISSING_ATTR = "** MISSING **"
@@ -17,7 +19,18 @@ class Predicate:
raise NotImplementedError
def __str__(self):
return f"{self.__class__.__name__}({self.value if self.value is not None else ''})"
if self.value is None:
str_value = ''
elif isinstance(self.value, str):
str_value = self.value
elif isinstance(self.value, (list, tuple)):
if len(self.value) == 1:
str_value = self.value[0]
else:
str_value = str(self.value)
else:
str_value = str(self.value)
return f"{self.__class__.__name__}({str_value})"
def __repr__(self):
return f"{self.__class__.__name__}({self.value if self.value is not None else ''})"
@@ -47,20 +60,28 @@ class StartsWith(AttrPredicate):
return actual.startswith(self.value)
class Contains(AttrPredicate):
class EndsWith(AttrPredicate):
def __init__(self, value):
super().__init__(value)
def validate(self, actual):
return self.value in actual
return actual.endswith(self.value)
class Contains(AttrPredicate):
def __init__(self, *value):
super().__init__(value)
def validate(self, actual):
return all(val in actual for val in self.value)
class DoesNotContain(AttrPredicate):
def __init__(self, value):
def __init__(self, *value):
super().__init__(value)
def validate(self, actual):
return self.value not in actual
return all(val not in actual for val in self.value)
class AnyValue(AttrPredicate):
@@ -75,6 +96,14 @@ class AnyValue(AttrPredicate):
return actual is not None
class Regex(AttrPredicate):
def __init__(self, pattern):
super().__init__(pattern)
def validate(self, actual):
return re.match(self.value, actual) is not None
class ChildrenPredicate(Predicate):
"""
Predicate given as a child of an element.
@@ -122,12 +151,33 @@ class TestObject:
self.attrs = kwargs
class TestIcon(TestObject):
def __init__(self, name: Optional[str] = ''):
super().__init__("div")
self.name = snake_to_pascal(name) if (name and name[0].islower()) else name
self.children = [
TestObject(NotStr, s=Regex(f'<svg name="\\w+-{self.name}'))
]
def __str__(self):
return f'<div><svg name="{self.name}" .../></div>'
class TestCommand(TestObject):
def __init__(self, name, **kwargs):
super().__init__("Command", **kwargs)
self.attrs = {"name": name} | kwargs # name should be first
class TestScript(TestObject):
def __init__(self, script):
super().__init__("script")
self.script = script
self.children = [
NotStr(self.script),
]
@dataclass
class DoNotCheck:
desc: str = None
@@ -136,7 +186,7 @@ class DoNotCheck:
def _get_type(x):
if hasattr(x, "tag"):
return x.tag
if isinstance(x, TestObject):
if isinstance(x, (TestObject, TestIcon)):
return x.cls.__name__ if isinstance(x.cls, type) else str(x.cls)
return type(x).__name__
@@ -291,21 +341,21 @@ class ErrorOutput:
element_type = _get_type(element)
expected_type = _get_type(expected)
type_error = (" " if element_type == expected_type else "^") * len(element_type)
element_attrs = {attr_name: _get_attr(element, attr_name) for attr_name in _get_attributes(expected)}
expected_attrs = {attr_name: _get_attr(expected, attr_name) for attr_name in _get_attributes(expected)}
attrs_in_error = {attr_name for attr_name, attr_value in element_attrs.items() if
not self._matches(attr_value, expected_attrs[attr_name])}
attrs_error = " ".join(
len(f'"{name}"="{value}"') * ("^" if name in attrs_in_error else " ")
for name, value in element_attrs.items()
)
if type_error.strip() or attrs_error.strip():
return f" {type_error} {attrs_error}"
return None
# For simple values
else:
if not self._matches(element, expected):
@@ -435,7 +485,7 @@ class Matcher:
# Validate the type/tag
assert _get_type(actual) == _get_type(expected), \
self._error_msg("The types are different.", _actual=_get_type(actual), _expected=_get_type(expected))
# Special conditions (ChildrenPredicate)
expected_children = _get_children(expected)
for predicate in [c for c in expected_children if isinstance(c, ChildrenPredicate)]:
@@ -443,25 +493,25 @@ class Matcher:
self._error_msg(f"The condition '{predicate}' is not satisfied.",
_actual=actual,
_expected=predicate.to_debug(expected))
# Compare the attributes
expected_attrs = _get_attributes(expected)
for expected_attr, expected_value in expected_attrs.items():
actual_value = _get_attr(actual, expected_attr)
# Check if attribute exists
if actual_value == MISSING_ATTR:
self._assert_error(f"'{expected_attr}' is not found in Actual.",
self._assert_error(f"'{expected_attr}' is not found in Actual. (attributes: {self._str_attrs(actual)})",
_actual=actual,
_expected=expected)
# Handle Predicate values
if isinstance(expected_value, Predicate):
assert expected_value.validate(actual_value), \
self._error_msg(f"The condition '{expected_value}' is not satisfied.",
_actual=actual,
_expected=expected)
# Handle TestObject recursive matching
elif isinstance(expected, TestObject):
try:
@@ -476,23 +526,23 @@ class Matcher:
self._assert_error(f"The values are different for '{expected_attr}'.",
_actual=actual_value,
_expected=expected_value)
# Handle regular value comparison
else:
assert actual_value == expected_value, \
self._error_msg(f"The values are different for '{expected_attr}'.",
_actual=actual,
_expected=expected)
# Compare the children (only if present)
if expected_children:
# Filter out Predicate children
expected_children = [c for c in expected_children if not isinstance(c, Predicate)]
actual_children = _get_children(actual)
if len(actual_children) < len(expected_children):
self._assert_error("Actual is lesser than expected.", _actual=actual, _expected=expected)
for actual_child, expected_child in zip(actual_children, expected_children):
assert self.matches(actual_child, expected_child)
@@ -517,7 +567,7 @@ class Matcher:
def _match_notstr(self, actual, expected):
"""Match NotStr type."""
to_compare = actual.s.lstrip('\n').lstrip()
to_compare = _get_attr(actual, "s").lstrip('\n').lstrip()
assert to_compare.startswith(expected.s), self._error_msg("Notstr values are different: ",
_actual=to_compare,
_expected=expected.s)
@@ -567,8 +617,9 @@ class Matcher:
return elt.__class__.__name__
@staticmethod
def _str_attrs(attrs: dict):
def _str_attrs(element):
"""Format attributes as a string."""
attrs = _get_attributes(element)
return " ".join(f'"{attr_name}"="{attr_value}"' for attr_name, attr_value in attrs.items())
@staticmethod
@@ -610,75 +661,83 @@ def find(ft, expected):
Raises:
AssertionError: If no matching elements are found
"""
def _elements_match(actual, expected):
"""Check if two elements are the same based on tag, attributes, and children."""
# Quick equality check
if actual == expected:
if isinstance(expected, DoNotCheck):
return True
# Check if both are FT elements
if not (hasattr(actual, "tag") and hasattr(expected, "tag")):
return False
if isinstance(expected, NotStr):
to_compare = _get_attr(actual, "s").lstrip('\n').lstrip()
return to_compare.startswith(expected.s)
if isinstance(actual, NotStr) and _get_type(actual) != _get_type(expected):
return False # to manage the unexpected __eq__ behavior of NotStr
if not isinstance(expected, (TestObject, FT)):
return actual == expected
# Compare tags
if _get_type(actual) != _get_type(expected):
return False
# Compare attributes
expected_attrs = _get_attributes(expected)
actual_attrs = _get_attributes(actual)
for attr_name, attr_value in expected_attrs.items():
if attr_name not in actual_attrs or actual_attrs[attr_name] != attr_value:
for attr_name, expected_attr_value in expected_attrs.items():
actual_attr_value = _get_attr(actual, attr_name)
# attribute is missing
if actual_attr_value == MISSING_ATTR:
return False
# manage predicate values
if isinstance(expected_attr_value, Predicate):
return expected_attr_value.validate(actual_attr_value)
# finally compare values
return actual_attr_value == expected_attr_value
# Compare children recursively
expected_children = _get_children(expected)
actual_children = _get_children(actual)
for expected_child in expected_children:
# Check if this expected child exists somewhere in actual children
if not any(_elements_match(actual_child, expected_child) for actual_child in actual_children):
return False
return True
def _search_tree(current, pattern):
"""Recursively search for pattern in the tree rooted at current."""
# Type mismatch - can't be the same
if type(current) != type(pattern):
return []
# For non-FT elements, simple equality check
if not hasattr(current, "tag"):
return [current] if current == pattern else []
# Check if current element matches
matches = []
if _elements_match(current, pattern):
matches.append(current)
# Recursively search in children
# Recursively search in children, in the case that the pattern also appears in children
for child in _get_children(current):
matches.extend(_search_tree(child, pattern))
return matches
# Normalize input to list
elements_to_search = ft if isinstance(ft, (list, tuple, set)) else [ft]
# Search in all provided elements
all_matches = []
for element in elements_to_search:
all_matches.extend(_search_tree(element, expected))
# Raise error if nothing found
if not all_matches:
raise AssertionError(f"No element found for '{expected}'")
return all_matches
def find_one(ft, expected):
found = find(ft, expected)
assert len(found) == 1, f"Found {len(found)} elements for '{expected}'"
return found[0]
def _str_attrs(attrs: dict):
return " ".join(f'"{attr_name}"="{attr_value}"' for attr_name, attr_value in attrs.items())