744 lines
24 KiB
Python
744 lines
24 KiB
Python
import re
|
|
from dataclasses import dataclass
|
|
from typing import Optional
|
|
|
|
from fastcore.basics import NotStr
|
|
from fastcore.xml import FT
|
|
|
|
from myfasthtml.core.utils import quoted_str, snake_to_pascal
|
|
from myfasthtml.test.testclient import MyFT
|
|
|
|
MISSING_ATTR = "** MISSING **"
|
|
|
|
|
|
class Predicate:
    """
    Base class of the matching predicates.

    A predicate stores a reference ``value`` and exposes ``validate(actual)``,
    which subclasses implement to decide whether ``actual`` is acceptable.
    Two predicates are equal when they have exactly the same class and
    equal values.
    """

    def __init__(self, value):
        self.value = value

    def validate(self, actual):
        """Tell whether ``actual`` satisfies the predicate (to be overridden)."""
        raise NotImplementedError

    def __str__(self):
        val = self.value
        if val is None:
            shown = ''
        elif isinstance(val, str):
            shown = val
        elif isinstance(val, (list, tuple)) and len(val) == 1:
            # a single-item collection is displayed as its only item
            shown = val[0]
        else:
            shown = str(val)
        return f"{type(self).__name__}({shown})"

    def __repr__(self):
        shown = '' if self.value is None else self.value
        return f"{type(self).__name__}({shown})"

    def __eq__(self, other):
        # predicates of different classes never compare equal
        return type(self) is type(other) and self.value == other.value

    def __hash__(self):
        return hash(self.value)
|
|
|
|
|
|
class AttrPredicate(Predicate):
    """
    Marker base class for the predicates that check an attribute value.

    Such a predicate is supplied in place of the expected value of an attribute.
    """
|
|
|
|
|
|
class StartsWith(AttrPredicate):
    """
    Predicate satisfied when the actual value starts with the given prefix.

    The constructor is inherited: ``StartsWith(value)`` stores the prefix
    in ``self.value``.
    """

    def validate(self, actual):
        """
        Return True when ``actual.startswith(self.value)`` holds.

        A value that cannot be tested (``None``, a number, ...) does not
        satisfy the predicate: False is returned instead of letting an
        AttributeError/TypeError escape, so the caller can report a regular
        mismatch.
        """
        try:
            return actual.startswith(self.value)
        except (AttributeError, TypeError):
            return False
|
|
|
|
|
|
class EndsWith(AttrPredicate):
    """
    Predicate satisfied when the actual value ends with the given suffix.

    The constructor is inherited: ``EndsWith(value)`` stores the suffix
    in ``self.value``.
    """

    def validate(self, actual):
        """
        Return True when ``actual.endswith(self.value)`` holds.

        A value that cannot be tested (``None``, a number, ...) does not
        satisfy the predicate: False is returned instead of letting an
        AttributeError/TypeError escape, so the caller can report a regular
        mismatch.
        """
        try:
            return actual.endswith(self.value)
        except (AttributeError, TypeError):
            return False
|
|
|
|
|
|
class Contains(AttrPredicate):
    """
    Predicate satisfied when every given value is found in the actual value.

    The values are received as positional arguments and kept as a tuple.
    """

    def __init__(self, *value):
        super().__init__(value)

    def validate(self, actual):
        """Return True only if each stored value is ``in`` ``actual``."""
        for wanted in self.value:
            if wanted not in actual:
                return False
        return True
|
|
|
|
|
|
class DoesNotContain(AttrPredicate):
    """
    Predicate satisfied when none of the given values is found in the actual value.

    The values are received as positional arguments and kept as a tuple.
    """

    def __init__(self, *value):
        super().__init__(value)

    def validate(self, actual):
        """Return True only if no stored value is ``in`` ``actual``."""
        return not any(unwanted in actual for unwanted in self.value)
|
|
|
|
|
|
class AnyValue(AttrPredicate):
    """
    Predicate satisfied by any value that is not None.

    It carries no reference value (``self.value`` is None).
    """

    def __init__(self):
        super().__init__(None)

    def validate(self, actual):
        """Return False for None, True for anything else."""
        if actual is None:
            return False
        return True
