First Working version. I can add table
This commit is contained in:
0
src/core/__init__.py
Normal file
0
src/core/__init__.py
Normal file
335
src/core/dbengine.py
Normal file
335
src/core/dbengine.py
Normal file
@@ -0,0 +1,335 @@
|
||||
import datetime
|
||||
import hashlib
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import pickle
|
||||
from threading import RLock
|
||||
|
||||
from core.serializer import Serializer, DebugSerializer
|
||||
from core.utils import get_stream_digest
|
||||
|
||||
TYPE_KEY = "__type__"
|
||||
TAG_PARENT = "__parent__"
|
||||
TAG_USER = "__user_id__"
|
||||
TAG_DATE = "__date__"
|
||||
BUFFER_SIZE = 4096
|
||||
FAKE_USER_ID = "FakeUserId"
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DbException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class RefHelper:
|
||||
def __init__(self, get_obj_path):
|
||||
self.get_obj_path = get_obj_path
|
||||
|
||||
def save_ref(self, obj):
|
||||
"""
|
||||
|
||||
:param obj:
|
||||
:return:
|
||||
"""
|
||||
buffer = io.BytesIO()
|
||||
pickler = pickle.Pickler(buffer)
|
||||
pickler.dump(obj)
|
||||
|
||||
digest = get_stream_digest(buffer)
|
||||
|
||||
target_path = self.get_obj_path(digest)
|
||||
if not os.path.exists(os.path.dirname(target_path)):
|
||||
os.makedirs(os.path.dirname(target_path))
|
||||
|
||||
buffer.seek(0)
|
||||
with open(self.get_obj_path(digest), "wb") as file:
|
||||
while chunk := buffer.read(BUFFER_SIZE):
|
||||
file.write(chunk)
|
||||
|
||||
logger.debug(f"Saved object type '{type(obj).__name__}' with digest {digest}")
|
||||
return digest
|
||||
|
||||
def load_ref(self, digest):
|
||||
"""
|
||||
|
||||
:param digest:
|
||||
:return:
|
||||
"""
|
||||
with open(self.get_obj_path(digest), 'rb') as file:
|
||||
return pickle.load(file)
|
||||
|
||||
|
||||
class DbEngine:
|
||||
"""
|
||||
Personal implementation of DB engine
|
||||
Inspire by the way git manage its files
|
||||
Designed to keep history of the modifications
|
||||
"""
|
||||
ObjectsFolder = "objects" # group objects in the same folder
|
||||
HeadFile = "head" # used to keep track the latest version of all entries
|
||||
|
||||
def __init__(self, root: str = None):
|
||||
self.root = root or ".mytools_db"
|
||||
self.serializer = Serializer(RefHelper(self._get_obj_path))
|
||||
self.debug_serializer = DebugSerializer(RefHelper(self.debug_load))
|
||||
self.lock = RLock()
|
||||
|
||||
def is_initialized(self):
|
||||
"""
|
||||
|
||||
:return:
|
||||
"""
|
||||
return os.path.exists(self.root)
|
||||
|
||||
def init(self):
|
||||
"""
|
||||
Make sure that the DbEngine is properly initialized
|
||||
:return:
|
||||
"""
|
||||
if not os.path.exists(self.root):
|
||||
logger.debug(f"Creating root folder in {os.path.abspath(self.root)}.")
|
||||
os.mkdir(self.root)
|
||||
|
||||
def save(self, user_id: str, entry: str, obj: object) -> str:
|
||||
"""
|
||||
Save a snapshot of an entry
|
||||
:param user_id:
|
||||
:param entry:
|
||||
:param obj: snapshot to save
|
||||
:return:
|
||||
"""
|
||||
with self.lock:
|
||||
logger.info(f"Saving {user_id=}, {entry=}, {obj=}")
|
||||
# prepare the data
|
||||
as_dict = self._serialize(obj)
|
||||
as_dict[TAG_PARENT] = [self._get_entry_digest(entry)]
|
||||
as_dict[TAG_USER] = user_id
|
||||
as_dict[TAG_DATE] = datetime.datetime.now().strftime('%Y%m%d %H:%M:%S %z')
|
||||
|
||||
# transform into a stream
|
||||
as_str = json.dumps(as_dict, sort_keys=True, indent=4)
|
||||
logger.debug(f"Serialized object : {as_str}")
|
||||
byte_stream = as_str.encode("utf-8")
|
||||
|
||||
# compute the digest to know where to store it
|
||||
digest = hashlib.sha256(byte_stream).hexdigest()
|
||||
|
||||
target_path = self._get_obj_path(digest)
|
||||
if os.path.exists(target_path):
|
||||
# the same object is already saved. Noting to do
|
||||
return digest
|
||||
|
||||
# save the new value
|
||||
if not os.path.exists(os.path.dirname(target_path)):
|
||||
os.makedirs(os.path.dirname(target_path))
|
||||
with open(target_path, "wb") as file:
|
||||
file.write(byte_stream)
|
||||
|
||||
# update the head to remember where is the latest entry
|
||||
self._update_head(entry, digest)
|
||||
logger.debug(f"New head for entry '{entry}' is {digest}")
|
||||
return digest
|
||||
|
||||
def load(self, user_id: str, entry, digest=None):
|
||||
"""
|
||||
Loads a snapshot
|
||||
:param user_id:
|
||||
:param entry:
|
||||
:param digest:
|
||||
:return:
|
||||
"""
|
||||
with self.lock:
|
||||
logger.info(f"Loading {user_id=}, {entry=}, {digest=}")
|
||||
|
||||
digest_to_use = digest or self._get_entry_digest(entry)
|
||||
logger.debug(f"Using digest {digest_to_use}.")
|
||||
|
||||
if digest_to_use is None:
|
||||
raise DbException(entry)
|
||||
|
||||
target_file = self._get_obj_path(digest_to_use)
|
||||
with open(target_file, 'r', encoding='utf-8') as file:
|
||||
as_dict = json.load(file)
|
||||
|
||||
return self._deserialize(as_dict)
|
||||
|
||||
def put(self, user_id: str, entry, key: str, value: object):
|
||||
"""
|
||||
Save a specific record.
|
||||
This will create a new snapshot is the record is new or different
|
||||
|
||||
You should not mix the usage of put_many() and save() as it's two different way to manage the db
|
||||
:param user_id:
|
||||
:param entry:
|
||||
:param key:
|
||||
:param value:
|
||||
:return:
|
||||
"""
|
||||
with self.lock:
|
||||
logger.info(f"Adding {user_id=}, {entry=}, {key=}, {value=}")
|
||||
try:
|
||||
entry_content = self.load(user_id, entry)
|
||||
except DbException:
|
||||
entry_content = {}
|
||||
|
||||
# Do not save if the entry is the same
|
||||
if key in entry_content:
|
||||
old_value = entry_content[key]
|
||||
if old_value == value:
|
||||
return False
|
||||
|
||||
entry_content[key] = value
|
||||
self.save(user_id, entry, entry_content)
|
||||
return True
|
||||
|
||||
def put_many(self, user_id: str, entry, items: list):
|
||||
"""
|
||||
Save a list of item as one single snapshot
|
||||
A new snapshot will not be created if all the items already exist
|
||||
|
||||
You should not mix the usage of put_many() and save() as it's two different way to manage the db
|
||||
:param user_id:
|
||||
:param entry:
|
||||
:param items:
|
||||
:return:
|
||||
"""
|
||||
with self.lock:
|
||||
logger.info(f"Adding many {user_id=}, {entry=}, {items=}")
|
||||
try:
|
||||
entry_content = self.load(user_id, entry)
|
||||
except DbException:
|
||||
entry_content = {}
|
||||
|
||||
is_dirty = False
|
||||
for item in items:
|
||||
key = item.get_key()
|
||||
if key in entry_content and entry_content[key] == item:
|
||||
continue
|
||||
else:
|
||||
entry_content[key] = item
|
||||
is_dirty = True
|
||||
|
||||
if is_dirty:
|
||||
self.save(user_id, entry, entry_content)
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def exists(self, entry: str):
|
||||
"""
|
||||
Tells if an entry exist
|
||||
:param user_id:
|
||||
:param entry:
|
||||
:return:
|
||||
"""
|
||||
with self.lock:
|
||||
return self._get_entry_digest(entry) is not None
|
||||
|
||||
def get(self, user_id: str, entry: str, key: str | None = None, digest=None):
|
||||
"""
|
||||
Retrieve an item from the snapshot
|
||||
:param user_id:
|
||||
:param entry:
|
||||
:param key:
|
||||
:param digest:
|
||||
:return:
|
||||
"""
|
||||
with self.lock:
|
||||
logger.info(f"Getting {user_id=}, {entry=}, {key=}, {digest=}")
|
||||
entry_content = self.load(user_id, entry, digest)
|
||||
|
||||
if key is None:
|
||||
# return all items as list
|
||||
return [v for k, v in entry_content.items() if not k.startswith("__")]
|
||||
|
||||
return entry_content[key]
|
||||
|
||||
def debug_head(self):
|
||||
with self.lock:
|
||||
head_path = os.path.join(self.root, self.HeadFile)
|
||||
# load
|
||||
try:
|
||||
with open(head_path, 'r') as file:
|
||||
head = json.load(file)
|
||||
except FileNotFoundError:
|
||||
head = {}
|
||||
|
||||
return head
|
||||
|
||||
def debug_load(self, digest):
|
||||
with self.lock:
|
||||
target_file = self._get_obj_path(digest)
|
||||
with open(target_file, 'r', encoding='utf-8') as file:
|
||||
as_dict = json.load(file)
|
||||
|
||||
return self.debug_serializer.deserialize(as_dict)
|
||||
|
||||
def _serialize(self, obj):
|
||||
"""
|
||||
Just call the serializer
|
||||
:param obj:
|
||||
:return:
|
||||
"""
|
||||
# serializer = Serializer(RefHelper(self._get_obj_path))
|
||||
use_refs = getattr(obj, "use_refs")() if hasattr(obj, "use_refs") else None
|
||||
return self.serializer.serialize(obj, use_refs)
|
||||
|
||||
def _deserialize(self, as_dict):
|
||||
return self.serializer.deserialize(as_dict)
|
||||
|
||||
def _update_head(self, entry, digest):
|
||||
"""
|
||||
Actually dumps the snapshot in file system
|
||||
:param entry:
|
||||
:param digest:
|
||||
:return:
|
||||
"""
|
||||
head_path = os.path.join(self.root, self.HeadFile)
|
||||
# load
|
||||
try:
|
||||
with open(head_path, 'r') as file:
|
||||
head = json.load(file)
|
||||
except FileNotFoundError:
|
||||
head = {}
|
||||
|
||||
# update
|
||||
head[entry] = digest
|
||||
|
||||
# and save
|
||||
with open(head_path, 'w') as file:
|
||||
json.dump(head, file)
|
||||
|
||||
def _get_entry_digest(self, entry):
|
||||
"""
|
||||
Search for the latest digest, for a given entry
|
||||
:param entry:
|
||||
:return:
|
||||
"""
|
||||
head_path = os.path.join(self.root, self.HeadFile)
|
||||
try:
|
||||
with open(head_path, 'r') as file:
|
||||
head = json.load(file)
|
||||
return head[str(entry)]
|
||||
|
||||
except FileNotFoundError:
|
||||
return None
|
||||
except KeyError:
|
||||
return None
|
||||
|
||||
def _get_head_path(self):
|
||||
"""
|
||||
Location of the Head file
|
||||
:return:
|
||||
"""
|
||||
return os.path.join(self.root, self.HeadFile)
|
||||
|
||||
def _get_obj_path(self, digest):
|
||||
"""
|
||||
Location of objects
|
||||
:param digest:
|
||||
:return:
|
||||
"""
|
||||
return os.path.join(self.root, "objects", digest[:24], digest)
|
||||
59
src/core/handlers.py
Normal file
59
src/core/handlers.py
Normal file
@@ -0,0 +1,59 @@
|
||||
# I delegate the complexity of some data type within specific handlers
|
||||
|
||||
import datetime
|
||||
|
||||
from core.utils import has_tag
|
||||
|
||||
TAG_SPECIAL = "__special__"
|
||||
|
||||
|
||||
class BaseHandler:
|
||||
def is_eligible_for(self, obj):
|
||||
pass
|
||||
|
||||
def tag(self):
|
||||
pass
|
||||
|
||||
def serialize(self, obj) -> dict:
|
||||
pass
|
||||
|
||||
def deserialize(self, data: dict) -> object:
|
||||
pass
|
||||
|
||||
|
||||
class DateHandler(BaseHandler):
|
||||
def is_eligible_for(self, obj):
|
||||
return isinstance(obj, datetime.date)
|
||||
|
||||
def tag(self):
|
||||
return "Date"
|
||||
|
||||
def serialize(self, obj):
|
||||
return {
|
||||
TAG_SPECIAL: self.tag(),
|
||||
"year": obj.year,
|
||||
"month": obj.month,
|
||||
"day": obj.day,
|
||||
}
|
||||
|
||||
def deserialize(self, data: dict) -> object:
|
||||
return datetime.date(year=data["year"], month=data["month"], day=data["day"])
|
||||
|
||||
|
||||
class Handlers:
|
||||
|
||||
def __init__(self, handlers_):
|
||||
self.handlers = handlers_
|
||||
|
||||
def get_handler(self, obj):
|
||||
if has_tag(obj, TAG_SPECIAL):
|
||||
return [h for h in self.handlers if h.tag() == obj[TAG_SPECIAL]][0]
|
||||
|
||||
for h in self.handlers:
|
||||
if h.is_eligible_for(obj):
|
||||
return h
|
||||
|
||||
return None
|
||||
|
||||
|
||||
handlers = Handlers([DateHandler()])
|
||||
108
src/core/instance_manager.py
Normal file
108
src/core/instance_manager.py
Normal file
@@ -0,0 +1,108 @@
|
||||
import logging
|
||||
|
||||
from components.BaseComponent import BaseComponent
|
||||
|
||||
SESSION_ID_KEY = "user_id"
|
||||
NO_SESSION = "__NO_SESSION__"
|
||||
NOT_LOGGED = "__NOT_LOGGED__"
|
||||
|
||||
logger = logging.getLogger("InstanceManager")
|
||||
|
||||
|
||||
def debug_session(session):
|
||||
if session is None:
|
||||
return f"session={NO_SESSION}"
|
||||
else:
|
||||
return f"session={InstanceManager.get_session_id(session)}"
|
||||
|
||||
class InstanceManager:
|
||||
_instances = {}
|
||||
|
||||
@staticmethod
|
||||
def get(session: dict | None, instance_id: str, instance_type: type = None, **kwargs):
|
||||
"""
|
||||
Retrieves an instance from the InstanceManager or creates a new one if it does not exist,
|
||||
using the provided session, instance_id, and instance_type. If the instance already exists
|
||||
in the InstanceManager, it will be returned directly. If the instance does not exist and
|
||||
the instance_type is provided, a new instance will be created based on the instance_type.
|
||||
|
||||
:param session: The session context associated with the instance. Can be None.
|
||||
:type session: dict or None
|
||||
:param instance_id: The unique identifier for the instance to retrieve or create.
|
||||
:type instance_id: str
|
||||
:param instance_type: The class type for creating a new instance if it does not already
|
||||
exist. Defaults to None.
|
||||
:type instance_type: type
|
||||
:param kwargs: Additional keyword arguments to initialize a new instance if it is created.
|
||||
:return: The existing or newly created instance for the given session and instance_id.
|
||||
"""
|
||||
logger.debug(f"'get' session={InstanceManager.get_session_id(session)}, {instance_id=}")
|
||||
|
||||
key = (InstanceManager.get_session_id(session), instance_id)
|
||||
|
||||
if key not in InstanceManager._instances and instance_type is not None:
|
||||
new_instance = instance_type(session, instance_id, **kwargs) \
|
||||
if issubclass(instance_type, BaseComponent) \
|
||||
else instance_type(instance_id, **kwargs)
|
||||
InstanceManager._instances[key] = new_instance
|
||||
|
||||
return InstanceManager._instances[key]
|
||||
|
||||
@staticmethod
|
||||
def register(session: dict | None, instance, instance_id: str = None):
|
||||
"""
|
||||
Register an instance with the given session_id and instance_id.
|
||||
If instance_id is None, attempt to fetch it from the _id attribute of the instance.
|
||||
If the instance has no _id attribute and instance_id is None, raise a ValueError.
|
||||
If session_id is None, it is allowed.
|
||||
|
||||
Args:
|
||||
session (dict): The current session.
|
||||
instance: The instance to be registered.
|
||||
instance_id (str): The instance ID.
|
||||
|
||||
Raises:
|
||||
ValueError: If instance_id is not provided and the instance does not have an _id attribute.
|
||||
"""
|
||||
logger.debug(f"'register' session={InstanceManager.get_session_id(session)}, {instance_id=}")
|
||||
|
||||
if instance_id is None:
|
||||
if hasattr(instance, "_id"):
|
||||
instance_id = getattr(instance, "_id")
|
||||
else:
|
||||
raise ValueError(f"`instance_id` is not provided and the instance '{instance}' has no `_id` attribute.")
|
||||
|
||||
key = (InstanceManager.get_session_id(session), instance_id)
|
||||
InstanceManager._instances[key] = instance
|
||||
|
||||
@staticmethod
|
||||
def register_many(*instances):
|
||||
for instance in instances:
|
||||
InstanceManager.register(None, instance)
|
||||
|
||||
@staticmethod
|
||||
def remove(session: dict, instance_id: str):
|
||||
"""
|
||||
Remove a specific instance by its key.
|
||||
"""
|
||||
logger.debug(f"'remove' session={InstanceManager.get_session_id(session)}, {instance_id=}")
|
||||
key = (InstanceManager.get_session_id(session), instance_id)
|
||||
if key in InstanceManager._instances:
|
||||
instance = InstanceManager._instances[key]
|
||||
if hasattr(instance, "dispose"):
|
||||
instance.dispose()
|
||||
del InstanceManager._instances[key]
|
||||
|
||||
@staticmethod
|
||||
def clear():
|
||||
"""
|
||||
Clear all stored instances.
|
||||
"""
|
||||
logger.debug(f"'clear'")
|
||||
InstanceManager._instances.clear()
|
||||
|
||||
@staticmethod
|
||||
def get_session_id(session: dict | None):
|
||||
return session[SESSION_ID_KEY] if session else NOT_LOGGED
|
||||
|
||||
|
||||
201
src/core/serializer.py
Normal file
201
src/core/serializer.py
Normal file
@@ -0,0 +1,201 @@
|
||||
import copy
|
||||
|
||||
from core.handlers import handlers
|
||||
from core.utils import has_tag, is_dictionary, is_list, is_object, is_set, is_tuple, is_primitive, importable_name, \
|
||||
get_class, get_full_qualified_name, is_enum
|
||||
|
||||
TAG_ID = "__id__"
|
||||
TAG_OBJECT = "__object__"
|
||||
TAG_TUPLE = "__tuple__"
|
||||
TAG_SET = "__set__"
|
||||
TAG_REF = "__ref__"
|
||||
TAG_ENUM = "__enum__"
|
||||
|
||||
|
||||
class Serializer:
|
||||
def __init__(self, ref_helper=None):
|
||||
self.ref_helper = ref_helper
|
||||
|
||||
self.ids = {}
|
||||
self.objs = []
|
||||
self.id_count = 0
|
||||
|
||||
def serialize(self, obj, use_refs=None):
|
||||
"""
|
||||
From object to dictionary
|
||||
:param obj:
|
||||
:param use_refs: Sometimes it easier / quicker to use pickle !
|
||||
:return:
|
||||
"""
|
||||
if use_refs:
|
||||
use_refs = set("root." + path for path in use_refs)
|
||||
|
||||
return self._serialize(obj, use_refs or set(), "root")
|
||||
|
||||
def deserialize(self, obj: dict):
|
||||
"""
|
||||
From dictionary to object (or primitive)
|
||||
:param obj:
|
||||
:return:
|
||||
"""
|
||||
if has_tag(obj, TAG_REF):
|
||||
return self.ref_helper.load_ref(obj[TAG_REF])
|
||||
|
||||
if has_tag(obj, TAG_ID):
|
||||
return self._restore_id(obj)
|
||||
|
||||
if has_tag(obj, TAG_TUPLE):
|
||||
return tuple([self.deserialize(v) for v in obj[TAG_TUPLE]])
|
||||
|
||||
if has_tag(obj, TAG_SET):
|
||||
return set([self.deserialize(v) for v in obj[TAG_SET]])
|
||||
|
||||
if has_tag(obj, TAG_ENUM):
|
||||
return self._deserialize_enum(obj)
|
||||
|
||||
if has_tag(obj, TAG_OBJECT):
|
||||
return self._deserialize_obj_instance(obj)
|
||||
|
||||
if (handler := handlers.get_handler(obj)) is not None:
|
||||
return handler.deserialize(obj)
|
||||
|
||||
if is_list(obj):
|
||||
return [self.deserialize(v) for v in obj]
|
||||
|
||||
if is_dictionary(obj):
|
||||
return {k: self.deserialize(v) for k, v in obj.items()}
|
||||
|
||||
return obj
|
||||
|
||||
def _serialize(self, obj, use_refs: set | None, path):
|
||||
if use_refs is not None and path in use_refs:
|
||||
digest = self.ref_helper.save_ref(obj)
|
||||
return {TAG_REF: digest}
|
||||
|
||||
if is_primitive(obj):
|
||||
return obj
|
||||
|
||||
if is_tuple(obj):
|
||||
return {TAG_TUPLE: [self._serialize(v, use_refs, path) for v in obj]}
|
||||
|
||||
if is_set(obj):
|
||||
return {TAG_SET: [self._serialize(v, use_refs, path) for v in obj]}
|
||||
|
||||
if is_list(obj):
|
||||
return [self._serialize(v, use_refs, path) for v in obj]
|
||||
|
||||
if is_dictionary(obj):
|
||||
return {k: self._serialize(v, use_refs, path) for k, v in obj.items()}
|
||||
|
||||
if is_enum(obj):
|
||||
return self._serialize_enum(obj, use_refs, path)
|
||||
|
||||
if is_object(obj):
|
||||
return self._serialize_obj_instance(obj, use_refs, path)
|
||||
|
||||
raise Exception(f"Cannot serialize '{obj}'")
|
||||
|
||||
def _serialize_enum(self, obj, use_refs: set | None, path):
|
||||
# check if the object was already seen
|
||||
if (seen := self._check_already_seen(obj)) is not None:
|
||||
return seen
|
||||
|
||||
data = {}
|
||||
class_name = get_full_qualified_name(obj)
|
||||
data[TAG_ENUM] = class_name + "." + obj.name
|
||||
return data
|
||||
|
||||
def _serialize_obj_instance(self, obj, use_refs: set | None, path):
|
||||
# check if the object was already seen
|
||||
if (seen := self._check_already_seen(obj)) is not None:
|
||||
return seen
|
||||
|
||||
# try to manage use_refs
|
||||
current_obj_use_refs = getattr(obj, "use_refs")() if hasattr(obj, "use_refs") else None
|
||||
if current_obj_use_refs:
|
||||
use_refs.update(f"{path}.{sub_path}" for sub_path in current_obj_use_refs)
|
||||
|
||||
if (handler := handlers.get_handler(obj)) is not None:
|
||||
return handler.serialize(obj)
|
||||
|
||||
# flatten
|
||||
data = {}
|
||||
cls = obj.__class__ if hasattr(obj, '__class__') else type(obj)
|
||||
class_name = importable_name(cls)
|
||||
data[TAG_OBJECT] = class_name
|
||||
|
||||
if hasattr(obj, "__dict__"):
|
||||
for k, v in obj.__dict__.items():
|
||||
data[k] = self._serialize(v, use_refs, f"{path}.{k}")
|
||||
|
||||
return data
|
||||
|
||||
def _check_already_seen(self, obj):
|
||||
_id = self._exist(obj)
|
||||
if _id is not None:
|
||||
return {TAG_ID: _id}
|
||||
|
||||
# else:
|
||||
self.ids[id(obj)] = self.id_count
|
||||
self.objs.append(obj)
|
||||
self.id_count = self.id_count + 1
|
||||
|
||||
return None
|
||||
|
||||
def _deserialize_enum(self, obj):
|
||||
cls_name, enum_name = obj[TAG_ENUM].rsplit(".", 1)
|
||||
cls = get_class(cls_name)
|
||||
obj = getattr(cls, enum_name)
|
||||
self.objs.append(obj)
|
||||
return obj
|
||||
|
||||
def _deserialize_obj_instance(self, obj):
|
||||
|
||||
cls = get_class(obj[TAG_OBJECT])
|
||||
instance = cls.__new__(cls)
|
||||
self.objs.append(instance)
|
||||
|
||||
for k, v in obj.items():
|
||||
value = self.deserialize(v)
|
||||
setattr(instance, k, value)
|
||||
|
||||
return instance
|
||||
|
||||
def _restore_id(self, obj):
|
||||
try:
|
||||
return self.objs[obj[TAG_ID]]
|
||||
except IndexError:
|
||||
pass
|
||||
|
||||
def _exist(self, obj):
|
||||
try:
|
||||
v = self.ids[id(obj)]
|
||||
return v
|
||||
except KeyError:
|
||||
return None
|
||||
|
||||
|
||||
class DebugSerializer(Serializer):
|
||||
def __init__(self, ref_helper=None):
|
||||
super().__init__(ref_helper)
|
||||
|
||||
def _deserialize_obj_instance(self, obj):
|
||||
data = {TAG_OBJECT: obj[TAG_OBJECT]}
|
||||
self.objs.append(data)
|
||||
|
||||
for k, v in obj.items():
|
||||
value = self.deserialize(v)
|
||||
data[k] = value
|
||||
|
||||
return data
|
||||
|
||||
def _deserialize_enum(self, obj):
|
||||
cls_name, enum_name = obj[TAG_ENUM].rsplit(".", 1)
|
||||
self.objs.append(enum_name)
|
||||
return enum_name
|
||||
|
||||
def _restore_id(self, obj):
|
||||
try:
|
||||
return copy.deepcopy(self.objs[obj[TAG_ID]])
|
||||
except IndexError:
|
||||
pass
|
||||
210
src/core/settings_management.py
Normal file
210
src/core/settings_management.py
Normal file
@@ -0,0 +1,210 @@
|
||||
import json
|
||||
import logging
|
||||
import os.path
|
||||
|
||||
from core.dbengine import DbEngine, DbException
|
||||
from core.instance_manager import NO_SESSION, NOT_LOGGED
|
||||
from core.settings_objects import *
|
||||
|
||||
load_settings_obj() # needed to make sure that the import of core is not removed
|
||||
|
||||
FAKE_USER_ID = "FakeUserId"
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class NoDefaultCls:
|
||||
pass
|
||||
|
||||
|
||||
NoDefault = NoDefaultCls()
|
||||
|
||||
|
||||
class DummyDbEngine:
|
||||
"""
|
||||
Dummy DB engine
|
||||
Can only serialize object defined in settings_object module
|
||||
Save everything in a single file
|
||||
"""
|
||||
|
||||
def __init__(self, setting_path="settings.json"):
|
||||
self.db_path = setting_path
|
||||
|
||||
def save(self, user_id: str, entry: str, obj: object) -> bool:
|
||||
if not hasattr(obj, "as_dict"):
|
||||
raise Exception("'as_dict' not found. Not supported")
|
||||
|
||||
as_dict = getattr(obj, "as_dict")()
|
||||
as_dict["__type__"] = type(obj).__name__
|
||||
|
||||
if os.path.exists(self.db_path):
|
||||
with open(self.db_path, "r") as settings_file:
|
||||
as_json = json.load(settings_file)
|
||||
as_json[entry] = as_dict
|
||||
with open(self.db_path, "w") as settings_file:
|
||||
json.dump(as_json, settings_file)
|
||||
else:
|
||||
as_json = {entry: as_dict}
|
||||
with open(self.db_path, "w") as settings_file:
|
||||
json.dump(as_json, settings_file)
|
||||
|
||||
return True
|
||||
|
||||
def load(self, user_id: str, entry: str, digest: str = None):
|
||||
try:
|
||||
with open(self.db_path, "r") as settings_file:
|
||||
as_json = json.load(settings_file)
|
||||
|
||||
as_dict = as_json[entry]
|
||||
obj_type = as_dict.pop("__type__")
|
||||
obj = globals()[obj_type]()
|
||||
getattr(obj, "from_dict")(as_dict)
|
||||
return obj
|
||||
except Exception as ex:
|
||||
raise DbException(f"Entry '{entry}' is not found.")
|
||||
|
||||
def is_initialized(self):
|
||||
return os.path.exists(self.db_path)
|
||||
|
||||
def init(self):
|
||||
pass
|
||||
|
||||
|
||||
class MemoryDbEngine:
|
||||
"""
|
||||
Keeps everything in memory
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.db = {}
|
||||
|
||||
def init_db(self, entry, key, obj):
|
||||
self.db[entry] = {key: obj}
|
||||
|
||||
def save(self, user_id: str, entry: str, obj: object) -> bool:
|
||||
self.db[entry] = obj
|
||||
return True
|
||||
|
||||
def load(self, user_id: str, entry: str, digest: str = None):
|
||||
try:
|
||||
return self.db[entry]
|
||||
except KeyError:
|
||||
return {}
|
||||
|
||||
def get(self, user_id: str, entry: str, key: str | None = None, digest=None):
|
||||
return self.db[entry][key]
|
||||
|
||||
def put(self, user_id: str, entry, key: str, value: object):
|
||||
if entry not in self.db:
|
||||
self.db[entry] = {}
|
||||
self.db[entry][key] = value
|
||||
|
||||
def is_initialized(self):
|
||||
return True
|
||||
|
||||
|
||||
class SettingsManager:
|
||||
def __init__(self, engine=None):
|
||||
self._db_engine = engine or DbEngine()
|
||||
|
||||
def save(self, user_id: str, entry: str, obj: object):
|
||||
return self._db_engine.save(user_id, entry, obj)
|
||||
|
||||
def load(self, user_id: str, entry: str):
|
||||
return self._db_engine.load(user_id, entry)
|
||||
|
||||
def get_all(self, user_id: str, entry: str):
|
||||
""""
|
||||
Returns all the items of an entry
|
||||
"""
|
||||
return self._db_engine.get(user_id, entry, None)
|
||||
|
||||
def put(self, session: dict, key: str, value: object):
|
||||
"""
|
||||
Inserts or updates a key-value pair in the database for the current user session.
|
||||
The method extracts the user ID and email from the session dictionary and
|
||||
utilizes the database engine to perform the storage operation.
|
||||
|
||||
:param session: A dictionary containing session-specific details,
|
||||
including 'user_id' and 'user_email'.
|
||||
:type session: dict
|
||||
:param key: The key under which the value should be stored in the database.
|
||||
:type key: str
|
||||
:param value: The value to be stored, associated with the specified key.
|
||||
:type value: object
|
||||
:return: The result of the database engine's put operation.
|
||||
:rtype: object
|
||||
"""
|
||||
user_id = session["user_id"] if session else NO_SESSION
|
||||
user_email = session["user_email"] if session else NOT_LOGGED
|
||||
return self._db_engine.put(user_email, str(user_id), key, value)
|
||||
|
||||
def get(self, session: dict, key: str | None = None, default=NoDefault):
|
||||
"""
|
||||
Fetches a value associated with a specific key for a user session from the
|
||||
database. If the key is not found in the database and a default value is
|
||||
provided, returns the default value. If no default is provided and the key
|
||||
is not found, raises a KeyError.
|
||||
|
||||
:param session: A dictionary containing session data. Must include "user_id"
|
||||
and "user_email" keys.
|
||||
:type session: dict
|
||||
:param key: The key to fetch from the database for the given session user.
|
||||
Defaults to None if not specified.
|
||||
:type key: str | None
|
||||
:param default: The default value to return if the key is not found in the
|
||||
database. If not provided, raises KeyError when the key is missing.
|
||||
:type default: Any
|
||||
:return: The value associated with the key for the user session if found in
|
||||
the database, or the provided default value if the key is not found.
|
||||
"""
|
||||
try:
|
||||
user_id = session["user_id"] if session else NO_SESSION
|
||||
user_email = session["user_email"] if session else NOT_LOGGED
|
||||
|
||||
return self._db_engine.get(user_email, str(user_id), key)
|
||||
except KeyError:
|
||||
if default is NoDefault:
|
||||
raise
|
||||
else:
|
||||
return default
|
||||
|
||||
def init_user(self, user_id: str, user_email: str):
|
||||
"""
|
||||
Init the settings block space for a user
|
||||
:param user_id:
|
||||
:param user_email:
|
||||
:return:
|
||||
"""
|
||||
if not self._db_engine.exists(user_id):
|
||||
self._db_engine.save(user_email, user_id, {})
|
||||
|
||||
def get_db_engine_root(self):
|
||||
return os.path.abspath(self._db_engine.root)
|
||||
|
||||
def get_db_engine(self):
|
||||
return self._db_engine
|
||||
|
||||
|
||||
class SettingsTransaction:
|
||||
def __init__(self, session, settings_manager: SettingsManager):
|
||||
self._settings_manager = settings_manager
|
||||
self._session = session
|
||||
self._user_id = session["user_id"] if session else NO_SESSION
|
||||
self._user_email = session["user_email"] if session else NOT_LOGGED
|
||||
self._entries = None
|
||||
|
||||
def __enter__(self):
|
||||
self._entries = self._settings_manager.load(self._user_email, self._user_id)
|
||||
return self
|
||||
|
||||
def put(self, key: str, value: object):
|
||||
self._entries[key] = value
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
if exc_type is None:
|
||||
self._settings_manager.save(self._user_email, self._user_id, self._entries)
|
||||
|
||||
#
|
||||
# settings_manager = SettingsManager()
|
||||
# settings_manager.init()
|
||||
303
src/core/settings_objects.py
Normal file
303
src/core/settings_objects.py
Normal file
@@ -0,0 +1,303 @@
|
||||
import dataclasses
|
||||
|
||||
from pandas import DataFrame
|
||||
|
||||
|
||||
def load_settings_obj():
|
||||
"""
|
||||
Do not remove. Used to dynamically load objects
|
||||
:return:
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
BUDGET_TRACKER_SETTINGS_ENTRY = "BudgetTrackerSettings"
|
||||
BUDGET_TRACKER_MAPPINGS_ENTRY = "BudgetTrackerMappings"
|
||||
|
||||
PROJECTS_CODES_SETTINGS_ENTRY = "ProjectsCodesSettings"
|
||||
PROJECTS_CODES_ENTRY = "ProjectsCodes"
|
||||
|
||||
BUDGET_TRACKER_FILES_ENTRY = "BudgetTrackerFiles"
|
||||
BUDGETS_FILES_ENTRY = "BudgetFiles"
|
||||
VIEWS_DEFINITIONS_ENTRY = "ViewsDefinitions"
|
||||
MAPPINGS_DEFINITIONS_ENTRY = "MappingsDefinitions"
|
||||
IMPORT_SETTINGS_ENTRY = "ImportSettingsDefinitions"
|
||||
VIEWS_DEFINITIONS_NEW_ENTRY = "ViewsDefinitionsNew"
|
||||
|
||||
# Get the columns names from BudgetTrackerSettings
|
||||
COL_ROW_NUM = "col_row_num"
|
||||
COL_PROJECT = "col_project"
|
||||
COL_OWNER = "col_owner"
|
||||
COL_CAPEX = "col_capex"
|
||||
COL_DETAILS = "col_details"
|
||||
COL_SUPPLIER = "col_supplier"
|
||||
COL_BUDGET = "col_budget"
|
||||
COL_ACTUAL = "col_actual"
|
||||
COL_FORCAST5_7 = "col_forecast5_7"
|
||||
# other columns
|
||||
COL_INDEX = "col_index"
|
||||
COL_LEVEL1 = "col_level1"
|
||||
COL_LEVEL2 = "col_level2"
|
||||
COL_LEVEL3 = "col_level3"
|
||||
COL_PERCENTAGE = "col_percentage"
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class BaseSettingObj:
|
||||
def as_dict(self):
|
||||
return dataclasses.asdict(self)
|
||||
|
||||
def from_dict(self, as_dict):
|
||||
for k, v in as_dict.items():
|
||||
setattr(self, k, v)
|
||||
|
||||
return self
|
||||
|
||||
def format_props(self):
|
||||
return {}
|
||||
|
||||
def get_display_name(self, prop):
|
||||
return self.format_props().get(prop, prop)
|
||||
|
||||
def use_refs(self) -> set:
|
||||
"""
|
||||
List of attributes to store as reference, rather than as dictionary
|
||||
:return:
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class BudgetTrackerSettings(BaseSettingObj):
|
||||
"""
|
||||
Class that holds the settings to read the 'Suivi de budget' xls file
|
||||
"""
|
||||
spread_sheet: str = "full charges"
|
||||
col_row_num: str = "A"
|
||||
col_project: str = "D"
|
||||
col_owner: str = "E"
|
||||
col_capex: str = "G"
|
||||
col_details: str = "H"
|
||||
col_supplier: str = "I"
|
||||
col_budget_amt: str = "BQ"
|
||||
col_actual_amt: str = "AO"
|
||||
col_forecast5_7_amt: str = "BC"
|
||||
|
||||
def format_props(self):
|
||||
return {
|
||||
"spread_sheet": "Spread Sheet",
|
||||
"col_row_num": "Row Number",
|
||||
"col_project": "Project",
|
||||
"col_owner": "Owner",
|
||||
"col_capex": "Capex",
|
||||
"col_details": "Details",
|
||||
"col_supplier": "Supplier",
|
||||
"col_budget_amt": "Budget",
|
||||
"col_actual_amt": "Actual",
|
||||
"col_forecast5_7_amt": "Forecast 5+7"
|
||||
}
|
||||
|
||||
def get_key(self):
|
||||
return self.spread_sheet
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class BudgetTrackerMappings(BaseSettingObj):
|
||||
"""
|
||||
This class holds the link between the old nomenclature
|
||||
as it's used in xsl 'budget tracking' file
|
||||
and the new one we want to use from the MDM (Master Data Management)
|
||||
"""
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Mapping:
|
||||
# Keys to use for the matching
|
||||
col_index: int = 0
|
||||
col_project: str = None
|
||||
col_owner: str = None
|
||||
col_details: str = None
|
||||
col_supplier: str = None
|
||||
|
||||
# Values to add
|
||||
col_level1: str = None
|
||||
col_level2: str = None
|
||||
col_level3: str = None
|
||||
col_percentage: int = None
|
||||
|
||||
mappings: list[Mapping] = None
|
||||
|
||||
def from_dict(self, as_dict):
|
||||
self.mappings = [BudgetTrackerMappings.Mapping(**item) for item in as_dict["mappings"]]
|
||||
return self
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ProjectsCodesSettings(BaseSettingObj):
|
||||
"""
|
||||
Class that holds the settings to read the project code xls file
|
||||
"""
|
||||
spread_sheet: str = "Proposition"
|
||||
col_level1: str = "E"
|
||||
col_level2: str = "F"
|
||||
col_level3: str = "G"
|
||||
min_row: int = 150
|
||||
|
||||
def format_props(self):
|
||||
return {
|
||||
"col_level1": "Niveau 1",
|
||||
"col_level2": "Niveau 2",
|
||||
"col_level3": "Niveau 3",
|
||||
}
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ProjectsCodes(BaseSettingObj):
|
||||
"""
|
||||
Class that lists the projects id
|
||||
It only stores the aggregated value of the 3 levels, concatenated with a hyphen '-'
|
||||
"""
|
||||
level1: list = None
|
||||
level2: dict = None
|
||||
level3: dict = None
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class BudgetTrackerFile(BaseSettingObj):
|
||||
year: int = None
|
||||
month: int = None
|
||||
file_name: str = None # original file name
|
||||
sheet_name: str = None # when uploaded from a file, sheet to use
|
||||
grid_settings: dict = None
|
||||
data: DataFrame = dataclasses.field(default=None, compare=False)
|
||||
|
||||
@staticmethod
|
||||
def use_refs() -> set:
|
||||
return {"data"}
|
||||
|
||||
def __str__(self):
|
||||
return f"{self.year}-{self.month}"
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class BudgetTrackerFiles(BaseSettingObj):
|
||||
"""
|
||||
Stores all the budget tracker files
|
||||
"""
|
||||
files: list[BudgetTrackerFile] = dataclasses.field(default_factory=list)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ViewDefinition(BaseSettingObj):
|
||||
"""
|
||||
This class is used to define a new set of codification
|
||||
"""
|
||||
name: str = None
|
||||
sheet_name: str = None # when uploaded from a file, sheet to use
|
||||
grid_settings: dict = None
|
||||
data: DataFrame = dataclasses.field(default=None, compare=False)
|
||||
|
||||
@staticmethod
|
||||
def use_refs() -> set:
|
||||
return {"data"}
|
||||
|
||||
def __str__(self):
|
||||
return self.name
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ViewsDefinitions(BaseSettingObj):
|
||||
"""
|
||||
This class is used to define a new set of codification
|
||||
"""
|
||||
views: list[ViewDefinition] = dataclasses.field(default_factory=list)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class MappingDefinition(BaseSettingObj):
|
||||
"""
|
||||
Define the items from the left will translate into items in the right
|
||||
We only need to store the names of the columns of the two sides
|
||||
"""
|
||||
name: str = None # name given to this mapping
|
||||
file: str = None # budget tracking file from with the mapping is constructed
|
||||
view: str = None # view from with the mapping is constructed
|
||||
keys: list[str] = None # columns from the tracking file
|
||||
values: list[str] = None # columns from the view
|
||||
grid_settings: dict = None # keys + values + percent display names
|
||||
data: DataFrame = dataclasses.field(default=None, compare=False)
|
||||
|
||||
@staticmethod
|
||||
def use_refs() -> set:
|
||||
return {"data"}
|
||||
|
||||
def __str__(self):
|
||||
return self.name
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class MappingsDefinitions(BaseSettingObj):
|
||||
mappings: list[MappingDefinition] = dataclasses.field(default_factory=list)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ImportSettingsDefinition(BaseSettingObj):
|
||||
@dataclasses.dataclass
|
||||
class ColumnDef:
|
||||
col_name: str = None
|
||||
col_location: str = None
|
||||
col_display_name: str = None
|
||||
|
||||
"""
|
||||
Defines how to import data from an excel file
|
||||
"""
|
||||
name: str = None # name of the setting
|
||||
setting_type: str = None
|
||||
spread_sheet: str = None # spreadsheet where the info is located
|
||||
columns_definitions: list[ColumnDef] = None
|
||||
|
||||
def get_key(self):
|
||||
return f"{self.name}|{self.setting_type}"
|
||||
|
||||
def __str__(self):
|
||||
return self.name
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class BudgetFile(BaseSettingObj):
|
||||
name: str = None
|
||||
year: int = None
|
||||
month: int = None
|
||||
file_name: str = None # original file name
|
||||
original_import_settings: str = None # import setting used to import
|
||||
sheet_name: str = None # when uploaded from a file, sheet to use
|
||||
grid_settings: dict = None
|
||||
data: DataFrame = dataclasses.field(default=None, compare=False) # content of the file
|
||||
|
||||
@staticmethod
|
||||
def use_refs() -> set:
|
||||
return {"data"}
|
||||
|
||||
def get_key(self):
|
||||
return f"{self.name}-{self.year}-{self.month}"
|
||||
|
||||
def __str__(self):
|
||||
return f"{self.name}-{self.year}-{self.month}"
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ViewDefinition2(BaseSettingObj):
|
||||
"""
|
||||
This class is used to define a new set of codification
|
||||
"""
|
||||
name: str = None
|
||||
import_settings_name: str = None # when uploaded from a file, sheet to use
|
||||
file_name: str = None
|
||||
grid_settings: dict = None
|
||||
data: DataFrame = dataclasses.field(default=None, compare=False)
|
||||
|
||||
@staticmethod
|
||||
def use_refs() -> set:
|
||||
return {"data"}
|
||||
|
||||
def __str__(self):
|
||||
return self.name
|
||||
258
src/core/user_dao.py
Normal file
258
src/core/user_dao.py
Normal file
@@ -0,0 +1,258 @@
|
||||
import logging
|
||||
import secrets
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from .user_database import user_db
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class UserDAO:
|
||||
"""Data Access Object for user management."""
|
||||
|
||||
@staticmethod
|
||||
def create_user(username: str, email: str, password: str | None = None, github_id: str | None = None) -> int:
|
||||
"""
|
||||
Create a new user with email/password or GitHub authentication.
|
||||
|
||||
Args:
|
||||
username: The username
|
||||
email: The user's email
|
||||
password: The user's password (optional)
|
||||
github_id: GitHub user ID (optional)
|
||||
|
||||
Returns:
|
||||
int: ID of the new user or 0 if creation failed
|
||||
"""
|
||||
try:
|
||||
with user_db.get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Check if user already exists
|
||||
cursor.execute(
|
||||
"SELECT id FROM users WHERE email = ? OR username = ? OR (github_id = ? AND github_id IS NOT NULL)",
|
||||
(email, username, github_id)
|
||||
)
|
||||
|
||||
if cursor.fetchone():
|
||||
# User already exists
|
||||
return 0
|
||||
|
||||
# Prepare values for insertion
|
||||
password_hash = None
|
||||
salt = None
|
||||
|
||||
if password:
|
||||
# Generate salt and hash password for email auth
|
||||
salt = secrets.token_hex(16)
|
||||
password_hash = user_db._hash_password(password, salt)
|
||||
|
||||
cursor.execute('''
|
||||
INSERT INTO users (username, email, password_hash, salt, github_id, is_admin)
|
||||
VALUES (?, ?, ?, ?, ?, 0)
|
||||
''', (username, email, password_hash, salt, github_id))
|
||||
|
||||
conn.commit()
|
||||
return cursor.lastrowid
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating user: {e}")
|
||||
return 0
|
||||
|
||||
@staticmethod
|
||||
def authenticate_email(email: str, password: str) -> dict[str, Any] | None:
|
||||
"""
|
||||
Authenticate a user with email and password.
|
||||
|
||||
Args:
|
||||
email: The user's email
|
||||
password: The user's password
|
||||
|
||||
Returns:
|
||||
Dict or None: User record if authentication succeeds, None otherwise
|
||||
"""
|
||||
with user_db.get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Get user record by email
|
||||
cursor.execute(
|
||||
"SELECT id, username, email, password_hash, salt, is_admin FROM users WHERE email = ?",
|
||||
(email,)
|
||||
)
|
||||
|
||||
user = cursor.fetchone()
|
||||
if not user or not user['password_hash'] or not user['salt']:
|
||||
return None
|
||||
|
||||
# Hash the provided password with the stored salt
|
||||
password_hash = user_db._hash_password(password, user['salt'])
|
||||
|
||||
# Check if password matches
|
||||
if password_hash != user['password_hash']:
|
||||
return None
|
||||
|
||||
# Update last login time
|
||||
cursor.execute(
|
||||
"UPDATE users SET last_login = ? WHERE id = ?",
|
||||
(datetime.now().isoformat(), user['id'])
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
# Return user info
|
||||
return dict(user)
|
||||
|
||||
@staticmethod
|
||||
def find_or_create_github_user(github_id: str, username: str, email: str | None) -> dict[str, Any] | None:
|
||||
"""
|
||||
Find existing GitHub user or create a new one.
|
||||
|
||||
Args:
|
||||
github_id: GitHub user ID
|
||||
username: The username from GitHub
|
||||
email: The email from GitHub (may be None)
|
||||
|
||||
Returns:
|
||||
Dict or None: User record if found or created, None on error
|
||||
"""
|
||||
with user_db.get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Try to find user by GitHub ID
|
||||
cursor.execute(
|
||||
"SELECT id, username, email, is_admin FROM users WHERE github_id = ?",
|
||||
(github_id,)
|
||||
)
|
||||
|
||||
user = cursor.fetchone()
|
||||
if user:
|
||||
# Update last login time
|
||||
cursor.execute(
|
||||
"UPDATE users SET last_login = ? WHERE id = ?",
|
||||
(datetime.now().isoformat(), user['id'])
|
||||
)
|
||||
conn.commit()
|
||||
return dict(user)
|
||||
|
||||
# Create new user
|
||||
# Use GitHub username with random suffix if email not provided
|
||||
user_email = email or f"{username}-{secrets.token_hex(4)}@github.user"
|
||||
|
||||
try:
|
||||
cursor.execute('''
|
||||
INSERT INTO users (username, email, github_id, is_admin)
|
||||
VALUES (?, ?, ?, 0)
|
||||
''', (username, user_email, github_id))
|
||||
|
||||
user_id = cursor.lastrowid
|
||||
conn.commit()
|
||||
|
||||
# Return the new user info
|
||||
return {
|
||||
'id': user_id,
|
||||
'username': username,
|
||||
'email': user_email,
|
||||
'is_admin': 0
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating GitHub user: {e}")
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def get_user_by_id(user_id: int) -> dict[str, Any] | None:
|
||||
"""
|
||||
Get a user by ID.
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
|
||||
Returns:
|
||||
Dict or None: User record if found, None otherwise
|
||||
"""
|
||||
with user_db.get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute(
|
||||
"SELECT id, username, email, is_admin, created_at, last_login FROM users WHERE id = ?",
|
||||
(user_id,)
|
||||
)
|
||||
|
||||
user = cursor.fetchone()
|
||||
return dict(user) if user else None
|
||||
|
||||
@staticmethod
|
||||
def get_all_users(limit: int = 100, offset: int = 0) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Get all users with pagination.
|
||||
|
||||
Args:
|
||||
limit: Maximum number of users to return
|
||||
offset: Number of users to skip
|
||||
|
||||
Returns:
|
||||
List of user records
|
||||
"""
|
||||
with user_db.get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute('''
|
||||
SELECT id, username, email, is_admin, created_at, last_login,
|
||||
(github_id IS NOT NULL) as is_github_user
|
||||
FROM users
|
||||
ORDER BY created_at DESC
|
||||
LIMIT ? OFFSET ?
|
||||
''', (limit, offset))
|
||||
|
||||
return [dict(user) for user in cursor.fetchall()]
|
||||
|
||||
@staticmethod
|
||||
def set_admin_status(user_id: int, is_admin: bool) -> bool:
|
||||
"""
|
||||
Change a user's admin status.
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
is_admin: True to make admin, False to remove admin status
|
||||
|
||||
Returns:
|
||||
bool: True if successful, False otherwise
|
||||
"""
|
||||
try:
|
||||
with user_db.get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute(
|
||||
"UPDATE users SET is_admin = ? WHERE id = ?",
|
||||
(1 if is_admin else 0, user_id)
|
||||
)
|
||||
|
||||
conn.commit()
|
||||
return cursor.rowcount > 0
|
||||
except Exception as e:
|
||||
logger.error(f"Error setting admin status: {e}")
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def delete_user(user_id: int) -> bool:
|
||||
"""
|
||||
Delete a user and all their data.
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
|
||||
Returns:
|
||||
bool: True if successful, False otherwise
|
||||
"""
|
||||
try:
|
||||
with user_db.get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Delete user's history records
|
||||
cursor.execute("DELETE FROM title_history WHERE user_id = ?", (user_id,))
|
||||
|
||||
# Delete user
|
||||
cursor.execute("DELETE FROM users WHERE id = ?", (user_id,))
|
||||
|
||||
conn.commit()
|
||||
return cursor.rowcount > 0
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting user: {e}")
|
||||
return False
|
||||
115
src/core/user_database.py
Normal file
115
src/core/user_database.py
Normal file
@@ -0,0 +1,115 @@
|
||||
import logging
|
||||
import sqlite3
|
||||
import os
|
||||
from contextlib import contextmanager
|
||||
import config
|
||||
import hashlib
|
||||
import secrets
|
||||
|
||||
logger = logging.getLogger("UserDatabase")
|
||||
|
||||
class Database:
|
||||
"""Handles database connections and initialization."""
|
||||
|
||||
def __init__(self, db_path=None):
|
||||
"""
|
||||
Initialize the database connection.
|
||||
|
||||
Args:
|
||||
db_path: Path to the SQLite database file (defaults to config setting)
|
||||
"""
|
||||
# If db_path is None or empty, use a default path
|
||||
self.db_path = db_path or config.DB_PATH
|
||||
if not self.db_path:
|
||||
# Set default path if DB_PATH is empty
|
||||
self.db_path = "tools.db"
|
||||
self._initialize_db()
|
||||
|
||||
def _initialize_db(self):
|
||||
"""Create database tables if they don't exist."""
|
||||
with self.get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Create the users table
|
||||
cursor.execute('''
|
||||
CREATE TABLE IF NOT EXISTS users (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
username TEXT UNIQUE,
|
||||
email TEXT UNIQUE,
|
||||
password_hash TEXT,
|
||||
salt TEXT,
|
||||
github_id TEXT UNIQUE,
|
||||
is_admin BOOLEAN DEFAULT 0,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
last_login TIMESTAMP
|
||||
)
|
||||
''')
|
||||
logger.info("Created users table")
|
||||
|
||||
|
||||
# Check if we need to create an admin user
|
||||
cursor.execute("SELECT COUNT(*) FROM users WHERE is_admin = 1")
|
||||
if cursor.fetchone()[0] == 0:
|
||||
if config.ADMIN_EMAIL and config.ADMIN_PASSWORD:
|
||||
# Create admin user if credentials are provided in config
|
||||
salt = secrets.token_hex(16)
|
||||
password_hash = self._hash_password(config.ADMIN_PASSWORD, salt)
|
||||
|
||||
cursor.execute('''
|
||||
INSERT INTO users (username, email, password_hash, salt, is_admin)
|
||||
VALUES (?, ?, ?, ?, 1)
|
||||
''', ('admin', config.ADMIN_EMAIL, password_hash, salt))
|
||||
logger.info("Created admin user")
|
||||
else:
|
||||
logger.error(f"Failed to create admin user. Admin user is '{config.ADMIN_EMAIL}'.")
|
||||
|
||||
conn.commit()
|
||||
|
||||
@staticmethod
|
||||
def _hash_password(password, salt):
|
||||
"""
|
||||
Hash a password with the given salt using PBKDF2.
|
||||
|
||||
Args:
|
||||
password: The plain text password
|
||||
salt: The salt to use
|
||||
|
||||
Returns:
|
||||
str: The hashed password
|
||||
"""
|
||||
# Use PBKDF2 with SHA-256, 100,000 iterations
|
||||
return hashlib.pbkdf2_hmac(
|
||||
'sha256',
|
||||
password.encode('utf-8'),
|
||||
salt.encode('utf-8'),
|
||||
100000
|
||||
).hex()
|
||||
|
||||
@contextmanager
|
||||
def get_connection(self):
|
||||
"""
|
||||
Context manager for database connections.
|
||||
|
||||
Yields:
|
||||
sqlite3.Connection: Active database connection
|
||||
"""
|
||||
# Check if db_path has a directory component
|
||||
db_dir = os.path.dirname(self.db_path)
|
||||
|
||||
# Only try to create directories if there's a directory path
|
||||
if db_dir:
|
||||
os.makedirs(db_dir, exist_ok=True)
|
||||
|
||||
# Connect to the database
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
|
||||
# Configure connection
|
||||
conn.row_factory = sqlite3.Row # Use dictionary-like rows
|
||||
|
||||
try:
|
||||
yield conn
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
# Create a singleton instance
|
||||
user_db = Database()
|
||||
380
src/core/utils.py
Normal file
380
src/core/utils.py
Normal file
@@ -0,0 +1,380 @@
|
||||
import base64
|
||||
import hashlib
|
||||
import importlib
|
||||
import inspect
|
||||
import pkgutil
|
||||
import re
|
||||
import types
|
||||
import uuid
|
||||
from enum import Enum
|
||||
from io import BytesIO
|
||||
|
||||
import pandas as pd
|
||||
|
||||
PRIMITIVES = (str, bool, type(None), int, float)
|
||||
|
||||
|
||||
def get_stream_digest(stream):
|
||||
"""
|
||||
Compute a SHA256 from a stream
|
||||
:param stream:
|
||||
:type stream:
|
||||
:return:
|
||||
:rtype:
|
||||
"""
|
||||
sha256_hash = hashlib.sha256()
|
||||
stream.seek(0)
|
||||
for byte_block in iter(lambda: stream.read(4096), b""):
|
||||
sha256_hash.update(byte_block)
|
||||
|
||||
return sha256_hash.hexdigest()
|
||||
|
||||
|
||||
def has_tag(obj, tag):
|
||||
"""
|
||||
|
||||
:param obj:
|
||||
:param tag:
|
||||
:return:
|
||||
"""
|
||||
return type(obj) is dict and tag in obj
|
||||
|
||||
|
||||
def is_primitive(obj):
|
||||
"""
|
||||
|
||||
:param obj:
|
||||
:return:
|
||||
"""
|
||||
return isinstance(obj, PRIMITIVES)
|
||||
|
||||
|
||||
def is_dictionary(obj):
|
||||
"""
|
||||
|
||||
:param obj:
|
||||
:return:
|
||||
"""
|
||||
return isinstance(obj, dict)
|
||||
|
||||
|
||||
def is_list(obj):
|
||||
"""
|
||||
|
||||
:param obj:
|
||||
:return:
|
||||
"""
|
||||
return isinstance(obj, list)
|
||||
|
||||
|
||||
def is_set(obj):
|
||||
"""
|
||||
|
||||
:param obj:
|
||||
:return:
|
||||
"""
|
||||
return isinstance(obj, set)
|
||||
|
||||
|
||||
def is_tuple(obj):
|
||||
"""
|
||||
|
||||
:param obj:
|
||||
:return:
|
||||
"""
|
||||
return isinstance(obj, tuple)
|
||||
|
||||
|
||||
def is_enum(obj):
|
||||
return isinstance(obj, Enum)
|
||||
|
||||
|
||||
def is_object(obj):
|
||||
"""Returns True is obj is a reference to an object instance."""
|
||||
|
||||
return (isinstance(obj, object) and
|
||||
not isinstance(obj, (type,
|
||||
types.FunctionType,
|
||||
types.BuiltinFunctionType,
|
||||
types.GeneratorType)))
|
||||
|
||||
|
||||
def get_full_qualified_name(obj):
|
||||
"""
|
||||
Returns the full qualified name of a class (including its module name )
|
||||
:param obj:
|
||||
:return:
|
||||
"""
|
||||
if obj.__class__ == type:
|
||||
module = obj.__module__
|
||||
if module is None or module == str.__class__.__module__:
|
||||
return obj.__name__ # Avoid reporting __builtin__
|
||||
else:
|
||||
return module + '.' + obj.__name__
|
||||
else:
|
||||
module = obj.__class__.__module__
|
||||
if module is None or module == str.__class__.__module__:
|
||||
return obj.__class__.__name__ # Avoid reporting __builtin__
|
||||
else:
|
||||
return module + '.' + obj.__class__.__name__
|
||||
|
||||
|
||||
def importable_name(cls):
|
||||
"""
|
||||
Fully qualified name (prefixed by builtin when needed)
|
||||
"""
|
||||
# Use the fully-qualified name if available (Python >= 3.3)
|
||||
name = getattr(cls, '__qualname__', cls.__name__)
|
||||
|
||||
# manage python 2
|
||||
lookup = dict(__builtin__='builtins', exceptions='builtins')
|
||||
module = lookup.get(cls.__module__, cls.__module__)
|
||||
|
||||
return f"{module}.{name}"
|
||||
|
||||
|
||||
def get_class(qualified_class_name: str):
|
||||
"""
|
||||
Dynamically loads and returns a class type from its fully qualified name.
|
||||
Note that the class is not instantiated.
|
||||
|
||||
:param qualified_class_name: Fully qualified name of the class (e.g., 'some.module.ClassName').
|
||||
:return: The class object.
|
||||
:raises ImportError: If the module cannot be imported.
|
||||
:raises AttributeError: If the class cannot be resolved in the module.
|
||||
"""
|
||||
module_name, class_name = qualified_class_name.rsplit(".", 1)
|
||||
|
||||
try:
|
||||
module = importlib.import_module(module_name)
|
||||
except ModuleNotFoundError as e:
|
||||
raise ImportError(f"Could not import module '{module_name}' for '{qualified_class_name}': {e}")
|
||||
|
||||
if not hasattr(module, class_name):
|
||||
raise AttributeError(f"Component '{class_name}' not found in '{module.__name__}'.")
|
||||
|
||||
return getattr(module, class_name)
|
||||
|
||||
|
||||
def make_html_id(s: str | None) -> str | None:
|
||||
"""
|
||||
Creates a valid html id
|
||||
:param s:
|
||||
:return:
|
||||
"""
|
||||
if s is None:
|
||||
return None
|
||||
|
||||
s = str(s).strip()
|
||||
# Replace spaces and special characters with hyphens or remove them
|
||||
s = re.sub(r'[^a-zA-Z0-9_-]', '-', s)
|
||||
|
||||
# Ensure the ID starts with a letter or underscore
|
||||
if not re.match(r'^[a-zA-Z_]', s):
|
||||
s = 'id_' + s # Add a prefix if it doesn't
|
||||
|
||||
# Collapse multiple consecutive hyphens into one
|
||||
s = re.sub(r'-+', '-', s)
|
||||
|
||||
# Replace trailing hyphens with underscores
|
||||
s = re.sub(r'-+$', '_', s)
|
||||
|
||||
return s
|
||||
|
||||
|
||||
def snake_case_to_capitalized_words(s: str) -> str:
|
||||
"""
|
||||
Try to (re)create the column title from the column id
|
||||
>>> assert snake_case_to_capitalized_words("column_id") == "Column Id"
|
||||
>>> assert snake_case_to_capitalized_words("this_is_a_column_name") == "This Is A Column Name"
|
||||
:param s:
|
||||
:return:
|
||||
"""
|
||||
parts = s.split('_')
|
||||
capitalized_parts = [part.capitalize() for part in parts]
|
||||
|
||||
# Join the capitalized parts with spaces
|
||||
transformed_name = ' '.join(capitalized_parts)
|
||||
|
||||
return transformed_name
|
||||
|
||||
|
||||
def make_column_id(s: str | None):
|
||||
if s is None:
|
||||
return None
|
||||
|
||||
res = re.sub('-', '_', make_html_id(s)) # replace '-' by '_'
|
||||
return res.lower() # no uppercase
|
||||
|
||||
|
||||
def update_elements(elts, updates: list[dict]):
|
||||
"""
|
||||
walk through elements and update them if needed
|
||||
:param elts:
|
||||
:param updates:
|
||||
:return:
|
||||
"""
|
||||
|
||||
def _update_elt(_elt):
|
||||
if hasattr(_elt, 'attrs'):
|
||||
for blue_print in updates:
|
||||
if "id" in _elt.attrs and _elt.attrs["id"] == blue_print["id"]:
|
||||
method = blue_print["method"]
|
||||
_elt.attrs[method] = blue_print["value"]
|
||||
|
||||
if hasattr(_elt, "children"):
|
||||
for child in _elt.children:
|
||||
_update_elt(child)
|
||||
|
||||
if elts is None:
|
||||
return None
|
||||
|
||||
to_use = elts if isinstance(elts, (list, tuple, set)) else [elts]
|
||||
for elt in to_use:
|
||||
_update_elt(elt)
|
||||
|
||||
return elts
|
||||
|
||||
|
||||
def get_sheets_names(file_content):
|
||||
try:
|
||||
excel_file = pd.ExcelFile(BytesIO(file_content))
|
||||
sheet_names = excel_file.sheet_names
|
||||
except Exception:
|
||||
sheet_names = []
|
||||
|
||||
return sheet_names
|
||||
|
||||
|
||||
def to_bool(value: str):
|
||||
if isinstance(value, bool):
|
||||
return value
|
||||
|
||||
if value is None:
|
||||
return False
|
||||
|
||||
if not isinstance(value, str):
|
||||
raise NotImplemented("Cannot convert to bool")
|
||||
|
||||
return value.lower() in ("yes", "true", "t", "1")
|
||||
|
||||
|
||||
def from_bool(value: bool):
|
||||
return "true" if value else "false"
|
||||
|
||||
|
||||
def append_once(lst: list, elt):
|
||||
if elt in lst:
|
||||
return
|
||||
|
||||
lst.append(elt)
|
||||
|
||||
|
||||
def find_classes_in_modules(modules, base_class_name):
|
||||
"""
|
||||
Recursively search for all classes in the given list of modules (and their submodules)
|
||||
that inherit from a specified base class.
|
||||
|
||||
:param modules: List of top-level module names (e.g., ["core.settings_objects", "another.module"])
|
||||
:param base_class_name: Name of the base class to search for (e.g., "BaseSettingObj")
|
||||
"""
|
||||
# List to store matching classes
|
||||
derived_classes = []
|
||||
|
||||
def inspect_module(_module_name):
|
||||
"""Recursively inspect a module and its submodules for matching classes."""
|
||||
try:
|
||||
# Import the module dynamically
|
||||
module = importlib.import_module(_module_name)
|
||||
|
||||
# Iterate over all objects in the module
|
||||
for name, obj in inspect.getmembers(module, inspect.isclass):
|
||||
# Check if the class inherits from the specified base class
|
||||
for base in obj.__bases__:
|
||||
if base.__name__ == base_class_name:
|
||||
derived_classes.append(f"{_module_name}.{name}")
|
||||
|
||||
# Recursively inspect submodules
|
||||
if hasattr(module, "__path__"): # Check if the module has submodules
|
||||
for submodule_info in pkgutil.iter_modules(module.__path__):
|
||||
inspect_module(f"{_module_name}.{submodule_info.name}")
|
||||
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Start inspecting from the top-level modules
|
||||
for module_name in modules:
|
||||
inspect_module(module_name)
|
||||
|
||||
return derived_classes
|
||||
|
||||
|
||||
def instantiate_class(qualified_class_name):
|
||||
"""
|
||||
Dynamically instantiates a class provided its full module path. The function takes
|
||||
the fully-qualified class path, imports the corresponding module at runtime,
|
||||
retrieves the class from the module, and instantiates it. Any exceptions during
|
||||
this process are caught and logged.
|
||||
|
||||
:param qualified_class_name: Full dot-separated path to the class to be instantiated.
|
||||
Example: 'module.submodule.ClassName'
|
||||
:type qualified_class_name: str
|
||||
:return: An instance of the dynamically instantiated class.
|
||||
:rtype: object
|
||||
:raises ValueError: If the class path fails to split correctly into module and
|
||||
class parts.
|
||||
:raises ModuleNotFoundError: If the specified module cannot be imported.
|
||||
:raises AttributeError: If the specified class does not exist in the module.
|
||||
:raises TypeError: For errors in class instantiation process.
|
||||
"""
|
||||
try:
|
||||
# Split module and class name
|
||||
module_name, class_name = qualified_class_name.rsplit(".", 1)
|
||||
|
||||
# Dynamically import the module
|
||||
module = importlib.import_module(module_name)
|
||||
|
||||
# Get the class from the module
|
||||
cls = getattr(module, class_name)
|
||||
|
||||
# Instantiate the class (pass arguments here if required)
|
||||
return cls()
|
||||
except Exception as e:
|
||||
print(f"Failed to instantiate {qualified_class_name}: {e}")
|
||||
|
||||
|
||||
def get_unique_id(prefix: str = None):
|
||||
suffix = base64.urlsafe_b64encode(uuid.uuid4().bytes).rstrip(b'=').decode('ascii')
|
||||
if prefix is None:
|
||||
return suffix
|
||||
else:
|
||||
return f"{prefix}_{suffix}"
|
||||
|
||||
|
||||
def merge_classes(*args):
|
||||
all_elements = []
|
||||
for element in args:
|
||||
if element is None or element == '':
|
||||
continue
|
||||
|
||||
if isinstance(element, (tuple, list, set)):
|
||||
all_elements.extend(element)
|
||||
|
||||
elif isinstance(element, dict):
|
||||
if "cls" in element:
|
||||
all_elements.append(element.pop("cls"))
|
||||
elif "class" in element:
|
||||
all_elements.append(element.pop("class"))
|
||||
|
||||
elif isinstance(element, str):
|
||||
all_elements.append(element)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Cannot merge {element} of type {type(element)}")
|
||||
|
||||
if all_elements:
|
||||
# Remove duplicates while preserving order
|
||||
unique_elements = list(dict.fromkeys(all_elements))
|
||||
return " ".join(unique_elements)
|
||||
else:
|
||||
return None
|
||||
Reference in New Issue
Block a user