From 878064b1400a5dbcc8f94fcce49acc6db2fe3949 Mon Sep 17 00:00:00 2001 From: Kodjo Sossouvi Date: Fri, 17 Oct 2025 21:08:20 +0200 Subject: [PATCH] first commit --- .gitignore | 216 ++++++++++ .idea/.gitignore | 8 + .idea/MyDbEngine.iml | 11 + .idea/inspectionProfiles/Project_Default.xml | 14 + .../inspectionProfiles/profiles_settings.xml | 6 + .idea/misc.xml | 7 + .idea/modules.xml | 8 + .idea/vcs.xml | 6 + Readme.md | 187 +++++++++ main.py | 16 + requirements.txt | 5 + src/__init__.py | 0 src/core/__init__.py | 0 src/core/dbengine.py | 391 ++++++++++++++++++ src/core/handlers.py | 59 +++ src/core/serializer.py | 201 +++++++++ src/core/utils.py | 195 +++++++++ tests/__init__.py | 0 tests/test_dbengine.py | 273 ++++++++++++ tests/test_serializer.py | 268 ++++++++++++ 20 files changed, 1871 insertions(+) create mode 100644 .gitignore create mode 100644 .idea/.gitignore create mode 100644 .idea/MyDbEngine.iml create mode 100644 .idea/inspectionProfiles/Project_Default.xml create mode 100644 .idea/inspectionProfiles/profiles_settings.xml create mode 100644 .idea/misc.xml create mode 100644 .idea/modules.xml create mode 100644 .idea/vcs.xml create mode 100644 Readme.md create mode 100644 main.py create mode 100644 requirements.txt create mode 100644 src/__init__.py create mode 100644 src/core/__init__.py create mode 100644 src/core/dbengine.py create mode 100644 src/core/handlers.py create mode 100644 src/core/serializer.py create mode 100644 src/core/utils.py create mode 100644 tests/__init__.py create mode 100644 tests/test_dbengine.py create mode 100644 tests/test_serializer.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..1f7286a --- /dev/null +++ b/.gitignore @@ -0,0 +1,216 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[codz] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py.cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +# Pipfile.lock + +# UV +# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 
+# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# uv.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +# poetry.lock +# poetry.toml + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +# pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python. +# https://pdm-project.org/en/latest/usage/project/#working-with-version-control +# pdm.lock +# pdm.toml +.pdm-python +.pdm-build/ + +# pixi +# Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control. +# pixi.lock +# Pixi creates a virtual environment in the .pixi directory, just like venv module creates one +# in the .venv directory. It is recommended not to include this directory in version control. +.pixi + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# Redis +*.rdb +*.aof +*.pid + +# RabbitMQ +mnesia/ +rabbitmq/ +rabbitmq-data/ + +# ActiveMQ +activemq-data/ + +# SageMath parsed files +*.sage.py + +# Environments +.env +.envrc +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +# .idea/ + +# Abstra +# Abstra is an AI-powered process automation framework. +# Ignore directories containing user credentials, local state, and settings. +# Learn more at https://abstra.io/docs +.abstra/ + +# Visual Studio Code +# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore +# that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore +# and can be added to the global gitignore or merged into this file. 
However, if you prefer, +# you could uncomment the following to ignore the entire vscode folder +# .vscode/ + +# Ruff stuff: +.ruff_cache/ + +# PyPI configuration file +.pypirc + +# Marimo +marimo/_static/ +marimo/_lsp/ +__marimo__/ + +# Streamlit +.streamlit/secrets.toml \ No newline at end of file diff --git a/.idea/.gitignore b/.idea/.gitignore new file mode 100644 index 0000000..1c2fda5 --- /dev/null +++ b/.idea/.gitignore @@ -0,0 +1,8 @@ +# Default ignored files +/shelf/ +/workspace.xml +# Editor-based HTTP Client requests +/httpRequests/ +# Datasource local storage ignored files +/dataSources/ +/dataSources.local.xml diff --git a/.idea/MyDbEngine.iml b/.idea/MyDbEngine.iml new file mode 100644 index 0000000..60d3a59 --- /dev/null +++ b/.idea/MyDbEngine.iml @@ -0,0 +1,11 @@ + + + + + + + + + + + \ No newline at end of file diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml new file mode 100644 index 0000000..1658d1d --- /dev/null +++ b/.idea/inspectionProfiles/Project_Default.xml @@ -0,0 +1,14 @@ + + + + \ No newline at end of file diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml new file mode 100644 index 0000000..105ce2d --- /dev/null +++ b/.idea/inspectionProfiles/profiles_settings.xml @@ -0,0 +1,6 @@ + + + + \ No newline at end of file diff --git a/.idea/misc.xml b/.idea/misc.xml new file mode 100644 index 0000000..31efb58 --- /dev/null +++ b/.idea/misc.xml @@ -0,0 +1,7 @@ + + + + + + \ No newline at end of file diff --git a/.idea/modules.xml b/.idea/modules.xml new file mode 100644 index 0000000..1976935 --- /dev/null +++ b/.idea/modules.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml new file mode 100644 index 0000000..9661ac7 --- /dev/null +++ b/.idea/vcs.xml @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/Readme.md b/Readme.md new file mode 100644 index 0000000..a926341 --- /dev/null +++ b/Readme.md @@ -0,0 +1,187 @@ +# DbEngine + +A lightweight, git-inspired database engine for Python that maintains complete history of all modifications. + +## Overview + +DbEngine is a personal implementation of a versioned database engine that stores snapshots of data changes over time. Each modification creates a new immutable snapshot, allowing you to track the complete history of your data. 
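+
+A minimal end-to-end sketch of the idea (the import path follows the test
+suite's `core.dbengine` layout; the tenant and user names are illustrative):
+
+```python
+from core.dbengine import DbEngine
+
+db = DbEngine(root=".mytools_db")
+db.init("acme")
+
+# Each save() writes a content-addressed snapshot and returns its digest.
+d1 = db.save("acme", "alice@example.com", "users", {"john": {"age": 30}})
+d2 = db.save("acme", "alice@example.com", "users", {"john": {"age": 31}})
+
+# history() walks the parent chain, newest first.
+assert db.history("acme", "users") == [d2, d1]
+```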
+
+## Key Features
+
+- **Version Control**: Every change creates a new snapshot with a unique digest (SHA-256 hash)
+- **History Tracking**: Access any previous version of your data
+- **Multi-tenant Support**: Isolated data storage per tenant
+- **Thread-safe**: Built-in locking mechanism for concurrent access
+- **Git-inspired Architecture**: Objects are stored in a content-addressable format
+- **Efficient Storage**: Identical objects are stored only once
+
+## Architecture
+
+The engine uses a file-based storage system with the following structure:
+
+```
+.mytools_db/
+├── {tenant_id}/
+│   ├── head                     # Points to the latest version of each entry
+│   └── objects/
+│       └── {digest_prefix}/
+│           └── {full_digest}    # Actual object data
+└── refs/                        # Shared references
+```
+
+## Getting Started
+
+```python
+from core.dbengine import DbEngine
+
+# Initialize with the default root
+db = DbEngine()
+
+# Or specify a custom root directory
+db = DbEngine(root="/path/to/database")
+```
+
+## Basic Usage
+
+### Initialize the Database for a Tenant
+
+```python
+tenant_id = "my_company"
+db.init(tenant_id)
+```
+
+### Save Data
+
+```python
+# Save a complete object
+user_id = "john_doe"
+entry = "users"
+data = {"name": "John", "age": 30}
+
+digest = db.save(tenant_id, user_id, entry, data)
+```
+
+### Load Data
+
+```python
+# Load the latest version
+data = db.load(tenant_id, entry="users")
+
+# Load a specific version by digest
+data = db.load(tenant_id, entry="users", digest="abc123...")
+```
+
+### Work with Individual Records
+
+```python
+# Add or update a single record
+db.put(tenant_id, user_id, entry="users", key="john", value={"name": "John", "age": 30})
+
+# Add or update multiple records at once
+items = {
+    "john": {"name": "John", "age": 30},
+    "jane": {"name": "Jane", "age": 25}
+}
+db.put_many(tenant_id, user_id, entry="users", items=items)
+
+# Get a specific record
+user = db.get(tenant_id, entry="users", key="john")
+
+# Get all records
+all_users = db.get(tenant_id, entry="users")
+```
+
+### Check Existence
+
+```python
+if db.exists(tenant_id, entry="users"):
+    print("Entry exists")
+```
+
+### Access History
+
+```python
+# Get the history of an entry (returns a list of digests, newest first)
+history = db.history(tenant_id, entry="users", max_items=10)
+
+# Load a previous version
+old_data = db.load(tenant_id, entry="users", digest=history[1])
+```
+
+## Metadata
+
+Each snapshot automatically includes metadata:
+
+- `__parent__`: Single-element list holding the digest of the previous version (`[None]` for the first snapshot)
+- `__user_id__`: User ID of the person who made the change
+- `__date__`: Timestamp of the change (strftime format: `%Y%m%d %H:%M:%S %z`)
+
+## API Reference
+
+### Core Methods
+
+#### `init(tenant_id: str)`
+Initialize the database structure for a tenant.
+
+#### `save(tenant_id: str, user_id: str, entry: str, obj: object) -> str`
+Save a complete snapshot. Returns the digest of the saved object.
+
+#### `load(tenant_id: str, entry: str, digest: str = None) -> object`
+Load a snapshot. If digest is None, loads the latest version.
+
+#### `put(tenant_id: str, user_id: str, entry: str, key: str, value: object) -> bool`
+Add or update a single record. Returns True if a new snapshot was created.
+
+#### `put_many(tenant_id: str, user_id: str, entry: str, items: list | dict) -> bool`
+Add or update multiple records. Returns True if a new snapshot was created.
+
+#### `get(tenant_id: str, entry: str, key: str = None, digest: str = None) -> object`
+Retrieve record(s). If key is None, returns all records as a list.
+
+#### `exists(tenant_id: str, entry: str) -> bool`
+Check if an entry exists.
+
+#### `history(tenant_id: str, entry: str, digest: str = None, max_items: int = 1000) -> list`
+Get the history chain of digests for an entry, newest first.
+
+#### `get_digest(tenant_id: str, entry: str) -> str`
+Get the current digest for an entry.
+
+## Usage Patterns
+
+### Pattern 1: Snapshot-based (using `save()`)
+Best for saving complete states of complex objects.
+
+```python
+config = {"theme": "dark", "language": "en"}
+db.save(tenant_id, user_id, "config", config)
+```
+
+### Pattern 2: Record-based (using `put()` / `put_many()`)
+Best for managing collections of items incrementally.
+
+```python
+db.put(tenant_id, user_id, "settings", "theme", "dark")
+db.put(tenant_id, user_id, "settings", "language", "en")
+```
+
+**Note**: Don't mix these patterns for the same entry, as they use different data structures.
+
+## Thread Safety
+
+DbEngine guards its public methods with an `RLock`, so it is safe to use from multiple threads within a single process. Cross-process access is not synchronized.
+
+## Exceptions
+
+- `DbException`: Raised for database-related errors (missing entries, invalid parameters, etc.)
+
+## Performance Considerations
+
+- Objects are stored as JSON files
+- Identical objects (same SHA-256 digest) are stored only once
+- History chains can become long; use the `max_items` parameter to limit traversal
+- File system performance impacts overall speed
+
+## License
+
+This is a personal implementation. Please check with the author for licensing terms.
\ No newline at end of file
diff --git a/main.py b/main.py
new file mode 100644
index 0000000..b104d14
--- /dev/null
+++ b/main.py
@@ -0,0 +1,16 @@
+# This is a sample Python script.
+
+# Press Ctrl+F5 to execute it or replace it with your code.
+# Press Double Shift to search everywhere for classes, files, tool windows, actions, and settings.
+
+
+def print_hi(name):
+    # Use a breakpoint in the code line below to debug your script.
+    print(f'Hi, {name}')  # Press F9 to toggle the breakpoint.
+
+
+# Press the green button in the gutter to run the script.
+if __name__ == '__main__': + print_hi('PyCharm') + +# See PyCharm help at https://www.jetbrains.com/help/pycharm/ diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..f2a871b --- /dev/null +++ b/requirements.txt @@ -0,0 +1,5 @@ +iniconfig==2.1.0 +packaging==25.0 +pluggy==1.6.0 +Pygments==2.19.2 +pytest==8.4.2 diff --git a/src/__init__.py b/src/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/core/__init__.py b/src/core/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/core/dbengine.py b/src/core/dbengine.py new file mode 100644 index 0000000..dd2cf6d --- /dev/null +++ b/src/core/dbengine.py @@ -0,0 +1,391 @@ +import datetime +import hashlib +import io +import json +import logging +import os +import pickle +from threading import RLock + +from core.serializer import Serializer +from core.utils import get_stream_digest + +TYPE_KEY = "__type__" +TAG_PARENT = "__parent__" +TAG_USER = "__user_id__" +TAG_DATE = "__date__" +BUFFER_SIZE = 4096 + +logger = logging.getLogger(__name__) + + +class DbException(Exception): + pass + + +class RefHelper: + def __init__(self, get_ref_path): + self.get_ref_path = get_ref_path + + def save_ref(self, obj): + """ + + :param obj: + :return: + """ + buffer = io.BytesIO() + pickler = pickle.Pickler(buffer) + pickler.dump(obj) + + digest = get_stream_digest(buffer) + + target_path = self.get_ref_path(digest) + if not os.path.exists(os.path.dirname(target_path)): + os.makedirs(os.path.dirname(target_path)) + + buffer.seek(0) + with open(self.get_ref_path(digest), "wb") as file: + while chunk := buffer.read(BUFFER_SIZE): + file.write(chunk) + + logger.debug(f"Saved object type '{type(obj).__name__}' with digest {digest}") + return digest + + def load_ref(self, digest): + """ + + :param digest: + :return: + """ + with open(self.get_ref_path(digest), 'rb') as file: + return pickle.load(file) + + +class DbEngine: + """ + Personal implementation of DB engine + Inspired by the way git manage its files + Designed to keep history of the modifications + """ + ObjectsFolder = "objects" # group objects in the same folder + HeadFile = "head" # used to keep track of the latest version of all entries + + def __init__(self, root: str = None): + self.root = root or ".mytools_db" + self.lock = RLock() + + def is_initialized(self, tenant_id: str): + """ + + :return: + """ + return os.path.exists(self._get_user_root(tenant_id)) + + def init(self, tenant_id: str): + """ + Make sure that the DbEngine is properly initialized + :return: + """ + if not os.path.exists(self._get_user_root(tenant_id)): + logger.debug(f"Creating root folder in {os.path.abspath(self._get_user_root(tenant_id))}.") + os.makedirs(self._get_user_root(tenant_id)) + + def save(self, tenant_id: str, user_id: str, entry: str, obj: object) -> str: + """ + Save a snapshot of an entry + :param tenant_id: + :param user_id: + :param entry: + :param obj: snapshot to save + :return: + """ + with self.lock: + logger.info(f"Saving {tenant_id=}, {entry=}, {obj=}") + + if not tenant_id: + raise DbException("tenant_id is None") + + if not user_id: + raise DbException("user_id is None") + + if not entry: + raise DbException("entry is None") + # prepare the data + as_dict = self._serialize(obj) + as_dict[TAG_PARENT] = [self._get_entry_digest(tenant_id, entry)] + as_dict[TAG_USER] = user_id + as_dict[TAG_DATE] = datetime.datetime.now().strftime('%Y%m%d %H:%M:%S %z') + + # transform into a stream + as_str = json.dumps(as_dict, sort_keys=True, indent=4) + 
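+            # The digest below is the SHA-256 of this canonical JSON
+            # (sort_keys=True gives a stable key order), so byte-identical
+            # snapshots map to the same object path and are stored only once.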
logger.debug(f"Serialized object : {as_str}") + byte_stream = as_str.encode("utf-8") + + # compute the digest to know where to store it + digest = hashlib.sha256(byte_stream).hexdigest() + + target_path = self._get_obj_path(tenant_id, digest) + if os.path.exists(target_path): + # the same object is already saved. Noting to do + return digest + + # save the new value + if not os.path.exists(os.path.dirname(target_path)): + os.makedirs(os.path.dirname(target_path)) + with open(target_path, "wb") as file: + file.write(byte_stream) + + # update the head to remember where the latest entry is + self._update_head(tenant_id, entry, digest) + logger.debug(f"New head for entry '{entry}' is {digest}") + return digest + + def load(self, tenant_id: str, entry, digest=None): + """ + Loads a snapshot + :param tenant_id: + :param entry: + :param digest: + :return: + """ + with self.lock: + logger.info(f"Loading {tenant_id=}, {entry=}, {digest=}") + + digest_to_use = digest or self._get_entry_digest(tenant_id, entry) + logger.debug(f"Using digest {digest_to_use}.") + + if digest_to_use is None: + raise DbException(entry) + + target_file = self._get_obj_path(tenant_id, digest_to_use) + with open(target_file, 'r', encoding='utf-8') as file: + as_dict = json.load(file) + + return self._deserialize(as_dict) + + def put(self, tenant_id: str, user_id, entry, key: str, value: object): + """ + Save a specific record. + This will create a new snapshot is the record is new or different + + You should not mix the usage of put_many() and save() as it's two different way to manage the db + :param user_id: + :param tenant_id: + :param entry: + :param key: + :param value: + :return: + """ + with self.lock: + logger.info(f"Adding {tenant_id=}, {entry=}, {key=}, {value=}") + try: + entry_content = self.load(tenant_id, entry) + except DbException: + entry_content = {} + + # Do not save if the entry is the same + if key in entry_content: + old_value = entry_content[key] + if old_value == value: + return False + + entry_content[key] = value + self.save(tenant_id, user_id, entry, entry_content) + return True + + def put_many(self, tenant_id: str, user_id, entry, items: list | dict): + """ + Save a list of item as one single snapshot + A new snapshot will not be created if all the items already exist + + You should not mix the usage of put_many() and save() as it's two different way to manage the db + :param tenant_id: + :param user_id: + :param entry: + :param items: + :return: + """ + with self.lock: + logger.info(f"Adding many {tenant_id=}, {entry=}, {items=}") + try: + entry_content = self.load(tenant_id, entry) + except DbException: + entry_content = {} + + is_dirty = False + + if isinstance(items, dict): + for key, item in items.items(): + if key in entry_content and entry_content[key] == item: + continue + else: + entry_content[key] = item + is_dirty = True + + else: + + for item in items: + key = item.get_key() + if key in entry_content and entry_content[key] == item: + continue + else: + entry_content[key] = item + is_dirty = True + + if is_dirty: + self.save(tenant_id, user_id, entry, entry_content) + return True + + return False + + def exists(self, tenant_id, entry: str): + """ + Tells if an entry exist + :param tenant_id: + :param entry: + :return: + """ + with self.lock: + return self._get_entry_digest(tenant_id, entry) is not None + + def get(self, tenant_id: str, entry: str, key: str | None = None, digest=None): + """ + Retrieve an item from the snapshot + :param tenant_id: + :param entry: + :param key: + :param 
digest: + :return: + """ + with self.lock: + logger.info(f"Getting {tenant_id=}, {entry=}, {key=}, {digest=}") + entry_content = self.load(tenant_id, entry, digest) + + if key is None: + # return all items as list + return [v for k, v in entry_content.items() if not k.startswith("__")] + + try: + return entry_content[key] + except KeyError: + raise DbException(f"Key '{key}' not found in entry '{entry}'") + + def history(self, tenant_id, entry, digest=None, max_items=1000): + """ + Gives the current digest and all its ancestors + :param tenant_id: + :param entry: + :param digest: + :param max_items: + :return: + """ + with self.lock: + logger.info(f"History for {tenant_id=}, {entry=}, {digest=}") + + digest_to_use = digest or self._get_entry_digest(tenant_id, entry) + logger.debug(f"Using digest {digest_to_use}.") + + count = 0 + history = [] + + while True: + if count >= max_items or digest_to_use is None: + break + + history.append(digest_to_use) + count += 1 + + try: + target_file = self._get_obj_path(tenant_id, digest_to_use) + with open(target_file, 'r', encoding='utf-8') as file: + as_dict = json.load(file) + + digest_to_use = as_dict[TAG_PARENT][0] + except FileNotFoundError: + break + + return history + + def get_digest(self, tenant_id, entry): + return self._get_entry_digest(tenant_id, entry) + + def _serialize(self, obj): + """ + Just call the serializer + :param obj: + :return: + """ + with self.lock: + serializer = Serializer(RefHelper(self._get_ref_path)) + use_refs = getattr(obj, "use_refs")() if hasattr(obj, "use_refs") else None + return serializer.serialize(obj, use_refs) + + def _deserialize(self, as_dict): + with self.lock: + serializer = Serializer(RefHelper(self._get_ref_path)) + return serializer.deserialize(as_dict) + + def _update_head(self, tenant_id, entry, digest): + """ + Actually dumps the snapshot in file system + :param entry: + :param digest: + :return: + """ + head_path = os.path.join(self.root, tenant_id, self.HeadFile) + # load + try: + with open(head_path, 'r') as file: + head = json.load(file) + except FileNotFoundError: + head = {} + + # update + head[entry] = digest + + # and save + with open(head_path, 'w') as file: + json.dump(head, file) + + def _get_user_root(self, tenant_id): + return os.path.join(self.root, tenant_id) + + def _get_entry_digest(self, tenant_id, entry): + """ + Search for the latest digest, for a given entry + :param entry: + :return: + """ + head_path = os.path.join(self._get_user_root(tenant_id), self.HeadFile) + try: + with open(head_path, 'r') as file: + head = json.load(file) + return head[str(entry)] + + except FileNotFoundError: + return None + except KeyError: + return None + + def _get_head_path(self, tenant_id: str): + """ + Location of the Head file + :return: + """ + return os.path.join(self._get_user_root(tenant_id), self.HeadFile) + + def _get_obj_path(self, tenant_id, digest): + """ + Location of objects + :param digest: + :return: + """ + return os.path.join(self._get_user_root(tenant_id), "objects", digest[:24], digest) + + def _get_ref_path(self, digest): + """ + Location of reference. 
They are not linked to the user folder + :param digest: + :return: + """ + return os.path.join(self.root, "refs", digest[:24], digest) diff --git a/src/core/handlers.py b/src/core/handlers.py new file mode 100644 index 0000000..d63a2be --- /dev/null +++ b/src/core/handlers.py @@ -0,0 +1,59 @@ +# I delegate the complexity of some data type within specific handlers + +import datetime + +from core.utils import has_tag + +TAG_SPECIAL = "__special__" + + +class BaseHandler: + def is_eligible_for(self, obj): + pass + + def tag(self): + pass + + def serialize(self, obj) -> dict: + pass + + def deserialize(self, data: dict) -> object: + pass + + +class DateHandler(BaseHandler): + def is_eligible_for(self, obj): + return isinstance(obj, datetime.date) + + def tag(self): + return "Date" + + def serialize(self, obj): + return { + TAG_SPECIAL: self.tag(), + "year": obj.year, + "month": obj.month, + "day": obj.day, + } + + def deserialize(self, data: dict) -> object: + return datetime.date(year=data["year"], month=data["month"], day=data["day"]) + + +class Handlers: + + def __init__(self, handlers_): + self.handlers = handlers_ + + def get_handler(self, obj): + if has_tag(obj, TAG_SPECIAL): + return [h for h in self.handlers if h.tag() == obj[TAG_SPECIAL]][0] + + for h in self.handlers: + if h.is_eligible_for(obj): + return h + + return None + + +handlers = Handlers([DateHandler()]) diff --git a/src/core/serializer.py b/src/core/serializer.py new file mode 100644 index 0000000..99aaa4d --- /dev/null +++ b/src/core/serializer.py @@ -0,0 +1,201 @@ +import copy + +from core.handlers import handlers +from core.utils import has_tag, is_dictionary, is_list, is_object, is_set, is_tuple, is_primitive, importable_name, \ + get_class, get_full_qualified_name, is_enum + +TAG_ID = "__id__" +TAG_OBJECT = "__object__" +TAG_TUPLE = "__tuple__" +TAG_SET = "__set__" +TAG_REF = "__ref__" +TAG_ENUM = "__enum__" + + +class Serializer: + def __init__(self, ref_helper=None): + self.ref_helper = ref_helper + + self.ids = {} + self.objs = [] + self.id_count = 0 + + def serialize(self, obj, use_refs=None): + """ + From object to dictionary + :param obj: + :param use_refs: Sometimes it easier / quicker to use pickle ! + :return: + """ + if use_refs: + use_refs = set("root." 
+ path for path in use_refs) + + return self._serialize(obj, use_refs or set(), "root") + + def deserialize(self, obj: dict): + """ + From dictionary to object (or primitive) + :param obj: + :return: + """ + if has_tag(obj, TAG_REF): + return self.ref_helper.load_ref(obj[TAG_REF]) + + if has_tag(obj, TAG_ID): + return self._restore_id(obj) + + if has_tag(obj, TAG_TUPLE): + return tuple([self.deserialize(v) for v in obj[TAG_TUPLE]]) + + if has_tag(obj, TAG_SET): + return set([self.deserialize(v) for v in obj[TAG_SET]]) + + if has_tag(obj, TAG_ENUM): + return self._deserialize_enum(obj) + + if has_tag(obj, TAG_OBJECT): + return self._deserialize_obj_instance(obj) + + if (handler := handlers.get_handler(obj)) is not None: + return handler.deserialize(obj) + + if is_list(obj): + return [self.deserialize(v) for v in obj] + + if is_dictionary(obj): + return {k: self.deserialize(v) for k, v in obj.items()} + + return obj + + def _serialize(self, obj, use_refs: set | None, path): + if use_refs is not None and path in use_refs: + digest = self.ref_helper.save_ref(obj) + return {TAG_REF: digest} + + if is_primitive(obj): + return obj + + if is_tuple(obj): + return {TAG_TUPLE: [self._serialize(v, use_refs, path) for v in obj]} + + if is_set(obj): + return {TAG_SET: [self._serialize(v, use_refs, path) for v in obj]} + + if is_list(obj): + return [self._serialize(v, use_refs, path) for v in obj] + + if is_dictionary(obj): + return {k: self._serialize(v, use_refs, path) for k, v in obj.items()} + + if is_enum(obj): + return self._serialize_enum(obj, use_refs, path) + + if is_object(obj): + return self._serialize_obj_instance(obj, use_refs, path) + + raise Exception(f"Cannot serialize '{obj}'") + + def _serialize_enum(self, obj, use_refs: set | None, path): + # check if the object was already seen + if (seen := self._check_already_seen(obj)) is not None: + return seen + + data = {} + class_name = get_full_qualified_name(obj) + data[TAG_ENUM] = class_name + "." 
+ obj.name + return data + + def _serialize_obj_instance(self, obj, use_refs: set | None, path): + # check if the object was already seen + if (seen := self._check_already_seen(obj)) is not None: + return seen + + # try to manage use_refs + current_obj_use_refs = getattr(obj, "use_refs")() if hasattr(obj, "use_refs") else None + if current_obj_use_refs: + use_refs.update(f"{path}.{sub_path}" for sub_path in current_obj_use_refs) + + if (handler := handlers.get_handler(obj)) is not None: + return handler.serialize(obj) + + # flatten + data = {} + cls = obj.__class__ if hasattr(obj, '__class__') else type(obj) + class_name = importable_name(cls) + data[TAG_OBJECT] = class_name + + if hasattr(obj, "__dict__"): + for k, v in obj.__dict__.items(): + data[k] = self._serialize(v, use_refs, f"{path}.{k}") + + return data + + def _check_already_seen(self, obj): + _id = self._exist(obj) + if _id is not None: + return {TAG_ID: _id} + + # else: + self.ids[id(obj)] = self.id_count + self.objs.append(obj) + self.id_count = self.id_count + 1 + + return None + + def _deserialize_enum(self, obj): + cls_name, enum_name = obj[TAG_ENUM].rsplit(".", 1) + cls = get_class(cls_name) + obj = getattr(cls, enum_name) + self.objs.append(obj) + return obj + + def _deserialize_obj_instance(self, obj): + + cls = get_class(obj[TAG_OBJECT]) + instance = cls.__new__(cls) + self.objs.append(instance) + + for k, v in obj.items(): + value = self.deserialize(v) + setattr(instance, k, value) + + return instance + + def _restore_id(self, obj): + try: + return self.objs[obj[TAG_ID]] + except IndexError: + pass + + def _exist(self, obj): + try: + v = self.ids[id(obj)] + return v + except KeyError: + return None + + +class DebugSerializer(Serializer): + def __init__(self, ref_helper=None): + super().__init__(ref_helper) + + def _deserialize_obj_instance(self, obj): + data = {TAG_OBJECT: obj[TAG_OBJECT]} + self.objs.append(data) + + for k, v in obj.items(): + value = self.deserialize(v) + data[k] = value + + return data + + def _deserialize_enum(self, obj): + cls_name, enum_name = obj[TAG_ENUM].rsplit(".", 1) + self.objs.append(enum_name) + return enum_name + + def _restore_id(self, obj): + try: + return copy.deepcopy(self.objs[obj[TAG_ID]]) + except IndexError: + pass diff --git a/src/core/utils.py b/src/core/utils.py new file mode 100644 index 0000000..f984e28 --- /dev/null +++ b/src/core/utils.py @@ -0,0 +1,195 @@ +import ast +import hashlib +import importlib +import types +from enum import Enum + +PRIMITIVES = (str, bool, type(None), int, float) + + +def get_stream_digest(stream): + """ + Compute a SHA256 from a stream + :param stream: + :type stream: + :return: + :rtype: + """ + sha256_hash = hashlib.sha256() + stream.seek(0) + for byte_block in iter(lambda: stream.read(4096), b""): + sha256_hash.update(byte_block) + + return sha256_hash.hexdigest() + + +def has_tag(obj, tag): + """ + + :param obj: + :param tag: + :return: + """ + return type(obj) is dict and tag in obj + + +def is_primitive(obj): + """ + + :param obj: + :return: + """ + return isinstance(obj, PRIMITIVES) + + +def is_dictionary(obj): + """ + + :param obj: + :return: + """ + return isinstance(obj, dict) + + +def is_list(obj): + """ + + :param obj: + :return: + """ + return isinstance(obj, list) + + +def is_set(obj): + """ + + :param obj: + :return: + """ + return isinstance(obj, set) + + +def is_tuple(obj): + """ + + :param obj: + :return: + """ + return isinstance(obj, tuple) + + +def is_enum(obj): + return isinstance(obj, Enum) + + +def is_object(obj): + 
"""Returns True is obj is a reference to an object instance.""" + + return (isinstance(obj, object) and + not isinstance(obj, (type, + types.FunctionType, + types.BuiltinFunctionType, + types.GeneratorType))) + + +def get_full_qualified_name(obj): + """ + Returns the full qualified name of a class (including its module name ) + :param obj: + :return: + """ + if obj.__class__ == type: + module = obj.__module__ + if module is None or module == str.__class__.__module__: + return obj.__name__ # Avoid reporting __builtin__ + else: + return module + '.' + obj.__name__ + else: + module = obj.__class__.__module__ + if module is None or module == str.__class__.__module__: + return obj.__class__.__name__ # Avoid reporting __builtin__ + else: + return module + '.' + obj.__class__.__name__ + + +def importable_name(cls): + """ + Fully qualified name (prefixed by builtin when needed) + """ + # Use the fully-qualified name if available (Python >= 3.3) + name = getattr(cls, '__qualname__', cls.__name__) + + # manage python 2 + lookup = dict(__builtin__='builtins', exceptions='builtins') + module = lookup.get(cls.__module__, cls.__module__) + + return f"{module}.{name}" + + +def get_class(qualified_class_name: str): + """ + Dynamically loads and returns a class type from its fully qualified name. + Note that the class is not instantiated. + + :param qualified_class_name: Fully qualified name of the class (e.g., 'some.module.ClassName'). + :return: The class object. + :raises ImportError: If the module cannot be imported. + :raises AttributeError: If the class cannot be resolved in the module. + """ + module_name, class_name = qualified_class_name.rsplit(".", 1) + + try: + module = importlib.import_module(module_name) + except ModuleNotFoundError as e: + raise ImportError(f"Could not import module '{module_name}' for '{qualified_class_name}': {e}") + + if not hasattr(module, class_name): + raise AttributeError(f"Component '{class_name}' not found in '{module.__name__}'.") + + return getattr(module, class_name) + + +class UnreferencedNamesVisitor(ast.NodeVisitor): + """ + Try to find symbols that will be requested by the ast + It can be variable names, but also function names + """ + + def __init__(self): + self.names = set() + + def get_names(self, node): + self.visit(node) + return self.names + + def visit_Name(self, node): + self.names.add(node.id) + + def visit_For(self, node: ast.For): + self.visit_selected(node, ["body", "orelse"]) + + def visit_selected(self, node, to_visit): + """Called if no explicit visitor function exists for a node.""" + for field in to_visit: + value = getattr(node, field) + if isinstance(value, list): + for item in value: + if isinstance(item, ast.AST): + self.visit(item) + elif isinstance(value, ast.AST): + self.visit(value) + + def visit_Call(self, node: ast.Call): + self.visit_selected(node, ["args", "keywords"]) + + def visit_keyword(self, node: ast.keyword): + """ + Keywords are parameters that are defined with a double star (**) in function / method definition + ex: def fun(positional, *args, **keywords) + :param node: + :type node: + :return: + :rtype: + """ + self.names.add(node.arg) + self.visit_selected(node, ["value"]) diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_dbengine.py b/tests/test_dbengine.py new file mode 100644 index 0000000..cb6ac54 --- /dev/null +++ b/tests/test_dbengine.py @@ -0,0 +1,273 @@ +import os.path +import shutil + +import pytest + +from core.dbengine import DbEngine, 
DbException, TAG_PARENT, TAG_USER, TAG_DATE + +DB_ENGINE_ROOT = "TestDBEngineRoot" +FAKE_TENANT_ID = "FakeTenantId" +FAKE_USER_EMAIL = "fake_user@me.com" + + +class DummyObj: + def __init__(self, a, b, c): + self.a = a + self.b = b + self.c = c + + def __eq__(self, other): + if id(self) == id(other): + return True + + if not isinstance(other, DummyObj): + return False + + return self.a == other.a and self.b == other.b and self.c == other.c + + def __hash__(self): + return hash((self.a, self.b, self.c)) + + +class DummyObjWithRef(DummyObj): + @staticmethod + def use_refs() -> set: + return {"c"} + + +class DummyObjWithKey(DummyObj): + def get_key(self) -> set: + return self.a + + +@pytest.fixture() +def engine(): + if os.path.exists(DB_ENGINE_ROOT): + shutil.rmtree(DB_ENGINE_ROOT) + + engine = DbEngine(DB_ENGINE_ROOT) + engine.init(FAKE_TENANT_ID) + + yield engine + + shutil.rmtree(DB_ENGINE_ROOT) + + +@pytest.fixture() +def dummy_obj(): + return DummyObj(1, "a", False) + + +@pytest.fixture() +def dummy_obj2(): + return DummyObj(2, "b", True) + + +@pytest.fixture() +def dummy_obj_with_ref(): + data = { + 'Key1': ['A', 'B', 'C'], + 'Key2': ['X', 'Y', 'Z'], + 'Percentage': [0.1, 0.2, 0.15], + } + return DummyObjWithRef(1, "a", data) + + +def test_i_can_test_init(): + if os.path.exists(DB_ENGINE_ROOT): + shutil.rmtree(DB_ENGINE_ROOT) + + engine = DbEngine(DB_ENGINE_ROOT) + assert not engine.is_initialized(FAKE_TENANT_ID) + + engine.init(FAKE_TENANT_ID) + assert engine.is_initialized(FAKE_TENANT_ID) + + +def test_i_can_save_and_load(engine, dummy_obj): + digest = engine.save(FAKE_TENANT_ID, FAKE_USER_EMAIL, "MyEntry", dummy_obj) + + res = engine.load(FAKE_TENANT_ID, "MyEntry") + + assert digest is not None + assert isinstance(res, DummyObj) + + assert res.a == dummy_obj.a + assert res.b == dummy_obj.b + assert res.c == dummy_obj.c + + # check that the files are created + assert os.path.exists(os.path.join(DB_ENGINE_ROOT, FAKE_TENANT_ID, "objects")) + assert os.path.exists(os.path.join(DB_ENGINE_ROOT, FAKE_TENANT_ID, "head")) + + +def test_save_invalid_inputs(engine): + """ + Test save with invalid inputs. 
+ """ + with pytest.raises(DbException): + engine.save(None, FAKE_USER_EMAIL, "InvalidEntry", DummyObj(1, 2, 3)) + + with pytest.raises(DbException): + engine.save(FAKE_TENANT_ID, None, "InvalidEntry", DummyObj(1, 2, 3)) + + with pytest.raises(DbException): + engine.save(FAKE_TENANT_ID, FAKE_USER_EMAIL, "", DummyObj(1, 2, 3)) + + with pytest.raises(DbException): + engine.save(FAKE_TENANT_ID, FAKE_USER_EMAIL, None, DummyObj(1, 2, 3)) + + +def test_i_can_save_using_ref(engine, dummy_obj_with_ref): + engine.save(FAKE_TENANT_ID, FAKE_USER_EMAIL, "MyEntry", dummy_obj_with_ref) + + res = engine.load(FAKE_TENANT_ID, "MyEntry") + assert isinstance(res, DummyObjWithRef) + + assert res.a == dummy_obj_with_ref.a + assert res.b == dummy_obj_with_ref.b + assert res.c == dummy_obj_with_ref.c + + # check that the files are created + assert os.path.exists(os.path.join(DB_ENGINE_ROOT, FAKE_TENANT_ID, "objects")) + assert os.path.exists(os.path.join(DB_ENGINE_ROOT, FAKE_TENANT_ID, "head")) + assert os.path.exists(os.path.join(DB_ENGINE_ROOT, "refs")) + + +def test_refs_are_share_across_users(engine, dummy_obj_with_ref): + engine.save(FAKE_TENANT_ID, FAKE_USER_EMAIL, "MyEntry", dummy_obj_with_ref) + engine.save("AnotherUserId", "AnotherUser", "AnotherMyEntry", dummy_obj_with_ref) + + refs_path = os.path.join(DB_ENGINE_ROOT, "refs") + assert len(os.listdir(refs_path)) == 1 + + +def test_metadata_are_correctly_set(engine, dummy_obj): + digest = engine.save(FAKE_TENANT_ID, FAKE_USER_EMAIL, "MyEntry", {"obj": dummy_obj}) + + as_dict = engine.load(FAKE_TENANT_ID, "MyEntry", digest) + assert as_dict[TAG_PARENT] == [None] + assert as_dict[TAG_USER] == FAKE_USER_EMAIL + assert as_dict[TAG_DATE] is not None + + +def test_i_can_track_parents(engine): + digest = engine.save(FAKE_TENANT_ID, FAKE_USER_EMAIL, "MyEntry", {"obj": DummyObj(1, "a", False)}) + second_digest = engine.save(FAKE_TENANT_ID, FAKE_USER_EMAIL, "MyEntry", {"obj": DummyObj(1, "a", True)}) + + as_dict = engine.load(FAKE_TENANT_ID, "MyEntry", second_digest) + + assert as_dict[TAG_PARENT] == [digest] + + +def test_i_can_put_and_get_one_object(engine, dummy_obj): + engine.put(FAKE_TENANT_ID, FAKE_USER_EMAIL, "MyEntry", "key1", dummy_obj) + from_db = engine.get(FAKE_TENANT_ID, "MyEntry", "key1") + + assert from_db == dummy_obj + + +def test_i_can_put_and_get_multiple_objects(engine, dummy_obj, dummy_obj2): + engine.put(FAKE_TENANT_ID, FAKE_USER_EMAIL, "MyEntry", "key1", dummy_obj) + engine.put(FAKE_TENANT_ID, FAKE_USER_EMAIL, "MyEntry", "key2", dummy_obj2) + + from_db1 = engine.get(FAKE_TENANT_ID, "MyEntry", "key1") + from_db2 = engine.get(FAKE_TENANT_ID, "MyEntry", "key2") + + assert from_db1 == dummy_obj + assert from_db2 == dummy_obj2 + + as_dict = engine.load(FAKE_TENANT_ID, "MyEntry") + + assert "key1" in as_dict + assert "key2" in as_dict + assert as_dict["key1"] == dummy_obj + assert as_dict["key2"] == dummy_obj2 + + +def test_i_automatically_replace_keys(engine, dummy_obj, dummy_obj2): + engine.put(FAKE_TENANT_ID, FAKE_USER_EMAIL, "MyEntry", "key1", dummy_obj) + engine.put(FAKE_TENANT_ID, FAKE_USER_EMAIL, "MyEntry", "key1", dummy_obj2) + + from_db1 = engine.get(FAKE_TENANT_ID, "MyEntry", "key1") + assert from_db1 == dummy_obj2 + + all_items = engine.get(FAKE_TENANT_ID, "MyEntry") + assert all_items == [dummy_obj2] + + +def test_i_do_not_save_twice_when_the_entries_are_the_same(engine, dummy_obj): + engine.put(FAKE_TENANT_ID, FAKE_USER_EMAIL, "MyEntry", "key1", dummy_obj) + + entry_content = engine.load(FAKE_TENANT_ID, "MyEntry") + assert 
entry_content[TAG_PARENT] == [None] + + # Save the same entry again + engine.put(FAKE_TENANT_ID, FAKE_USER_EMAIL, "MyEntry", "key1", dummy_obj) + + entry_content = engine.load(FAKE_TENANT_ID, "MyEntry") + assert entry_content[TAG_PARENT] == [None] # still no other parent + + +def test_i_can_put_many(engine): + dummy_obj = DummyObjWithKey("1", "a", True) + dummy_obj2 = DummyObjWithKey("2", "b", False) + engine.put_many(FAKE_TENANT_ID, FAKE_USER_EMAIL, "MyEntry", [dummy_obj, dummy_obj2]) + + from_db1 = engine.get(FAKE_TENANT_ID, "MyEntry", "1") + from_db2 = engine.get(FAKE_TENANT_ID, "MyEntry", "2") + + assert from_db1 == dummy_obj + assert from_db2 == dummy_obj2 + + entry_content = engine.load(FAKE_TENANT_ID, "MyEntry") + assert entry_content[TAG_PARENT] == [None] # only one save was made + + +def test_put_many_save_only_if_necessary(engine): + dummy_obj = DummyObjWithKey("1", "a", True) + dummy_obj2 = DummyObjWithKey("2", "b", False) + + engine.put_many(FAKE_TENANT_ID, FAKE_USER_EMAIL, "MyEntry", [dummy_obj, dummy_obj2]) + engine.put_many(FAKE_TENANT_ID, FAKE_USER_EMAIL, "MyEntry", [dummy_obj, dummy_obj2]) + + entry_content = engine.load(FAKE_TENANT_ID, "MyEntry") + assert entry_content[TAG_PARENT] == [None] # Still None, nothing was save + + +def test_i_can_retrieve_history_using_put(engine): + engine.put(FAKE_TENANT_ID, FAKE_USER_EMAIL, "MyEntry", "key1", DummyObj(1, "a", False)) + engine.put(FAKE_TENANT_ID, FAKE_USER_EMAIL, "MyEntry", "key1", DummyObj(2, "a", False)) + engine.put(FAKE_TENANT_ID, FAKE_USER_EMAIL, "MyEntry", "key1", DummyObj(3, "a", False)) + + history = engine.history(FAKE_TENANT_ID, "MyEntry") + assert len(history) == 3 + + v0 = engine.load(FAKE_TENANT_ID, "MyEntry", history[0]) + v1 = engine.load(FAKE_TENANT_ID, "MyEntry", history[1]) + v2 = engine.load(FAKE_TENANT_ID, "MyEntry", history[2]) + + assert v0["key1"] == DummyObj(3, "a", False) + assert v1["key1"] == DummyObj(2, "a", False) + assert v2["key1"] == DummyObj(1, "a", False) + + assert v2[TAG_PARENT] == [None] + + +def test_i_can_retrieve_history_using_save(engine): + engine.save(FAKE_TENANT_ID, FAKE_USER_EMAIL, "MyEntry", {"key1": DummyObj(1, "a", False)}) + engine.save(FAKE_TENANT_ID, FAKE_USER_EMAIL, "MyEntry", {"key1": DummyObj(2, "a", False)}) + engine.save(FAKE_TENANT_ID, FAKE_USER_EMAIL, "MyEntry", {"key1": DummyObj(3, "a", False)}) + + history = engine.history(FAKE_TENANT_ID, "MyEntry") + assert len(history) == 3 + + v0 = engine.load(FAKE_TENANT_ID, "MyEntry", history[0]) + v1 = engine.load(FAKE_TENANT_ID, "MyEntry", history[1]) + v2 = engine.load(FAKE_TENANT_ID, "MyEntry", history[2]) + + assert v0["key1"] == DummyObj(3, "a", False) + assert v1["key1"] == DummyObj(2, "a", False) + assert v2["key1"] == DummyObj(1, "a", False) + + assert v2[TAG_PARENT] == [None] diff --git a/tests/test_serializer.py b/tests/test_serializer.py new file mode 100644 index 0000000..88e1c4f --- /dev/null +++ b/tests/test_serializer.py @@ -0,0 +1,268 @@ +import dataclasses +import datetime +import hashlib +import pickle +from enum import Enum + +import pytest + +from core.serializer import TAG_TUPLE, TAG_SET, Serializer, TAG_OBJECT, TAG_ID, TAG_REF + + +class Obj: + def __init__(self, a, b, c): + self.a = a + self.b = b + self.c = c + + def __eq__(self, other): + if id(self) == id(other): + return True + + if not isinstance(other, Obj): + return False + + return self.a == other.a and self.b == other.b and self.c == other.c + + def __hash__(self): + return hash((self.a, self.b, self.c)) + + +class Obj2: + class InnerClass: + def 
__init__(self, x): + self.x = x + + def __eq__(self, other): + if not isinstance(other, Obj2.InnerClass): + return False + + return self.x == other.x + + def __hash__(self): + return hash(self.x) + + def __init__(self, a, b, x): + self.a = a + self.b = b + self.x = Obj2.InnerClass(x) + + def __eq__(self, other): + if not isinstance(other, Obj2): + return False + + return (self.a == other.a and + self.b == other.b and + self.x == other.x) + + def __hash__(self): + return hash((self.a, self.b)) + + +class ObjEnum(Enum): + A = 1 + B = "second" + C = "last" + + +@dataclasses.dataclass +class DummyComplexClass: + prop1: str + prop2: Obj + prop3: ObjEnum + + +class DummyRefHelper: + """ + When something is too complicated to serialize, we just default to pickle + That is what this helper class is doing + """ + + def __init__(self): + self.refs = {} + + def save_ref(self, obj): + sha256_hash = hashlib.sha256() + + pickled_data = pickle.dumps(obj) + sha256_hash.update(pickled_data) + digest = sha256_hash.hexdigest() + + self.refs[digest] = pickled_data + return digest + + def load_ref(self, digest): + return pickle.loads(self.refs[digest]) + + +@pytest.mark.parametrize("obj, expected", [ + (1, 1), + (3.14, 3.14), + ("a string", "a string"), + (True, True), + (None, None), + ([1, 3.14, "a string"], [1, 3.14, "a string"]), + ((1, 3.14, "a string"), {TAG_TUPLE: [1, 3.14, "a string"]}), + ({1}, {TAG_SET: [1]}), + ({"a": "a", "b": 3.14, "c": True}, {"a": "a", "b": 3.14, "c": True}), + ({1: "a", 2: 3.14, 3: True}, {1: "a", 2: 3.14, 3: True}), + ([1, [3.14, "a string"]], [1, [3.14, "a string"]]), + ([1, (3.14, "a string")], [1, {TAG_TUPLE: [3.14, "a string"]}]), + ([], []), +]) +def test_i_can_flatten_and_restore_primitives(obj, expected): + serializer = Serializer() + + flatten = serializer.serialize(obj) + assert flatten == expected + + decoded = serializer.deserialize(flatten) + assert decoded == obj + + +def test_i_can_flatten_and_restore_instances(): + serializer = Serializer() + obj1 = Obj(1, "b", True) + obj2 = Obj(3.14, ("a", "b"), obj1) + + flatten = serializer.serialize(obj2) + assert flatten == {TAG_OBJECT: 'tests.test_serializer.Obj', + 'a': 3.14, + 'b': {TAG_TUPLE: ['a', 'b']}, + 'c': {TAG_OBJECT: 'tests.test_serializer.Obj', + 'a': 1, + 'b': 'b', + 'c': True}} + + decoded = serializer.deserialize(flatten) + assert decoded == obj2 + + +def test_i_can_flatten_and_restore_enum(): + serializer = Serializer() + obj1 = ObjEnum.A + obj2 = ObjEnum.B + obj3 = ObjEnum.C + + wrapper = { + "a": obj1, + "b": obj2, + "c": obj3, + "d": obj1 + } + flatten = serializer.serialize(wrapper) + assert flatten == {'a': {'__enum__': 'tests.test_serializer.ObjEnum.A'}, + 'b': {'__enum__': 'tests.test_serializer.ObjEnum.B'}, + 'c': {'__enum__': 'tests.test_serializer.ObjEnum.C'}, + 'd': {'__id__': 0}} + decoded = serializer.deserialize(flatten) + assert decoded == wrapper + + +def test_i_can_flatten_and_restore_list_with_enum(): + serializer = Serializer() + obj = [DummyComplexClass("a", Obj(1, "a", ObjEnum.A), ObjEnum.A), + DummyComplexClass("b", Obj(2, "b", ObjEnum.B), ObjEnum.B), + DummyComplexClass("c", Obj(3, "c", ObjEnum.C), ObjEnum.B)] + + flatten = serializer.serialize(obj) + assert flatten == [{'__object__': 'tests.test_serializer.DummyComplexClass', + 'prop1': 'a', + 'prop2': {'__object__': 'tests.test_serializer.Obj', + 'a': 1, + 'b': 'a', + 'c': {'__enum__': 'tests.test_serializer.ObjEnum.A'}}, + 'prop3': {'__id__': 2}}, + {'__object__': 'tests.test_serializer.DummyComplexClass', + 'prop1': 'b', + 
'prop2': {'__object__': 'tests.test_serializer.Obj', + 'a': 2, + 'b': 'b', + 'c': {'__enum__': 'tests.test_serializer.ObjEnum.B'}}, + 'prop3': {'__id__': 5}}, + {'__object__': 'tests.test_serializer.DummyComplexClass', + 'prop1': 'c', + 'prop2': {'__object__': 'tests.test_serializer.Obj', + 'a': 3, + 'b': 'c', + 'c': {'__enum__': 'tests.test_serializer.ObjEnum.C'}}, + 'prop3': {'__id__': 5}}] + decoded = serializer.deserialize(flatten) + assert decoded == obj + + +def test_i_can_manage_circular_reference(): + serializer = Serializer() + obj1 = Obj(1, "b", True) + obj1.c = obj1 + + flatten = serializer.serialize(obj1) + assert flatten == {TAG_OBJECT: 'tests.test_serializer.Obj', + 'a': 1, + 'b': 'b', + 'c': {TAG_ID: 0}} + + decoded = serializer.deserialize(flatten) + assert decoded.a == obj1.a + assert decoded.b == obj1.b + assert decoded.c == decoded + + +def test_i_can_use_refs_on_primitive(): + serializer = Serializer(DummyRefHelper()) + obj1 = Obj(1, "b", True) + + flatten = serializer.serialize(obj1, ["c"]) + assert flatten == {TAG_OBJECT: 'tests.test_serializer.Obj', + 'a': 1, + 'b': 'b', + 'c': {TAG_REF: '112bda3b495d867b6a98c899fac7c25eb60ca4b6e6fe5ec7ab9299f93e8274bc'}} + + decoded = serializer.deserialize(flatten) + assert decoded == obj1 + + +def test_i_can_use_refs_on_path(): + serializer = Serializer(DummyRefHelper()) + obj1 = Obj(1, "b", True) + obj2 = Obj(1, "b", obj1) + + flatten = serializer.serialize(obj2, ["c.b"]) + assert flatten == {TAG_OBJECT: 'tests.test_serializer.Obj', + 'a': 1, + 'b': 'b', + 'c': {TAG_OBJECT: 'tests.test_serializer.Obj', + 'a': 1, + 'b': {TAG_REF: '897f2e2b559dd876ad870c82283197b8cfecdf84736192ea6fb9ee5a5080a3a4'}, + 'c': True}} + + decoded = serializer.deserialize(flatten) + assert decoded == obj2 + + +def test_can_use_refs_when_circular_reference(): + serializer = Serializer(DummyRefHelper()) + obj1 = Obj(1, "b", True) + obj1.c = obj1 + + flatten = serializer.serialize(obj1, ["c"]) + assert flatten == {TAG_OBJECT: 'tests.test_serializer.Obj', + 'a': 1, + 'b': 'b', + 'c': {TAG_REF: "87b1980d83bd267e2c8cc2fbc435ba00349e45b736c40f3984f710ebb4495adc"}} + + decoded = serializer.deserialize(flatten) + assert decoded.a == obj1.a + assert decoded.b == obj1.b + assert decoded.c == decoded + + +def test_i_can_serialize_date(): + obj = datetime.date.today() + serializer = Serializer() + + flatten = serializer.serialize(obj) + + decoded = serializer.deserialize(flatten) + + assert decoded == obj
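+
+
+# A minimal sketch exercising DebugSerializer (defined in core.serializer):
+# it restores tagged dicts instead of live instances, which is handy for
+# inspecting stored snapshots when the original classes are unavailable.
+def test_debug_serializer_returns_plain_dicts():
+    from core.serializer import DebugSerializer
+
+    flatten = Serializer().serialize(Obj(1, "b", True))
+    decoded = DebugSerializer().deserialize(flatten)
+
+    # Instead of an Obj instance, we get a dict carrying the __object__ tag.
+    assert decoded[TAG_OBJECT] == "tests.test_serializer.Obj"
+    assert decoded["a"] == 1
+    assert decoded["b"] == "b"
+    assert decoded["c"] is True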