|
|
|
|
|
|
class Regex(AttrPredicate):
    """
    Predicate satisfied when the actual value matches a regular expression.

    The pattern is applied with ``re.match``, i.e. it must match from the
    beginning of the actual value.
    """

    def __init__(self, pattern):
        super().__init__(pattern)

    def validate(self, actual):
        """
        Return True when ``re.match(pattern, actual)`` finds a match.

        A value that ``re.match`` cannot handle (``None``, a number, ...)
        does not satisfy the predicate: False is returned instead of letting
        the TypeError escape. An invalid pattern still raises ``re.error``.
        """
        try:
            return re.match(self.value, actual) is not None
        except TypeError:
            return False
|
|
|
|
|
|
class ChildrenPredicate(Predicate):
    """
    Base class for the predicates that are listed among the children of an
    expected element.
    """

    def to_debug(self, element):
        """
        Return the element to display when the predicate is not satisfied.

        The default implementation hands ``element`` back untouched.
        """
        return element
|
|
|
|
|
|
class Empty(ChildrenPredicate):
    """
    Predicate satisfied by an element that has neither children nor attributes.
    """

    def __init__(self):
        super().__init__(None)

    def validate(self, actual):
        """Return True when ``actual.children`` and ``actual.attrs`` are both empty."""
        if len(actual.children) != 0:
            return False
        return len(actual.attrs) == 0
|
|
|
|
|
|
class NoChildren(ChildrenPredicate):
    """
    Predicate satisfied by an element that has no children.
    """

    def __init__(self):
        super().__init__(None)

    def validate(self, actual):
        """Return True when ``actual.children`` is empty."""
        children_count = len(actual.children)
        return children_count == 0
|
|
|
|
|
|
class AttributeForbidden(ChildrenPredicate):
    """
    Predicate checking that an attribute is not present in an element.

    ``value`` is the name of the forbidden attribute.
    """

    def __init__(self, value):
        super().__init__(value)

    def validate(self, actual):
        """Return True when the attribute is absent from ``actual.attrs`` or set to None."""
        attrs = actual.attrs
        if self.value not in attrs:
            return True
        return attrs[self.value] is None

    def to_debug(self, element):
        """
        Flag the forbidden attribute on ``element`` for the debug output.

        NOTE: ``element`` is modified in place (its ``attrs`` receives the
        marker) and the same object is returned.
        """
        element.attrs[self.value] = "** NOT ALLOWED **"
        return element
|
|
|
|
|
|
class TestObject:
    """
    Description of an expected object.

    ``cls`` is the expected type (a class or a name) and the keyword
    arguments are the expected attributes, kept in ``attrs``.
    """

    def __init__(self, cls, **kwargs):
        self.cls, self.attrs = cls, kwargs
|
|
|
|
|
|
class TestIcon(TestObject):
    """
    Expected pattern for an icon: a ``div`` whose only expected child is a
    ``NotStr`` with an ``s`` attribute checked by a ``Regex`` on the
    ``<svg name="...`` opening.

    Args:
        name: Icon name. A name starting with a lowercase letter is converted
            with ``snake_to_pascal``; any other name is used as given.
            ``None`` is treated like the empty string (any icon name).
    """

    def __init__(self, name: Optional[str] = ''):
        super().__init__("div")
        # None is accepted by the signature: normalize it so that it is not
        # rendered as the literal text "None" in the pattern below
        name = name or ''
        self.name = snake_to_pascal(name) if name[:1].islower() else name
        self.children = [
            TestObject(NotStr, s=Regex(f'<svg name="\\w+-{self.name}'))
        ]

    def __str__(self):
        return f'<div><svg name="{self.name}" .../></div>'
|
|
|
|
|
|
class TestCommand(TestObject):
    """
    Expected pattern for an object of type ``"Command"``.

    The expected attributes are ``name`` followed by the keyword arguments.
    """

    def __init__(self, name, **kwargs):
        super().__init__("Command", **kwargs)
        # rebuild the dict so that "name" comes first
        self.attrs = {"name": name, **kwargs}
|
|
|
|
|
|
class TestScript(TestObject):
    """
    Expected pattern for a ``script`` element whose single expected child
    is the script text wrapped in a ``NotStr``.
    """

    def __init__(self, script):
        super().__init__("script")
        self.script = script
        self.children = [NotStr(script)]
|
|
|
|
|
|
@dataclass
class DoNotCheck:
    """
    Placeholder for an expected value that must not be compared.

    ``desc`` is an optional free-text description.
    """
    desc: Optional[str] = None
