Added Application HolidayViewer
This commit is contained in:
116
tests/helpers.py
116
tests/helpers.py
@@ -2,6 +2,7 @@ import dataclasses
|
||||
import json
|
||||
import re
|
||||
from collections import OrderedDict
|
||||
from typing import Any
|
||||
|
||||
import numpy
|
||||
import pandas as pd
|
||||
@@ -43,6 +44,54 @@ class Contains:
|
||||
s: str
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class JsonViewerNode:
|
||||
is_expanded: bool | None
|
||||
key: str
|
||||
value: Any
|
||||
debug_key: Any = None
|
||||
debug_folding: Any = None
|
||||
|
||||
def find(self, path):
|
||||
"""
|
||||
Finds and returns a node in a hierarchical structure based on a dot-delimited path.
|
||||
|
||||
The method uses a recursive helper function to navigate through a tree-like
|
||||
hierarchical node structure. Each node in the structure is assumed to potentially
|
||||
have a "children" attribute, which is iterated to find matching keys in the path.
|
||||
If, at any point, a node does not have the expected structure or the key is not
|
||||
found within the children, the method will return None.
|
||||
|
||||
:param path: A dot-delimited string representing the hierarchical path to
|
||||
the desired node (e.g., "root.child.subchild").
|
||||
:return: The node in the hierarchy that matches the specified path or None
|
||||
if no such node exists.
|
||||
"""
|
||||
|
||||
def _find(node, path_parts):
|
||||
if len(path_parts) == 0:
|
||||
return node
|
||||
|
||||
element = node.value # to deal with ft element
|
||||
|
||||
if not hasattr(element, "children"):
|
||||
return None
|
||||
|
||||
to_find = path_parts[0]
|
||||
|
||||
for child in element.children:
|
||||
child_node = extract_jsonviewer_node(child)
|
||||
if child_node is not None and child_node.key == to_find:
|
||||
return _find(child_node, path_parts[1:])
|
||||
|
||||
return None
|
||||
|
||||
path_parts = path.split(".")
|
||||
return _find(self, path_parts)
|
||||
|
||||
def text_value(self):
|
||||
return str(self.value.children[0])
|
||||
|
||||
Empty = EmptyElement()
|
||||
|
||||
|
||||
@@ -424,6 +473,37 @@ def matches(actual, expected, path=""):
|
||||
return True
|
||||
|
||||
|
||||
def contains(lst, element, recursive=False):
|
||||
"""
|
||||
Check if any item in the list matches the given element pattern
|
||||
using the existing matches() function.
|
||||
|
||||
Args:
|
||||
lst: List of elements to search through
|
||||
element: Element pattern to match against
|
||||
recursive: If True, also search in children of each element
|
||||
|
||||
Returns:
|
||||
bool: True if a match is found, False otherwise
|
||||
"""
|
||||
if not lst:
|
||||
return False
|
||||
|
||||
for item in lst:
|
||||
try:
|
||||
if matches(item, element):
|
||||
return True
|
||||
except AssertionError:
|
||||
pass
|
||||
|
||||
# If recursive is True, check children too
|
||||
if recursive and hasattr(item, "children") and item.children:
|
||||
if contains(item.children, element, recursive=True):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def get_selected(return_elements):
|
||||
assert isinstance(return_elements, list), "result must be a list"
|
||||
for element in return_elements:
|
||||
@@ -616,6 +696,42 @@ def extract_popup_content(element, filter_input=True) -> OrderedDict:
|
||||
return res
|
||||
|
||||
|
||||
def extract_jsonviewer_node(element):
|
||||
# This structure of the Jsonview Node is
|
||||
# 3 children
|
||||
# 1st : Span(NotStr(name="expanded|collapse")) or None
|
||||
# 2nd : Span("key : ") or None (None is possible only for the root node)
|
||||
# 3rd : Span(value)
|
||||
|
||||
if not hasattr(element, "children") or len(element.children) != 3:
|
||||
return None
|
||||
|
||||
debug_folding = element.children[0]
|
||||
debug_key = element.children[1]
|
||||
value = element.children[2]
|
||||
|
||||
if contains([debug_folding], span_icon("expanded")):
|
||||
is_expanded = True
|
||||
elif contains([debug_folding], span_icon("collapsed")):
|
||||
is_expanded = False
|
||||
else:
|
||||
is_expanded = None
|
||||
|
||||
if debug_key is not None:
|
||||
assert hasattr(debug_key, "tag") and debug_key.tag == "span", "debug_key must be a span"
|
||||
key = debug_key.children[0].split(" : ")[0]
|
||||
else:
|
||||
key = None
|
||||
|
||||
return JsonViewerNode(
|
||||
is_expanded,
|
||||
key,
|
||||
value,
|
||||
debug_key,
|
||||
debug_folding
|
||||
)
|
||||
|
||||
|
||||
def to_array(dataframe: pd.DataFrame) -> list:
|
||||
return [[val for val in row] for _, row in dataframe.iterrows()]
|
||||
|
||||
|
||||
132
tests/test_calendar_helper.py
Normal file
132
tests/test_calendar_helper.py
Normal file
@@ -0,0 +1,132 @@
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
import pytest
|
||||
|
||||
from components.hoildays.helpers.calendar_helper import CalendarHelper
|
||||
from components.hoildays.helpers.nibelisparser import OffPeriodDetails
|
||||
|
||||
|
||||
def test_get_period_end_before_start():
|
||||
with pytest.raises(ValueError) as err:
|
||||
CalendarHelper.get_period(datetime.today(), datetime.today() - timedelta(days=1))
|
||||
|
||||
assert str(err.value) == "end date is before start date."
|
||||
|
||||
|
||||
def test_get_period():
|
||||
start = datetime.today()
|
||||
res = CalendarHelper.get_period(start, start + timedelta(days=3))
|
||||
|
||||
assert res == [
|
||||
start,
|
||||
start + timedelta(days=1),
|
||||
start + timedelta(days=2),
|
||||
start + timedelta(days=3)
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("start_am_pm, end_am_pm, is_start, is_end, expected", [
|
||||
("am", None, True, False, "reason_am"),
|
||||
("pm", None, True, False, "reason_pm"),
|
||||
(None, "am", False, True, "reason_am"),
|
||||
(None, "pm", False, True, "reason_pm"),
|
||||
("am", "pm", True, True, "reason_am_pm"),
|
||||
])
|
||||
def test_get_reason(start_am_pm, end_am_pm, is_start, is_end, expected):
|
||||
record = OffPeriodDetails(
|
||||
first_name="first_name",
|
||||
last_name="last_name",
|
||||
start_date=datetime.today(),
|
||||
start_am_pm=start_am_pm,
|
||||
end_date=datetime.today() + timedelta(days=1),
|
||||
end_am_pm=end_am_pm,
|
||||
total=2,
|
||||
reason="reason",
|
||||
date_import=datetime.today(),
|
||||
)
|
||||
actual = CalendarHelper.get_reason(record, is_start, is_end)
|
||||
|
||||
assert actual == expected
|
||||
|
||||
|
||||
def test_create_calendar_one_user():
|
||||
records = [
|
||||
OffPeriodDetails("john", "doo", datetime(2025, 6, 19), None, datetime(2025, 6, 20), None, 1, "CP", datetime.now())
|
||||
]
|
||||
|
||||
names, calendar = CalendarHelper.create_calendar(records, datetime(2025, 6, 18), datetime(2025, 6, 21))
|
||||
expected_calendar = {
|
||||
datetime(2025, 6, 18): [None],
|
||||
datetime(2025, 6, 19): [["CP"]],
|
||||
datetime(2025, 6, 20): [["CP"]],
|
||||
datetime(2025, 6, 21): [None],
|
||||
}
|
||||
|
||||
assert names == ["john doo"]
|
||||
assert calendar == expected_calendar
|
||||
|
||||
|
||||
def test_create_calendar_multiple_users():
|
||||
now = datetime.now()
|
||||
records = [
|
||||
OffPeriodDetails("john", "doo", datetime(2025, 6, 19), None, datetime(2025, 6, 20), None, 1, "CP", now),
|
||||
OffPeriodDetails("jane", "doo", datetime(2025, 6, 18), None, datetime(2025, 6, 19), None, 1, "CP", now)
|
||||
]
|
||||
|
||||
names, calendar = CalendarHelper.create_calendar(records, datetime(2025, 6, 18), datetime(2025, 6, 21))
|
||||
expected_calendar = {
|
||||
datetime(2025, 6, 18): [["CP"], None],
|
||||
datetime(2025, 6, 19): [["CP"], ["CP"]],
|
||||
datetime(2025, 6, 20): [None, ["CP"]],
|
||||
datetime(2025, 6, 21): [None, None],
|
||||
}
|
||||
|
||||
assert names == ["jane doo", "john doo"]
|
||||
assert calendar == expected_calendar
|
||||
|
||||
|
||||
def test_create_calendar_end_is_missing():
|
||||
records = [
|
||||
OffPeriodDetails("john", "doo", datetime(2025, 6, 19), None, datetime(2025, 6, 20), None, 1, "CP", datetime.now())
|
||||
]
|
||||
|
||||
names, calendar = CalendarHelper.create_calendar(records, datetime(2025, 6, 18))
|
||||
expected_calendar = {
|
||||
datetime(2025, 6, 18): [None],
|
||||
datetime(2025, 6, 19): [["CP"]],
|
||||
datetime(2025, 6, 20): [["CP"]],
|
||||
}
|
||||
|
||||
assert names == ["john doo"]
|
||||
assert calendar == expected_calendar
|
||||
|
||||
|
||||
def test_create_calendar_start_is_missing():
|
||||
records = [
|
||||
OffPeriodDetails("john", "doo", datetime(2025, 6, 19), None, datetime(2025, 6, 20), None, 1, "CP", datetime.now())
|
||||
]
|
||||
|
||||
names, calendar = CalendarHelper.create_calendar(records, None, datetime(2025, 6, 21))
|
||||
expected_calendar = {
|
||||
datetime(2025, 6, 19): [["CP"]],
|
||||
datetime(2025, 6, 20): [["CP"]],
|
||||
datetime(2025, 6, 21): [None],
|
||||
}
|
||||
|
||||
assert names == ["john doo"]
|
||||
assert calendar == expected_calendar
|
||||
|
||||
|
||||
def test_create_calendar_start_and_end_are_missing():
|
||||
records = [
|
||||
OffPeriodDetails("john", "doo", datetime(2025, 6, 19), None, datetime(2025, 6, 20), None, 1, "CP", datetime.now())
|
||||
]
|
||||
|
||||
names, calendar = CalendarHelper.create_calendar(records)
|
||||
expected_calendar = {
|
||||
datetime(2025, 6, 19): [["CP"]],
|
||||
datetime(2025, 6, 20): [["CP"]],
|
||||
}
|
||||
|
||||
assert names == ["john doo"]
|
||||
assert calendar == expected_calendar
|
||||
@@ -1,15 +1,9 @@
|
||||
from collections import OrderedDict
|
||||
|
||||
import pandas as pd
|
||||
import pytest
|
||||
from fastcore.basics import NotStr
|
||||
from fastcore.xml import to_xml
|
||||
from fasthtml.components import *
|
||||
|
||||
from components.datagrid.DataGrid import DataGrid
|
||||
from helpers import matches, search_elements_by_name, search_elements_by_path, extract_table_values, get_from_html, \
|
||||
extract_popup_content, \
|
||||
Empty, get_path_attributes, find_first_match, StartsWith, search_first_with_attribute, Contains
|
||||
from components.debugger.components.JsonViewer import JsonViewer
|
||||
from helpers import *
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -329,3 +323,118 @@ def test_i_can_search_first_with_attribute(tag, attr, expected, sample_structure
|
||||
assert result.tag == expected[0]
|
||||
assert attr in result.attrs
|
||||
assert result.attrs[attr] == expected[1]
|
||||
|
||||
|
||||
# Add tests for extract_jsonviewer_node
|
||||
def test_extract_jsonviewer_node():
|
||||
# Create a valid JsonViewer node element
|
||||
element = Div(
|
||||
span_icon("expanded"),
|
||||
Span("key : "),
|
||||
Span("value")
|
||||
)
|
||||
|
||||
result = extract_jsonviewer_node(element)
|
||||
|
||||
assert result is not None
|
||||
assert result.is_expanded is True
|
||||
assert result.key == "key"
|
||||
assert result.value == element.children[2]
|
||||
assert result.debug_key == element.children[1]
|
||||
assert result.debug_folding == element.children[0]
|
||||
|
||||
|
||||
def test_extract_jsonviewer_node_collapsed():
|
||||
# Create a collapsed JsonViewer node element
|
||||
element = Div(
|
||||
span_icon("collapsed"),
|
||||
Span("key : "),
|
||||
Span("value")
|
||||
)
|
||||
|
||||
result = extract_jsonviewer_node(element)
|
||||
|
||||
assert result is not None
|
||||
assert result.is_expanded is False
|
||||
assert result.key == "key"
|
||||
assert result.value == element.children[2]
|
||||
|
||||
|
||||
def test_extract_jsonviewer_node_no_expansion_state():
|
||||
# Create a JsonViewer node with no expansion state
|
||||
element = Div(
|
||||
Span(),
|
||||
Span("key : "),
|
||||
Span("value")
|
||||
)
|
||||
|
||||
result = extract_jsonviewer_node(element)
|
||||
|
||||
assert result is not None
|
||||
assert result.is_expanded is None
|
||||
assert result.key == "key"
|
||||
assert result.value == element.children[2]
|
||||
|
||||
|
||||
def test_extract_jsonviewer_node_root_node():
|
||||
# Create a root JsonViewer node (no key)
|
||||
element = Div(
|
||||
span_icon("expanded"),
|
||||
None,
|
||||
Span("value")
|
||||
)
|
||||
|
||||
result = extract_jsonviewer_node(element)
|
||||
|
||||
assert result is not None
|
||||
assert result.is_expanded is True
|
||||
assert result.key is None
|
||||
assert result.value == element.children[2]
|
||||
|
||||
|
||||
def test_extract_jsonviewer_node_invalid_structure():
|
||||
# Test with invalid node structure (not enough children)
|
||||
element = Div(
|
||||
span_icon("expanded"),
|
||||
Span("key : ")
|
||||
)
|
||||
|
||||
result = extract_jsonviewer_node(element)
|
||||
|
||||
assert result is None
|
||||
|
||||
# Test with element that has no children attribute
|
||||
element = "not an element with children"
|
||||
|
||||
result = extract_jsonviewer_node(element)
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_json_viewer_find():
|
||||
value = {"a": [1, 2, 3], "b": {"x": "y", "z": True}}
|
||||
jsonviewer = JsonViewer(None, None, None, None, value)
|
||||
elements = jsonviewer.__ft__()
|
||||
root_div = search_elements_by_name(elements, "div", attrs={"id": f"{jsonviewer.get_id()}-root"})[0]
|
||||
first_level_div = root_div.children[0]
|
||||
|
||||
as_node = extract_jsonviewer_node(first_level_div)
|
||||
child_b = as_node.find("b")
|
||||
|
||||
assert isinstance(child_b, JsonViewerNode)
|
||||
assert child_b.key == "b"
|
||||
|
||||
|
||||
def test_json_viewer_find_with_path():
|
||||
value = {"a": {"x": None, "y": ["first", "second"], "z": True}}
|
||||
jsonviewer = JsonViewer(None, None, None, None, value)
|
||||
jsonviewer.set_folding_mode("expand")
|
||||
elements = jsonviewer.__ft__()
|
||||
root_div = search_elements_by_name(elements, "div", attrs={"id": f"{jsonviewer.get_id()}-root"})[0]
|
||||
first_level_div = root_div.children[0]
|
||||
|
||||
as_node = extract_jsonviewer_node(first_level_div)
|
||||
child = as_node.find("a.y.0")
|
||||
|
||||
assert isinstance(child, JsonViewerNode)
|
||||
assert child.key == "0"
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
import pytest
|
||||
|
||||
from components.BaseComponent import BaseComponent
|
||||
from core.instance_manager import InstanceManager, SESSION_ID_KEY, NOT_LOGGED # Adjust import path as needed
|
||||
from constants import NO_SESSION, SESSION_USER_ID_KEY, NOT_LOGGED
|
||||
from core.instance_manager import InstanceManager
|
||||
|
||||
|
||||
class MockBaseComponent(BaseComponent):
|
||||
@@ -52,7 +53,7 @@ def session():
|
||||
"""
|
||||
Fixture to provide a default mocked session dictionary with a fixed user_id.
|
||||
"""
|
||||
return {SESSION_ID_KEY: "test_user"}
|
||||
return {SESSION_USER_ID_KEY: "test_user"}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -140,7 +141,7 @@ def test_register_registers_instance(session, base_component_instance):
|
||||
"""
|
||||
InstanceManager.register(session, base_component_instance)
|
||||
|
||||
key = (session[SESSION_ID_KEY], base_component_instance._id)
|
||||
key = (session[SESSION_USER_ID_KEY], base_component_instance._id)
|
||||
assert key in InstanceManager._instances
|
||||
assert InstanceManager._instances[key] == base_component_instance
|
||||
|
||||
@@ -167,7 +168,7 @@ def test_register_fetches_id_from_instance_attribute(session):
|
||||
instance = MockInstanceWithId()
|
||||
InstanceManager.register(session, instance)
|
||||
|
||||
key = (session[SESSION_ID_KEY], instance._id) # `_id` value taken from the instance
|
||||
key = (session[SESSION_USER_ID_KEY], instance._id) # `_id` value taken from the instance
|
||||
assert key in InstanceManager._instances
|
||||
assert InstanceManager._instances[key] == instance
|
||||
|
||||
@@ -181,8 +182,8 @@ def test_register_many_without_session():
|
||||
|
||||
InstanceManager.register_many(instance1, instance2)
|
||||
|
||||
key1 = (NOT_LOGGED, "id1")
|
||||
key2 = (NOT_LOGGED, "id2")
|
||||
key1 = (NO_SESSION, "id1")
|
||||
key2 = (NO_SESSION, "id2")
|
||||
assert key1 in InstanceManager._instances
|
||||
assert InstanceManager._instances[key1] == instance1
|
||||
assert key2 in InstanceManager._instances
|
||||
@@ -197,7 +198,7 @@ def test_remove_registered_instance(session, instance_id):
|
||||
InstanceManager.register(session, instance)
|
||||
|
||||
InstanceManager.remove(session, instance_id)
|
||||
key = (session[SESSION_ID_KEY], instance_id)
|
||||
key = (session[SESSION_USER_ID_KEY], instance_id)
|
||||
assert key not in InstanceManager._instances
|
||||
|
||||
|
||||
@@ -210,7 +211,7 @@ def test_remove_with_dispose_method(session, instance_id):
|
||||
InstanceManager.remove(session, instance_id)
|
||||
|
||||
assert hasattr(instance, "disposed") and instance.disposed
|
||||
key = (session[SESSION_ID_KEY], instance_id)
|
||||
key = (session[SESSION_USER_ID_KEY], instance_id)
|
||||
assert key not in InstanceManager._instances
|
||||
|
||||
|
||||
@@ -230,7 +231,7 @@ def test_get_session_id_returns_logged_in_user_id(session):
|
||||
Test that _get_session_id extracts the session ID correctly.
|
||||
"""
|
||||
session_id = InstanceManager.get_session_id(session)
|
||||
assert session_id == session[SESSION_ID_KEY]
|
||||
assert session_id == session[SESSION_USER_ID_KEY]
|
||||
|
||||
|
||||
def test_get_session_id_returns_default_logged_out_value():
|
||||
@@ -238,4 +239,7 @@ def test_get_session_id_returns_default_logged_out_value():
|
||||
Test that _get_session_id returns NOT_LOGGED when session is None.
|
||||
"""
|
||||
session_id = InstanceManager.get_session_id(None)
|
||||
assert session_id == NO_SESSION
|
||||
|
||||
session_id = InstanceManager.get_session_id({})
|
||||
assert session_id == NOT_LOGGED
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
import pytest
|
||||
from fasthtml.components import *
|
||||
|
||||
from components.debugger.components.JsonViewer import JsonViewer, DictNode, ListNode, ValueNode
|
||||
from helpers import matches, span_icon, search_elements_by_name
|
||||
from components.debugger.components.JsonViewer import *
|
||||
from helpers import matches, span_icon, search_elements_by_name, extract_jsonviewer_node
|
||||
|
||||
JSON_VIEWER_INSTANCE_ID = "json_viewer"
|
||||
ML_20 = "margin-left: 20px;"
|
||||
@@ -19,6 +18,11 @@ def json_viewer(session):
|
||||
return JsonViewer(session, JSON_VIEWER_INSTANCE_ID, None, USER_ID, {})
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def helper():
|
||||
return JsonViewerHelper()
|
||||
|
||||
|
||||
def jv_id(x):
|
||||
return f"{JSON_VIEWER_INSTANCE_ID}-{x}"
|
||||
|
||||
@@ -61,7 +65,7 @@ def test_i_can_render(json_viewer):
|
||||
def test_i_can_render_simple_value(session, value, expected_inner):
|
||||
jsonv = JsonViewer(session, JSON_VIEWER_INSTANCE_ID, None, USER_ID, value)
|
||||
actual = jsonv.__ft__()
|
||||
to_compare = search_elements_by_name(actual, "div", attrs={"id": f"{jv_id("root")}"})[0]
|
||||
to_compare = search_elements_by_name(actual, "div", attrs={"id": f"{jv_id('root')}"})[0]
|
||||
expected = Div(
|
||||
|
||||
Div(
|
||||
@@ -70,7 +74,7 @@ def test_i_can_render_simple_value(session, value, expected_inner):
|
||||
expected_inner,
|
||||
style=ML_20),
|
||||
|
||||
id=f"{jv_id("root")}")
|
||||
id=f"{jv_id('root')}")
|
||||
|
||||
assert matches(to_compare, expected)
|
||||
|
||||
@@ -78,21 +82,25 @@ def test_i_can_render_simple_value(session, value, expected_inner):
|
||||
def test_i_can_render_expanded_list_node(session):
|
||||
value = [1, "hello", True]
|
||||
jsonv = JsonViewer(session, JSON_VIEWER_INSTANCE_ID, None, USER_ID, value)
|
||||
# Force expansion of the node
|
||||
jsonv.set_folding_mode("expand")
|
||||
|
||||
actual = jsonv.__ft__()
|
||||
to_compare = search_elements_by_name(actual, "div", attrs={"id": f"{jv_id("root")}"})[0]
|
||||
to_compare = search_elements_by_name(actual, "div", attrs={"id": f"{jv_id('root')}"})[0]
|
||||
to_compare = to_compare.children[0] # I want to compare what is inside the div
|
||||
|
||||
expected_inner = Span("[",
|
||||
Div(None, Span("0 : "), Span('1'), style=ML_20),
|
||||
Div(None, Span("1 : "), Span('"hello"'), style=ML_20),
|
||||
Div(None, Span("2 : "), Span('true'), style=ML_20),
|
||||
Div("]")),
|
||||
Div(None, Span("0 : "), Span('1', cls=f"{CLS_PREFIX}-number"), style=ML_20),
|
||||
Div(None, Span("1 : "), Span('"hello"', cls=f"{CLS_PREFIX}-string"), style=ML_20),
|
||||
Div(None, Span("2 : "), Span('true', cls=f"{CLS_PREFIX}-bool"), style=ML_20),
|
||||
Div("]"))
|
||||
|
||||
expected = Div(
|
||||
span_icon("expanded"),
|
||||
None, # 'key :' is missing for the first node
|
||||
expected_inner,
|
||||
style=ML_20)
|
||||
style=ML_20,
|
||||
id=jv_id(0))
|
||||
|
||||
assert matches(to_compare, expected)
|
||||
|
||||
@@ -100,21 +108,25 @@ def test_i_can_render_expanded_list_node(session):
|
||||
def test_i_can_render_expanded_dict_node(session):
|
||||
value = {"a": 1, "b": "hello", "c": True}
|
||||
jsonv = JsonViewer(session, JSON_VIEWER_INSTANCE_ID, None, USER_ID, value)
|
||||
# Force expansion of the node
|
||||
jsonv.set_folding_mode("expand")
|
||||
|
||||
actual = jsonv.__ft__()
|
||||
to_compare = search_elements_by_name(actual, "div", attrs={"id": f"{jv_id("root")}"})[0]
|
||||
to_compare = search_elements_by_name(actual, "div", attrs={"id": f"{jv_id('root')}"})[0]
|
||||
to_compare = to_compare.children[0] # I want to compare what is inside the div
|
||||
|
||||
expected_inner = Span("{",
|
||||
Div(None, Span("a : "), Span('1'), style=ML_20),
|
||||
Div(None, Span("b : "), Span('"hello"'), style=ML_20),
|
||||
Div(None, Span("c : "), Span('true'), style=ML_20),
|
||||
Div(None, Span("a : "), Span('1', cls=f"{CLS_PREFIX}-number"), style=ML_20),
|
||||
Div(None, Span("b : "), Span('"hello"', cls=f"{CLS_PREFIX}-string"), style=ML_20),
|
||||
Div(None, Span("c : "), Span('true', cls=f"{CLS_PREFIX}-bool"), style=ML_20),
|
||||
Div("}"))
|
||||
|
||||
expected = Div(
|
||||
span_icon("expanded"),
|
||||
None, # 'key :' is missing for the first node
|
||||
expected_inner,
|
||||
style=ML_20)
|
||||
style=ML_20,
|
||||
id=jv_id(0))
|
||||
|
||||
assert matches(to_compare, expected)
|
||||
|
||||
@@ -122,8 +134,11 @@ def test_i_can_render_expanded_dict_node(session):
|
||||
def test_i_can_render_expanded_list_of_dict_node(session):
|
||||
value = [{"a": 1, "b": "hello"}]
|
||||
jsonv = JsonViewer(session, JSON_VIEWER_INSTANCE_ID, None, USER_ID, value)
|
||||
# Force expansion of all nodes
|
||||
jsonv.set_folding_mode("expand")
|
||||
|
||||
actual = jsonv.__ft__()
|
||||
to_compare = search_elements_by_name(actual, "div", attrs={"id": f"{jv_id("root")}"})[0]
|
||||
to_compare = search_elements_by_name(actual, "div", attrs={"id": f"{jv_id('root')}"})[0]
|
||||
to_compare = to_compare.children[0] # I want to compare what is inside the div
|
||||
|
||||
expected_inner = Span("[",
|
||||
@@ -131,9 +146,10 @@ def test_i_can_render_expanded_list_of_dict_node(session):
|
||||
Div(span_icon("expanded"),
|
||||
Span("0 : "),
|
||||
Span("{",
|
||||
Div(None, Span("a : "), Span('1'), style=ML_20),
|
||||
Div(None, Span("b : "), Span('"hello"'), style=ML_20),
|
||||
Div(None, Span("a : "), Span('1', cls=f"{CLS_PREFIX}-number"), style=ML_20),
|
||||
Div(None, Span("b : "), Span('"hello"', cls=f"{CLS_PREFIX}-string"), style=ML_20),
|
||||
Div("}")),
|
||||
style=ML_20,
|
||||
id=f"{jv_id(1)}"),
|
||||
|
||||
Div("]"))
|
||||
@@ -142,11 +158,193 @@ def test_i_can_render_expanded_list_of_dict_node(session):
|
||||
span_icon("expanded"),
|
||||
None, # 'key :' is missing for the first node
|
||||
expected_inner,
|
||||
style=ML_20)
|
||||
style=ML_20,
|
||||
id=jv_id(0))
|
||||
|
||||
assert matches(to_compare, expected)
|
||||
|
||||
|
||||
def test_render_with_collapse_folding_mode(session):
|
||||
# Create a nested structure to test collapse rendering
|
||||
value = {"a": [1, 2, 3], "b": {"x": "y", "z": True}}
|
||||
jsonv = JsonViewer(session, JSON_VIEWER_INSTANCE_ID, None, USER_ID, value)
|
||||
|
||||
# Ensure folding mode is set to collapse (should be default)
|
||||
jsonv.set_folding_mode("collapse")
|
||||
assert jsonv.get_folding_mode() == "collapse"
|
||||
|
||||
actual = jsonv.__ft__()
|
||||
root_div = search_elements_by_name(actual, "div", attrs={"id": f"{jv_id('root')}"})[0]
|
||||
|
||||
# In collapse mode, the first level should show collapsed representations
|
||||
# The dict node should be rendered as "{...}"
|
||||
first_level_div = root_div.children[0]
|
||||
|
||||
# Verify that the first level shows a collapsed view
|
||||
expected_first_level = Div(
|
||||
span_icon("collapsed"),
|
||||
None, # No key for the root node
|
||||
Span("{...}", id=jv_id(0)),
|
||||
style=ML_20,
|
||||
id=jv_id(0)
|
||||
)
|
||||
|
||||
assert matches(first_level_div, expected_first_level)
|
||||
|
||||
|
||||
def test_render_with_specific_node_expanded_in_collapse_mode(session):
|
||||
# Create a nested structure to test mixed collapse/expand rendering
|
||||
value = {"a": [1, 2, 3], "b": {"x": "y", "z": True}}
|
||||
jsonv = JsonViewer(session, JSON_VIEWER_INSTANCE_ID, None, USER_ID, value)
|
||||
|
||||
# Ensure folding mode is set to collapse
|
||||
jsonv.set_folding_mode(FoldingMode.COLLAPSE)
|
||||
|
||||
# Manually expand the root node
|
||||
jsonv.set_node_folding(f"{JSON_VIEWER_INSTANCE_ID}-0", "expand")
|
||||
|
||||
actual = jsonv.__ft__()
|
||||
root_div = search_elements_by_name(actual, "div", attrs={"id": f"{jv_id('root')}"})[0]
|
||||
first_level_div = root_div.children[0]
|
||||
|
||||
as_node = extract_jsonviewer_node(first_level_div)
|
||||
|
||||
# The first level should now be expanded but children should be collapsed
|
||||
assert as_node.is_expanded is True
|
||||
|
||||
# Find div with "a" key
|
||||
a_node = as_node.find("a")
|
||||
b_node = as_node.find("b")
|
||||
|
||||
# Verify that both a and b nodes show collapsed representations
|
||||
assert a_node is not None
|
||||
assert b_node is not None
|
||||
|
||||
assert a_node.is_expanded is False
|
||||
assert a_node.text_value() == "[...]"
|
||||
|
||||
assert b_node.is_expanded is False
|
||||
assert b_node.text_value() == "{...}"
|
||||
|
||||
|
||||
def test_multiple_folding_levels_in_collapse_mode(session):
|
||||
# Create a deeply nested structure
|
||||
value = {"level1": {"level2": {"level3": [1, 2, 3]}}}
|
||||
jsonv = JsonViewer(session, JSON_VIEWER_INSTANCE_ID, None, USER_ID, value)
|
||||
|
||||
# Set folding mode to collapse
|
||||
jsonv.set_folding_mode(FoldingMode.COLLAPSE)
|
||||
|
||||
# Expand the first two levels
|
||||
jsonv.set_node_folding(f"{jsonv.get_id()}-0", FoldingMode.EXPAND) # top level
|
||||
jsonv.set_node_folding(f"{jsonv.get_id()}-1", FoldingMode.EXPAND) # level1
|
||||
jsonv.set_node_folding(f"{jsonv.get_id()}-2", FoldingMode.EXPAND) # level2
|
||||
|
||||
actual = jsonv.__ft__()
|
||||
root_div = search_elements_by_name(actual, "div", attrs={"id": f"{jv_id('root')}"})[0]
|
||||
|
||||
# Navigate to level3 to verify it's still collapsed
|
||||
first_level_div = root_div.children[0]
|
||||
first_level_node = extract_jsonviewer_node(first_level_div)
|
||||
assert first_level_node.is_expanded is True
|
||||
|
||||
# Find level2 in the rendered structure
|
||||
level2_node = first_level_node.find("level1.level2")
|
||||
assert level2_node is not None
|
||||
assert level2_node.is_expanded is True
|
||||
|
||||
# Find level3 in the rendered structure
|
||||
level3_node = level2_node.find("level3")
|
||||
assert level3_node is not None
|
||||
assert level3_node.is_expanded is False
|
||||
assert level3_node.text_value() == "[...]"
|
||||
|
||||
|
||||
def test_toggle_between_folding_modes(session):
|
||||
value = {"a": [1, 2, 3], "b": {"x": "y"}}
|
||||
jsonv = JsonViewer(session, JSON_VIEWER_INSTANCE_ID, None, USER_ID, value)
|
||||
|
||||
# Start with collapse mode
|
||||
jsonv.set_folding_mode("collapse")
|
||||
|
||||
# Expand specific node
|
||||
jsonv.set_node_folding(f"{JSON_VIEWER_INSTANCE_ID}-0", "expand")
|
||||
|
||||
# Verify node is in tracked nodes (exceptions to collapse mode)
|
||||
assert f"{JSON_VIEWER_INSTANCE_ID}-0" in jsonv._nodes_to_track
|
||||
|
||||
# Now switch to expand mode
|
||||
jsonv.set_folding_mode("expand")
|
||||
|
||||
# Tracked nodes should be cleared
|
||||
assert len(jsonv._nodes_to_track) == 0
|
||||
|
||||
# Collapse specific node
|
||||
jsonv.set_node_folding(f"{JSON_VIEWER_INSTANCE_ID}-0", "collapse")
|
||||
|
||||
# Verify node is in tracked nodes (exceptions to expand mode)
|
||||
assert f"{JSON_VIEWER_INSTANCE_ID}-0" in jsonv._nodes_to_track
|
||||
|
||||
# Render and verify the output
|
||||
actual = jsonv.__ft__()
|
||||
root_div = search_elements_by_name(actual, "div", attrs={"id": f"{jv_id('root')}"})[0]
|
||||
first_level_div = root_div.children[0]
|
||||
|
||||
# First level should be collapsed in an otherwise expanded tree
|
||||
as_node = extract_jsonviewer_node(first_level_div)
|
||||
assert as_node.is_expanded is False
|
||||
assert as_node.text_value() == "{...}"
|
||||
|
||||
|
||||
def test_custom_hook_rendering(session, helper):
|
||||
# Define a custom hook for testing
|
||||
def custom_predicate(key, node, h):
|
||||
return isinstance(node.value, str) and node.value == "custom_hook_test"
|
||||
|
||||
def custom_renderer(key, node, h):
|
||||
return Span("CUSTOM_HOOK_RENDER", cls="custom-hook-class")
|
||||
|
||||
hooks = [(custom_predicate, custom_renderer)]
|
||||
|
||||
# Create JsonViewer with the custom hook
|
||||
jsonv = JsonViewer(session, JSON_VIEWER_INSTANCE_ID, None, USER_ID, "custom_hook_test", hooks=hooks)
|
||||
|
||||
actual = jsonv.__ft__()
|
||||
to_compare = search_elements_by_name(actual, "div", attrs={"id": f"{jv_id('root')}"})[0]
|
||||
|
||||
expected = Div(
|
||||
Div(
|
||||
None,
|
||||
None,
|
||||
Span("CUSTOM_HOOK_RENDER", cls="custom-hook-class"),
|
||||
style=ML_20),
|
||||
id=f"{jv_id('root')}")
|
||||
|
||||
assert matches(to_compare, expected)
|
||||
|
||||
|
||||
def test_folding_mode_operations(session):
|
||||
jsonv = JsonViewer(session, JSON_VIEWER_INSTANCE_ID, None, USER_ID, {"a": [1, 2, 3]})
|
||||
|
||||
# Check default folding mode
|
||||
assert jsonv.get_folding_mode() == "collapse"
|
||||
|
||||
# Change folding mode
|
||||
jsonv.set_folding_mode("expand")
|
||||
assert jsonv.get_folding_mode() == "expand"
|
||||
|
||||
# Set node folding
|
||||
node_id = f"{JSON_VIEWER_INSTANCE_ID}-0"
|
||||
jsonv.set_node_folding(node_id, "collapse")
|
||||
|
||||
# Node should be in tracked nodes since it differs from the default mode
|
||||
assert node_id in jsonv._nodes_to_track
|
||||
|
||||
# Restore to match default mode
|
||||
jsonv.set_node_folding(node_id, "expand")
|
||||
assert node_id not in jsonv._nodes_to_track
|
||||
|
||||
|
||||
@pytest.mark.parametrize("input_value, expected_output", [
|
||||
('Hello World', '"Hello World"'), # No quotes in input
|
||||
('Hello "World"', "'Hello \"World\"'"), # Contains double quotes
|
||||
@@ -157,3 +355,16 @@ def test_i_can_render_expanded_list_of_dict_node(session):
|
||||
def test_add_quotes(input_value, expected_output):
|
||||
result = JsonViewer.add_quotes(input_value)
|
||||
assert result == expected_output
|
||||
|
||||
|
||||
def test_helper_is_sha256(helper):
|
||||
# Valid SHA256
|
||||
assert helper.is_sha256("a" * 64)
|
||||
assert helper.is_sha256("0123456789abcdef" * 4)
|
||||
assert helper.is_sha256("0123456789ABCDEF" * 4)
|
||||
|
||||
# Invalid cases
|
||||
assert not helper.is_sha256("a" * 63) # Too short
|
||||
assert not helper.is_sha256("a" * 65) # Too long
|
||||
assert not helper.is_sha256("g" * 64) # Invalid character
|
||||
assert not helper.is_sha256("test") # Not a hash
|
||||
|
||||
248
tests/test_mcp_server.py
Normal file
248
tests/test_mcp_server.py
Normal file
@@ -0,0 +1,248 @@
|
||||
import pytest
|
||||
import inspect
|
||||
from typing import Dict, Any, List, Optional, Callable
|
||||
|
||||
from ai.mcp_server import DummyMCPServer
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mcp_server(session):
|
||||
return DummyMCPServer(session, None)
|
||||
|
||||
|
||||
def test_register_tool_basic(mcp_server):
|
||||
"""Test basic tool registration"""
|
||||
|
||||
# Define a simple test function
|
||||
def test_func(param1: str, param2: int):
|
||||
return f"Test {param1} {param2}"
|
||||
|
||||
# Register the tool
|
||||
result = mcp_server.register_tool("test_tool", test_func)
|
||||
|
||||
# Verify the tool was registered
|
||||
assert "test_tool" in mcp_server.available_tools
|
||||
assert mcp_server.available_tools["test_tool"]["name"] == "test_tool"
|
||||
assert mcp_server.available_tools["test_tool"]["handler"] == test_func
|
||||
|
||||
# Verify method chaining works
|
||||
assert result is mcp_server
|
||||
|
||||
|
||||
def test_register_tool_with_description(mcp_server):
|
||||
"""Test tool registration with a custom description"""
|
||||
|
||||
def test_func():
|
||||
return "test"
|
||||
|
||||
mcp_server.register_tool("test_tool", test_func, description="Custom description")
|
||||
|
||||
assert mcp_server.available_tools["test_tool"]["description"] == "Custom description"
|
||||
|
||||
|
||||
def test_register_tool_without_description(mcp_server):
|
||||
"""Test tool registration without a description"""
|
||||
|
||||
def test_func():
|
||||
return "test"
|
||||
|
||||
mcp_server.register_tool("test_tool", test_func)
|
||||
|
||||
assert mcp_server.available_tools["test_tool"]["description"] == "Tool test_tool"
|
||||
|
||||
|
||||
def test_register_tool_parameter_types(mcp_server):
|
||||
"""Test parameter type inference"""
|
||||
|
||||
def test_func(str_param: str, int_param: int, float_param: float, bool_param: bool, untyped_param):
|
||||
return "test"
|
||||
|
||||
mcp_server.register_tool("test_tool", test_func)
|
||||
|
||||
params = mcp_server.available_tools["test_tool"]["parameters"]
|
||||
|
||||
assert params["str_param"]["type"] == "string"
|
||||
assert params["int_param"]["type"] == "integer"
|
||||
assert params["float_param"]["type"] == "number"
|
||||
assert params["bool_param"]["type"] == "boolean"
|
||||
assert params["untyped_param"]["type"] == "string" # Default type for untyped parameters
|
||||
|
||||
|
||||
def test_register_tool_parameter_descriptions(mcp_server):
|
||||
"""Test parameter descriptions"""
|
||||
|
||||
def test_func(param1, param2):
|
||||
return "test"
|
||||
|
||||
mcp_server.register_tool("test_tool", test_func)
|
||||
|
||||
params = mcp_server.available_tools["test_tool"]["parameters"]
|
||||
|
||||
assert params["param1"]["description"] == "Parameter param1"
|
||||
assert params["param2"]["description"] == "Parameter param2"
|
||||
|
||||
|
||||
def test_register_tool_with_sphinx_docstring(mcp_server):
|
||||
"""Test parameter descriptions from Sphinx-style docstrings"""
|
||||
|
||||
def test_func(name: str, age: int):
|
||||
"""Test function with Sphinx docstring
|
||||
|
||||
:param name: The person's name
|
||||
:param age: The person's age in years
|
||||
:return: A greeting message
|
||||
"""
|
||||
return f"Hello {name}, you are {age} years old!"
|
||||
|
||||
mcp_server.register_tool("sphinx_doc_tool", test_func)
|
||||
|
||||
params = mcp_server.available_tools["sphinx_doc_tool"]["parameters"]
|
||||
|
||||
assert params["name"]["description"] == "The person's name"
|
||||
assert params["age"]["description"] == "The person's age in years"
|
||||
|
||||
|
||||
def test_register_tool_with_google_docstring(mcp_server):
|
||||
"""Test parameter descriptions from Google-style docstrings"""
|
||||
|
||||
def test_func(name: str, age: int, height: float):
|
||||
"""Test function with Google-style docstring
|
||||
|
||||
Args:
|
||||
name: The person's name
|
||||
age: The person's age in years
|
||||
height: The person's height in meters
|
||||
|
||||
Returns:
|
||||
A greeting message
|
||||
"""
|
||||
return f"Hello {name}, you are {age} years old and {height}m tall!"
|
||||
|
||||
mcp_server.register_tool("google_doc_tool", test_func)
|
||||
|
||||
params = mcp_server.available_tools["google_doc_tool"]["parameters"]
|
||||
|
||||
assert params["name"]["description"] == "The person's name"
|
||||
assert params["age"]["description"] == "The person's age in years"
|
||||
assert params["height"]["description"] == "The person's height in meters"
|
||||
|
||||
|
||||
def test_register_tool_with_parameters_keyword(mcp_server):
|
||||
"""Test parameter descriptions with 'Parameters:' keyword instead of 'Args:'"""
|
||||
|
||||
def test_func(x: int, y: int):
|
||||
"""Test function with Parameters keyword
|
||||
|
||||
Parameters:
|
||||
x: The x coordinate
|
||||
y: The y coordinate
|
||||
|
||||
Returns:
|
||||
The sum of coordinates
|
||||
"""
|
||||
return x + y
|
||||
|
||||
mcp_server.register_tool("parameters_doc_tool", test_func)
|
||||
|
||||
params = mcp_server.available_tools["parameters_doc_tool"]["parameters"]
|
||||
|
||||
assert params["x"]["description"] == "The x coordinate"
|
||||
assert params["y"]["description"] == "The y coordinate"
|
||||
|
||||
|
||||
def test_register_tool_with_mixed_docstrings(mcp_server):
|
||||
"""Test parameter descriptions with mixed docstring styles"""
|
||||
|
||||
def test_func(a: int, b: str, c: float):
|
||||
"""Test function with mixed docstring styles
|
||||
|
||||
:param a: Parameter a from Sphinx style
|
||||
|
||||
Args:
|
||||
b: Parameter b from Google style
|
||||
c: Parameter c from Google style
|
||||
"""
|
||||
return f"{a} {b} {c}"
|
||||
|
||||
mcp_server.register_tool("mixed_doc_tool", test_func)
|
||||
|
||||
params = mcp_server.available_tools["mixed_doc_tool"]["parameters"]
|
||||
|
||||
assert params["a"]["description"] == "Parameter a from Sphinx style"
|
||||
assert params["b"]["description"] == "Parameter b from Google style"
|
||||
assert params["c"]["description"] == "Parameter c from Google style"
|
||||
|
||||
|
||||
def test_register_tool_with_missing_docstrings(mcp_server):
|
||||
"""Test parameter descriptions when some parameters are missing from docstring"""
|
||||
|
||||
def test_func(a: int, b: str, c: float):
|
||||
"""Test function with incomplete docstring
|
||||
|
||||
Args:
|
||||
a: Parameter a description
|
||||
"""
|
||||
return f"{a} {b} {c}"
|
||||
|
||||
mcp_server.register_tool("incomplete_doc_tool", test_func)
|
||||
|
||||
params = mcp_server.available_tools["incomplete_doc_tool"]["parameters"]
|
||||
|
||||
assert params["a"]["description"] == "Parameter a description"
|
||||
assert params["b"]["description"] == "Parameter b" # Default description
|
||||
assert params["c"]["description"] == "Parameter c" # Default description
|
||||
|
||||
|
||||
async def test_tool_can_be_called(mcp_server):
|
||||
"""Test that a registered tool can be called through call_tool method"""
|
||||
|
||||
# Define a simple test function
|
||||
def test_func(value: int):
|
||||
return value * 2
|
||||
|
||||
# Register the tool
|
||||
mcp_server.register_tool("multiply", test_func)
|
||||
|
||||
# Call the tool
|
||||
result = await mcp_server.call_tool("multiply", {"value": 5})
|
||||
|
||||
# Verify the result
|
||||
assert result["success"] is True
|
||||
assert result["result"] == 10
|
||||
assert result["tool_name"] == "multiply"
|
||||
|
||||
|
||||
async def test_async_tool_can_be_called(mcp_server):
|
||||
"""Test that a registered async tool can be called through call_tool method"""
|
||||
|
||||
# Define an async test function
|
||||
async def async_test_func(value: int):
|
||||
return value * 3
|
||||
|
||||
# Register the async tool
|
||||
mcp_server.register_tool("async_multiply", async_test_func)
|
||||
|
||||
# Call the async tool
|
||||
result = await mcp_server.call_tool("async_multiply", {"value": 5})
|
||||
|
||||
# Verify the result
|
||||
assert result["success"] is True
|
||||
assert result["result"] == 15
|
||||
assert result["tool_name"] == "async_multiply"
|
||||
|
||||
|
||||
def test_multiple_tools_registration(mcp_server):
|
||||
"""Test registering multiple tools"""
|
||||
|
||||
def tool1(): return "tool1"
|
||||
|
||||
def tool2(): return "tool2"
|
||||
|
||||
nb_internal_tools = len(mcp_server.available_tools)
|
||||
mcp_server.register_tool("tool1", tool1)
|
||||
mcp_server.register_tool("tool2", tool2)
|
||||
|
||||
# Check both tools were registered
|
||||
assert "tool1" in mcp_server.available_tools
|
||||
assert "tool2" in mcp_server.available_tools
|
||||
assert len(mcp_server.available_tools) == nb_internal_tools + 2
|
||||
@@ -176,7 +176,7 @@ def test_add_table_success(db, settings_manager_with_existing_repo):
|
||||
|
||||
def test_add_table_repository_not_found(db):
|
||||
"""Test adding a table to a non-existent repository."""
|
||||
with pytest.raises(ValueError, match="Repository 'NonExistentRepo' does not exist."):
|
||||
with pytest.raises(NameError, match="Repository 'NonExistentRepo' does not exist."):
|
||||
db.add_table("NonExistentRepo", "NewTable")
|
||||
|
||||
|
||||
@@ -210,13 +210,13 @@ def test_modify_table_success(db, settings_manager_with_existing_repo):
|
||||
|
||||
def test_modify_table_repository_not_found(db):
|
||||
"""Test modifying a table in a non-existent repository."""
|
||||
with pytest.raises(ValueError, match="Repository 'NonExistentRepo' does not exist."):
|
||||
with pytest.raises(NameError, match="Repository 'NonExistentRepo' does not exist."):
|
||||
db.modify_table("NonExistentRepo", "Table1", "NewTable")
|
||||
|
||||
|
||||
def test_modify_table_not_found(db, settings_manager_with_existing_repo):
|
||||
"""Test modifying a non-existent table."""
|
||||
with pytest.raises(ValueError, match="Table 'NonExistentTable' does not exist in repository 'ExistingRepo'."):
|
||||
with pytest.raises(NameError, match="Table 'NonExistentTable' does not exist in repository 'ExistingRepo'."):
|
||||
db.modify_table("ExistingRepo", "NonExistentTable", "NewTable")
|
||||
|
||||
|
||||
@@ -249,13 +249,13 @@ def test_remove_table_success(db, settings_manager_with_existing_repo):
|
||||
|
||||
def test_remove_table_repository_not_found(db):
|
||||
"""Test removing a table from a non-existent repository."""
|
||||
with pytest.raises(ValueError, match="Repository 'NonExistentRepo' does not exist."):
|
||||
with pytest.raises(NameError, match="Repository 'NonExistentRepo' does not exist."):
|
||||
db.remove_table("NonExistentRepo", "Table1")
|
||||
|
||||
|
||||
def test_remove_table_not_found(db, settings_manager_with_existing_repo):
|
||||
"""Test removing a non-existent table."""
|
||||
with pytest.raises(ValueError, match="Table 'NonExistentTable' does not exist in repository 'ExistingRepo'."):
|
||||
with pytest.raises(NameError, match="Table 'NonExistentTable' does not exist in repository 'ExistingRepo'."):
|
||||
db.remove_table("ExistingRepo", "NonExistentTable")
|
||||
|
||||
|
||||
@@ -273,3 +273,41 @@ def test_remove_table_empty_repository_name(db):
|
||||
db.remove_table("", "Table1")
|
||||
with pytest.raises(ValueError, match="Repository name cannot be empty."):
|
||||
db.remove_table(None, "Table1")
|
||||
|
||||
def test_repository_exists(db, settings_manager):
|
||||
assert db.exists_repository("SomeRepo") is False
|
||||
|
||||
settings = RepositoriesSettings()
|
||||
repo = Repository(name="SomeRepo", tables=["Table1"])
|
||||
settings.repositories.append(repo)
|
||||
settings_manager.save(db.session, REPOSITORIES_SETTINGS_ENTRY, settings)
|
||||
|
||||
assert db.exists_repository("SomeRepo") is True
|
||||
|
||||
|
||||
def test_repository_table(db, settings_manager_with_existing_repo):
|
||||
assert db.exists_table("ExistingRepo", "SomeTable") is False
|
||||
|
||||
db.add_table("ExistingRepo", "SomeTable")
|
||||
|
||||
assert db.exists_table("ExistingRepo", "SomeTable") is True
|
||||
|
||||
def test_exists_table_fails_when_repo_doesnt_exist(db):
|
||||
assert db.exists_table("NonExistentRepo", "SomeTable") is False
|
||||
|
||||
def test_ensure_exists(db, settings_manager):
|
||||
settings = settings_manager.load(db.session, REPOSITORIES_SETTINGS_ENTRY, default=RepositoriesSettings())
|
||||
assert len(settings.repositories) == 0
|
||||
|
||||
db.ensure_exists("SomeRepo", "SomeTable")
|
||||
|
||||
settings = settings_manager.load(db.session, REPOSITORIES_SETTINGS_ENTRY)
|
||||
assert len(settings.repositories) == 1
|
||||
assert settings.repositories[0].name == "SomeRepo"
|
||||
assert settings.repositories[0].tables == ["SomeTable"]
|
||||
|
||||
db.ensure_exists("SomeRepo", "SomeTable") # as no effect when called twice
|
||||
settings = settings_manager.load(db.session, REPOSITORIES_SETTINGS_ENTRY)
|
||||
assert len(settings.repositories) == 1
|
||||
assert settings.repositories[0].name == "SomeRepo"
|
||||
assert settings.repositories[0].tables == ["SomeTable"]
|
||||
|
||||
@@ -2,7 +2,7 @@ import dataclasses
|
||||
|
||||
import pytest
|
||||
|
||||
from core.settings_management import SettingsManager, MemoryDbEngine
|
||||
from core.settings_management import SettingsManager, MemoryDbEngine, GenericDbManager, NestedSettingsManager
|
||||
|
||||
FAKE_USER_ID = "FakeUserId"
|
||||
|
||||
@@ -20,6 +20,19 @@ class DummySettings:
|
||||
prop2: str
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class DummyObjectWithDefault:
|
||||
a: int = 5
|
||||
b: str = "default_b"
|
||||
c: bool = False
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class DummySettingsWithDefault:
|
||||
prop1: DummyObjectWithDefault = dataclasses.field(default_factory=DummyObjectWithDefault)
|
||||
prop2: str = "prop2"
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def manager():
|
||||
return SettingsManager(MemoryDbEngine())
|
||||
@@ -33,6 +46,16 @@ def settings():
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def generic_db_manager(session, manager):
|
||||
return GenericDbManager(session, manager, "TestSettings", DummySettingsWithDefault)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def nested_settings_manager(session, manager):
|
||||
return NestedSettingsManager(session, manager, "TestSettings", DummySettingsWithDefault, "prop1")
|
||||
|
||||
|
||||
def test_i_can_save_and_load_settings(session, manager, settings):
|
||||
manager.save(session, "MyEntry", settings)
|
||||
|
||||
@@ -86,3 +109,213 @@ def test_i_can_put_many_items_list(session, manager):
|
||||
assert loaded['key1'] == 'value1'
|
||||
assert loaded['key2'] == 'value2'
|
||||
assert loaded['key3'] == 'value3'
|
||||
|
||||
|
||||
# Tests for GenericDbManager
|
||||
|
||||
def test_generic_db_manager_get_attribute(generic_db_manager, manager, session):
|
||||
# Setup initial settings
|
||||
initial_settings = DummySettingsWithDefault(
|
||||
prop1=DummyObjectWithDefault(1, "2", True),
|
||||
prop2="initial_value"
|
||||
)
|
||||
manager.save(session, "TestSettings", initial_settings)
|
||||
|
||||
# Get attribute via GenericDbManager
|
||||
assert generic_db_manager.prop2 == "initial_value"
|
||||
assert generic_db_manager.prop1.a == 1
|
||||
assert generic_db_manager.prop1.b == "2"
|
||||
assert generic_db_manager.prop1.c is True
|
||||
|
||||
|
||||
def test_generic_db_manager_set_attribute(generic_db_manager, manager, session):
|
||||
# Setup initial settings
|
||||
initial_settings = DummySettingsWithDefault(
|
||||
prop1=DummyObjectWithDefault(1, "2", True),
|
||||
prop2="initial_value"
|
||||
)
|
||||
manager.save(session, "TestSettings", initial_settings)
|
||||
|
||||
# Set attribute via GenericDbManager
|
||||
generic_db_manager.prop2 = "updated_value"
|
||||
|
||||
# Verify that the change was saved to the database
|
||||
loaded_settings = manager.load(session, "TestSettings")
|
||||
assert loaded_settings.prop2 == "updated_value"
|
||||
|
||||
# Also verify direct access works
|
||||
assert generic_db_manager.prop2 == "updated_value"
|
||||
|
||||
|
||||
def test_generic_db_manager_set_nested_attribute(generic_db_manager, manager, session):
|
||||
# Setup initial settings
|
||||
initial_settings = DummySettingsWithDefault(
|
||||
prop1=DummyObjectWithDefault(1, "2", True),
|
||||
prop2="initial_value"
|
||||
)
|
||||
manager.save(session, "TestSettings", initial_settings)
|
||||
|
||||
# Set nested attribute
|
||||
generic_db_manager.prop1.a = 42
|
||||
generic_db_manager.prop1.b = "modified"
|
||||
generic_db_manager.prop1.c = False
|
||||
|
||||
# Verify the changes were saved
|
||||
loaded_settings = manager.load(session, "TestSettings")
|
||||
assert loaded_settings.prop1.a == 42
|
||||
assert loaded_settings.prop1.b == "modified"
|
||||
assert loaded_settings.prop1.c is False
|
||||
|
||||
|
||||
def test_generic_db_manager_attribute_error(generic_db_manager):
|
||||
# Test that accessing a non-existent attribute raises AttributeError
|
||||
with pytest.raises(AttributeError) as excinfo:
|
||||
generic_db_manager.non_existent_attribute
|
||||
|
||||
assert "has no attribute 'non_existent_attribute'." in str(excinfo.value)
|
||||
|
||||
|
||||
def test_generic_db_manager_set_attribute_error(generic_db_manager):
|
||||
# Test that setting a non-existent attribute raises AttributeError
|
||||
with pytest.raises(AttributeError) as excinfo:
|
||||
generic_db_manager.non_existent_attribute = "value"
|
||||
|
||||
assert "has no attribute 'non_existent_attribute'." in str(excinfo.value)
|
||||
|
||||
|
||||
def test_generic_db_manager_no_initialization(session, manager):
|
||||
# Test initialization with default object
|
||||
db_manager = GenericDbManager(session, manager, "TestSettings", DummySettingsWithDefault)
|
||||
|
||||
value = db_manager.prop2 # Accessing an attribute will create a new entry
|
||||
assert value == "prop2"
|
||||
|
||||
|
||||
def test_generic_db_manager_no_initialization_set(session, manager):
|
||||
db_manager = GenericDbManager(session, manager, "TestSettings", DummySettingsWithDefault)
|
||||
|
||||
db_manager.prop2 = "new_value"
|
||||
|
||||
# Verify that a default object was created and saved
|
||||
loaded_settings = manager.load(session, "TestSettings")
|
||||
assert isinstance(loaded_settings, DummySettingsWithDefault)
|
||||
|
||||
# The attributes should have their default values
|
||||
assert loaded_settings.prop1.a == 5
|
||||
assert loaded_settings.prop1.b == "default_b"
|
||||
assert loaded_settings.prop1.c is False
|
||||
assert loaded_settings.prop2 == "new_value"
|
||||
|
||||
|
||||
# Tests for NestedSettingsManager
|
||||
|
||||
def test_nested_settings_manager_get_attribute(nested_settings_manager, manager, session):
|
||||
# Setup initial settings
|
||||
initial_settings = DummySettingsWithDefault(
|
||||
prop1=DummyObjectWithDefault(10, "test_value", True),
|
||||
prop2="initial_value"
|
||||
)
|
||||
manager.save(session, "TestSettings", initial_settings)
|
||||
|
||||
# Get attributes via NestedSettingsManager
|
||||
assert nested_settings_manager.a == 10
|
||||
assert nested_settings_manager.b == "test_value"
|
||||
assert nested_settings_manager.c is True
|
||||
|
||||
|
||||
def test_nested_settings_manager_set_attribute(nested_settings_manager, manager, session):
|
||||
# Setup initial settings
|
||||
initial_settings = DummySettingsWithDefault(
|
||||
prop1=DummyObjectWithDefault(10, "test_value", True),
|
||||
prop2="initial_value"
|
||||
)
|
||||
manager.save(session, "TestSettings", initial_settings)
|
||||
|
||||
# Set attribute via NestedSettingsManager
|
||||
nested_settings_manager.a = 99
|
||||
nested_settings_manager.b = "updated_nested_value"
|
||||
nested_settings_manager.c = False
|
||||
|
||||
# Verify that the changes were saved to the database
|
||||
loaded_settings = manager.load(session, "TestSettings")
|
||||
assert loaded_settings.prop1.a == 99
|
||||
assert loaded_settings.prop1.b == "updated_nested_value"
|
||||
assert loaded_settings.prop1.c is False
|
||||
|
||||
# Also verify direct access works
|
||||
assert nested_settings_manager.a == 99
|
||||
assert nested_settings_manager.b == "updated_nested_value"
|
||||
assert nested_settings_manager.c is False
|
||||
|
||||
|
||||
def test_nested_settings_manager_attribute_error(nested_settings_manager):
|
||||
# Test that accessing a non-existent attribute raises AttributeError
|
||||
with pytest.raises(AttributeError) as excinfo:
|
||||
nested_settings_manager.non_existent_attribute
|
||||
|
||||
assert "has no attribute 'non_existent_attribute'" in str(excinfo.value)
|
||||
|
||||
|
||||
def test_nested_settings_manager_set_attribute_error(nested_settings_manager):
|
||||
# Test that setting a non-existent attribute raises AttributeError
|
||||
with pytest.raises(AttributeError) as excinfo:
|
||||
nested_settings_manager.non_existent_attribute = "value"
|
||||
|
||||
assert "has no attribute 'non_existent_attribute'" in str(excinfo.value)
|
||||
|
||||
|
||||
def test_nested_settings_manager_no_initialization(session, manager):
|
||||
# Test initialization with default object
|
||||
nested_db_manager = NestedSettingsManager(session, manager, "TestSettings", DummySettingsWithDefault, "prop1")
|
||||
|
||||
# Accessing an attribute will create a new entry with default values
|
||||
assert nested_db_manager.a == 5
|
||||
assert nested_db_manager.b == "default_b"
|
||||
assert nested_db_manager.c is False
|
||||
|
||||
|
||||
def test_nested_settings_manager_no_initialization_set(session, manager):
|
||||
nested_db_manager = NestedSettingsManager(session, manager, "TestSettings", DummySettingsWithDefault, "prop1")
|
||||
|
||||
# Set attribute will create a new entry with the modified value
|
||||
nested_db_manager.a = 42
|
||||
|
||||
# Verify that a default object was created and saved
|
||||
loaded_settings = manager.load(session, "TestSettings")
|
||||
assert isinstance(loaded_settings, DummySettingsWithDefault)
|
||||
|
||||
# The specified attribute should be updated, while others retain default values
|
||||
assert loaded_settings.prop1.a == 42
|
||||
assert loaded_settings.prop1.b == "default_b"
|
||||
assert loaded_settings.prop1.c is False
|
||||
assert loaded_settings.prop2 == "prop2"
|
||||
|
||||
|
||||
def test_nested_settings_manager_invalid_nested_attribute(session, manager):
|
||||
# Test with an invalid nested attribute
|
||||
invalid_nested_manager = NestedSettingsManager(session, manager, "TestSettings", DummySettingsWithDefault,
|
||||
"non_existent")
|
||||
|
||||
# Accessing any attribute should raise an AttributeError
|
||||
with pytest.raises(AttributeError) as excinfo:
|
||||
invalid_nested_manager.a
|
||||
|
||||
assert "has no attribute 'non_existent'" in str(excinfo.value)
|
||||
|
||||
|
||||
def test_i_can_update_nested_settings(nested_settings_manager, session, manager):
|
||||
new_values = {
|
||||
"a": 10,
|
||||
"b": "new_value",
|
||||
"c": True,
|
||||
}
|
||||
nested_settings_manager.update(new_values)
|
||||
|
||||
# Verify that a default object was created and saved
|
||||
loaded_settings = manager.load(session, "TestSettings")
|
||||
assert isinstance(loaded_settings, DummySettingsWithDefault)
|
||||
|
||||
# The specified attribute should be updated, while others retain default values
|
||||
assert loaded_settings.prop1.a == 10
|
||||
assert loaded_settings.prop1.b == "new_value"
|
||||
assert loaded_settings.prop1.c is True
|
||||
|
||||
Reference in New Issue
Block a user