Refactored instances management

This commit is contained in:
2025-11-23 19:52:03 +01:00
parent 97247f824c
commit b1be747101
24 changed files with 783 additions and 216 deletions

View File

@@ -10,7 +10,7 @@ from myfasthtml.controls.Keyboard import Keyboard
from myfasthtml.controls.Layout import Layout
from myfasthtml.controls.TabsManager import TabsManager
from myfasthtml.controls.helpers import Ids, mk
from myfasthtml.core.instances import InstancesManager, RootInstance
from myfasthtml.core.instances import SingleInstance
from myfasthtml.icons.carbon import volume_object_storage
from myfasthtml.icons.fluent_p3 import folder_open20_regular
from myfasthtml.myfastapp import create_app
@@ -32,21 +32,22 @@ app, rt = create_app(protect_routes=True,
@rt("/")
def index(session):
layout = InstancesManager.get(session, Ids.Layout, Layout, RootInstance, "Testing Layout")
session_instance = SingleInstance(session=session, _id=Ids.UserSession)
layout = Layout(session_instance, "Testing Layout")
layout.set_footer("Goodbye World")
tabs_manager = TabsManager(layout, _id=f"{Ids.TabsManager}-main")
tabs_manager = TabsManager(layout, _id=f"{TabsManager.get_prefix()}-main")
btn_show_right_drawer = mk.button("show",
command=layout.commands.toggle_drawer("right"),
id="btn_show_right_drawer_id")
instances_debugger = InstancesManager.get(session, Ids.InstancesDebugger, InstancesDebugger, layout)
instances_debugger = InstancesDebugger(layout)
btn_show_instances_debugger = mk.label("Instances",
icon=volume_object_storage,
command=tabs_manager.commands.add_tab("Instances", instances_debugger),
id=instances_debugger.get_id())
commands_debugger = InstancesManager.get(session, Ids.CommandsDebugger, CommandsDebugger, layout)
commands_debugger = CommandsDebugger(layout)
btn_show_commands_debugger = mk.label("Commands",
icon=None,
command=tabs_manager.commands.add_tab("Commands", commands_debugger),

View File

@@ -26,8 +26,8 @@ class Boundaries(SingleInstance):
Keep the boundaries updated
"""
def __init__(self, session, owner, container_id: str = None, on_resize=None):
super().__init__(session, Ids.Boundaries, owner)
def __init__(self, owner, container_id: str = None, on_resize=None, _id=None):
super().__init__(owner, _id=_id)
self._owner = owner
self._container_id = container_id or owner.get_id()
self._on_resize = on_resize

View File

@@ -1,13 +1,12 @@
from myfasthtml.controls.VisNetwork import VisNetwork
from myfasthtml.controls.helpers import Ids
from myfasthtml.core.commands import CommandsManager
from myfasthtml.core.instances import SingleInstance
from myfasthtml.core.network_utils import from_parent_child_list
class CommandsDebugger(SingleInstance):
def __init__(self, session, parent, _id=None):
super().__init__(session, Ids.CommandsDebugger, parent)
def __init__(self, parent, _id=None):
super().__init__(parent, _id=_id)
def render(self):
commands = self._get_commands()

View File

@@ -16,7 +16,7 @@ logger = logging.getLogger("FileUpload")
class FileUploadState(DbObject):
def __init__(self, owner):
super().__init__(owner.get_session(), owner.get_id())
super().__init__(owner)
with self.initializing():
# persisted in DB
@@ -37,7 +37,7 @@ class Commands(BaseCommands):
class FileUpload(MultipleInstance):
def __init__(self, parent, _id=None):
super().__init__(Ids.FileUpload, parent, _id=_id)
super().__init__(parent, _id=_id)
self.commands = Commands(self)
self._state = FileUploadState(self)

View File

@@ -1,12 +1,11 @@
from myfasthtml.controls.VisNetwork import VisNetwork
from myfasthtml.controls.helpers import Ids
from myfasthtml.core.instances import SingleInstance, InstancesManager
from myfasthtml.core.network_utils import from_parent_child_list
class InstancesDebugger(SingleInstance):
def __init__(self, session, parent, _id=None):
super().__init__(session, Ids.InstancesDebugger, parent)
def __init__(self, parent, _id=None):
super().__init__(parent, _id=_id)
def render(self):
instances = self._get_instances()
@@ -15,8 +14,15 @@ class InstancesDebugger(SingleInstance):
label_getter=lambda x: x.get_prefix(),
parent_getter=lambda x: x.get_parent().get_id() if x.get_parent() else None
)
for edge in edges:
edge["color"] = "green"
edge["arrows"] = {"to": {"enabled": False, "type": "circle"}}
for node in nodes:
node["shape"] = "box"
vis_network = VisNetwork(self, nodes=nodes, edges=edges)
#vis_network.add_to_options(physics={"wind": {"x": 0, "y": 1}})
return vis_network
def _get_instances(self):

View File

@@ -2,14 +2,13 @@ import json
from fasthtml.xtend import Script
from myfasthtml.controls.helpers import Ids
from myfasthtml.core.commands import BaseCommand
from myfasthtml.core.instances import MultipleInstance
class Keyboard(MultipleInstance):
def __init__(self, parent, _id=None, combinations=None):
super().__init__(Ids.Keyboard, parent)
super().__init__(parent, _id=_id)
self.combinations = combinations or {}
def add(self, sequence: str, command: BaseCommand):

View File

@@ -12,10 +12,10 @@ from fasthtml.common import *
from myfasthtml.controls.BaseCommands import BaseCommands
from myfasthtml.controls.Boundaries import Boundaries
from myfasthtml.controls.UserProfile import UserProfile
from myfasthtml.controls.helpers import mk, Ids
from myfasthtml.controls.helpers import mk
from myfasthtml.core.commands import Command
from myfasthtml.core.dbmanager import DbObject
from myfasthtml.core.instances import InstancesManager, SingleInstance
from myfasthtml.core.instances import SingleInstance
from myfasthtml.core.utils import get_id
from myfasthtml.icons.fluent import panel_left_expand20_regular as left_drawer_icon
from myfasthtml.icons.fluent_p2 import panel_right_expand20_regular as right_drawer_icon
@@ -25,7 +25,7 @@ logger = logging.getLogger("LayoutControl")
class LayoutState(DbObject):
def __init__(self, owner):
super().__init__(owner.get_session(), owner.get_id())
super().__init__(owner)
with self.initializing():
self.left_drawer_open: bool = True
self.right_drawer_open: bool = True
@@ -100,7 +100,7 @@ class Layout(SingleInstance):
def get_groups(self):
return self._groups
def __init__(self, session, app_name, parent=None):
def __init__(self, parent, app_name, _id=None):
"""
Initialize the Layout component.
@@ -109,13 +109,13 @@ class Layout(SingleInstance):
left_drawer (bool): Enable left drawer. Default is True.
right_drawer (bool): Enable right drawer. Default is True.
"""
super().__init__(session, Ids.Layout, parent)
super().__init__(parent, _id=_id)
self.app_name = app_name
# Content storage
self._main_content = None
self._state = LayoutState(self)
self._boundaries = Boundaries(session, self)
self._boundaries = Boundaries(self)
self.commands = Commands(self)
self.left_drawer = self.Content(self)
self.right_drawer = self.Content(self)
@@ -193,7 +193,7 @@ class Layout(SingleInstance):
),
Div( # right
*self.header_right.get_content(),
InstancesManager.get(self._session, Ids.UserProfile, UserProfile),
UserProfile(self),
cls="flex gap-1"
),
cls="mf-layout-header"

View File

@@ -35,14 +35,13 @@ class Search(MultipleInstance):
a callable for extracting a string value from items, and a template callable for rendering
the filtered items. It provides functionality to handle and organize item-based operations.
:param session: The session object to maintain state or context across operations.
:param _id: Optional identifier for the component.
:param items: An optional list of names for the items to be filtered.
:param get_attr: Callable function to extract a string value from an item for filtering. Defaults to a
function that returns the item as is.
:param template: Callable function to render the filtered items. Defaults to a Div rendering function.
"""
super().__init__(Ids.Search, parent, _id=_id)
super().__init__(parent, _id=_id)
self.items_names = items_names or ''
self.items = items or []
self.filtered = self.items.copy()

View File

@@ -9,7 +9,7 @@ from fasthtml.xtend import Script
from myfasthtml.controls.BaseCommands import BaseCommands
from myfasthtml.controls.Search import Search
from myfasthtml.controls.VisNetwork import VisNetwork
from myfasthtml.controls.helpers import Ids, mk
from myfasthtml.controls.helpers import mk
from myfasthtml.core.commands import Command
from myfasthtml.core.dbmanager import DbObject
from myfasthtml.core.instances import MultipleInstance, BaseInstance
@@ -45,7 +45,7 @@ class Boundaries:
class TabsManagerState(DbObject):
def __init__(self, owner):
super().__init__(owner.get_session(), owner.get_id())
super().__init__(owner)
with self.initializing():
# persisted in DB
self.tabs: dict[str, Any] = {}
@@ -78,7 +78,7 @@ class TabsManager(MultipleInstance):
_tab_count = 0
def __init__(self, parent, _id=None):
super().__init__(Ids.TabsManager, parent, _id=_id)
super().__init__(parent, _id=_id)
self._state = TabsManagerState(self)
self.commands = Commands(self)
self._boundaries = Boundaries()

View File

@@ -1,9 +1,10 @@
from fasthtml.components import *
from myfasthtml.controls.BaseCommands import BaseCommands
from myfasthtml.controls.helpers import Ids, mk
from myfasthtml.controls.helpers import mk
from myfasthtml.core.AuthProxy import AuthProxy
from myfasthtml.core.commands import Command
from myfasthtml.core.instances import SingleInstance, InstancesManager
from myfasthtml.core.instances import SingleInstance, InstancesManager, RootInstance
from myfasthtml.core.utils import retrieve_user_info
from myfasthtml.icons.material import dark_mode_filled, person_outline_sharp
from myfasthtml.icons.material_p1 import light_mode_filled, alternate_email_filled
@@ -25,7 +26,7 @@ class UserProfileState:
def save(self):
user_settings = {k: v for k, v in self.__dict__.items() if not k.startswith("_")}
auth_proxy = InstancesManager.get_auth_proxy()
auth_proxy = AuthProxy(RootInstance)
auth_proxy.save_user_info(self._session["access_token"], {"user_settings": user_settings})
@@ -35,8 +36,8 @@ class Commands(BaseCommands):
class UserProfile(SingleInstance):
def __init__(self, session, parent=None):
super().__init__(session, Ids.UserProfile, parent)
def __init__(self, parent=None, _id=None):
super().__init__(parent, _id=_id)
self._state = UserProfileState(self)
self._commands = Commands(self)

View File

@@ -3,7 +3,6 @@ import logging
from fasthtml.components import Script, Div
from myfasthtml.controls.helpers import Ids
from myfasthtml.core.dbmanager import DbObject
from myfasthtml.core.instances import MultipleInstance
@@ -12,7 +11,7 @@ logger = logging.getLogger("VisNetwork")
class VisNetworkState(DbObject):
def __init__(self, owner):
super().__init__(owner.get_session(), owner.get_id())
super().__init__(owner)
with self.initializing():
# persisted in DB
self.nodes: list = []
@@ -30,7 +29,7 @@ class VisNetworkState(DbObject):
class VisNetwork(MultipleInstance):
def __init__(self, parent, _id=None, nodes=None, edges=None, options=None):
super().__init__(Ids.VisNetwork, parent, _id=_id)
super().__init__(parent, _id=_id)
logger.debug(f"VisNetwork created with id: {self._id}")
self._state = VisNetworkState(self)
@@ -50,7 +49,13 @@ class VisNetwork(MultipleInstance):
state.options = options
self._state.update(state)
def add_to_options(self, **kwargs):
logger.debug(f"add_to_options: {kwargs=}")
new_options = self._state.options.copy() | kwargs
self._update_state(None, None, new_options)
return self
def render(self):
# Serialize nodes and edges to JSON

View File

@@ -7,19 +7,8 @@ from myfasthtml.core.utils import merge_classes
class Ids:
# Please keep the alphabetical order
AuthProxy = "mf-auth-proxy"
Boundaries = "mf-boundaries"
CommandsDebugger = "mf-commands-debugger"
DbManager = "mf-dbmanager"
FileUpload = "mf-file-upload"
InstancesDebugger = "mf-instances-debugger"
Keyboard = "mf-keyboard"
Layout = "mf-layout"
Root = "mf-root"
Search = "mf-search"
TabsManager = "mf-tabs-manager"
UserProfile = "mf-user-profile"
VisNetwork = "mf-vis-network"
UserSession = "mf-user_session"
class mk:

View File

@@ -1,11 +1,10 @@
from myfasthtml.auth.utils import login_user, save_user_info, register_user
from myfasthtml.controls.helpers import Ids
from myfasthtml.core.instances import UniqueInstance, RootInstance
from myfasthtml.core.instances import SingleInstance
class AuthProxy(UniqueInstance):
def __init__(self, base_url: str = None):
super().__init__(Ids.AuthProxy, RootInstance)
class AuthProxy(SingleInstance):
def __init__(self, parent, base_url: str = None):
super().__init__(parent)
self._base_url = base_url
def login_user(self, email: str, password: str):

View File

@@ -3,14 +3,13 @@ from types import SimpleNamespace
from dbengine.dbengine import DbEngine
from myfasthtml.controls.helpers import Ids
from myfasthtml.core.instances import SingleInstance, InstancesManager
from myfasthtml.core.instances import SingleInstance, BaseInstance
from myfasthtml.core.utils import retrieve_user_info
class DbManager(SingleInstance):
def __init__(self, session, parent=None, root=".myFastHtmlDb", auto_register: bool = True):
super().__init__(session, Ids.DbManager, parent, auto_register=auto_register)
def __init__(self, parent, root=".myFastHtmlDb", auto_register: bool = True):
super().__init__(parent, auto_register=auto_register)
self.db = DbEngine(root=root)
def save(self, entry, obj):
@@ -35,12 +34,12 @@ class DbObject:
It loads from DB at startup
"""
_initializing = False
_forbidden_attrs = {"_initializing", "_db_manager", "_name", "_session", "_forbidden_attrs"}
_forbidden_attrs = {"_initializing", "_db_manager", "_name", "_owner", "_forbidden_attrs"}
def __init__(self, session, name=None, db_manager=None):
self._session = session
def __init__(self, owner: BaseInstance, name=None, db_manager=None):
self._owner = owner
self._name = name or self.__class__.__name__
self._db_manager = db_manager or InstancesManager.get(self._session, Ids.DbManager, DbManager)
self._db_manager = db_manager or DbManager(self._owner)
self._finalize_initialization()

View File

@@ -1,7 +1,8 @@
import uuid
from typing import Self
from typing import Optional
from myfasthtml.controls.helpers import Ids
from myfasthtml.core.utils import pascal_to_snake
special_session = {
"user_info": {"id": "** SPECIAL SESSION **"}
@@ -18,25 +19,76 @@ class BaseInstance:
Base class for all instances (manageable by InstancesManager)
"""
def __init__(self, session: dict, prefix: str, _id: str, parent: Self, auto_register: bool = True):
self._session = session
self._id = _id
self._prefix = prefix
def __new__(cls, *args, **kwargs):
# Extract arguments from both positional and keyword arguments
# Signature matches __init__: parent, session=None, _id=None, auto_register=True
parent = args[0] if len(args) > 0 and isinstance(args[0], BaseInstance) else kwargs.get("parent", None)
session = args[1] if len(args) > 1 and isinstance(args[1], dict) else kwargs.get("session", None)
_id = args[2] if len(args) > 2 and isinstance(args[2], str) else kwargs.get("_id", None)
# Compute _id if not provided
if _id is None:
_id = cls.compute_id()
if session is None:
if parent is not None:
session = parent.get_session()
else:
raise TypeError("Either session or parent must be provided")
session_id = InstancesManager.get_session_id(session)
key = (session_id, _id)
if key in InstancesManager.instances:
res = InstancesManager.instances[key]
if type(res) is not cls:
raise TypeError(f"Instance with id {_id} already exists, but is of type {type(res)}")
return res
# Otherwise create a new instance
instance = super().__new__(cls)
instance._is_new_instance = True # mark as fresh
return instance
def __init__(self, parent: Optional['BaseInstance'],
session: Optional[dict] = None,
_id: Optional[str] = None,
auto_register: bool = True):
if not getattr(self, "_is_new_instance", False):
# Skip __init__ if instance already existed
return
else:
# make sure that it's no longer considered as a new instance
self._is_new_instance = False
self._parent = parent
self._session = session or (parent.get_session() if parent else None)
self._id = _id or self.compute_id()
if auto_register:
InstancesManager.register(session, self)
InstancesManager.register(self._session, self)
def get_id(self):
return self._id
def get_session(self):
def get_session(self) -> dict:
return self._session
def get_prefix(self):
return self._prefix
def get_id(self) -> str:
return self._id
def get_parent(self):
def get_parent(self) -> Optional['BaseInstance']:
return self._parent
@classmethod
def get_prefix(cls):
return f"mf-{pascal_to_snake(cls.__name__)}"
@classmethod
def compute_id(cls):
prefix = cls.get_prefix()
if issubclass(cls, SingleInstance):
_id = prefix
else:
_id = f"{prefix}-{str(uuid.uuid4())}"
return _id
class SingleInstance(BaseInstance):
@@ -44,19 +96,12 @@ class SingleInstance(BaseInstance):
Base class for instances that can only have one instance at a time.
"""
def __init__(self, session: dict, prefix: str, parent, auto_register: bool = True):
super().__init__(session, prefix, prefix, parent, auto_register)
class UniqueInstance(BaseInstance):
"""
Base class for instances that can only have one instance at a time.
Does not throw exception if the instance already exists, it simply overwrites it.
"""
def __init__(self, prefix: str, parent: BaseInstance, auto_register: bool = True):
super().__init__(parent.get_session(), prefix, prefix, parent, auto_register)
self._prefix = prefix
def __init__(self,
parent: Optional[BaseInstance] = None,
session: Optional[dict] = None,
_id: Optional[str] = None,
auto_register: bool = True):
super().__init__(parent, session, _id, auto_register)
class MultipleInstance(BaseInstance):
@@ -64,9 +109,11 @@ class MultipleInstance(BaseInstance):
Base class for instances that can have multiple instances at a time.
"""
def __init__(self, prefix: str, parent: BaseInstance, auto_register: bool = True, _id=None):
super().__init__(parent.get_session(), prefix, _id or f"{prefix}-{str(uuid.uuid4())}", parent, auto_register)
self._prefix = prefix
def __init__(self, parent: BaseInstance,
session: Optional[dict] = None,
_id: Optional[str] = None,
auto_register: bool = True):
super().__init__(parent, session, _id, auto_register)
class InstancesManager:
@@ -80,7 +127,7 @@ class InstancesManager:
:param instance:
:return:
"""
key = (InstancesManager._get_session_id(session), instance.get_id())
key = (InstancesManager.get_session_id(session), instance.get_id())
if isinstance(instance, SingleInstance) and key in InstancesManager.instances:
raise DuplicateInstanceError(instance)
@@ -89,48 +136,27 @@ class InstancesManager:
return instance
@staticmethod
def get(session: dict, instance_id: str, instance_type: type = None, parent: BaseInstance = None, *args, **kwargs):
def get(session: dict, instance_id: str):
"""
Get or create an instance of the given type (from its id)
:param session:
:param instance_id:
:param instance_type:
:param parent:
:param args:
:param kwargs:
:return:
"""
try:
key = (InstancesManager._get_session_id(session), instance_id)
return InstancesManager.instances[key]
except KeyError:
if instance_type:
if not issubclass(instance_type, SingleInstance):
assert parent is not None, "Parent instance must be provided if not SingleInstance"
if isinstance(parent, MultipleInstance):
return instance_type(parent, _id=instance_id, *args, **kwargs)
else:
return instance_type(session, parent=parent, *args, **kwargs) # it will be automatically registered
else:
raise
key = (InstancesManager.get_session_id(session), instance_id)
return InstancesManager.instances[key]
@staticmethod
def _get_session_id(session):
if not session:
def get_session_id(session):
if session is None:
return "** NOT LOGGED IN **"
if "user_info" not in session:
return "** UNKNOWN USER **"
return session["user_info"].get("id", "** INVALID SESSION **")
@staticmethod
def get_auth_proxy():
return InstancesManager.get(special_session, Ids.AuthProxy)
@staticmethod
def reset():
return InstancesManager.instances.clear()
InstancesManager.instances.clear()
RootInstance = SingleInstance(special_session, Ids.Root, None)
RootInstance = SingleInstance(None, special_session, Ids.Root)

View File

@@ -144,50 +144,34 @@ def from_tree_with_metadata(
def from_parent_child_list(
items: list,
id_getter: callable = None,
label_getter: callable = None,
parent_getter: callable = None,
ghost_color: str = "#ff9999"
id_getter: Callable = None,
label_getter: Callable = None,
parent_getter: Callable = None,
ghost_color: str = "#ff9999",
root_color: str | None = "#ff9999"
) -> tuple[list, list]:
"""
Convert a list of items with parent references to vis.js nodes and edges format.
Args:
items: List of items (dicts or objects) with parent references
(e.g., [{"id": "child", "parent": "root", "label": "Child"}, ...])
id_getter: Optional callback to extract node ID from item
Default: lambda item: item.get("id")
label_getter: Optional callback to extract node label from item
Default: lambda item: item.get("label", "")
parent_getter: Optional callback to extract parent ID from item
Default: lambda item: item.get("parent")
ghost_color: Color to use for ghost nodes (nodes referenced as parents but not in list)
Default: "#ff9999" (light red)
id_getter: callback to extract node ID
label_getter: callback to extract node label
parent_getter: callback to extract parent ID
ghost_color: color for ghost nodes (referenced parents)
root_color: color for root nodes (nodes without parent)
Returns:
tuple: (nodes, edges) where:
- nodes: list of dicts with IDs from items, ghost nodes have color property
- edges: list of dicts with 'from' and 'to' keys
Note:
- Nodes with parent=None or parent="" are treated as root nodes
- If a parent is referenced but doesn't exist in items, a ghost node is created
with the ghost_color applied
Example:
>>> items = [
... {"id": "root", "label": "Root"},
... {"id": "child1", "parent": "root", "label": "Child 1"},
... {"id": "child2", "parent": "unknown", "label": "Child 2"}
... ]
>>> nodes, edges = from_parent_child_list(items)
>>> # "unknown" will be created as a ghost node with color="#ff9999"
tuple: (nodes, edges)
"""
# Default getters
if id_getter is None:
id_getter = lambda item: item.get("id")
if label_getter is None:
label_getter = lambda item: item.get("label", "")
if parent_getter is None:
parent_getter = lambda item: item.get("parent")
@@ -205,34 +189,48 @@ def from_parent_child_list(
existing_ids.add(node_id)
nodes.append({
"id": node_id,
"label": node_label
"label": node_label,
# root color assigned later
})
# Track ghost nodes to avoid duplicates
# Track ghost nodes
ghost_nodes = set()
# Second pass: create edges and identify ghost nodes
# Track which nodes have parents
nodes_with_parent = set()
# Second pass: create edges and detect ghost nodes
for item in items:
node_id = id_getter(item)
parent_id = parent_getter(item)
# Skip if no parent or parent is empty string or None
# Skip roots
if parent_id is None or parent_id == "":
continue
# Create edge from parent to child
# Child has a parent
nodes_with_parent.add(node_id)
# Create edge parent → child
edges.append({
"from": parent_id,
"to": node_id
})
# Check if parent exists, if not create ghost node
# Create ghost node if parent not found
if parent_id not in existing_ids and parent_id not in ghost_nodes:
ghost_nodes.add(parent_id)
nodes.append({
"id": parent_id,
"label": str(parent_id), # Use ID as label for ghost nodes
"label": str(parent_id),
"color": ghost_color
})
# Final pass: assign color to root nodes
if root_color is not None:
for node in nodes:
if node["id"] not in nodes_with_parent and node["id"] not in ghost_nodes:
# Root node
node["color"] = root_color
return nodes, edges

View File

@@ -1,4 +1,5 @@
import logging
import re
from bs4 import Tag
from fastcore.xml import FT
@@ -234,6 +235,18 @@ def get_id(obj):
return str(obj)
def pascal_to_snake(name: str) -> str:
"""Convert a PascalCase or CamelCase string to snake_case."""
if name is None:
return None
name = name.strip()
# Insert underscore before capital letters (except the first one)
s1 = re.sub(r'(.)([A-Z][a-z]+)', r'\1_\2', name)
# Handle consecutive capital letters (like 'HTTPServer' -> 'http_server')
s2 = re.sub(r'([a-z0-9])([A-Z])', r'\1_\2', s1)
return s2.lower()
@utils_rt(Routes.Commands)
def post(session, c_id: str, client_response: dict = None):
"""

View File

@@ -11,6 +11,10 @@ import re
def pascal_to_snake(name: str) -> str:
"""Convert a PascalCase or CamelCase string to snake_case."""
if name is None:
return None
name = name.strip()
# Insert underscore before capital letters (except the first one)
s1 = re.sub(r'(.)([A-Z][a-z]+)', r'\1_\2', name)
# Handle consecutive capital letters (like 'HTTPServer' -> 'http_server')

View File

@@ -10,6 +10,7 @@ from starlette.responses import Response
from myfasthtml.auth.routes import setup_auth_routes
from myfasthtml.auth.utils import create_auth_beforeware
from myfasthtml.core.AuthProxy import AuthProxy
from myfasthtml.core.instances import RootInstance
from myfasthtml.core.utils import utils_app
logger = logging.getLogger("MyFastHtml")
@@ -104,6 +105,6 @@ def create_app(daisyui: Optional[bool] = True,
setup_auth_routes(app, rt, base_url=base_url)
# create the AuthProxy instance
AuthProxy(base_url) # using the auto register mechanism to expose it
AuthProxy(RootInstance, base_url) # using the auto register mechanism to expose it
return app, rt

View File

@@ -20,4 +20,4 @@ def session():
@pytest.fixture(scope="session")
def root_instance(session):
return SingleInstance(session, "TestRoot", None)
return SingleInstance(None, session, "TestRoot")

View File

@@ -1,3 +1,5 @@
import shutil
import pytest
from fasthtml.components import *
from fasthtml.xtend import Script
@@ -10,9 +12,11 @@ from .conftest import session
@pytest.fixture()
def tabs_manager(root_instance):
shutil.rmtree(".myFastHtmlDb", ignore_errors=True)
yield TabsManager(root_instance)
InstancesManager.reset()
shutil.rmtree(".myFastHtmlDb", ignore_errors=True)
class TestTabsManagerBehaviour:

View File

@@ -4,6 +4,7 @@ from dataclasses import dataclass
import pytest
from myfasthtml.core.dbmanager import DbManager, DbObject
from myfasthtml.core.instances import SingleInstance, BaseInstance
@pytest.fixture(scope="session")
@@ -19,9 +20,14 @@ def session():
@pytest.fixture
def db_manager(session):
def parent(session):
return SingleInstance(session=session, _id="test_parent_id")
@pytest.fixture
def db_manager(parent):
shutil.rmtree("TestDb", ignore_errors=True)
db_manager_instance = DbManager(session, root="TestDb", auto_register=False)
db_manager_instance = DbManager(parent, root="TestDb", auto_register=False)
yield db_manager_instance
@@ -32,17 +38,17 @@ def simplify(res: dict) -> dict:
return {k: v for k, v in res.items() if not k.startswith("_")}
def test_i_can_init(session, db_manager):
def test_i_can_init(parent, db_manager):
class DummyObject(DbObject):
def __init__(self, sess: dict):
super().__init__(sess, "DummyObject", db_manager)
def __init__(self, owner: BaseInstance):
super().__init__(owner, "DummyObject", db_manager)
with self.initializing():
self.value: str = "hello"
self.number: int = 42
self.none_value: None = None
dummy = DummyObject(session)
dummy = DummyObject(parent)
props = dummy._get_properties()
@@ -52,17 +58,17 @@ def test_i_can_init(session, db_manager):
assert len(history) == 1
def test_i_can_init_from_dataclass(session, db_manager):
def test_i_can_init_from_dataclass(parent, db_manager):
@dataclass
class DummyObject(DbObject):
def __init__(self, sess: dict):
super().__init__(sess, "DummyObject", db_manager)
def __init__(self, owner: BaseInstance):
super().__init__(owner, "DummyObject", db_manager)
value: str = "hello"
number: int = 42
none_value: None = None
DummyObject(session)
DummyObject(parent)
in_db = db_manager.load("DummyObject")
history = db_manager.db.history(db_manager.get_tenant(), "DummyObject")
@@ -70,10 +76,10 @@ def test_i_can_init_from_dataclass(session, db_manager):
assert len(history) == 1
def test_i_can_init_from_db_with(session, db_manager):
def test_i_can_init_from_db_with(parent, db_manager):
class DummyObject(DbObject):
def __init__(self, sess: dict):
super().__init__(sess, "DummyObject", db_manager)
def __init__(self, owner: BaseInstance):
super().__init__(owner, "DummyObject", db_manager)
with self.initializing():
self.value: str = "hello"
@@ -82,17 +88,17 @@ def test_i_can_init_from_db_with(session, db_manager):
# insert other values in db
db_manager.save("DummyObject", {"value": "other_value", "number": 34})
dummy = DummyObject(session)
dummy = DummyObject(parent)
assert dummy.value == "other_value"
assert dummy.number == 34
def test_i_can_init_from_db_with_dataclass(session, db_manager):
def test_i_can_init_from_db_with_dataclass(parent, db_manager):
@dataclass
class DummyObject(DbObject):
def __init__(self, sess: dict):
super().__init__(sess, "DummyObject", db_manager)
def __init__(self, owner: BaseInstance):
super().__init__(owner, "DummyObject", db_manager)
value: str = "hello"
number: int = 42
@@ -100,16 +106,16 @@ def test_i_can_init_from_db_with_dataclass(session, db_manager):
# insert other values in db
db_manager.save("DummyObject", {"value": "other_value", "number": 34})
dummy = DummyObject(session)
dummy = DummyObject(parent)
assert dummy.value == "other_value"
assert dummy.number == 34
def test_i_do_not_save_when_prefixed_by_underscore_or_ns(session, db_manager):
def test_i_do_not_save_when_prefixed_by_underscore_or_ns(parent, db_manager):
class DummyObject(DbObject):
def __init__(self, sess: dict):
super().__init__(sess, "DummyObject", db_manager)
def __init__(self, owner: BaseInstance):
super().__init__(owner, "DummyObject", db_manager)
with self.initializing():
self.to_save: str = "value"
@@ -120,7 +126,7 @@ def test_i_do_not_save_when_prefixed_by_underscore_or_ns(session, db_manager):
_not_to_save: str = "value"
ns_not_to_save: str = "value"
dummy = DummyObject(session)
dummy = DummyObject(parent)
dummy.to_save = "other_value"
dummy.ns_not_to_save = "other_value"
dummy._not_to_save = "other_value"
@@ -131,17 +137,17 @@ def test_i_do_not_save_when_prefixed_by_underscore_or_ns(session, db_manager):
assert "ns_not_to_save" not in in_db
def test_i_do_not_save_when_prefixed_by_underscore_or_ns_with_dataclass(session, db_manager):
def test_i_do_not_save_when_prefixed_by_underscore_or_ns_with_dataclass(parent, db_manager):
@dataclass
class DummyObject(DbObject):
def __init__(self, sess: dict):
super().__init__(sess, "DummyObject", db_manager)
def __init__(self, owner: BaseInstance):
super().__init__(owner, "DummyObject", db_manager)
to_save: str = "value"
_not_to_save: str = "value"
ns_not_to_save: str = "value"
dummy = DummyObject(session)
dummy = DummyObject(parent)
dummy.to_save = "other_value"
dummy.ns_not_to_save = "other_value"
dummy._not_to_save = "other_value"
@@ -152,31 +158,31 @@ def test_i_do_not_save_when_prefixed_by_underscore_or_ns_with_dataclass(session,
assert "ns_not_to_save" not in in_db
def test_db_is_updated_when_attribute_is_modified(session, db_manager):
def test_db_is_updated_when_attribute_is_modified(parent, db_manager):
@dataclass
class DummyObject(DbObject):
def __init__(self, sess: dict):
super().__init__(sess, "DummyObject", db_manager)
def __init__(self, owner: BaseInstance):
super().__init__(owner, "DummyObject", db_manager)
value: str = "hello"
number: int = 42
dummy = DummyObject(session)
dummy = DummyObject(parent)
dummy.value = "other_value"
assert simplify(db_manager.load("DummyObject")) == {"value": "other_value", "number": 42}
def test_i_do_not_save_in_db_when_value_is_the_same(session, db_manager):
def test_i_do_not_save_in_db_when_value_is_the_same(parent, db_manager):
@dataclass
class DummyObject(DbObject):
def __init__(self, sess: dict):
super().__init__(sess, "DummyObject", db_manager)
def __init__(self, owner: BaseInstance):
super().__init__(owner, "DummyObject", db_manager)
value: str = "hello"
number: int = 42
dummy = DummyObject(session)
dummy = DummyObject(parent)
dummy.value = "other_value"
in_db_1 = db_manager.load("DummyObject")
@@ -186,16 +192,16 @@ def test_i_do_not_save_in_db_when_value_is_the_same(session, db_manager):
assert in_db_1["__parent__"] == in_db_2["__parent__"]
def test_i_can_update(session, db_manager):
def test_i_can_update(parent, db_manager):
@dataclass
class DummyObject(DbObject):
def __init__(self, sess: dict):
super().__init__(sess, "DummyObject", db_manager)
def __init__(self, owner: BaseInstance):
super().__init__(owner, "DummyObject", db_manager)
value: str = "hello"
number: int = 42
dummy = DummyObject(session)
dummy = DummyObject(parent)
clone = dummy.copy()
clone.number = 34
@@ -207,54 +213,52 @@ def test_i_can_update(session, db_manager):
assert simplify(db_manager.load("DummyObject")) == {"value": "other_value", "number": 34}
def test_forbidden_attributes_are_not_the_copy(session, db_manager):
def test_forbidden_attributes_are_not_the_copy(parent, db_manager):
class DummyObject(DbObject):
def __init__(self, sess: dict):
super().__init__(sess, "DummyObject", db_manager)
def __init__(self, owner: BaseInstance):
super().__init__(owner, "DummyObject", db_manager)
with self.initializing():
self.value: str = "hello"
self.number: int = 42
self.none_value: None = None
dummy = DummyObject(session)
dummy = DummyObject(parent)
clone = dummy.copy()
for k in DbObject._forbidden_attrs:
assert not hasattr(clone, k), f"Clone should not have forbidden attribute '{k}'"
def test_forbidden_attributes_are_not_the_copy_for_dataclass(session, db_manager):
def test_forbidden_attributes_are_not_the_copy_for_dataclass(parent, db_manager):
@dataclass
class DummyObject(DbObject):
def __init__(self, sess: dict):
super().__init__(sess, "DummyObject", db_manager)
def __init__(self, owner: BaseInstance):
super().__init__(owner, "DummyObject", db_manager)
value: str = "hello"
number: int = 42
none_value: None = None
dummy = DummyObject(session)
dummy = DummyObject(parent)
clone = dummy.copy()
for k in DbObject._forbidden_attrs:
assert not hasattr(clone, k), f"Clone should not have forbidden attribute '{k}'"
def test_i_cannot_update_a_forbidden_attribute(session, db_manager):
def test_i_cannot_update_a_forbidden_attribute(parent, db_manager):
@dataclass
class DummyObject(DbObject):
def __init__(self, sess: dict):
super().__init__(sess, "DummyObject", db_manager)
def __init__(self, owner: BaseInstance):
super().__init__(owner, "DummyObject", db_manager)
value: str = "hello"
number: int = 42
none_value: None = None
dummy = DummyObject(session)
dummy = DummyObject(parent)
dummy.update(_session="other_value")
dummy.update(_owner="other_value")
assert dummy._session == session
assert dummy._owner is parent

View File

@@ -0,0 +1,387 @@
import pytest
from myfasthtml.core.instances import (
BaseInstance,
SingleInstance,
MultipleInstance,
InstancesManager,
DuplicateInstanceError,
special_session,
Ids,
RootInstance
)
@pytest.fixture(autouse=True)
def reset_instances():
"""Reset instances before each test to ensure isolation."""
InstancesManager.instances.clear()
yield
InstancesManager.instances.clear()
@pytest.fixture
def session():
"""Create a test session."""
return {"user_info": {"id": "test-user-123"}}
@pytest.fixture
def another_session():
"""Create another test session."""
return {"user_info": {"id": "test-user-456"}}
@pytest.fixture
def root_instance(session):
"""Create a root instance for testing."""
return SingleInstance(parent=None, session=session, _id="test-root")
# Example subclasses for testing
class SubSingleInstance(SingleInstance):
"""Example subclass of SingleInstance with simplified signature."""
def __init__(self, parent):
super().__init__(parent=parent)
class SubMultipleInstance(MultipleInstance):
"""Example subclass of MultipleInstance with custom parameter."""
def __init__(self, parent, _id=None, custom_param=None):
super().__init__(parent=parent, _id=_id)
self.custom_param = custom_param
class TestBaseInstance:
def test_i_can_create_a_base_instance_with_positional_args(self, session, root_instance):
"""Test that a BaseInstance can be created with positional arguments."""
instance = BaseInstance(root_instance, session, "test_id")
assert instance is not None
assert instance.get_id() == "test_id"
assert instance.get_session() == session
assert instance.get_parent() == root_instance
def test_i_can_create_a_base_instance_with_kwargs(self, session, root_instance):
"""Test that a BaseInstance can be created with keyword arguments."""
instance = BaseInstance(parent=root_instance, session=session, _id="test_id")
assert instance is not None
assert instance.get_id() == "test_id"
assert instance.get_session() == session
assert instance.get_parent() == root_instance
def test_i_can_create_a_base_instance_with_mixed_args(self, session, root_instance):
"""Test that a BaseInstance can be created with mixed positional and keyword arguments."""
instance = BaseInstance(root_instance, session=session, _id="test_id")
assert instance is not None
assert instance.get_id() == "test_id"
assert instance.get_session() == session
assert instance.get_parent() == root_instance
def test_i_can_retrieve_the_same_instance_when_using_same_session_and_id(self, session, root_instance):
"""Test that creating an instance with same session and id returns the existing instance."""
instance1 = BaseInstance(root_instance, session, "same_id")
instance2 = BaseInstance(root_instance, session, "same_id")
assert instance1 is instance2
def test_i_can_control_instances_registration(self, session, root_instance):
"""Test that auto_register=False prevents automatic registration."""
BaseInstance(parent=root_instance, session=session, _id="test_id", auto_register=False)
session_id = InstancesManager.get_session_id(session)
key = (session_id, "test_id")
assert key not in InstancesManager.instances
def test_i_can_have_different_instances_for_different_sessions(self, session, another_session, root_instance):
"""Test that different sessions can have instances with the same id."""
root_instance2 = SingleInstance(parent=None, session=another_session, _id="test-root")
instance1 = BaseInstance(root_instance, session, "same_id")
instance2 = BaseInstance(root_instance2, another_session, "same_id")
assert instance1 is not instance2
assert instance1.get_session() == session
assert instance2.get_session() == another_session
def test_i_can_create_instance_with_parent_only(self, session, root_instance):
"""Test that session can be extracted from parent when not provided."""
instance = BaseInstance(parent=root_instance, _id="test_id")
assert instance.get_session() == root_instance.get_session()
assert instance.get_parent() == root_instance
def test_i_cannot_create_instance_without_parent_or_session(self):
"""Test that creating an instance without parent or session raises TypeError."""
with pytest.raises(TypeError, match="Either session or parent must be provided"):
BaseInstance(None, _id="test_id")
def test_i_can_get_auto_generated_id(self, session, root_instance):
"""Test that if _id is not provided, an ID is auto-generated via compute_id()."""
instance = BaseInstance(parent=root_instance, session=session)
assert instance.get_id() is not None
assert instance.get_id().startswith("mf-base_instance-")
def test_i_can_get_prefix_from_class_name(self):
"""Test that get_prefix() returns the correct snake_case prefix."""
prefix = BaseInstance.get_prefix()
assert prefix == "mf-base_instance"
class TestSingleInstance:
def test_i_can_create_a_single_instance(self, session, root_instance):
"""Test that a SingleInstance can be created."""
instance = SingleInstance(parent=root_instance, session=session)
assert instance is not None
assert instance.get_id() == "mf-single_instance"
assert instance.get_session() == session
assert instance.get_parent() == root_instance
def test_i_can_create_single_instance_with_positional_args(self, session, root_instance):
"""Test that a SingleInstance can be created with positional arguments."""
instance = SingleInstance(root_instance, session, "custom_id")
assert instance is not None
assert instance.get_id() == "custom_id"
assert instance.get_session() == session
assert instance.get_parent() == root_instance
def test_the_same_instance_is_returned(self, session):
"""Test that single instance is cached and returned on subsequent calls."""
instance1 = SingleInstance(parent=None, session=session, _id="unique_id")
instance2 = SingleInstance(parent=None, session=session, _id="unique_id")
assert instance1 is instance2
def test_i_cannot_create_duplicate_single_instance(self, session):
"""Test that creating a duplicate SingleInstance raises DuplicateInstanceError."""
instance = SingleInstance(parent=None, session=session, _id="unique_id")
with pytest.raises(DuplicateInstanceError):
InstancesManager.register(session, instance)
def test_i_can_retrieve_existing_single_instance(self, session):
"""Test that attempting to create an existing SingleInstance returns the same instance."""
instance1 = SingleInstance(parent=None, session=session, _id="same_id")
instance2 = SingleInstance(parent=None, session=session, _id="same_id", auto_register=False)
assert instance1 is instance2
def test_i_can_get_auto_computed_id_for_single_instance(self, session):
"""Test that the default ID equals prefix for SingleInstance."""
instance = SingleInstance(parent=None, session=session)
assert instance.get_id() == "mf-single_instance"
assert instance.get_id() == SingleInstance.get_prefix()
class TestSingleInstanceSubclass:
def test_i_can_create_subclass_of_single_instance(self, root_instance):
"""Test that a subclass of SingleInstance works correctly."""
instance = SubSingleInstance(root_instance)
assert instance is not None
assert isinstance(instance, SingleInstance)
assert isinstance(instance, SubSingleInstance)
def test_i_can_create_subclass_with_custom_signature(self, root_instance):
"""Test that subclass with simplified signature works correctly."""
instance = SubSingleInstance(root_instance)
assert instance.get_parent() == root_instance
assert instance.get_session() == root_instance.get_session()
assert instance.get_id() == "mf-sub_single_instance"
assert instance.get_prefix() == "mf-sub_single_instance"
def test_i_can_retrieve_subclass_instance_from_cache(self, root_instance):
"""Test that cache works for subclasses."""
instance1 = SubSingleInstance(root_instance)
instance2 = SubSingleInstance(root_instance)
assert instance1 is instance2
class TestMultipleInstance:
def test_i_can_create_multiple_instances_with_same_prefix(self, session, root_instance):
"""Test that multiple MultipleInstance objects can be created with the same prefix."""
instance1 = MultipleInstance(parent=root_instance, session=session)
instance2 = MultipleInstance(parent=root_instance, session=session)
assert instance1 is not instance2
assert instance1.get_id() != instance2.get_id()
assert instance1.get_id().startswith("mf-multiple_instance-")
assert instance2.get_id().startswith("mf-multiple_instance-")
def test_i_can_have_auto_generated_unique_ids(self, session, root_instance):
"""Test that each MultipleInstance receives a unique auto-generated ID."""
instances = [MultipleInstance(parent=root_instance, session=session) for _ in range(5)]
ids = [inst.get_id() for inst in instances]
# All IDs should be unique
assert len(ids) == len(set(ids))
# All IDs should start with the prefix
assert all(id.startswith("mf-multiple_instance-") for id in ids)
def test_i_can_provide_custom_id_to_multiple_instance(self, session, root_instance):
"""Test that a custom _id can be provided to MultipleInstance."""
custom_id = "custom-instance-id"
instance = MultipleInstance(parent=root_instance, session=session, _id=custom_id)
assert instance.get_id() == custom_id
def test_i_can_retrieve_multiple_instance_by_custom_id(self, session, root_instance):
"""Test that a MultipleInstance with custom _id can be retrieved from cache."""
custom_id = "custom-instance-id"
instance1 = MultipleInstance(parent=root_instance, session=session, _id=custom_id)
instance2 = MultipleInstance(parent=root_instance, session=session, _id=custom_id)
assert instance1 is instance2
class TestMultipleInstanceSubclass:
def test_i_can_create_subclass_of_multiple_instance(self, root_instance):
"""Test that a subclass of MultipleInstance works correctly."""
instance = SubMultipleInstance(root_instance, custom_param="test")
assert instance is not None
assert isinstance(instance, MultipleInstance)
assert isinstance(instance, SubMultipleInstance)
assert instance.custom_param == "test"
def test_i_can_create_multiple_subclass_instances_with_auto_generated_ids(self, root_instance):
"""Test that multiple instances of subclass can be created with unique IDs."""
instance1 = SubMultipleInstance(root_instance, custom_param="first")
instance2 = SubMultipleInstance(root_instance, custom_param="second")
assert instance1 is not instance2
assert instance1.get_id() != instance2.get_id()
assert instance1.get_id().startswith("mf-sub_multiple_instance-")
assert instance2.get_id().startswith("mf-sub_multiple_instance-")
def test_i_can_create_subclass_with_custom_signature(self, root_instance):
"""Test that subclass with custom parameters works correctly."""
instance = SubMultipleInstance(root_instance, custom_param="value")
assert instance.get_parent() == root_instance
assert instance.get_session() == root_instance.get_session()
assert instance.custom_param == "value"
def test_i_can_retrieve_subclass_instance_from_cache(self, root_instance):
"""Test that cache works for subclasses."""
instance1 = SubMultipleInstance(root_instance, custom_param="first")
instance2 = SubMultipleInstance(root_instance, custom_param="second", _id=instance1.get_id())
assert instance1 is instance2
def test_i_cannot_retrieve_subclass_instance_when_type_differs(self, root_instance):
"""Test that cache works for subclasses with custom _id."""
# Need to pass _id explicitly to enable caching
instance1 = SubMultipleInstance(root_instance)
with pytest.raises(TypeError):
MultipleInstance(parent=root_instance, _id=instance1.get_id())
def test_i_can_get_correct_prefix_for_multiple_subclass(self):
"""Test that subclass has correct auto-generated prefix."""
prefix = SubMultipleInstance.get_prefix()
assert prefix == "mf-sub_multiple_instance"
class TestInstancesManager:
def test_i_can_register_an_instance_manually(self, session, root_instance):
"""Test that an instance can be manually registered."""
instance = BaseInstance(parent=root_instance, session=session, _id="manual_id", auto_register=False)
InstancesManager.register(session, instance)
session_id = InstancesManager.get_session_id(session)
key = (session_id, "manual_id")
assert key in InstancesManager.instances
assert InstancesManager.instances[key] is instance
def test_i_can_get_existing_instance_by_id(self, session, root_instance):
"""Test that an existing instance can be retrieved by ID."""
instance = BaseInstance(parent=root_instance, session=session, _id="get_id")
retrieved = InstancesManager.get(session, "get_id")
assert retrieved is instance
def test_i_cannot_get_nonexistent_instance_without_type(self, session):
"""Test that getting a non-existent instance without type raises KeyError."""
with pytest.raises(KeyError):
InstancesManager.get(session, "nonexistent_id")
def test_i_can_get_session_id_from_valid_session(self, session):
"""Test that session ID is correctly extracted from a valid session."""
session_id = InstancesManager.get_session_id(session)
assert session_id == "test-user-123"
def test_i_can_handle_none_session(self):
"""Test that None session returns a special identifier."""
session_id = InstancesManager.get_session_id(None)
assert session_id == "** NOT LOGGED IN **"
def test_i_can_handle_invalid_session(self):
"""Test that invalid sessions return appropriate identifiers."""
# Session is None
session_id = InstancesManager.get_session_id(None)
assert session_id == "** NOT LOGGED IN **"
# Session without user_info
session_no_user = {}
session_id = InstancesManager.get_session_id(session_no_user)
assert session_id == "** UNKNOWN USER **"
# Session with user_info but no id
session_no_id = {"user_info": {}}
session_id = InstancesManager.get_session_id(session_no_id)
assert session_id == "** INVALID SESSION **"
def test_i_can_reset_all_instances(self, session, root_instance):
"""Test that reset() clears all instances."""
BaseInstance(parent=root_instance, session=session, _id="id1")
BaseInstance(parent=root_instance, session=session, _id="id2")
assert len(InstancesManager.instances) > 0
InstancesManager.reset()
assert len(InstancesManager.instances) == 0
class TestRootInstance:
def test_i_can_create_root_instance_with_positional_args(self):
"""Test that RootInstance can be created with positional arguments."""
root = SingleInstance(None, special_session, Ids.Root)
assert root is not None
assert root.get_id() == Ids.Root
assert root.get_session() == special_session
assert root.get_parent() is None
def test_i_can_access_root_instance(self):
"""Test that RootInstance is created and accessible."""
assert RootInstance is not None
assert RootInstance.get_id() == Ids.Root
assert RootInstance.get_session() == special_session

View File

@@ -311,7 +311,7 @@ class TestFromParentChildList:
nodes, edges = from_parent_child_list(items)
assert len(nodes) == 1
assert nodes[0] == {"id": "root", "label": "Root"}
assert nodes[0] == {'color': '#ff9999', 'id': 'root', 'label': 'Root'}
assert len(edges) == 0
def test_i_can_convert_simple_parent_child_relationship(self):
@@ -323,7 +323,7 @@ class TestFromParentChildList:
nodes, edges = from_parent_child_list(items)
assert len(nodes) == 2
assert {"id": "root", "label": "Root"} in nodes
assert {'color': '#ff9999', 'id': 'root', 'label': 'Root'} in nodes
assert {"id": "child", "label": "Child"} in nodes
assert len(edges) == 1
@@ -513,3 +513,136 @@ class TestFromParentChildList:
ghost_node = [n for n in nodes if n["id"] == "ghost_parent"][0]
assert ghost_node["label"] == "ghost_parent"
def test_i_can_apply_root_color_to_single_root(self):
"""Test that a single root node receives the root_color."""
items = [{"id": "root", "label": "Root"}]
nodes, edges = from_parent_child_list(items, root_color="#ff0000")
assert len(nodes) == 1
assert nodes[0]["color"] == "#ff0000"
def test_i_can_apply_root_color_to_multiple_roots(self):
"""Test root_color is assigned to all nodes without parent."""
items = [
{"id": "root1", "label": "Root 1"},
{"id": "root2", "label": "Root 2"},
{"id": "child", "parent": "root1", "label": "Child"}
]
nodes, edges = from_parent_child_list(items, root_color="#aa0000")
root_nodes = [n for n in nodes if n["id"] in ("root1", "root2")]
assert all(n.get("color") == "#aa0000" for n in root_nodes)
# child must NOT have root_color
child_node = next(n for n in nodes if n["id"] == "child")
assert "color" not in child_node
def test_i_can_handle_root_with_parent_none(self):
"""Test that root_color is applied when parent=None."""
items = [
{"id": "r1", "parent": None, "label": "R1"}
]
nodes, edges = from_parent_child_list(items, root_color="#112233")
assert nodes[0]["color"] == "#112233"
def test_i_can_handle_root_with_parent_empty_string(self):
"""Test that root_color is applied when parent=''."""
items = [
{"id": "r1", "parent": "", "label": "R1"}
]
nodes, edges = from_parent_child_list(items, root_color="#334455")
assert nodes[0]["color"] == "#334455"
def test_i_do_not_apply_root_color_to_non_roots(self):
"""Test that only real roots receive root_color."""
items = [
{"id": "root", "label": "Root"},
{"id": "child", "parent": "root", "label": "Child"}
]
nodes, edges = from_parent_child_list(items, root_color="#ff0000")
# Only one root → only this one has the color
root_node = next(n for n in nodes if n["id"] == "root")
assert root_node["color"] == "#ff0000"
child_node = next(n for n in nodes if n["id"] == "child")
assert "color" not in child_node
def test_i_do_not_override_ghost_color_with_root_color(self):
"""Ghost nodes must keep ghost_color, not root_color."""
items = [
{"id": "child", "parent": "ghost_parent", "label": "Child"}
]
nodes, edges = from_parent_child_list(
items,
root_color="#ff0000",
ghost_color="#00ff00"
)
ghost_node = next(n for n in nodes if n["id"] == "ghost_parent")
assert ghost_node["color"] == "#00ff00"
# child is not root → no color
child_node = next(n for n in nodes if n["id"] == "child")
assert "color" not in child_node
def test_i_can_use_custom_root_color(self):
"""Test that a custom root_color is applied instead of default."""
items = [{"id": "root", "label": "Root"}]
nodes, edges = from_parent_child_list(items, root_color="#123456")
assert nodes[0]["color"] == "#123456"
def test_i_can_mix_root_nodes_and_ghost_nodes(self):
"""Ensure root_color applies only to roots and ghost nodes keep ghost_color."""
items = [
{"id": "root", "label": "Root"},
{"id": "child", "parent": "ghost_parent", "label": "Child"}
]
nodes, edges = from_parent_child_list(
items,
root_color="#ff0000",
ghost_color="#00ff00"
)
root_node = next(n for n in nodes if n["id"] == "root")
ghost_node = next(n for n in nodes if n["id"] == "ghost_parent")
assert root_node["color"] == "#ff0000"
assert ghost_node["color"] == "#00ff00"
def test_i_do_not_mark_node_as_root_if_parent_field_exists(self):
"""Node with parent key but non-empty value should NOT get root_color."""
items = [
{"id": "root", "label": "Root"},
{"id": "child", "parent": "root", "label": "Child"},
{"id": "other", "parent": "unknown_parent", "label": "Other"}
]
nodes, edges = from_parent_child_list(
items,
root_color="#ff0000",
ghost_color="#00ff00"
)
# "root" is the only real root
root_node = next(n for n in nodes if n["id"] == "root")
assert root_node["color"] == "#ff0000"
# "other" is NOT root, even though its parent is missing
other_node = next(n for n in nodes if n["id"] == "other")
assert "color" not in other_node
# ghost parent must have ghost_color
ghost_node = next(n for n in nodes if n["id"] == "unknown_parent")
assert ghost_node["color"] == "#00ff00"
def test_i_do_no_add_root_color_when_its_none(self):
"""Test that a single root node receives the root_color."""
items = [{"id": "root", "label": "Root"}]
nodes, edges = from_parent_child_list(items, root_color=None)
assert len(nodes) == 1
assert "color" not in nodes[0]