|
|
|
|
|
|
def _get_type(x):
    """
    Return the name used to compare the type of ``x``.

    - an object with a ``tag`` attribute: its tag
    - a ``TestObject`` (or subclass): the name of ``cls`` when it is a class,
      otherwise ``str(cls)``
    - anything else: the name of its Python type
    """
    if hasattr(x, "tag"):
        return x.tag
    if not isinstance(x, TestObject):
        return type(x).__name__
    declared = x.cls
    if isinstance(declared, type):
        return declared.__name__
    return str(declared)
|
|
|
|
|
|
def _get_attr(x, attr):
    """
    Return the value of ``attr`` for ``x``, or ``MISSING_ATTR`` when absent.

    When ``x`` has an ``attrs`` mapping the value is looked up there,
    otherwise ``attr`` is read as a plain Python attribute of ``x``.
    """
    if hasattr(x, "attrs"):
        return x.attrs.get(attr, MISSING_ATTR)
    return getattr(x, attr, MISSING_ATTR)
|
|
|
|
|
|
def _get_attributes(x):
|
|
"""Get the attributes dict from an element."""
|
|
if hasattr(x, "attrs"):
|
|
return x.attrs
|
|
return {}
|
|
|
|
|
|
def _get_children(x):
|
|
"""Get the children list from an element."""
|
|
if hasattr(x, "children"):
|
|
return x.children
|
|
return []
|
|
|
|
|
|
class ErrorOutput:
    """
    Build the textual rendering of ``element`` shown in error reports.

    The rendering is a list of lines stored in ``output``. It contains the
    parent hierarchy taken from ``path``, the element itself (restricted to
    the attributes listed by ``expected``) and, below each line that differs
    from ``expected``, a line of ``^`` markers.
    """

    # Indentation of one nesting level. The same constant is used to indent
    # and to unindent so that both operations always stay in sync.
    _INDENT = "  "

    def __init__(self, path, element, expected):
        self.path = path
        self.element = element
        self.expected = expected
        self.output = []
        self.indent = ""

    @staticmethod
    def _unconstruct_path_item(item):
        """
        Split a path item into ``(element name, attribute name, attribute value)``.

        Supported forms are ``name#id``, ``name[class=...]`` and
        ``name[name=...]``; any other item gives ``(item, None, None)``.
        """
        if "#" in item:
            # split only once: the id itself may contain a '#'
            elt_name, elt_id = item.split("#", 1)
            return elt_name, "id", elt_id
        if match := re.match(r'(\w+)\[(class|name)=([^]]+)]', item):
            return match.groups()
        return item, None, None

    def __str__(self):
        return f"ErrorOutput({self.output})"

    def compute(self):
        """Fill ``output`` with the rendering of the path, the element and the error markers."""
        # start from a clean state so that calling compute() twice does not
        # duplicate the lines nor accumulate the indentation
        self.output = []
        self.indent = ""

        # first render the path hierarchy (all the items but the last one)
        for p in self.path.split(".")[:-1]:
            elt_name, attr_name, attr_value = self._unconstruct_path_item(p)
            path_str = self._str_element(MyFT(elt_name, {attr_name: attr_value}), keep_open=True)
            self._add_to_output(f"{path_str}")
            self.indent += self._INDENT

        # then render the element
        if hasattr(self.expected, "tag") and hasattr(self.element, "tag"):
            # display the tag and its attributes
            self._add_to_output(self._str_element(self.element, self.expected))
            self._add_error_marker(self.element, self.expected)
            self._render_children()

        elif isinstance(self.expected, TestObject):
            cls = _get_type(self.element)
            attrs = {attr_name: _get_attr(self.element, attr_name) for attr_name in self.expected.attrs}
            self._add_to_output(f"({cls} {_str_attrs(attrs)})")
            self._add_error_marker(self.element, self.expected)

        else:
            self._add_to_output(str(self.element))
            self._add_error_marker(self.element, self.expected)

    def _add_error_marker(self, element, expected):
        """Append the line of markers showing where ``element`` differs, if it does."""
        error_str = self._detect_error(element, expected)
        if error_str:
            self._add_to_output(error_str)

    def _render_children(self):
        """Render one line per expected child (predicates excluded), then the closing parenthesis."""
        expected_children = [c for c in self.expected.children if not isinstance(c, ChildrenPredicate)]
        if len(expected_children) == 0:
            return

        self.indent += self._INDENT
        actual_children = self.element.children
        for index, expected_child in enumerate(expected_children):
            if index >= len(actual_children):
                # When there are fewer children than expected, we display a placeholder
                self._add_to_output("! ** MISSING ** !")
                continue

            # display the child
            element_child = actual_children[index]
            self._add_to_output(self._str_element(element_child, expected_child, keep_open=False))

            # manage errors (only when the expected is a FT element)
            if hasattr(expected_child, "tag"):
                self._add_error_marker(element_child, expected_child)

        # remove exactly the indentation added above
        self.indent = self.indent[:-len(self._INDENT)]
        self._add_to_output(")")

    def _add_to_output(self, msg):
        self.output.append(f"{self.indent}{msg}")

    @staticmethod
    def _str_element(element, expected=None, keep_open=None):
        """
        Render ``element`` on one line.

        Only the attributes listed by ``expected`` are displayed (``element``
        is compared to itself when ``expected`` is not given). ``keep_open``
        selects the ending: False closes the tag (with ``...`` when there are
        children), True leaves it open followed by ``...``, None closes it
        only when the element has no child other than predicates.
        An element without ``tag`` is rendered with ``quoted_str``.
        """
        # compare to itself if no expected element is provided
        if expected is None:
            expected = element

        if not hasattr(element, "tag"):
            return quoted_str(element)

        # the attributes are compared to the expected element; the expected
        # one may be a plain value without attributes
        elt_attrs = {attr_name: element.attrs.get(attr_name, MISSING_ATTR)
                     for attr_name in _get_attributes(expected)
                     if attr_name is not None}

        elt_attrs_str = " ".join(f'"{attr_name}"="{attr_value}"' for attr_name, attr_value in elt_attrs.items())
        tag_str = f"({element.tag} {elt_attrs_str}"

        # manage the closing tag
        if keep_open is False:
            tag_str += " ...)" if len(element.children) > 0 else ")"
        elif keep_open is True:
            tag_str += "..." if elt_attrs_str == "" else " ..."
        elif all(isinstance(c, Predicate) for c in element.children):
            # close the tag if there are no children
            tag_str += ")"
        return tag_str

    def _detect_error(self, element, expected):
        """
        Detect errors between element and expected, returning a visual marker string.

        Handles both FT elements and TestObjects. Returns None when no
        difference is detected.
        """
        # For simple values
        if not (hasattr(expected, "tag") or isinstance(expected, TestObject)):
            if not self._matches(element, expected):
                return len(str(element)) * "^"
            return None

        # For elements with structure (FT or TestObject)
        element_type = _get_type(element)
        expected_type = _get_type(expected)
        type_error = (" " if element_type == expected_type else "^") * len(element_type)

        expected_attrs = {attr_name: _get_attr(expected, attr_name) for attr_name in _get_attributes(expected)}
        element_attrs = {attr_name: _get_attr(element, attr_name) for attr_name in expected_attrs}
        # a missing attribute is always an error, even when the expected value
        # is a predicate that would accept the placeholder string
        attrs_in_error = {attr_name for attr_name, attr_value in element_attrs.items()
                          if attr_value == MISSING_ATTR
                          or not self._matches(attr_value, expected_attrs[attr_name])}

        attrs_error = " ".join(
            len(f'"{name}"="{value}"') * ("^" if name in attrs_in_error else " ")
            for name, value in element_attrs.items()
        )

        if type_error.strip() or attrs_error.strip():
            return f" {type_error} {attrs_error}"
        return None

    @staticmethod
    def _matches(element, expected):
        """Return True when ``element`` equals ``expected`` or satisfies it as a predicate."""
        if element == expected:
            return True
        if isinstance(expected, Predicate):
            return expected.validate(element)
        return False
|
|
|
|
|
|
class ErrorComparisonOutput:
    """
    Side-by-side rendering of two error outputs (actual on the left,
    expected on the right).

    Both arguments must expose an ``output`` list of lines and a
    ``compute()`` method that fills it.
    """

    def __init__(self, actual_error_output, expected_error_output):
        self.actual_error_output = actual_error_output
        self.expected_error_output = expected_error_output

    @staticmethod
    def adjust(to_adjust, reference):
        """
        Insert an empty line in ``to_adjust`` at the index of every line of
        ``reference`` that contains ``^^`` markers, so both lists stay aligned.

        ``to_adjust`` is modified in place and returned.
        """
        for index, ref_line in enumerate(reference):
            if "^^" in ref_line:
                # insert an empty line in to_adjust
                to_adjust.insert(index, "")

        return to_adjust

    def render(self):
        """
        Return the two outputs as ``actual | expected`` lines joined by newlines.

        The outputs are computed first when they are still empty. The stored
        outputs are not modified, so render() gives the same result however
        many times it is called.
        """
        # init if needed
        if not self.actual_error_output.output:
            self.actual_error_output.compute()
        if not self.expected_error_output.output:
            self.expected_error_output.compute()

        actual = self.actual_error_output.output
        # work on a copy: adjust() inserts lines in place
        expected = self.adjust(list(self.expected_error_output.output), actual)

        # width of the left column (0 when there is nothing to display)
        actual_max_length = max((len(line) for line in actual), default=0)

        output = [f"{a:<{actual_max_length}} | {e}".rstrip() for a, e in zip(actual, expected)]
        return "\n".join(output)
|
|
|
|
|
|
class Matcher:
    """
    Matcher class for comparing actual and expected elements.

    Provides flexible comparison with support for predicates, nested structures,
    and detailed error reporting. Every mismatch is reported by raising an
    AssertionError (explicitly raised, so the checks are not disabled by
    ``python -O``).
    """

    def __init__(self):
        # dotted path of the element being compared, displayed in the errors
        self.path = ""

    def matches(self, actual, expected):
        """
        Compare actual and expected elements.

        Args:
            actual: The actual element to compare
            expected: The expected element or pattern

        Returns:
            True if elements match, raises AssertionError otherwise
        """
        if actual is not None and expected is None:
            self._assert_error("Actual is not None while expected is None", _actual=actual)

        if isinstance(expected, DoNotCheck):
            return True

        if actual is None and expected is not None:
            self._assert_error("Actual is None while expected is ", _expected=expected)

        # set the path
        current_path = self._get_current_path(actual)
        original_path = self.path
        self.path = self.path + "." + current_path if self.path else current_path

        try:
            self._match_elements(actual, expected)
        finally:
            # restore the original path for sibling comparisons
            self.path = original_path

        return True

    def _match_elements(self, actual, expected):
        """Internal method that performs the actual comparison logic."""
        if isinstance(expected, TestObject) or hasattr(expected, "tag"):
            self._match_element(actual, expected)
            return

        if isinstance(expected, Predicate):
            if not expected.validate(actual):
                self._assert_error(f"The condition '{expected}' is not satisfied.",
                                   _actual=actual,
                                   _expected=expected)
            return

        if _get_type(actual) != _get_type(expected):
            self._assert_error("The types are different.", _actual=actual, _expected=expected)

        if isinstance(expected, (list, tuple)):
            self._match_list(actual, expected)
        elif isinstance(expected, dict):
            self._match_dict(actual, expected)
        elif isinstance(expected, NotStr):
            self._match_notstr(actual, expected)
        elif not (actual == expected):
            self._assert_error("The values are different",
                               _actual=actual,
                               _expected=expected)

    def _match_element(self, actual, expected):
        """Match a TestObject or FT element."""
        # Validate the type/tag
        if _get_type(actual) != _get_type(expected):
            self._assert_error("The types are different.",
                               _actual=_get_type(actual),
                               _expected=_get_type(expected))

        # Special conditions (ChildrenPredicate)
        expected_children = _get_children(expected)
        for predicate in [c for c in expected_children if isinstance(c, ChildrenPredicate)]:
            if not predicate.validate(actual):
                # to_debug() is only called when the condition fails
                self._assert_error(f"The condition '{predicate}' is not satisfied.",
                                   _actual=actual,
                                   _expected=predicate.to_debug(expected))

        # Compare the attributes
        expected_attrs = _get_attributes(expected)
        for expected_attr, expected_value in expected_attrs.items():
            actual_value = _get_attr(actual, expected_attr)

            # Check if attribute exists
            if actual_value == MISSING_ATTR:
                self._assert_error(f"'{expected_attr}' is not found in Actual. (attributes: {self._str_attrs(actual)})",
                                   _actual=actual,
                                   _expected=expected)

            # Handle Predicate values
            if isinstance(expected_value, Predicate):
                if not expected_value.validate(actual_value):
                    self._assert_error(f"The condition '{expected_value}' is not satisfied.",
                                       _actual=actual,
                                       _expected=expected)

            # Handle TestObject recursive matching
            elif isinstance(expected, TestObject):
                try:
                    self.matches(actual_value, expected_value)
                except AssertionError as e:
                    # reuse the message of the nested error when it can be extracted
                    match = re.search(r"Error : (.+?)\n", str(e))
                    if match:
                        self._assert_error(f"{match.group(1)} for '{expected_attr}'.",
                                           _actual=actual_value,
                                           _expected=expected_value)
                    else:
                        self._assert_error(f"The values are different for '{expected_attr}'.",
                                           _actual=actual_value,
                                           _expected=expected_value)

            # Handle regular value comparison
            elif not (actual_value == expected_value):
                self._assert_error(f"The values are different for '{expected_attr}'.",
                                   _actual=actual,
                                   _expected=expected)

        # Compare the children (only if present)
        if expected_children:
            # Filter out Predicate children
            expected_children = [c for c in expected_children if not isinstance(c, Predicate)]
            actual_children = _get_children(actual)

            if len(actual_children) < len(expected_children):
                self._assert_error("Actual is lesser than expected.", _actual=actual, _expected=expected)

            # extra actual children (beyond the expected ones) are ignored
            for actual_child, expected_child in zip(actual_children, expected_children):
                self.matches(actual_child, expected_child)

    def _match_list(self, actual, expected):
        """Match list or tuple."""
        if len(actual) < len(expected):
            self._assert_error("Actual is smaller than expected: ", _actual=actual, _expected=expected)
        if len(actual) > len(expected):
            self._assert_error("Actual is bigger than expected: ", _actual=actual, _expected=expected)

        for actual_child, expected_child in zip(actual, expected):
            self.matches(actual_child, expected_child)

    def _match_dict(self, actual, expected):
        """Match dictionary."""
        if len(actual) < len(expected):
            self._assert_error("Actual is smaller than expected: ", _actual=actual, _expected=expected)
        if len(actual) > len(expected):
            self._assert_error("Actual is bigger than expected: ", _actual=actual, _expected=expected)

        for k, v in expected.items():
            # same size does not mean same keys: report a missing key as a
            # mismatch instead of failing with a KeyError
            if k not in actual:
                self._assert_error(f"'{k}' is not found in Actual.", _actual=actual, _expected=expected)
            self.matches(actual[k], v)

    def _match_notstr(self, actual, expected):
        """Match NotStr type: the actual text, stripped on the left, must start with the expected one."""
        to_compare = _get_attr(actual, "s").lstrip('\n').lstrip()
        if not to_compare.startswith(expected.s):
            self._assert_error("Notstr values are different: ",
                               _actual=to_compare,
                               _expected=expected.s)

    def _print_path(self):
        """Format the current path for error messages."""
        return f"Path : '{self.path}'\n" if self.path else ""

    def _debug_compare(self, a, b):
        """Generate a comparison debug output."""
        actual_out = ErrorOutput(self.path, a, b)
        expected_out = ErrorOutput(self.path, b, b)

        comparison_out = ErrorComparisonOutput(actual_out, expected_out)
        return comparison_out.render()

    def _error_msg(self, msg, _actual=None, _expected=None):
        """Generate an error message with debug information."""
        if _actual is None and _expected is None:
            debug_info = ""
        elif _actual is None:
            debug_info = self._debug(_expected)
        elif _expected is None:
            debug_info = self._debug(_actual)
        else:
            debug_info = self._debug_compare(_actual, _expected)

        return f"{self._print_path()}Error : {msg}\n{debug_info}"

    def _assert_error(self, msg, _actual=None, _expected=None):
        """Raise an assertion error with formatted message."""
        # raised explicitly: an `assert False` would be removed by `python -O`
        raise AssertionError(self._error_msg(msg, _actual=_actual, _expected=_expected))

    @staticmethod
    def _get_current_path(elt):
        """Get the path representation of an element."""
        if not hasattr(elt, "tag"):
            return elt.__class__.__name__

        res = f"{elt.tag}"
        # the first available attribute among id, name and class qualifies the tag
        if "id" in elt.attrs:
            res += f"#{elt.attrs['id']}"
        elif "name" in elt.attrs:
            res += f"[name={elt.attrs['name']}]"
        elif "class" in elt.attrs:
            res += f"[class={elt.attrs['class']}]"
        return res

    @staticmethod
    def _str_attrs(element):
        """Format attributes as a string."""
        attrs = _get_attributes(element)
        return " ".join(f'"{attr_name}"="{attr_value}"' for attr_name, attr_value in attrs.items())

    @staticmethod
    def _debug(elt):
        """Format an element for debug output."""
        # only None is displayed as "None": falsy values (0, "", []) keep their own text
        return str(elt) if elt is not None else "None"
|
|
|
|
|
|
def matches(actual, expected, path=""):
    """
    Check that ``actual`` matches ``expected``.

    Shortcut that creates a ``Matcher``, gives it the initial ``path`` and
    runs the comparison.

    Args:
        actual: The actual element to compare
        expected: The expected element or pattern
        path: Optional initial path for error reporting

    Returns:
        True if elements match, raises AssertionError otherwise
    """
    engine = Matcher()
    engine.path = path
    outcome = engine.matches(actual, expected)
    return outcome
|
|
|
|
|
|
def find(ft, expected):
    """
    Find all occurrences of an expected element within a FastHTML tree.

    An element matches the pattern when it has the same type/tag, when
    every attribute of the pattern is present with a matching value
    (predicates are supported) and when every child of the pattern matches
    at least one of its children, whatever their order.

    Args:
        ft: A FastHTML element or list of elements to search in
        expected: The element pattern to find

    Returns:
        List of matching elements

    Raises:
        AssertionError: If no matching elements are found
    """

    def _elements_match(actual, pattern):
        """Check if ``actual`` matches ``pattern`` based on tag, attributes, and children."""
        if isinstance(pattern, DoNotCheck):
            return True

        if isinstance(pattern, NotStr):
            to_compare = _get_attr(actual, "s").lstrip('\n').lstrip()
            return to_compare.startswith(pattern.s)

        if isinstance(actual, NotStr) and _get_type(actual) != _get_type(pattern):
            return False  # to manage the unexpected __eq__ behavior of NotStr

        if not isinstance(pattern, (TestObject, FT)):
            return actual == pattern

        # Compare tags
        if _get_type(actual) != _get_type(pattern):
            return False

        # Compare attributes: all of them must match, so only a failure
        # ends the loop early
        for attr_name, expected_attr_value in _get_attributes(pattern).items():
            actual_attr_value = _get_attr(actual, attr_name)
            # attribute is missing
            if actual_attr_value == MISSING_ATTR:
                return False
            # manage predicate values
            if isinstance(expected_attr_value, Predicate):
                if not expected_attr_value.validate(actual_attr_value):
                    return False
            # finally compare values
            elif not (actual_attr_value == expected_attr_value):
                return False

        # Compare children recursively
        actual_children = _get_children(actual)
        for expected_child in _get_children(pattern):
            if isinstance(expected_child, ChildrenPredicate):
                # a children predicate is a condition on the element itself
                if not expected_child.validate(actual):
                    return False
                continue
            # Check if this expected child exists somewhere in actual children
            if not any(_elements_match(actual_child, expected_child) for actual_child in actual_children):
                return False

        return True

    def _search_tree(current, pattern):
        """Recursively search for pattern in the tree rooted at current."""
        # Check if current element matches
        found = [current] if _elements_match(current, pattern) else []

        # Recursively search in children, in the case that the pattern also appears in children
        for child in _get_children(current):
            found.extend(_search_tree(child, pattern))

        return found

    # Normalize input to list
    elements_to_search = ft if isinstance(ft, (list, tuple, set)) else [ft]

    # Search in all provided elements
    all_matches = []
    for element in elements_to_search:
        all_matches.extend(_search_tree(element, expected))

    # Raise error if nothing found
    if not all_matches:
        raise AssertionError(f"No element found for '{expected}'")

    return all_matches
|
|
|
|
|
|
def find_one(ft, expected):
    """
    Return the single element of ``ft`` found by ``find`` for ``expected``.

    Raises:
        AssertionError: If ``find`` does not return exactly one element
    """
    results = find(ft, expected)
    count = len(results)
    assert count == 1, f"Found {count} elements for '{expected}'"
    return results[0]
|
|
|
|
|
|
def _str_attrs(attrs: dict):
|
|
return " ".join(f'"{attr_name}"="{attr_value}"' for attr_name, attr_value in attrs.items())
|