Removed async
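In short: the commit swaps Motor's asynchronous MongoDB driver for PyMongo's synchronous client and strips async/await from the repositories, services, file watcher, and Celery task. A minimal before/after sketch of the call style (illustrative only; the function names and query below are placeholders, not code from this commit):

    # Before: Motor client, awaited from async code paths
    from motor.motor_asyncio import AsyncIOMotorClient

    async def find_user_async(mongodb_url: str, username: str):
        client = AsyncIOMotorClient(mongodb_url)
        db = client["mydocmanager"]
        return await db.users.find_one({"username": username})

    # After: PyMongo client, plain synchronous calls (no event loop needed)
    from pymongo import MongoClient

    def find_user_sync(mongodb_url: str, username: str):
        client = MongoClient(mongodb_url)
        db = client["mydocmanager"]
        return db.users.find_one({"username": username})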
@@ -67,7 +67,7 @@ services:
     networks:
       - mydocmanager-network
     command: celery -A tasks.main worker --loglevel=info
-    #command: celery -A main --loglevel=info # pour la production
+

 volumes:
   mongodb-data:
@@ -1,6 +1,7 @@
 amqp==5.3.1
 annotated-types==0.7.0
 anyio==4.10.0
+asgiref==3.9.1
 bcrypt==4.3.0
 billiard==4.2.1
 celery==5.5.3
@@ -12,9 +13,12 @@ dnspython==2.8.0
 email-validator==2.3.0
 fastapi==0.116.1
 h11==0.16.0
+hiredis==3.2.1
 httptools==0.6.4
 idna==3.10
+importlib_metadata==8.7.0
 iniconfig==2.1.0
+izulu==0.50.0
 kombu==5.5.4
 mongomock==4.3.0
 mongomock-motor==0.0.36
@@ -23,9 +27,11 @@ packaging==25.0
 pipdeptree==2.28.0
 pluggy==1.6.0
 prompt_toolkit==3.0.52
+pycron==3.2.0
 pydantic==2.11.9
 pydantic_core==2.33.2
 Pygments==2.19.2
+PyJWT==2.10.1
 pymongo==4.15.1
 pytest==8.4.2
 pytest-asyncio==1.2.0
@@ -35,6 +41,7 @@ python-dotenv==1.1.1
 python-magic==0.4.27
 pytz==2025.2
 PyYAML==6.0.2
+redis==6.4.0
 sentinels==1.1.1
 six==1.17.0
 sniffio==1.3.1
@@ -49,3 +56,4 @@ watchdog==6.0.0
 watchfiles==1.1.0
 wcwidth==0.2.13
 websockets==15.0.1
+zipp==3.23.0
@@ -30,6 +30,26 @@ def get_mongodb_database_name() -> str:
     return os.getenv("MONGODB_DATABASE", "mydocmanager")


+def get_redis_url() -> str:
+    return os.getenv("REDIS_URL", "redis://localhost:6379/0")
+
+
+# def get_redis_host() -> str:
+#     redis_url = get_redis_url()
+#     if redis_url.startswith("redis://"):
+#         return redis_url.split("redis://")[1].split("/")[0]
+#     else:
+#         return redis_url
+#
+#
+# def get_redis_port() -> int:
+#     redis_url = get_redis_url()
+#     if redis_url.startswith("redis://"):
+#         return int(redis_url.split("redis://")[1].split("/")[0].split(":")[1])
+#     else:
+#         return int(redis_url.split(":")[1])
+
+
 def get_jwt_secret_key() -> str:
     """
     Get JWT secret key from environment variables.
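The new get_redis_url() helper mirrors the existing MongoDB settings helpers and defaults to a local Redis instance. A hedged sketch of how it could feed a Celery broker and result backend; the app name and module wiring below are assumptions, not shown in this commit:

    # Hypothetical wiring (e.g. in tasks/main.py); only get_redis_url() comes from this commit.
    from celery import Celery

    from app.config.settings import get_redis_url

    celery_app = Celery(
        "mydocmanager",
        broker=get_redis_url(),   # e.g. redis://localhost:6379/0
        backend=get_redis_url(),
    )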
@@ -7,10 +7,11 @@ The application will terminate if MongoDB is not accessible at startup.

 import sys
 from typing import Optional

 from pymongo import MongoClient
 from pymongo.database import Database
 from pymongo.errors import ConnectionFailure, ServerSelectionTimeoutError
-from motor.motor_asyncio import AsyncIOMotorClient, AsyncIOMotorDatabase
 from app.config.settings import get_mongodb_url, get_mongodb_database_name

 # Global variables for singleton pattern
@@ -18,7 +19,7 @@ _client: Optional[MongoClient] = None
 _database: Optional[Database] = None


-def create_mongodb_client() -> AsyncIOMotorClient:
+def create_mongodb_client() -> MongoClient:
     """
     Create MongoDB client with connection validation.

@@ -32,7 +33,7 @@ def create_mongodb_client() -> AsyncIOMotorClient:

     try:
         # Create client with short timeout for fail-fast behavior
-        client = AsyncIOMotorClient(
+        client = MongoClient(
             mongodb_url,
             serverSelectionTimeoutMS=5000, # 5 seconds timeout
             connectTimeoutMS=5000,
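With the synchronous client, the fail-fast startup check can be an eager ping rather than an awaited call. A rough sketch under the same 5-second timeouts as the diff; the helper name is illustrative:

    # Sketch only: validate the connection eagerly with PyMongo's synchronous client.
    from pymongo import MongoClient
    from pymongo.errors import ConnectionFailure, ServerSelectionTimeoutError

    def ping_or_exit(mongodb_url: str) -> MongoClient:
        client = MongoClient(
            mongodb_url,
            serverSelectionTimeoutMS=5000,  # 5 seconds timeout
            connectTimeoutMS=5000,
        )
        try:
            client.admin.command("ping")  # raises if the server is unreachable
        except (ConnectionFailure, ServerSelectionTimeoutError) as exc:
            raise SystemExit(f"MongoDB is not reachable at startup: {exc}")
        return client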
@@ -8,7 +8,8 @@ in MongoDB with proper error handling and type safety.
 from typing import Optional, List

 from bson import ObjectId
-from motor.motor_asyncio import AsyncIOMotorCollection, AsyncIOMotorDatabase
+from pymongo.collection import Collection
+from pymongo.database import Database
 from pymongo.errors import DuplicateKeyError, PyMongoError

 from app.database.connection import get_extra_args
@@ -37,33 +38,29 @@ class FileDocumentRepository:
     with proper error handling and data validation.
     """

-    def __init__(self, database: AsyncIOMotorDatabase):
+    def __init__(self, database: Database):
         """Initialize file repository with database connection."""
         self.db = database
-        self.collection: AsyncIOMotorCollection = self.db.documents
+        self.collection: Collection = self.db.documents

-    async def initialize(self):
+    def initialize(self):
         """
         Initialize repository by ensuring required indexes exist.

         Should be called after repository instantiation to setup database indexes.
         """
-        await self._ensure_indexes()
+        self._ensure_indexes()
         return self

-    async def _ensure_indexes(self):
+    def _ensure_indexes(self):
         """
         Ensure required database indexes exist.

         Creates unique index on username field to prevent duplicates.
         """
-        try:
-            await self.collection.create_index("filepath", unique=True)
-        except PyMongoError:
-            # Index might already exist, ignore error
         pass

-    async def create_document(self, file_data: FileDocument, session=None) -> FileDocument:
+    def create_document(self, file_data: FileDocument, session=None) -> FileDocument:
         """
         Create a new file document in database.

@@ -83,7 +80,7 @@ class FileDocumentRepository:
             if "_id" in file_dict and file_dict["_id"] is None:
                 del file_dict["_id"]

-            result = await self.collection.insert_one(file_dict, **get_extra_args(session))
+            result = self.collection.insert_one(file_dict, **get_extra_args(session))
             file_data.id = result.inserted_id
             return file_data

@@ -92,7 +89,7 @@ class FileDocumentRepository:
         except PyMongoError as e:
             raise ValueError(f"Failed to create file document: {e}")

-    async def find_document_by_id(self, file_id: str) -> Optional[FileDocument]:
+    def find_document_by_id(self, file_id: str) -> Optional[FileDocument]:
         """
         Find file document by ID.

@@ -106,7 +103,7 @@ class FileDocumentRepository:
             if not ObjectId.is_valid(file_id):
                 return None

-            file_doc = await self.collection.find_one({"_id": ObjectId(file_id)})
+            file_doc = self.collection.find_one({"_id": ObjectId(file_id)})
             if file_doc:
                 return FileDocument(**file_doc)
             return None
@@ -114,7 +111,7 @@ class FileDocumentRepository:
         except PyMongoError:
             return None

-    async def find_document_by_hash(self, file_hash: str) -> Optional[FileDocument]:
+    def find_document_by_hash(self, file_hash: str) -> Optional[FileDocument]:
         """
         Find file document by file hash to detect duplicates.

@@ -125,7 +122,7 @@ class FileDocumentRepository:
             FileDocument or None: File document if found, None otherwise
         """
         try:
-            file_doc = await self.collection.find_one({"file_hash": file_hash})
+            file_doc = self.collection.find_one({"file_hash": file_hash})
             if file_doc:
                 return FileDocument(**file_doc)
             return None
@@ -133,7 +130,7 @@ class FileDocumentRepository:
         except PyMongoError:
             return None

-    async def find_document_by_filepath(self, filepath: str) -> Optional[FileDocument]:
+    def find_document_by_filepath(self, filepath: str) -> Optional[FileDocument]:
         """
         Find file document by exact filepath.

@@ -144,7 +141,7 @@ class FileDocumentRepository:
             FileDocument or None: File document if found, None otherwise
         """
         try:
-            file_doc = await self.collection.find_one({"filepath": filepath})
+            file_doc = self.collection.find_one({"filepath": filepath})
             if file_doc:
                 return FileDocument(**file_doc)
             return None
@@ -152,7 +149,7 @@ class FileDocumentRepository:
         except PyMongoError:
             return None

-    async def find_document_by_name(self, filename: str, matching_method: MatchMethodBase = None) -> List[FileDocument]:
+    def find_document_by_name(self, filename: str, matching_method: MatchMethodBase = None) -> List[FileDocument]:
         """
         Find file documents by filename using fuzzy matching.

@@ -166,8 +163,7 @@ class FileDocumentRepository:
         try:
             # Get all files from database
             cursor = self.collection.find({})
-            all_files = await cursor.to_list(length=None)
-            all_documents = [FileDocument(**file_doc) for file_doc in all_files]
+            all_documents = [FileDocument(**file_doc) for file_doc in cursor]

             if isinstance(matching_method, FuzzyMatching):
                 return fuzzy_matching(filename, all_documents, matching_method.threshold)
@@ -177,7 +173,7 @@ class FileDocumentRepository:
         except PyMongoError:
             return []

-    async def list_documents(self, skip: int = 0, limit: int = 100) -> List[FileDocument]:
+    def list_documents(self, skip: int = 0, limit: int = 100) -> List[FileDocument]:
         """
         List file documents with pagination.

@@ -190,13 +186,12 @@ class FileDocumentRepository:
         """
         try:
             cursor = self.collection.find({}).skip(skip).limit(limit).sort("detected_at", -1)
-            file_docs = await cursor.to_list(length=limit)
-            return [FileDocument(**doc) for doc in file_docs]
+            return [FileDocument(**doc) for doc in cursor]

         except PyMongoError:
             return []

-    async def count_documents(self) -> int:
+    def count_documents(self) -> int:
         """
         Count total number of file documents.

@@ -204,11 +199,11 @@ class FileDocumentRepository:
             int: Total number of file documents in collection
         """
         try:
-            return await self.collection.count_documents({})
+            return self.collection.count_documents({})
         except PyMongoError:
             return 0

-    async def update_document(self, file_id: str, update_data: dict, session=None) -> Optional[FileDocument]:
+    def update_document(self, file_id: str, update_data: dict, session=None) -> Optional[FileDocument]:
         """
         Update file document with new data.

@@ -228,9 +223,9 @@ class FileDocumentRepository:
             clean_update_data = {k: v for k, v in update_data.items() if v is not None}

             if not clean_update_data:
-                return await self.find_document_by_id(file_id)
+                return self.find_document_by_id(file_id)

-            result = await self.collection.find_one_and_update(
+            result = self.collection.find_one_and_update(
                 {"_id": ObjectId(file_id)},
                 {"$set": clean_update_data},
                 return_document=True,
@@ -244,7 +239,7 @@ class FileDocumentRepository:
         except PyMongoError:
             return None

-    async def delete_document(self, file_id: str, session=None) -> bool:
+    def delete_document(self, file_id: str, session=None) -> bool:
         """
         Delete file document from database.

@@ -259,7 +254,7 @@ class FileDocumentRepository:
             if not ObjectId.is_valid(file_id):
                 return False

-            result = await self.collection.delete_one({"_id": ObjectId(file_id)}, **get_extra_args(session))
+            result = self.collection.delete_one({"_id": ObjectId(file_id)}, **get_extra_args(session))
             return result.deleted_count > 0

         except PyMongoError:
@@ -8,7 +8,8 @@ with automatic timestamp management and error handling.
 from datetime import datetime
 from typing import List, Optional

-from motor.motor_asyncio import AsyncIOMotorCollection, AsyncIOMotorDatabase
+from pymongo.collection import Collection
+from pymongo.database import Database
 from pymongo.errors import PyMongoError

 from app.exceptions.job_exceptions import JobRepositoryError
@@ -24,33 +25,33 @@ class JobRepository:
     timestamp management and proper error handling.
     """

-    def __init__(self, database: AsyncIOMotorDatabase):
+    def __init__(self, database: Database):
         """Initialize repository with MongoDB collection reference."""
         self.db = database
-        self.collection: AsyncIOMotorCollection = self.db.processing_jobs
+        self.collection: Collection = self.db.processing_jobs

-    async def _ensure_indexes(self):
+    def _ensure_indexes(self):
         """
         Ensure required database indexes exist.

         Creates unique index on username field to prevent duplicates.
         """
         try:
-            await self.collection.create_index("document_id", unique=True)
+            self.collection.create_index("document_id", unique=True)
         except PyMongoError:
             # Index might already exist, ignore error
             pass

-    async def initialize(self):
+    def initialize(self):
         """
         Initialize repository by ensuring required indexes exist.

         Should be called after repository instantiation to setup database indexes.
         """
-        await self._ensure_indexes()
+        self._ensure_indexes()
         return self

-    async def create_job(self, document_id: PyObjectId, task_id: Optional[str] = None) -> ProcessingJob:
+    def create_job(self, document_id: PyObjectId, task_id: Optional[str] = None) -> ProcessingJob:
         """
         Create a new processing job.

@@ -75,7 +76,7 @@ class JobRepository:
                 "error_message": None
             }

-            result = await self.collection.insert_one(job_data)
+            result = self.collection.insert_one(job_data)
             job_data["_id"] = result.inserted_id

             return ProcessingJob(**job_data)
@@ -83,7 +84,7 @@ class JobRepository:
         except PyMongoError as e:
             raise JobRepositoryError("create_job", e)

-    async def find_job_by_id(self, job_id: PyObjectId) -> Optional[ProcessingJob]:
+    def find_job_by_id(self, job_id: PyObjectId) -> Optional[ProcessingJob]:
         """
         Retrieve a job by its ID.

@@ -98,7 +99,7 @@ class JobRepository:
             JobRepositoryError: If database operation fails
         """
         try:
-            job_data = await self.collection.find_one({"_id": job_id})
+            job_data = self.collection.find_one({"_id": job_id})
             if job_data:
                 return ProcessingJob(**job_data)

@@ -107,7 +108,7 @@ class JobRepository:
         except PyMongoError as e:
             raise JobRepositoryError("get_job_by_id", e)

-    async def update_job_status(
+    def update_job_status(
         self,
         job_id: PyObjectId,
         status: ProcessingStatus,
@@ -143,7 +144,7 @@ class JobRepository:
             if error_message is not None:
                 update_data["error_message"] = error_message

-            result = await self.collection.find_one_and_update(
+            result = self.collection.find_one_and_update(
                 {"_id": job_id},
                 {"$set": update_data},
                 return_document=True
@@ -157,7 +158,7 @@ class JobRepository:
         except PyMongoError as e:
             raise JobRepositoryError("update_job_status", e)

-    async def delete_job(self, job_id: PyObjectId) -> bool:
+    def delete_job(self, job_id: PyObjectId) -> bool:
         """
         Delete a job from the database.

@@ -171,14 +172,14 @@ class JobRepository:
             JobRepositoryError: If database operation fails
         """
         try:
-            result = await self.collection.delete_one({"_id": job_id})
+            result = self.collection.delete_one({"_id": job_id})

             return result.deleted_count > 0

         except PyMongoError as e:
             raise JobRepositoryError("delete_job", e)

-    async def find_jobs_by_document_id(self, document_id: PyObjectId) -> List[ProcessingJob]:
+    def find_jobs_by_document_id(self, document_id: PyObjectId) -> List[ProcessingJob]:
         """
         Retrieve all jobs for a specific file.

@@ -195,7 +196,7 @@ class JobRepository:
             cursor = self.collection.find({"document_id": document_id})

             jobs = []
-            async for job_data in cursor:
+            for job_data in cursor:
                 jobs.append(ProcessingJob(**job_data))

             return jobs
@@ -203,7 +204,7 @@ class JobRepository:
         except PyMongoError as e:
             raise JobRepositoryError("get_jobs_by_file_id", e)

-    async def get_jobs_by_status(self, status: ProcessingStatus) -> List[ProcessingJob]:
+    def get_jobs_by_status(self, status: ProcessingStatus) -> List[ProcessingJob]:
         """
         Retrieve all jobs with a specific status.

@@ -220,7 +221,7 @@ class JobRepository:
             cursor = self.collection.find({"status": status})

             jobs = []
-            async for job_data in cursor:
+            for job_data in cursor:
                 jobs.append(ProcessingJob(**job_data))

             return jobs
@@ -5,10 +5,12 @@ This module implements the repository pattern for user CRUD operations
 with dependency injection of the database connection using async/await.
 """

-from typing import Optional, List
 from datetime import datetime
+from typing import Optional, List

 from bson import ObjectId
-from motor.motor_asyncio import AsyncIOMotorDatabase, AsyncIOMotorCollection
+from pymongo.collection import Collection
+from pymongo.database import Database
 from pymongo.errors import DuplicateKeyError, PyMongoError

 from app.models.user import UserCreate, UserInDB, UserUpdate
@@ -23,7 +25,7 @@ class UserRepository:
     following the repository pattern with dependency injection and async/await.
     """

-    def __init__(self, database: AsyncIOMotorDatabase):
+    def __init__(self, database: Database):
         """
         Initialize repository with database dependency.

@@ -31,30 +33,30 @@ class UserRepository:
             database (AsyncIOMotorDatabase): MongoDB database instance
         """
         self.db = database
-        self.collection: AsyncIOMotorCollection = database.users
+        self.collection: Collection = database.users

-    async def initialize(self):
+    def initialize(self):
         """
         Initialize repository by ensuring required indexes exist.

         Should be called after repository instantiation to setup database indexes.
         """
-        await self._ensure_indexes()
+        self._ensure_indexes()
         return self

-    async def _ensure_indexes(self):
+    def _ensure_indexes(self):
         """
         Ensure required database indexes exist.

         Creates unique index on username field to prevent duplicates.
         """
         try:
-            await self.collection.create_index("username", unique=True)
+            self.collection.create_index("username", unique=True)
         except PyMongoError:
             # Index might already exist, ignore error
             pass

-    async def create_user(self, user_data: UserCreate) -> UserInDB:
+    def create_user(self, user_data: UserCreate) -> UserInDB:
         """
         Create a new user in the database.

@@ -79,7 +81,7 @@ class UserRepository:
         }

         try:
-            result = await self.collection.insert_one(user_dict)
+            result = self.collection.insert_one(user_dict)
             user_dict["_id"] = result.inserted_id
             return UserInDB(**user_dict)
         except DuplicateKeyError as e:
@@ -87,7 +89,7 @@ class UserRepository:
         except PyMongoError as e:
             raise ValueError(f"Failed to create user: {e}")

-    async def find_user_by_username(self, username: str) -> Optional[UserInDB]:
+    def find_user_by_username(self, username: str) -> Optional[UserInDB]:
         """
         Find user by username.

@@ -98,14 +100,14 @@ class UserRepository:
             UserInDB or None: User if found, None otherwise
         """
         try:
-            user_doc = await self.collection.find_one({"username": username})
+            user_doc = self.collection.find_one({"username": username})
             if user_doc:
                 return UserInDB(**user_doc)
             return None
         except PyMongoError:
             return None

-    async def find_user_by_id(self, user_id: str) -> Optional[UserInDB]:
+    def find_user_by_id(self, user_id: str) -> Optional[UserInDB]:
         """
         Find user by ID.

@@ -119,14 +121,14 @@ class UserRepository:
             if not ObjectId.is_valid(user_id):
                 return None

-            user_doc = await self.collection.find_one({"_id": ObjectId(user_id)})
+            user_doc = self.collection.find_one({"_id": ObjectId(user_id)})
             if user_doc:
                 return UserInDB(**user_doc)
             return None
         except PyMongoError:
             return None

-    async def find_user_by_email(self, email: str) -> Optional[UserInDB]:
+    def find_user_by_email(self, email: str) -> Optional[UserInDB]:
         """
         Find user by email address.

@@ -137,14 +139,14 @@ class UserRepository:
             UserInDB or None: User if found, None otherwise
         """
         try:
-            user_doc = await self.collection.find_one({"email": email})
+            user_doc = self.collection.find_one({"email": email})
             if user_doc:
                 return UserInDB(**user_doc)
             return None
         except PyMongoError:
             return None

-    async def update_user(self, user_id: str, user_update: UserUpdate) -> Optional[UserInDB]:
+    def update_user(self, user_id: str, user_update: UserUpdate) -> Optional[UserInDB]:
         """
         Update user information.

@@ -177,9 +179,9 @@ class UserRepository:
             clean_update_data = {k: v for k, v in update_data.items() if v is not None}

             if not clean_update_data:
-                return await self.find_user_by_id(user_id)
+                return self.find_user_by_id(user_id)

-            result = await self.collection.find_one_and_update(
+            result = self.collection.find_one_and_update(
                 {"_id": ObjectId(user_id)},
                 {"$set": clean_update_data},
                 return_document=True
@@ -192,7 +194,7 @@ class UserRepository:
         except PyMongoError:
             return None

-    async def delete_user(self, user_id: str) -> bool:
+    def delete_user(self, user_id: str) -> bool:
         """
         Delete user from database.

@@ -206,12 +208,12 @@ class UserRepository:
             if not ObjectId.is_valid(user_id):
                 return False

-            result = await self.collection.delete_one({"_id": ObjectId(user_id)})
+            result = self.collection.delete_one({"_id": ObjectId(user_id)})
             return result.deleted_count > 0
         except PyMongoError:
             return False

-    async def list_users(self, skip: int = 0, limit: int = 100) -> List[UserInDB]:
+    def list_users(self, skip: int = 0, limit: int = 100) -> List[UserInDB]:
         """
         List users with pagination.

@@ -224,12 +226,12 @@ class UserRepository:
         """
         try:
             cursor = self.collection.find({}).skip(skip).limit(limit).sort("created_at", -1)
-            user_docs = await cursor.to_list(length=limit)
+            user_docs = cursor.to_list(length=limit)
             return [UserInDB(**user_doc) for user_doc in user_docs]
         except PyMongoError:
             return []

-    async def count_users(self) -> int:
+    def count_users(self) -> int:
         """
         Count total number of users.

@@ -237,11 +239,11 @@ class UserRepository:
             int: Total number of users in database
         """
         try:
-            return await self.collection.count_documents({})
+            return self.collection.count_documents({})
         except PyMongoError:
             return 0

-    async def user_exists(self, username: str) -> bool:
+    def user_exists(self, username: str) -> bool:
         """
         Check if user exists by username.

@@ -252,7 +254,7 @@ class UserRepository:
             bool: True if user exists, False otherwise
         """
         try:
-            count = await self.collection.count_documents({"username": username})
+            count = self.collection.count_documents({"username": username})
             return count > 0
         except PyMongoError:
             return False
@@ -63,15 +63,17 @@ class DocumentFileEventHandler(FileSystemEventHandler):

         logger.info(f"Processing new file: {filepath}")

-        try:
-            from tasks.main import process_document
-            celery_result = process_document.delay(filepath)
-            celery_task_id = celery_result.id
-            logger.info(f"Dispatched Celery task with ID: {celery_task_id}")
-        except Exception as e:
-            logger.error(f"Failed to process file {filepath}: {str(e)}")
-            # Note: We don't re-raise the exception to keep the watcher running
+        # try:
+        from tasks.document_processing import process_document
+        task_result = process_document.delay(filepath)
+        print(task_result)
+        print("hello world")
+        # task_id = task_result.task_id
+        # logger.info(f"Dispatched Celery task with ID: {task_id}")
+        # except Exception as e:
+        # logger.error(f"Failed to process file {filepath}: {str(e)}")
+        # # Note: We don't re-raise the exception to keep the watcher running


 class FileWatcher:
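For reference, .delay() returns a Celery AsyncResult, so the logging that this hunk comments out could be restored along these lines (a sketch, not part of the commit):

    # Sketch: dispatch the task and log the broker-assigned id.
    task_result = process_document.delay(filepath)
    logger.info(f"Dispatched Celery task with ID: {task_result.id}")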
@@ -64,7 +64,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:

     # Create default admin user
     init_service = InitializationService(user_service)
-    await init_service.initialize_application()
+    init_service.initialize_application()
     logger.info("Default admin user initialization completed")

     # Create and start file watcher
@@ -44,8 +44,8 @@ class DocumentService:
         self.document_repository = FileDocumentRepository(self.db)
         self.objects_folder = objects_folder or get_objects_folder()

-    async def initialize(self):
-        await self.document_repository.initialize()
+    def initialize(self):
+        self.document_repository.initialize()
         return self

     @staticmethod
@@ -136,7 +136,7 @@ class DocumentService:
         with open(target_path, "wb") as f:
             f.write(content)

-    async def create_document(
+    def create_document(
         self,
         file_path: str,
         file_bytes: bytes | None = None,
@@ -187,7 +187,7 @@ class DocumentService:
                 mime_type=mime_type
             )

-            created_file = await self.document_repository.create_document(file_data)
+            created_file = self.document_repository.create_document(file_data)

             return created_file

@@ -195,7 +195,7 @@ class DocumentService:
             # Transaction will automatically rollback if supported
             raise PyMongoError(f"Failed to create document: {str(e)}")

-    async def get_document_by_id(self, document_id: PyObjectId) -> Optional[FileDocument]:
+    def get_document_by_id(self, document_id: PyObjectId) -> Optional[FileDocument]:
         """
         Retrieve a document by its ID.

@@ -205,9 +205,9 @@ class DocumentService:
         Returns:
             FileDocument if found, None otherwise
         """
-        return await self.document_repository.find_document_by_id(str(document_id))
+        return self.document_repository.find_document_by_id(str(document_id))

-    async def get_document_by_hash(self, file_hash: str) -> Optional[FileDocument]:
+    def get_document_by_hash(self, file_hash: str) -> Optional[FileDocument]:
         """
         Retrieve a document by its file hash.

@@ -217,9 +217,9 @@ class DocumentService:
         Returns:
             FileDocument if found, None otherwise
         """
-        return await self.document_repository.find_document_by_hash(file_hash)
+        return self.document_repository.find_document_by_hash(file_hash)

-    async def get_document_by_filepath(self, filepath: str) -> Optional[FileDocument]:
+    def get_document_by_filepath(self, filepath: str) -> Optional[FileDocument]:
         """
         Retrieve a document by its file path.

@@ -229,9 +229,9 @@ class DocumentService:
         Returns:
             FileDocument if found, None otherwise
         """
-        return await self.document_repository.find_document_by_filepath(filepath)
+        return self.document_repository.find_document_by_filepath(filepath)

-    async def get_document_content_by_hash(self, file_hash):
+    def get_document_content_by_hash(self, file_hash):
         target_path = self._get_document_path(file_hash)
         if not os.path.exists(target_path):
             return None
@@ -239,7 +239,7 @@ class DocumentService:
         with open(target_path, "rb") as f:
             return f.read()

-    async def list_documents(
+    def list_documents(
         self,
         skip: int = 0,
         limit: int = 100
@@ -254,18 +254,18 @@ class DocumentService:
         Returns:
             List of FileDocument instances
         """
-        return await self.document_repository.list_documents(skip=skip, limit=limit)
+        return self.document_repository.list_documents(skip=skip, limit=limit)

-    async def count_documents(self) -> int:
+    def count_documents(self) -> int:
         """
         Get total number of documents.

         Returns:
             Total document count
         """
-        return await self.document_repository.count_documents()
+        return self.document_repository.count_documents()

-    async def update_document(
+    def update_document(
         self,
         document_id: PyObjectId,
         update_data: Dict[str, Any]
@@ -285,9 +285,9 @@ class DocumentService:
             update_data["file_hash"] = file_hash
             self.save_content_if_needed(file_hash, update_data["file_bytes"])

-        return await self.document_repository.update_document(document_id, update_data)
+        return self.document_repository.update_document(document_id, update_data)

-    async def delete_document(self, document_id: PyObjectId) -> bool:
+    def delete_document(self, document_id: PyObjectId) -> bool:
         """
         Delete a document and its orphaned content.

@@ -308,17 +308,17 @@ class DocumentService:

         try:
             # Get document to find its hash
-            document = await self.document_repository.find_document_by_id(document_id)
+            document = self.document_repository.find_document_by_id(document_id)
             if not document:
                 return False

             # Delete the document
-            deleted = await self.document_repository.delete_document(document_id)
+            deleted = self.document_repository.delete_document(document_id)
             if not deleted:
                 return False

             # Check if content is orphaned
-            remaining_files = await self.document_repository.find_document_by_hash(document.file_hash)
+            remaining_files = self.document_repository.find_document_by_hash(document.file_hash)

             # If no other files reference this content, delete it
             if not remaining_files:
@@ -32,7 +32,7 @@ class InitializationService:
         """
         self.user_service = user_service

-    async def ensure_admin_user_exists(self) -> Optional[UserInDB]:
+    def ensure_admin_user_exists(self) -> Optional[UserInDB]:
         """
         Ensure default admin user exists in the system.

@@ -48,7 +48,7 @@ class InitializationService:
         logger.info("Checking if admin user exists...")

         # Check if any admin user already exists
-        if await self._admin_user_exists():
+        if self._admin_user_exists():
             logger.info("Admin user already exists, skipping creation")
             return None

@@ -63,7 +63,7 @@ class InitializationService:
                 role=UserRole.ADMIN
             )

-            created_user = await self.user_service.create_user(admin_data)
+            created_user = self.user_service.create_user(admin_data)
            logger.info(f"Default admin user created successfully with ID: {created_user.id}")
            logger.warning(
                "Default admin user created with username 'admin' and password 'admin'. "
@@ -76,7 +76,7 @@ class InitializationService:
            logger.error(f"Failed to create default admin user: {str(e)}")
            raise Exception(f"Admin user creation failed: {str(e)}")

-    async def _admin_user_exists(self) -> bool:
+    def _admin_user_exists(self) -> bool:
         """
         Check if any admin user exists in the system.

@@ -85,7 +85,7 @@ class InitializationService:
         """
         try:
             # Get all users and check if any have admin role
-            users = await self.user_service.list_users(limit=1000) # Reasonable limit for admin check
+            users = self.user_service.list_users(limit=1000) # Reasonable limit for admin check

             for user in users:
                 if user.role == UserRole.ADMIN and user.is_active:
@@ -98,7 +98,7 @@ class InitializationService:
             # In case of error, assume admin exists to avoid creating duplicates
             return True

-    async def initialize_application(self) -> dict:
+    def initialize_application(self) -> dict:
         """
         Perform all application initialization tasks.

@@ -118,7 +118,7 @@ class InitializationService:

         try:
             # Ensure admin user exists
-            created_admin = await self.ensure_admin_user_exists()
+            created_admin = self.ensure_admin_user_exists()
             if created_admin:
                 initialization_summary["admin_user_created"] = True

@@ -31,11 +31,11 @@ class JobService:
         self.db = database
         self.repository = JobRepository(database)

-    async def initialize(self):
-        await self.repository.initialize()
+    def initialize(self):
+        self.repository.initialize()
         return self

-    async def create_job(self, document_id: PyObjectId, task_id: Optional[str] = None) -> ProcessingJob:
+    def create_job(self, document_id: PyObjectId, task_id: Optional[str] = None) -> ProcessingJob:
         """
         Create a new processing job.

@@ -49,9 +49,9 @@ class JobService:
         Raises:
             JobRepositoryError: If database operation fails
         """
-        return await self.repository.create_job(document_id, task_id)
+        return self.repository.create_job(document_id, task_id)

-    async def get_job_by_id(self, job_id: PyObjectId) -> ProcessingJob:
+    def get_job_by_id(self, job_id: PyObjectId) -> ProcessingJob:
         """
         Retrieve a job by its ID.

@@ -65,9 +65,9 @@ class JobService:
             JobNotFoundError: If job doesn't exist
             JobRepositoryError: If database operation fails
         """
-        return await self.repository.find_job_by_id(job_id)
+        return self.repository.find_job_by_id(job_id)

-    async def mark_job_as_started(self, job_id: PyObjectId) -> ProcessingJob:
+    def mark_job_as_started(self, job_id: PyObjectId) -> ProcessingJob:
         """
         Mark a job as started (PENDING → PROCESSING).

@@ -83,16 +83,16 @@ class JobService:
            JobRepositoryError: If database operation fails
         """
         # Get current job to validate transition
-        current_job = await self.repository.find_job_by_id(job_id)
+        current_job = self.repository.find_job_by_id(job_id)

         # Validate status transition
         if current_job.status != ProcessingStatus.PENDING:
             raise InvalidStatusTransitionError(current_job.status, ProcessingStatus.PROCESSING)

         # Update status
-        return await self.repository.update_job_status(job_id, ProcessingStatus.PROCESSING)
+        return self.repository.update_job_status(job_id, ProcessingStatus.PROCESSING)

-    async def mark_job_as_completed(self, job_id: PyObjectId) -> ProcessingJob:
+    def mark_job_as_completed(self, job_id: PyObjectId) -> ProcessingJob:
         """
         Mark a job as completed (PROCESSING → COMPLETED).

@@ -108,16 +108,16 @@ class JobService:
             JobRepositoryError: If database operation fails
         """
         # Get current job to validate transition
-        current_job = await self.repository.find_job_by_id(job_id)
+        current_job = self.repository.find_job_by_id(job_id)

         # Validate status transition
         if current_job.status != ProcessingStatus.PROCESSING:
             raise InvalidStatusTransitionError(current_job.status, ProcessingStatus.COMPLETED)

         # Update status
-        return await self.repository.update_job_status(job_id, ProcessingStatus.COMPLETED)
+        return self.repository.update_job_status(job_id, ProcessingStatus.COMPLETED)

-    async def mark_job_as_failed(
+    def mark_job_as_failed(
         self,
         job_id: PyObjectId,
         error_message: Optional[str] = None
@@ -138,20 +138,20 @@ class JobService:
             JobRepositoryError: If database operation fails
         """
         # Get current job to validate transition
-        current_job = await self.repository.find_job_by_id(job_id)
+        current_job = self.repository.find_job_by_id(job_id)

         # Validate status transition
         if current_job.status != ProcessingStatus.PROCESSING:
             raise InvalidStatusTransitionError(current_job.status, ProcessingStatus.FAILED)

         # Update status with error message
-        return await self.repository.update_job_status(
+        return self.repository.update_job_status(
             job_id,
             ProcessingStatus.FAILED,
             error_message
         )

-    async def delete_job(self, job_id: PyObjectId) -> bool:
+    def delete_job(self, job_id: PyObjectId) -> bool:
         """
         Delete a job from the database.

@@ -164,9 +164,9 @@ class JobService:
         Raises:
             JobRepositoryError: If database operation fails
         """
-        return await self.repository.delete_job(job_id)
+        return self.repository.delete_job(job_id)

-    async def get_jobs_by_status(self, status: ProcessingStatus) -> list[ProcessingJob]:
+    def get_jobs_by_status(self, status: ProcessingStatus) -> list[ProcessingJob]:
         """
         Retrieve all jobs with a specific status.

@@ -179,4 +179,4 @@ class JobService:
         Raises:
             JobRepositoryError: If database operation fails
         """
-        return await self.repository.get_jobs_by_status(status)
+        return self.repository.get_jobs_by_status(status)
@@ -33,11 +33,11 @@ class UserService:
|
|||||||
self.user_repository = UserRepository(self.db)
|
self.user_repository = UserRepository(self.db)
|
||||||
self.auth_service = AuthService()
|
self.auth_service = AuthService()
|
||||||
|
|
||||||
async def initialize(self):
|
def initialize(self):
|
||||||
await self.user_repository.initialize()
|
self.user_repository.initialize()
|
||||||
return self
|
return self
|
||||||
|
|
||||||
async def create_user(self, user_data: UserCreate | UserCreateNoValidation) -> UserInDB:
|
def create_user(self, user_data: UserCreate | UserCreateNoValidation) -> UserInDB:
|
||||||
"""
|
"""
|
||||||
Create a new user with business logic validation.
|
Create a new user with business logic validation.
|
||||||
|
|
||||||
@@ -60,11 +60,11 @@ class UserService:
|
|||||||
raise ValueError(f"User with email '{user_data.email}' already exists")
|
raise ValueError(f"User with email '{user_data.email}' already exists")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return await self.user_repository.create_user(user_data)
|
return self.user_repository.create_user(user_data)
|
||||||
except DuplicateKeyError:
|
except DuplicateKeyError:
|
||||||
raise ValueError(f"User with username '{user_data.username}' already exists")
|
raise ValueError(f"User with username '{user_data.username}' already exists")
|
||||||
|
|
||||||
async def get_user_by_username(self, username: str) -> Optional[UserInDB]:
|
def get_user_by_username(self, username: str) -> Optional[UserInDB]:
|
||||||
"""
|
"""
|
||||||
Retrieve user by username.
|
Retrieve user by username.
|
||||||
|
|
||||||
@@ -74,9 +74,9 @@ class UserService:
|
|||||||
Returns:
|
Returns:
|
||||||
UserInDB or None: User if found, None otherwise
|
UserInDB or None: User if found, None otherwise
|
||||||
"""
|
"""
|
||||||
return await self.user_repository.find_user_by_username(username)
|
return self.user_repository.find_user_by_username(username)
|
||||||
|
|
||||||
async def get_user_by_id(self, user_id: str) -> Optional[UserInDB]:
|
def get_user_by_id(self, user_id: str) -> Optional[UserInDB]:
|
||||||
"""
|
"""
|
||||||
Retrieve user by ID.
|
Retrieve user by ID.
|
||||||
|
|
||||||
@@ -86,9 +86,9 @@ class UserService:
|
|||||||
Returns:
|
Returns:
|
||||||
UserInDB or None: User if found, None otherwise
|
UserInDB or None: User if found, None otherwise
|
||||||
"""
|
"""
|
||||||
return await self.user_repository.find_user_by_id(user_id)
|
return self.user_repository.find_user_by_id(user_id)
|
||||||
|
|
||||||
async def authenticate_user(self, username: str, password: str) -> Optional[UserInDB]:
|
def authenticate_user(self, username: str, password: str) -> Optional[UserInDB]:
|
||||||
"""
|
"""
|
||||||
Authenticate user with username and password.
|
Authenticate user with username and password.
|
||||||
|
|
||||||
@@ -111,7 +111,7 @@ class UserService:
|
|||||||
|
|
||||||
return user
|
return user
|
||||||
|
|
||||||
async def update_user(self, user_id: str, user_update: UserUpdate) -> Optional[UserInDB]:
|
def update_user(self, user_id: str, user_update: UserUpdate) -> Optional[UserInDB]:
|
||||||
"""
|
"""
|
||||||
Update user information.
|
Update user information.
|
||||||
|
|
||||||
@@ -137,9 +137,9 @@ class UserService:
|
|||||||
if existing_user and str(existing_user.id) != user_id:
|
if existing_user and str(existing_user.id) != user_id:
|
||||||
raise ValueError(f"Email '{user_update.email}' is already taken")
|
raise ValueError(f"Email '{user_update.email}' is already taken")
|
||||||
|
|
||||||
return await self.user_repository.update_user(user_id, user_update)
|
return self.user_repository.update_user(user_id, user_update)
|
||||||
|
|
||||||
async def delete_user(self, user_id: str) -> bool:
|
def delete_user(self, user_id: str) -> bool:
|
||||||
"""
|
"""
|
||||||
Delete user from system.
|
Delete user from system.
|
||||||
|
|
||||||
@@ -151,7 +151,7 @@ class UserService:
|
|||||||
"""
|
"""
|
||||||
return self.user_repository.delete_user(user_id)
|
return self.user_repository.delete_user(user_id)
|
||||||
|
|
||||||
async def list_users(self, skip: int = 0, limit: int = 100) -> List[UserInDB]:
|
def list_users(self, skip: int = 0, limit: int = 100) -> List[UserInDB]:
|
||||||
"""
|
"""
|
||||||
List users with pagination.
|
List users with pagination.
|
||||||
|
|
||||||
@@ -162,18 +162,18 @@ class UserService:
|
|||||||
Returns:
|
Returns:
|
||||||
List[UserInDB]: List of users
|
List[UserInDB]: List of users
|
||||||
"""
|
"""
|
||||||
return await self.user_repository.list_users(skip=skip, limit=limit)
|
return self.user_repository.list_users(skip=skip, limit=limit)
|
||||||
|
|
||||||
async def count_users(self) -> int:
|
def count_users(self) -> int:
|
||||||
"""
|
"""
|
||||||
Count total number of users.
|
Count total number of users.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
int: Total number of users in system
|
int: Total number of users in system
|
||||||
"""
|
"""
|
||||||
return await self.user_repository.count_users()
|
return self.user_repository.count_users()
|
||||||
|
|
||||||
async def user_exists(self, username: str) -> bool:
|
def user_exists(self, username: str) -> bool:
|
||||||
"""
|
"""
|
||||||
Check if user exists by username.
|
Check if user exists by username.
|
||||||
|
|
||||||
@@ -183,4 +183,4 @@ class UserService:
|
|||||||
Returns:
|
Returns:
|
||||||
bool: True if user exists, False otherwise
|
bool: True if user exists, False otherwise
|
||||||
"""
|
"""
|
||||||
return await self.user_repository.user_exists(username)
|
return self.user_repository.user_exists(username)
|
||||||
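Because every UserService method above is now synchronous, the HTTP layer that calls it no longer needs async plumbing: FastAPI runs plain `def` endpoints in a worker threadpool, so blocking service calls do not stall the event loop. A minimal, self-contained sketch of that endpoint style, purely illustrative and not code from this commit:

from fastapi import FastAPI

app = FastAPI()

@app.get("/health")
def health_check() -> dict:
    # A plain `def` endpoint: FastAPI executes it in a worker thread, so the
    # synchronous service/repository layer can be called here safely.
    return {"status": "ok"}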
@@ -1,3 +1,4 @@
+asgiref==3.9.1
bcrypt==4.3.0
celery==5.5.3
email-validator==2.3.0
@@ -1,3 +1,4 @@
+asgiref==3.9.1
bcrypt==4.3.0
celery==5.5.3
email-validator==2.3.0
@@ -11,11 +11,12 @@ from typing import Any, Dict
from app.config import settings
from app.database.connection import get_database
from app.services.document_service import DocumentService
+from tasks.main import celery_app

logger = logging.getLogger(__name__)

+@celery_app.task(bind=True, autoretry_for=(Exception,), retry_kwargs={'max_retries': 3, 'countdown': 60})
-async def process_document_async(self, filepath: str) -> Dict[str, Any]:
+def process_document(self, filepath: str) -> Dict[str, Any]:
    """
    Process a document file and extract its content.

@@ -45,18 +46,18 @@ async def process_document_async(self, filepath: str) -> Dict[str, Any]:
    job = None
    try:
        # Step 1: Insert the document in DB
-        document = await document_service.create_document(filepath)
+        document = document_service.create_document(filepath)
        logger.info(f"Job {task_id} created for document {document.id} with file path: {filepath}")

        # Step 2: Create a new job record for the document
-        job = await job_service.create_job(task_id=task_id, document_id=document.id)
+        job = job_service.create_job(task_id=task_id, document_id=document.id)

        # Step 3: Mark job as started
-        await job_service.mark_job_as_started(job_id=job.id)
+        job_service.mark_job_as_started(job_id=job.id)
        logger.info(f"Job {task_id} marked as PROCESSING")

        # Step 4: Mark job as completed
-        await job_service.mark_job_as_completed(job_id=job.id)
+        job_service.mark_job_as_completed(job_id=job.id)
        logger.info(f"Job {task_id} marked as COMPLETED")

        return {

@@ -72,7 +73,7 @@ async def process_document_async(self, filepath: str) -> Dict[str, Any]:
    try:
        # Mark job as failed
        if job is not None:
-            await job_service.mark_job_as_failed(job_id=job.id, error_message=error_message)
+            job_service.mark_job_as_failed(job_id=job.id, error_message=error_message)
            logger.info(f"Job {task_id} marked as FAILED")
        else:
            logger.error(f"Failed to process {filepath}. error = {str(e)}")

@@ -81,3 +82,4 @@ async def process_document_async(self, filepath: str) -> Dict[str, Any]:

    # Re-raise the exception to trigger Celery retry mechanism
    raise
+
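The renamed task above is still dispatched through Celery in the usual way; only its body now runs synchronously inside the worker. A minimal dispatch sketch, assuming the Redis broker from the compose file and a running worker (the file path is an example, not project data):

from tasks.document_processing import process_document

# .delay() enqueues the task on the broker; with bind=True the task instance
# is passed as `self`, so the caller only supplies the filepath argument.
result = process_document.delay("/watched/incoming/example.pdf")

# Optionally block on the Redis result backend for the returned dict.
print(result.get(timeout=300))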
@@ -3,13 +3,10 @@ Celery worker for MyDocManager document processing tasks.

This module contains all Celery tasks for processing documents.
"""
-import asyncio
import os

from celery import Celery

-from tasks.document_processing import process_document_async
-
# Environment variables
REDIS_URL = os.getenv("REDIS_URL", "redis://localhost:6379/0")
MONGODB_URL = os.getenv("MONGODB_URL", "mongodb://localhost:27017")

@@ -21,6 +18,8 @@ celery_app = Celery(
    backend=REDIS_URL,
)

+celery_app.autodiscover_tasks(["tasks.document_processing"])
+
# Celery configuration
celery_app.conf.update(
    task_serializer="json",

@@ -33,11 +32,5 @@ celery_app.conf.update(
    task_soft_time_limit=240,  # 4 minutes
)

-
-@celery_app.task(bind=True, autoretry_for=(Exception,), retry_kwargs={'max_retries': 3, 'countdown': 60})
-def process_document(self, filepath: str):
-    return asyncio.run(process_document_async(self, filepath))
-
-
if __name__ == "__main__":
    celery_app.start()
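With the asyncio wrapper gone, tasks/main.py no longer imports the task module directly; registration now relies on tasks/document_processing.py importing celery_app and applying the @celery_app.task decorator, alongside the autodiscover_tasks call above. A quick, hypothetical way to confirm that wiring from a Python shell (not part of the commit):

import tasks.document_processing  # importing the module applies @celery_app.task
from tasks.main import celery_app

# Registered task names are derived from the decorated function's module path.
print([name for name in celery_app.tasks if "document" in name])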
@@ -1,18 +1,16 @@
"""
-Test suite for FileDocumentRepository with async/await support.
+Test suite for FileDocumentRepository with async/support.

This module contains comprehensive tests for all FileDocumentRepository methods
using mongomock-motor for in-memory MongoDB testing.
"""

-import pytest
from datetime import datetime
-from typing import Dict, Any

-import pytest_asyncio
+import pytest
from bson import ObjectId
-from pymongo.errors import DuplicateKeyError, PyMongoError
+from mongomock.mongo_client import MongoClient
-from mongomock_motor import AsyncMongoMockClient
+from pymongo.errors import PyMongoError

from app.database.repositories.document_repository import (
    FileDocumentRepository,

@@ -23,13 +21,13 @@ from app.database.repositories.document_repository import (
from app.models.document import FileDocument, FileType, ExtractionMethod


-@pytest_asyncio.fixture
+@pytest.fixture
-async def in_memory_repository():
+def in_memory_repository():
    """Create an in-memory FileDocumentRepository for testing."""
-    client = AsyncMongoMockClient()
+    client = MongoClient()
    db = client.test_database
    repo = FileDocumentRepository(db)
-    await repo.initialize()
+    repo.initialize()
    return repo


@@ -107,14 +105,13 @@ def multiple_sample_files():
class TestFileDocumentRepositoryInitialization:
    """Tests for repository initialization."""

-    @pytest.mark.asyncio
-    async def test_i_can_initialize_repository(self):
+    def test_i_can_initialize_repository(self):
        """Test repository initialization."""
        # Arrange
-        client = AsyncMongoMockClient()
+        client = MongoClient()
        db = client.test_database
        repo = FileDocumentRepository(db)
-        await repo.initialize()
+        repo.initialize()

        # Act & Assert (should not raise any exception)
        assert repo.db is not None

@@ -125,11 +122,10 @@ class TestFileDocumentRepositoryInitialization:
class TestFileDocumentRepositoryCreation:
    """Tests for file document creation functionality."""

-    @pytest.mark.asyncio
-    async def test_i_can_create_file_document(self, in_memory_repository, sample_file_document):
+    def test_i_can_create_file_document(self, in_memory_repository, sample_file_document):
        """Test successful file document creation."""
        # Act
-        created_file = await in_memory_repository.create_document(sample_file_document)
+        created_file = in_memory_repository.create_document(sample_file_document)

        # Assert
        assert created_file is not None

@@ -144,46 +140,20 @@ class TestFileDocumentRepositoryCreation:
        assert created_file.id is not None
        assert isinstance(created_file.id, ObjectId)

-    @pytest.mark.asyncio
-    async def test_i_can_create_file_document_without_id(self, in_memory_repository, sample_file_document):
+    def test_i_can_create_file_document_without_id(self, in_memory_repository, sample_file_document):
        """Test creating file document with _id set to None (should be removed)."""
        # Arrange
        sample_file_document.id = None

        # Act
-        created_file = await in_memory_repository.create_document(sample_file_document)
+        created_file = in_memory_repository.create_document(sample_file_document)

        # Assert
        assert created_file is not None
        assert created_file.id is not None
        assert isinstance(created_file.id, ObjectId)

-    @pytest.mark.asyncio
-    async def test_i_cannot_create_duplicate_file_document(self, in_memory_repository, sample_file_document):
-        """Test that creating file document with duplicate filepath raises DuplicateKeyError."""
-        # Arrange
-        await in_memory_repository.create_document(sample_file_document)
-        duplicate_file = FileDocument(
-            filename="different_name.pdf",
-            filepath=sample_file_document.filepath,  # Same filepath
-            file_type=FileType.PDF,
-            extraction_method=ExtractionMethod.OCR,
-            metadata={"different": "metadata"},
-            detected_at=datetime.now(),
-            file_hash="different_hash_123456789012345678901234567890123456789012345678",
-            encoding="utf-8",
-            file_size=2000,
-            mime_type="application/pdf"
-        )
-
-        # Act & Assert
-        with pytest.raises(DuplicateKeyError) as exc_info:
-            await in_memory_repository.create_document(duplicate_file)
-
-        assert "already exists" in str(exc_info.value)
-
-    @pytest.mark.asyncio
-    async def test_i_cannot_create_file_document_with_pymongo_error(self, in_memory_repository,
+    def test_i_cannot_create_file_document_with_pymongo_error(self, in_memory_repository,
                                                                    sample_file_document, mocker):
        """Test handling of PyMongo errors during file document creation."""
        # Arrange

@@ -191,7 +161,7 @@ class TestFileDocumentRepositoryCreation:

        # Act & Assert
        with pytest.raises(ValueError) as exc_info:
-            await in_memory_repository.create_document(sample_file_document)
+            in_memory_repository.create_document(sample_file_document)

        assert "Failed to create file document" in str(exc_info.value)

@@ -199,14 +169,13 @@ class TestFileDocumentRepositoryCreation:
class TestFileDocumentRepositoryFinding:
    """Tests for file document finding functionality."""

-    @pytest.mark.asyncio
-    async def test_i_can_find_document_by_valid_id(self, in_memory_repository, sample_file_document):
+    def test_i_can_find_document_by_valid_id(self, in_memory_repository, sample_file_document):
        """Test finding file document by valid ObjectId."""
        # Arrange
-        created_file = await in_memory_repository.create_document(sample_file_document)
+        created_file = in_memory_repository.create_document(sample_file_document)

        # Act
-        found_file = await in_memory_repository.find_document_by_id(str(created_file.id))
+        found_file = in_memory_repository.find_document_by_id(str(created_file.id))

        # Assert
        assert found_file is not None

@@ -214,81 +183,74 @@ class TestFileDocumentRepositoryFinding:
        assert found_file.filename == created_file.filename
        assert found_file.filepath == created_file.filepath

-    @pytest.mark.asyncio
-    async def test_i_cannot_find_document_with_invalid_id(self, in_memory_repository):
+    def test_i_cannot_find_document_with_invalid_id(self, in_memory_repository):
        """Test that invalid ObjectId returns None."""
        # Act
-        found_file = await in_memory_repository.find_document_by_id("invalid_id")
+        found_file = in_memory_repository.find_document_by_id("invalid_id")

        # Assert
        assert found_file is None

-    @pytest.mark.asyncio
-    async def test_i_cannot_find_document_by_nonexistent_id(self, in_memory_repository):
+    def test_i_cannot_find_document_by_nonexistent_id(self, in_memory_repository):
        """Test that nonexistent but valid ObjectId returns None."""
        # Arrange
        nonexistent_id = str(ObjectId())

        # Act
-        found_file = await in_memory_repository.find_document_by_id(nonexistent_id)
+        found_file = in_memory_repository.find_document_by_id(nonexistent_id)

        # Assert
        assert found_file is None

-    @pytest.mark.asyncio
-    async def test_i_can_find_document_by_file_hash(self, in_memory_repository, sample_file_document):
+    def test_i_can_find_document_by_file_hash(self, in_memory_repository, sample_file_document):
        """Test finding file document by file hash."""
        # Arrange
-        created_file = await in_memory_repository.create_document(sample_file_document)
+        created_file = in_memory_repository.create_document(sample_file_document)

        # Act
-        found_file = await in_memory_repository.find_document_by_hash(sample_file_document.file_hash)
+        found_file = in_memory_repository.find_document_by_hash(sample_file_document.file_hash)

        # Assert
        assert found_file is not None
        assert found_file.file_hash == created_file.file_hash
        assert found_file.id == created_file.id

-    @pytest.mark.asyncio
-    async def test_i_cannot_find_document_with_nonexistent_file_hash(self, in_memory_repository):
+    def test_i_cannot_find_document_with_nonexistent_file_hash(self, in_memory_repository):
        """Test that nonexistent file hash returns None."""
        # Act
-        found_file = await in_memory_repository.find_document_by_hash("nonexistent_hash")
+        found_file = in_memory_repository.find_document_by_hash("nonexistent_hash")

        # Assert
        assert found_file is None

-    @pytest.mark.asyncio
-    async def test_i_can_find_document_by_filepath(self, in_memory_repository, sample_file_document):
+    def test_i_can_find_document_by_filepath(self, in_memory_repository, sample_file_document):
        """Test finding file document by filepath."""
        # Arrange
-        created_file = await in_memory_repository.create_document(sample_file_document)
+        created_file = in_memory_repository.create_document(sample_file_document)

        # Act
-        found_file = await in_memory_repository.find_document_by_filepath(sample_file_document.filepath)
+        found_file = in_memory_repository.find_document_by_filepath(sample_file_document.filepath)

        # Assert
        assert found_file is not None
        assert found_file.filepath == created_file.filepath
        assert found_file.id == created_file.id

-    @pytest.mark.asyncio
-    async def test_i_cannot_find_document_with_nonexistent_filepath(self, in_memory_repository):
+    def test_i_cannot_find_document_with_nonexistent_filepath(self, in_memory_repository):
        """Test that nonexistent filepath returns None."""
        # Act
-        found_file = await in_memory_repository.find_document_by_filepath("/nonexistent/path/file.pdf")
+        found_file = in_memory_repository.find_document_by_filepath("/nonexistent/path/file.pdf")

        # Assert
        assert found_file is None

-    @pytest.mark.asyncio
-    async def test_i_cannot_find_document_with_pymongo_error(self, in_memory_repository, mocker):
+    def test_i_cannot_find_document_with_pymongo_error(self, in_memory_repository, mocker):
        """Test handling of PyMongo errors during file document finding."""
        # Arrange
        mocker.patch.object(in_memory_repository.collection, 'find_one', side_effect=PyMongoError("Database error"))

        # Act
-        found_file = await in_memory_repository.find_document_by_hash("test_hash")
+        found_file = in_memory_repository.find_document_by_hash("test_hash")

        # Assert
        assert found_file is None

@@ -297,16 +259,15 @@ class TestFileDocumentRepositoryFinding:
class TestFileDocumentRepositoryNameMatching:
    """Tests for file document name matching functionality."""

-    @pytest.mark.asyncio
-    async def test_i_can_find_documents_by_name_with_fuzzy_matching(self, in_memory_repository, multiple_sample_files):
+    def test_i_can_find_documents_by_name_with_fuzzy_matching(self, in_memory_repository, multiple_sample_files):
        """Test finding file documents by filename using fuzzy matching."""
        # Arrange
        for file_doc in multiple_sample_files:
-            await in_memory_repository.create_document(file_doc)
+            in_memory_repository.create_document(file_doc)

        # Act
        fuzzy_method = FuzzyMatching(threshold=0.5)
-        found_files = await in_memory_repository.find_document_by_name("document", fuzzy_method)
+        found_files = in_memory_repository.find_document_by_name("document", fuzzy_method)

        # Assert
        assert len(found_files) >= 1

@@ -315,44 +276,41 @@ class TestFileDocumentRepositoryNameMatching:
        found_filenames = [f.filename for f in found_files]
        assert any("document" in fname.lower() for fname in found_filenames)

-    @pytest.mark.asyncio
-    async def test_i_can_find_documents_by_name_with_subsequence_matching(self, in_memory_repository,
+    def test_i_can_find_documents_by_name_with_subsequence_matching(self, in_memory_repository,
                                                                           multiple_sample_files):
        """Test finding file documents by filename using subsequence matching."""
        # Arrange
        for file_doc in multiple_sample_files:
-            await in_memory_repository.create_document(file_doc)
+            in_memory_repository.create_document(file_doc)

        # Act
        subsequence_method = SubsequenceMatching()
-        found_files = await in_memory_repository.find_document_by_name("doc", subsequence_method)
+        found_files = in_memory_repository.find_document_by_name("doc", subsequence_method)

        # Assert
        assert len(found_files) >= 1
        assert all(isinstance(file_doc, FileDocument) for file_doc in found_files)

-    @pytest.mark.asyncio
-    async def test_i_can_find_documents_by_name_with_default_method(self, in_memory_repository, multiple_sample_files):
+    def test_i_can_find_documents_by_name_with_default_method(self, in_memory_repository, multiple_sample_files):
        """Test finding file documents by filename with default matching method."""
        # Arrange
        for file_doc in multiple_sample_files:
-            await in_memory_repository.create_document(file_doc)
+            in_memory_repository.create_document(file_doc)

        # Act
-        found_files = await in_memory_repository.find_document_by_name("first")
+        found_files = in_memory_repository.find_document_by_name("first")

        # Assert
        assert len(found_files) >= 0
        assert all(isinstance(file_doc, FileDocument) for file_doc in found_files)

-    @pytest.mark.asyncio
-    async def test_i_cannot_find_documents_by_name_with_pymongo_error(self, in_memory_repository, mocker):
+    def test_i_cannot_find_documents_by_name_with_pymongo_error(self, in_memory_repository, mocker):
        """Test handling of PyMongo errors during document name matching."""
        # Arrange
        mocker.patch.object(in_memory_repository.collection, 'find', side_effect=PyMongoError("Database error"))

        # Act
-        found_files = await in_memory_repository.find_document_by_name("test")
+        found_files = in_memory_repository.find_document_by_name("test")

        # Assert
        assert found_files == []

@@ -361,30 +319,28 @@ class TestFileDocumentRepositoryNameMatching:
class TestFileDocumentRepositoryListing:
    """Tests for file document listing functionality."""

-    @pytest.mark.asyncio
-    async def test_i_can_list_documents_with_default_pagination(self, in_memory_repository, multiple_sample_files):
+    def test_i_can_list_documents_with_default_pagination(self, in_memory_repository, multiple_sample_files):
        """Test listing file documents with default pagination."""
        # Arrange
        for file_doc in multiple_sample_files:
-            await in_memory_repository.create_document(file_doc)
+            in_memory_repository.create_document(file_doc)

        # Act
-        files = await in_memory_repository.list_documents()
+        files = in_memory_repository.list_documents()

        # Assert
        assert len(files) == len(multiple_sample_files)
        assert all(isinstance(file_doc, FileDocument) for file_doc in files)

-    @pytest.mark.asyncio
-    async def test_i_can_list_documents_with_custom_pagination(self, in_memory_repository, multiple_sample_files):
+    def test_i_can_list_documents_with_custom_pagination(self, in_memory_repository, multiple_sample_files):
        """Test listing file documents with custom pagination."""
        # Arrange
        for file_doc in multiple_sample_files:
-            await in_memory_repository.create_document(file_doc)
+            in_memory_repository.create_document(file_doc)

        # Act
-        files_page1 = await in_memory_repository.list_documents(skip=0, limit=2)
+        files_page1 = in_memory_repository.list_documents(skip=0, limit=2)
-        files_page2 = await in_memory_repository.list_documents(skip=2, limit=2)
+        files_page2 = in_memory_repository.list_documents(skip=2, limit=2)

        # Assert
        assert len(files_page1) == 2

@@ -395,8 +351,7 @@ class TestFileDocumentRepositoryListing:
        page2_ids = [file_doc.id for file_doc in files_page2]
        assert len(set(page1_ids).intersection(set(page2_ids))) == 0

-    @pytest.mark.asyncio
-    async def test_i_can_list_documents_sorted_by_detected_at(self, in_memory_repository, sample_file_document):
+    def test_i_can_list_documents_sorted_by_detected_at(self, in_memory_repository, sample_file_document):
        """Test that file documents are sorted by detected_at in descending order."""
        # Arrange
        file1 = sample_file_document.model_copy()

@@ -411,11 +366,11 @@ class TestFileDocumentRepositoryListing:
        file2.file_hash = "hash2" + "0" * 58
        file2.detected_at = datetime(2024, 1, 2, 10, 0, 0)  # Later date

-        created_file1 = await in_memory_repository.create_document(file1)
+        created_file1 = in_memory_repository.create_document(file1)
-        created_file2 = await in_memory_repository.create_document(file2)
+        created_file2 = in_memory_repository.create_document(file2)

        # Act
-        files = await in_memory_repository.list_documents()
+        files = in_memory_repository.list_documents()

        # Assert
        assert len(files) == 2

@@ -423,23 +378,21 @@ class TestFileDocumentRepositoryListing:
        assert files[0].id == created_file2.id
        assert files[1].id == created_file1.id

-    @pytest.mark.asyncio
-    async def test_i_can_list_empty_documents(self, in_memory_repository):
+    def test_i_can_list_empty_documents(self, in_memory_repository):
        """Test listing file documents from empty collection."""
        # Act
-        files = await in_memory_repository.list_documents()
+        files = in_memory_repository.list_documents()

        # Assert
        assert files == []

-    @pytest.mark.asyncio
-    async def test_i_cannot_list_documents_with_pymongo_error(self, in_memory_repository, mocker):
+    def test_i_cannot_list_documents_with_pymongo_error(self, in_memory_repository, mocker):
        """Test handling of PyMongo errors during file document listing."""
        # Arrange
        mocker.patch.object(in_memory_repository.collection, 'find', side_effect=PyMongoError("Database error"))

        # Act
-        files = await in_memory_repository.list_documents()
+        files = in_memory_repository.list_documents()

        # Assert
        assert files == []

@@ -448,15 +401,14 @@ class TestFileDocumentRepositoryListing:
class TestFileDocumentRepositoryUpdate:
    """Tests for file document update functionality."""

-    @pytest.mark.asyncio
-    async def test_i_can_update_document_successfully(self, in_memory_repository, sample_file_document,
+    def test_i_can_update_document_successfully(self, in_memory_repository, sample_file_document,
                                                      sample_update_data):
        """Test successful file document update."""
        # Arrange
-        created_file = await in_memory_repository.create_document(sample_file_document)
+        created_file = in_memory_repository.create_document(sample_file_document)

        # Act
-        updated_file = await in_memory_repository.update_document(str(created_file.id), sample_update_data)
+        updated_file = in_memory_repository.update_document(str(created_file.id), sample_update_data)

        # Assert
        assert updated_file is not None

@@ -467,15 +419,14 @@ class TestFileDocumentRepositoryUpdate:
        assert updated_file.filename == created_file.filename  # Unchanged fields remain
        assert updated_file.filepath == created_file.filepath

-    @pytest.mark.asyncio
-    async def test_i_can_update_document_with_partial_data(self, in_memory_repository, sample_file_document):
+    def test_i_can_update_document_with_partial_data(self, in_memory_repository, sample_file_document):
        """Test updating file document with partial data."""
        # Arrange
-        created_file = await in_memory_repository.create_document(sample_file_document)
+        created_file = in_memory_repository.create_document(sample_file_document)
        partial_update = {"file_size": 999999}

        # Act
-        updated_file = await in_memory_repository.update_document(str(created_file.id), partial_update)
+        updated_file = in_memory_repository.update_document(str(created_file.id), partial_update)

        # Assert
        assert updated_file is not None

@@ -483,30 +434,28 @@ class TestFileDocumentRepositoryUpdate:
        assert updated_file.filename == created_file.filename  # Should remain unchanged
        assert updated_file.metadata == created_file.metadata  # Should remain unchanged

-    @pytest.mark.asyncio
-    async def test_i_can_update_document_filtering_none_values(self, in_memory_repository, sample_file_document):
+    def test_i_can_update_document_filtering_none_values(self, in_memory_repository, sample_file_document):
        """Test that None values are filtered out from update data."""
        # Arrange
-        created_file = await in_memory_repository.create_document(sample_file_document)
+        created_file = in_memory_repository.create_document(sample_file_document)
        update_with_none = {"file_size": 777777, "metadata": None}

        # Act
-        updated_file = await in_memory_repository.update_document(str(created_file.id), update_with_none)
+        updated_file = in_memory_repository.update_document(str(created_file.id), update_with_none)

        # Assert
        assert updated_file is not None
        assert updated_file.file_size == 777777
        assert updated_file.metadata == created_file.metadata  # Should remain unchanged (None filtered out)

-    @pytest.mark.asyncio
-    async def test_i_can_update_document_with_empty_data(self, in_memory_repository, sample_file_document):
+    def test_i_can_update_document_with_empty_data(self, in_memory_repository, sample_file_document):
        """Test updating file document with empty data returns current document."""
        # Arrange
-        created_file = await in_memory_repository.create_document(sample_file_document)
+        created_file = in_memory_repository.create_document(sample_file_document)
        empty_update = {}

        # Act
-        result = await in_memory_repository.update_document(str(created_file.id), empty_update)
+        result = in_memory_repository.update_document(str(created_file.id), empty_update)

        # Assert
        assert result is not None

@@ -514,38 +463,35 @@ class TestFileDocumentRepositoryUpdate:
        assert result.filepath == created_file.filepath
        assert result.metadata == created_file.metadata

-    @pytest.mark.asyncio
-    async def test_i_cannot_update_document_with_invalid_id(self, in_memory_repository, sample_update_data):
+    def test_i_cannot_update_document_with_invalid_id(self, in_memory_repository, sample_update_data):
        """Test that updating with invalid ID returns None."""
        # Act
-        result = await in_memory_repository.update_document("invalid_id", sample_update_data)
+        result = in_memory_repository.update_document("invalid_id", sample_update_data)

        # Assert
        assert result is None

-    @pytest.mark.asyncio
-    async def test_i_cannot_update_nonexistent_document(self, in_memory_repository, sample_update_data):
+    def test_i_cannot_update_nonexistent_document(self, in_memory_repository, sample_update_data):
        """Test that updating nonexistent file document returns None."""
        # Arrange
        nonexistent_id = str(ObjectId())

        # Act
-        result = await in_memory_repository.update_document(nonexistent_id, sample_update_data)
+        result = in_memory_repository.update_document(nonexistent_id, sample_update_data)

        # Assert
        assert result is None

-    @pytest.mark.asyncio
-    async def test_i_cannot_update_document_with_pymongo_error(self, in_memory_repository, sample_file_document,
+    def test_i_cannot_update_document_with_pymongo_error(self, in_memory_repository, sample_file_document,
                                                               sample_update_data, mocker):
        """Test handling of PyMongo errors during file document update."""
        # Arrange
-        created_file = await in_memory_repository.create_document(sample_file_document)
+        created_file = in_memory_repository.create_document(sample_file_document)
        mocker.patch.object(in_memory_repository.collection, 'find_one_and_update',
                            side_effect=PyMongoError("Database error"))

        # Act
-        result = await in_memory_repository.update_document(str(created_file.id), sample_update_data)
+        result = in_memory_repository.update_document(str(created_file.id), sample_update_data)

        # Assert
        assert result is None

@@ -554,52 +500,48 @@ class TestFileDocumentRepositoryUpdate:
class TestFileDocumentRepositoryDeletion:
    """Tests for file document deletion functionality."""

-    @pytest.mark.asyncio
-    async def test_i_can_delete_existing_document(self, in_memory_repository, sample_file_document):
+    def test_i_can_delete_existing_document(self, in_memory_repository, sample_file_document):
        """Test successful file document deletion."""
        # Arrange
-        created_file = await in_memory_repository.create_document(sample_file_document)
+        created_file = in_memory_repository.create_document(sample_file_document)

        # Act
-        deletion_result = await in_memory_repository.delete_document(str(created_file.id))
+        deletion_result = in_memory_repository.delete_document(str(created_file.id))

        # Assert
        assert deletion_result is True

        # Verify document is actually deleted
-        found_file = await in_memory_repository.find_document_by_id(str(created_file.id))
+        found_file = in_memory_repository.find_document_by_id(str(created_file.id))
        assert found_file is None

-    @pytest.mark.asyncio
-    async def test_i_cannot_delete_document_with_invalid_id(self, in_memory_repository):
+    def test_i_cannot_delete_document_with_invalid_id(self, in_memory_repository):
        """Test that deleting with invalid ID returns False."""
        # Act
-        result = await in_memory_repository.delete_document("invalid_id")
+        result = in_memory_repository.delete_document("invalid_id")

        # Assert
        assert result is False

-    @pytest.mark.asyncio
-    async def test_i_cannot_delete_nonexistent_document(self, in_memory_repository):
+    def test_i_cannot_delete_nonexistent_document(self, in_memory_repository):
        """Test that deleting nonexistent file document returns False."""
        # Arrange
        nonexistent_id = str(ObjectId())

        # Act
-        result = await in_memory_repository.delete_document(nonexistent_id)
+        result = in_memory_repository.delete_document(nonexistent_id)

        # Assert
        assert result is False

-    @pytest.mark.asyncio
-    async def test_i_cannot_delete_document_with_pymongo_error(self, in_memory_repository, sample_file_document, mocker):
+    def test_i_cannot_delete_document_with_pymongo_error(self, in_memory_repository, sample_file_document, mocker):
        """Test handling of PyMongo errors during file document deletion."""
        # Arrange
-        created_file = await in_memory_repository.create_document(sample_file_document)
+        created_file = in_memory_repository.create_document(sample_file_document)
        mocker.patch.object(in_memory_repository.collection, 'delete_one', side_effect=PyMongoError("Database error"))

        # Act
-        result = await in_memory_repository.delete_document(str(created_file.id))
+        result = in_memory_repository.delete_document(str(created_file.id))

        # Assert
        assert result is False

@@ -608,36 +550,33 @@ class TestFileDocumentRepositoryDeletion:
class TestFileDocumentRepositoryUtilities:
    """Tests for utility methods."""

-    @pytest.mark.asyncio
-    async def test_i_can_count_documents(self, in_memory_repository, sample_file_document):
+    def test_i_can_count_documents(self, in_memory_repository, sample_file_document):
        """Test counting file documents."""
        # Arrange
-        initial_count = await in_memory_repository.count_documents()
+        initial_count = in_memory_repository.count_documents()
-        await in_memory_repository.create_document(sample_file_document)
+        in_memory_repository.create_document(sample_file_document)

        # Act
-        final_count = await in_memory_repository.count_documents()
+        final_count = in_memory_repository.count_documents()

        # Assert
        assert final_count == initial_count + 1

-    @pytest.mark.asyncio
-    async def test_i_can_count_zero_documents(self, in_memory_repository):
+    def test_i_can_count_zero_documents(self, in_memory_repository):
        """Test counting file documents in empty collection."""
        # Act
-        count = await in_memory_repository.count_documents()
+        count = in_memory_repository.count_documents()

        # Assert
        assert count == 0

-    @pytest.mark.asyncio
-    async def test_i_cannot_count_documents_with_pymongo_error(self, in_memory_repository, mocker):
+    def test_i_cannot_count_documents_with_pymongo_error(self, in_memory_repository, mocker):
        """Test handling of PyMongo errors during file document counting."""
        # Arrange
        mocker.patch.object(in_memory_repository.collection, 'count_documents', side_effect=PyMongoError("Database error"))

        # Act
-        count = await in_memory_repository.count_documents()
+        count = in_memory_repository.count_documents()

        # Assert
        assert count == 0
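The whole test module above follows one repeated conversion: @pytest_asyncio.fixture and @pytest.mark.asyncio become plain pytest constructs, and mongomock's synchronous MongoClient replaces AsyncMongoMockClient. A condensed, self-contained sketch of that synchronous style (collection name and assertions are illustrative only, not project code):

import pytest
from mongomock.mongo_client import MongoClient

@pytest.fixture
def in_memory_collection():
    # mongomock provides an in-memory, pymongo-compatible client for tests.
    return MongoClient().test_database.documents

def test_insert_and_find(in_memory_collection):
    in_memory_collection.insert_one({"filename": "example.pdf"})
    assert in_memory_collection.find_one({"filename": "example.pdf"}) is not None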
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
"""
|
"""
|
||||||
Test suite for JobRepository with async/await support.
|
Test suite for JobRepository with async/support.
|
||||||
|
|
||||||
This module contains comprehensive tests for all JobRepository methods
|
This module contains comprehensive tests for all JobRepository methods
|
||||||
using mongomock-motor for in-memory MongoDB testing.
|
using mongomock-motor for in-memory MongoDB testing.
|
||||||
@@ -8,8 +8,8 @@ using mongomock-motor for in-memory MongoDB testing.
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import pytest_asyncio
|
|
||||||
from bson import ObjectId
|
from bson import ObjectId
|
||||||
|
from mongomock.mongo_client import MongoClient
|
||||||
from mongomock_motor import AsyncMongoMockClient
|
from mongomock_motor import AsyncMongoMockClient
|
||||||
from pymongo.errors import PyMongoError
|
from pymongo.errors import PyMongoError
|
||||||
|
|
||||||
@@ -19,13 +19,13 @@ from app.models.job import ProcessingJob, ProcessingStatus
|
|||||||
from app.models.types import PyObjectId
|
from app.models.types import PyObjectId
|
||||||
|
|
||||||
|
|
||||||
@pytest_asyncio.fixture
|
@pytest.fixture
|
||||||
async def in_memory_repository():
|
def in_memory_repository():
|
||||||
"""Create an in-memory JobRepository for testing."""
|
"""Create an in-memory JobRepository for testing."""
|
||||||
client = AsyncMongoMockClient()
|
client = MongoClient()
|
||||||
db = client.test_database
|
db = client.test_database
|
||||||
repo = JobRepository(db)
|
repo = JobRepository(db)
|
||||||
await repo.initialize()
|
repo.initialize()
|
||||||
return repo
|
return repo
|
||||||
|
|
||||||
|
|
||||||
@@ -82,8 +82,7 @@ def multiple_sample_jobs():
|
|||||||
class TestJobRepositoryInitialization:
|
class TestJobRepositoryInitialization:
|
||||||
"""Tests for repository initialization."""
|
"""Tests for repository initialization."""
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
def test_i_can_initialize_repository(self):
|
||||||
async def test_i_can_initialize_repository(self):
|
|
||||||
"""Test repository initialization."""
|
"""Test repository initialization."""
|
||||||
# Arrange
|
# Arrange
|
||||||
client = AsyncMongoMockClient()
|
client = AsyncMongoMockClient()
|
||||||
@@ -91,7 +90,7 @@ class TestJobRepositoryInitialization:
|
|||||||
repo = JobRepository(db)
|
repo = JobRepository(db)
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
initialized_repo = await repo.initialize()
|
initialized_repo = repo.initialize()
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert initialized_repo is repo
|
assert initialized_repo is repo
|
||||||
@@ -102,11 +101,10 @@ class TestJobRepositoryInitialization:
|
|||||||
class TestJobRepositoryCreation:
|
class TestJobRepositoryCreation:
|
||||||
"""Tests for job creation functionality."""
|
"""Tests for job creation functionality."""
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
def test_i_can_create_job_with_task_id(self, in_memory_repository, sample_document_id, sample_task_id):
|
||||||
async def test_i_can_create_job_with_task_id(self, in_memory_repository, sample_document_id, sample_task_id):
|
|
||||||
"""Test successful job creation with task ID."""
|
"""Test successful job creation with task ID."""
|
||||||
# Act
|
# Act
|
||||||
created_job = await in_memory_repository.create_job(sample_document_id, sample_task_id)
|
created_job = in_memory_repository.create_job(sample_document_id, sample_task_id)
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert created_job is not None
|
assert created_job is not None
|
||||||
@@ -120,11 +118,10 @@ class TestJobRepositoryCreation:
|
|||||||
assert created_job.id is not None
|
assert created_job.id is not None
|
||||||
assert isinstance(created_job.id, ObjectId)
|
assert isinstance(created_job.id, ObjectId)
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
def test_i_can_create_job_without_task_id(self, in_memory_repository, sample_document_id):
|
||||||
async def test_i_can_create_job_without_task_id(self, in_memory_repository, sample_document_id):
|
|
||||||
"""Test successful job creation without task ID."""
|
"""Test successful job creation without task ID."""
|
||||||
# Act
|
# Act
|
||||||
created_job = await in_memory_repository.create_job(sample_document_id)
|
created_job = in_memory_repository.create_job(sample_document_id)
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert created_job is not None
|
assert created_job is not None
|
||||||
@@ -138,28 +135,26 @@ class TestJobRepositoryCreation:
|
|||||||
assert created_job.id is not None
|
assert created_job.id is not None
|
||||||
assert isinstance(created_job.id, ObjectId)
|
assert isinstance(created_job.id, ObjectId)
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
def test_i_cannot_create_duplicate_job_for_document(self, in_memory_repository, sample_document_id,
|
||||||
async def test_i_cannot_create_duplicate_job_for_document(self, in_memory_repository, sample_document_id,
|
|
||||||
sample_task_id):
|
sample_task_id):
|
||||||
"""Test that creating job with duplicate document_id raises DuplicateKeyError."""
|
"""Test that creating job with duplicate document_id raises DuplicateKeyError."""
|
||||||
# Arrange
|
# Arrange
|
||||||
await in_memory_repository.create_job(sample_document_id, sample_task_id)
|
in_memory_repository.create_job(sample_document_id, sample_task_id)
|
||||||
|
|
||||||
# Act & Assert
|
# Act & Assert
|
||||||
with pytest.raises(JobRepositoryError) as exc_info:
|
with pytest.raises(JobRepositoryError) as exc_info:
|
||||||
await in_memory_repository.create_job(sample_document_id, "different-task-id")
|
in_memory_repository.create_job(sample_document_id, "different-task-id")
|
||||||
|
|
||||||
assert "create_job" in str(exc_info.value)
|
assert "create_job" in str(exc_info.value)
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
def test_i_cannot_create_job_with_pymongo_error(self, in_memory_repository, sample_document_id, mocker):
|
||||||
async def test_i_cannot_create_job_with_pymongo_error(self, in_memory_repository, sample_document_id, mocker):
|
|
||||||
"""Test handling of PyMongo errors during job creation."""
|
"""Test handling of PyMongo errors during job creation."""
|
||||||
# Arrange
|
# Arrange
|
||||||
mocker.patch.object(in_memory_repository.collection, 'insert_one', side_effect=PyMongoError("Database error"))
|
mocker.patch.object(in_memory_repository.collection, 'insert_one', side_effect=PyMongoError("Database error"))
|
||||||
|
|
||||||
# Act & Assert
|
# Act & Assert
|
||||||
with pytest.raises(JobRepositoryError) as exc_info:
|
with pytest.raises(JobRepositoryError) as exc_info:
|
||||||
await in_memory_repository.create_job(sample_document_id)
|
in_memory_repository.create_job(sample_document_id)
|
||||||
|
|
||||||
assert "create_job" in str(exc_info.value)
|
assert "create_job" in str(exc_info.value)
|
||||||
|
|
||||||
@@ -167,14 +162,13 @@ class TestJobRepositoryCreation:
|
|||||||
class TestJobRepositoryFinding:
|
class TestJobRepositoryFinding:
|
||||||
"""Tests for job finding functionality."""
|
"""Tests for job finding functionality."""
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
def test_i_can_find_job_by_valid_id(self, in_memory_repository, sample_document_id, sample_task_id):
|
||||||
async def test_i_can_find_job_by_valid_id(self, in_memory_repository, sample_document_id, sample_task_id):
|
|
||||||
"""Test finding job by valid ObjectId."""
|
"""Test finding job by valid ObjectId."""
|
||||||
# Arrange
|
# Arrange
|
||||||
created_job = await in_memory_repository.create_job(sample_document_id, sample_task_id)
|
created_job = in_memory_repository.create_job(sample_document_id, sample_task_id)
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
found_job = await in_memory_repository.find_job_by_id(created_job.id)
|
found_job = in_memory_repository.find_job_by_id(created_job.id)
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert found_job is not None
|
assert found_job is not None
|
||||||
@@ -183,97 +177,90 @@ class TestJobRepositoryFinding:
         assert found_job.task_id == created_job.task_id
         assert found_job.status == created_job.status

-    @pytest.mark.asyncio
-    async def test_i_cannot_find_job_by_nonexistent_id(self, in_memory_repository):
+    def test_i_cannot_find_job_by_nonexistent_id(self, in_memory_repository):
         """Test that nonexistent ObjectId returns None."""
         # Arrange
         nonexistent_id = PyObjectId()

         # Act
-        found_job = await in_memory_repository.find_job_by_id(nonexistent_id)
+        found_job = in_memory_repository.find_job_by_id(nonexistent_id)

         # Assert
         assert found_job is None

-    @pytest.mark.asyncio
-    async def test_i_cannot_find_job_with_pymongo_error(self, in_memory_repository, mocker):
+    def test_i_cannot_find_job_with_pymongo_error(self, in_memory_repository, mocker):
         """Test handling of PyMongo errors during job finding."""
         # Arrange
         mocker.patch.object(in_memory_repository.collection, 'find_one', side_effect=PyMongoError("Database error"))

         # Act & Assert
         with pytest.raises(JobRepositoryError) as exc_info:
-            await in_memory_repository.find_job_by_id(PyObjectId())
+            in_memory_repository.find_job_by_id(PyObjectId())

         assert "get_job_by_id" in str(exc_info.value)

-    @pytest.mark.asyncio
-    async def test_i_can_find_jobs_by_document_id(self, in_memory_repository, sample_document_id, sample_task_id):
+    def test_i_can_find_jobs_by_document_id(self, in_memory_repository, sample_document_id, sample_task_id):
         """Test finding jobs by document ID."""
         # Arrange
-        created_job = await in_memory_repository.create_job(sample_document_id, sample_task_id)
+        created_job = in_memory_repository.create_job(sample_document_id, sample_task_id)

         # Act
-        found_jobs = await in_memory_repository.find_jobs_by_document_id(sample_document_id)
+        found_jobs = in_memory_repository.find_jobs_by_document_id(sample_document_id)

         # Assert
         assert len(found_jobs) == 1
         assert found_jobs[0].id == created_job.id
         assert found_jobs[0].document_id == sample_document_id

-    @pytest.mark.asyncio
-    async def test_i_can_find_empty_jobs_list_for_nonexistent_document(self, in_memory_repository):
+    def test_i_can_find_empty_jobs_list_for_nonexistent_document(self, in_memory_repository):
         """Test that nonexistent document ID returns empty list."""
         # Arrange
         nonexistent_id = ObjectId()

         # Act
-        found_jobs = await in_memory_repository.find_jobs_by_document_id(nonexistent_id)
+        found_jobs = in_memory_repository.find_jobs_by_document_id(nonexistent_id)

         # Assert
         assert found_jobs == []

-    @pytest.mark.asyncio
-    async def test_i_cannot_find_jobs_by_document_with_pymongo_error(self, in_memory_repository, mocker):
+    def test_i_cannot_find_jobs_by_document_with_pymongo_error(self, in_memory_repository, mocker):
         """Test handling of PyMongo errors during finding jobs by document ID."""
         # Arrange
         mocker.patch.object(in_memory_repository.collection, 'find', side_effect=PyMongoError("Database error"))

         # Act & Assert
         with pytest.raises(JobRepositoryError) as exc_info:
-            await in_memory_repository.find_jobs_by_document_id(PyObjectId())
+            in_memory_repository.find_jobs_by_document_id(PyObjectId())

         assert "get_jobs_by_file_id" in str(exc_info.value)

-    @pytest.mark.asyncio
    @pytest.mark.parametrize("status", [
        ProcessingStatus.PENDING,
        ProcessingStatus.PROCESSING,
        ProcessingStatus.COMPLETED
    ])
-    async def test_i_can_find_jobs_by_pending_status(self, in_memory_repository, sample_document_id, status):
+    def test_i_can_find_jobs_by_pending_status(self, in_memory_repository, sample_document_id, status):
         """Test finding jobs by PENDING status."""
         # Arrange
-        created_job = await in_memory_repository.create_job(sample_document_id)
-        await in_memory_repository.update_job_status(created_job.id, status)
+        created_job = in_memory_repository.create_job(sample_document_id)
+        in_memory_repository.update_job_status(created_job.id, status)

         # Act
-        found_jobs = await in_memory_repository.get_jobs_by_status(status)
+        found_jobs = in_memory_repository.get_jobs_by_status(status)

         # Assert
         assert len(found_jobs) == 1
         assert found_jobs[0].id == created_job.id
         assert found_jobs[0].status == status

-    @pytest.mark.asyncio
-    async def test_i_can_find_jobs_by_failed_status(self, in_memory_repository, sample_document_id):
+    def test_i_can_find_jobs_by_failed_status(self, in_memory_repository, sample_document_id):
         """Test finding jobs by FAILED status."""
         # Arrange
-        created_job = await in_memory_repository.create_job(sample_document_id)
-        await in_memory_repository.update_job_status(created_job.id, ProcessingStatus.FAILED, "Test error")
+        created_job = in_memory_repository.create_job(sample_document_id)
+        in_memory_repository.update_job_status(created_job.id, ProcessingStatus.FAILED, "Test error")

         # Act
-        found_jobs = await in_memory_repository.get_jobs_by_status(ProcessingStatus.FAILED)
+        found_jobs = in_memory_repository.get_jobs_by_status(ProcessingStatus.FAILED)

         # Assert
         assert len(found_jobs) == 1
@@ -281,24 +268,22 @@ class TestJobRepositoryFinding:
         assert found_jobs[0].status == ProcessingStatus.FAILED
         assert found_jobs[0].error_message == "Test error"

-    @pytest.mark.asyncio
-    async def test_i_can_find_empty_jobs_list_for_unused_status(self, in_memory_repository):
+    def test_i_can_find_empty_jobs_list_for_unused_status(self, in_memory_repository):
         """Test that unused status returns empty list."""
         # Act
-        found_jobs = await in_memory_repository.get_jobs_by_status(ProcessingStatus.COMPLETED)
+        found_jobs = in_memory_repository.get_jobs_by_status(ProcessingStatus.COMPLETED)

         # Assert
         assert found_jobs == []

-    @pytest.mark.asyncio
-    async def test_i_cannot_find_jobs_by_status_with_pymongo_error(self, in_memory_repository, mocker):
+    def test_i_cannot_find_jobs_by_status_with_pymongo_error(self, in_memory_repository, mocker):
         """Test handling of PyMongo errors during finding jobs by status."""
         # Arrange
         mocker.patch.object(in_memory_repository.collection, 'find', side_effect=PyMongoError("Database error"))

         # Act & Assert
         with pytest.raises(JobRepositoryError) as exc_info:
-            await in_memory_repository.get_jobs_by_status(ProcessingStatus.PENDING)
+            in_memory_repository.get_jobs_by_status(ProcessingStatus.PENDING)

         assert "get_jobs_by_status" in str(exc_info.value)

@@ -306,14 +291,13 @@ class TestJobRepositoryFinding:
 class TestJobRepositoryStatusUpdate:
     """Tests for job status update functionality."""

-    @pytest.mark.asyncio
-    async def test_i_can_update_job_status_to_processing(self, in_memory_repository, sample_document_id):
+    def test_i_can_update_job_status_to_processing(self, in_memory_repository, sample_document_id):
         """Test updating job status to PROCESSING with started_at timestamp."""
         # Arrange
-        created_job = await in_memory_repository.create_job(sample_document_id)
+        created_job = in_memory_repository.create_job(sample_document_id)

         # Act
-        updated_job = await in_memory_repository.update_job_status(created_job.id, ProcessingStatus.PROCESSING)
+        updated_job = in_memory_repository.update_job_status(created_job.id, ProcessingStatus.PROCESSING)

         # Assert
         assert updated_job is not None
@@ -323,15 +307,14 @@ class TestJobRepositoryStatusUpdate:
         assert updated_job.completed_at is None
         assert updated_job.error_message is None

-    @pytest.mark.asyncio
-    async def test_i_can_update_job_status_to_completed(self, in_memory_repository, sample_document_id):
+    def test_i_can_update_job_status_to_completed(self, in_memory_repository, sample_document_id):
         """Test updating job status to COMPLETED with completed_at timestamp."""
         # Arrange
-        created_job = await in_memory_repository.create_job(sample_document_id)
-        await in_memory_repository.update_job_status(created_job.id, ProcessingStatus.PROCESSING)
+        created_job = in_memory_repository.create_job(sample_document_id)
+        in_memory_repository.update_job_status(created_job.id, ProcessingStatus.PROCESSING)

         # Act
-        updated_job = await in_memory_repository.update_job_status(created_job.id, ProcessingStatus.COMPLETED)
+        updated_job = in_memory_repository.update_job_status(created_job.id, ProcessingStatus.COMPLETED)

         # Assert
         assert updated_job is not None
@@ -341,15 +324,14 @@ class TestJobRepositoryStatusUpdate:
         assert updated_job.completed_at is not None
         assert updated_job.error_message is None

-    @pytest.mark.asyncio
-    async def test_i_can_update_job_status_to_failed_with_error(self, in_memory_repository, sample_document_id):
+    def test_i_can_update_job_status_to_failed_with_error(self, in_memory_repository, sample_document_id):
         """Test updating job status to FAILED with error message and completed_at timestamp."""
         # Arrange
-        created_job = await in_memory_repository.create_job(sample_document_id)
+        created_job = in_memory_repository.create_job(sample_document_id)
         error_message = "Processing failed due to invalid format"

         # Act
-        updated_job = await in_memory_repository.update_job_status(
+        updated_job = in_memory_repository.update_job_status(
             created_job.id, ProcessingStatus.FAILED, error_message
         )

@@ -360,14 +342,13 @@ class TestJobRepositoryStatusUpdate:
         assert updated_job.completed_at is not None
         assert updated_job.error_message == error_message

-    @pytest.mark.asyncio
-    async def test_i_can_update_job_status_to_failed_without_error(self, in_memory_repository, sample_document_id):
+    def test_i_can_update_job_status_to_failed_without_error(self, in_memory_repository, sample_document_id):
         """Test updating job status to FAILED without error message."""
         # Arrange
-        created_job = await in_memory_repository.create_job(sample_document_id)
+        created_job = in_memory_repository.create_job(sample_document_id)

         # Act
-        updated_job = await in_memory_repository.update_job_status(created_job.id, ProcessingStatus.FAILED)
+        updated_job = in_memory_repository.update_job_status(created_job.id, ProcessingStatus.FAILED)

         # Assert
         assert updated_job is not None
@@ -376,29 +357,27 @@ class TestJobRepositoryStatusUpdate:
         assert updated_job.completed_at is not None
         assert updated_job.error_message is None

-    @pytest.mark.asyncio
-    async def test_i_cannot_update_nonexistent_job_status(self, in_memory_repository):
+    def test_i_cannot_update_nonexistent_job_status(self, in_memory_repository):
         """Test that updating nonexistent job returns None."""
         # Arrange
         nonexistent_id = ObjectId()

         # Act
-        result = await in_memory_repository.update_job_status(nonexistent_id, ProcessingStatus.COMPLETED)
+        result = in_memory_repository.update_job_status(nonexistent_id, ProcessingStatus.COMPLETED)

         # Assert
         assert result is None

-    @pytest.mark.asyncio
-    async def test_i_cannot_update_job_status_with_pymongo_error(self, in_memory_repository, sample_document_id, mocker):
+    def test_i_cannot_update_job_status_with_pymongo_error(self, in_memory_repository, sample_document_id, mocker):
         """Test handling of PyMongo errors during job status update."""
         # Arrange
-        created_job = await in_memory_repository.create_job(sample_document_id)
+        created_job = in_memory_repository.create_job(sample_document_id)
         mocker.patch.object(in_memory_repository.collection, 'find_one_and_update',
                             side_effect=PyMongoError("Database error"))

         # Act & Assert
         with pytest.raises(JobRepositoryError) as exc_info:
-            await in_memory_repository.update_job_status(created_job.id, ProcessingStatus.COMPLETED)
+            in_memory_repository.update_job_status(created_job.id, ProcessingStatus.COMPLETED)

         assert "update_job_status" in str(exc_info.value)

@@ -406,44 +385,41 @@ class TestJobRepositoryStatusUpdate:
 class TestJobRepositoryDeletion:
     """Tests for job deletion functionality."""

-    @pytest.mark.asyncio
-    async def test_i_can_delete_existing_job(self, in_memory_repository, sample_document_id):
+    def test_i_can_delete_existing_job(self, in_memory_repository, sample_document_id):
         """Test successful job deletion."""
         # Arrange
-        created_job = await in_memory_repository.create_job(sample_document_id)
+        created_job = in_memory_repository.create_job(sample_document_id)

         # Act
-        deletion_result = await in_memory_repository.delete_job(created_job.id)
+        deletion_result = in_memory_repository.delete_job(created_job.id)

         # Assert
         assert deletion_result is True

         # Verify job is actually deleted
-        found_job = await in_memory_repository.find_job_by_id(created_job.id)
+        found_job = in_memory_repository.find_job_by_id(created_job.id)
         assert found_job is None

-    @pytest.mark.asyncio
-    async def test_i_cannot_delete_nonexistent_job(self, in_memory_repository):
+    def test_i_cannot_delete_nonexistent_job(self, in_memory_repository):
         """Test that deleting nonexistent job returns False."""
         # Arrange
         nonexistent_id = ObjectId()

         # Act
-        result = await in_memory_repository.delete_job(nonexistent_id)
+        result = in_memory_repository.delete_job(nonexistent_id)

         # Assert
         assert result is False

-    @pytest.mark.asyncio
-    async def test_i_cannot_delete_job_with_pymongo_error(self, in_memory_repository, sample_document_id, mocker):
+    def test_i_cannot_delete_job_with_pymongo_error(self, in_memory_repository, sample_document_id, mocker):
         """Test handling of PyMongo errors during job deletion."""
         # Arrange
-        created_job = await in_memory_repository.create_job(sample_document_id)
+        created_job = in_memory_repository.create_job(sample_document_id)
         mocker.patch.object(in_memory_repository.collection, 'delete_one', side_effect=PyMongoError("Database error"))

         # Act & Assert
         with pytest.raises(JobRepositoryError) as exc_info:
-            await in_memory_repository.delete_job(created_job.id)
+            in_memory_repository.delete_job(created_job.id)

         assert "delete_job" in str(exc_info.value)

@@ -451,38 +427,36 @@ class TestJobRepositoryDeletion:
 class TestJobRepositoryComplexScenarios:
     """Tests for complex job repository scenarios."""

-    @pytest.mark.asyncio
-    async def test_i_can_handle_complete_job_lifecycle(self, in_memory_repository, sample_document_id, sample_task_id):
+    def test_i_can_handle_complete_job_lifecycle(self, in_memory_repository, sample_document_id, sample_task_id):
         """Test complete job lifecycle from creation to completion."""
         # Create job
-        job = await in_memory_repository.create_job(sample_document_id, sample_task_id)
+        job = in_memory_repository.create_job(sample_document_id, sample_task_id)
         assert job.status == ProcessingStatus.PENDING
         assert job.started_at is None
         assert job.completed_at is None

         # Start processing
-        job = await in_memory_repository.update_job_status(job.id, ProcessingStatus.PROCESSING)
+        job = in_memory_repository.update_job_status(job.id, ProcessingStatus.PROCESSING)
         assert job.status == ProcessingStatus.PROCESSING
         assert job.started_at is not None
         assert job.completed_at is None

         # Complete job
-        job = await in_memory_repository.update_job_status(job.id, ProcessingStatus.COMPLETED)
+        job = in_memory_repository.update_job_status(job.id, ProcessingStatus.COMPLETED)
         assert job.status == ProcessingStatus.COMPLETED
         assert job.started_at is not None
         assert job.completed_at is not None
         assert job.error_message is None

-    @pytest.mark.asyncio
-    async def test_i_can_handle_job_failure_scenario(self, in_memory_repository, sample_document_id, sample_task_id):
+    def test_i_can_handle_job_failure_scenario(self, in_memory_repository, sample_document_id, sample_task_id):
         """Test job failure scenario with error message."""
         # Create and start job
-        job = await in_memory_repository.create_job(sample_document_id, sample_task_id)
-        job = await in_memory_repository.update_job_status(job.id, ProcessingStatus.PROCESSING)
+        job = in_memory_repository.create_job(sample_document_id, sample_task_id)
+        job = in_memory_repository.update_job_status(job.id, ProcessingStatus.PROCESSING)

         # Fail job with error
         error_msg = "File format not supported"
-        job = await in_memory_repository.update_job_status(job.id, ProcessingStatus.FAILED, error_msg)
+        job = in_memory_repository.update_job_status(job.id, ProcessingStatus.FAILED, error_msg)

         # Assert failure state
         assert job.status == ProcessingStatus.FAILED
@@ -490,28 +464,27 @@ class TestJobRepositoryComplexScenarios:
         assert job.completed_at is not None
         assert job.error_message == error_msg

-    @pytest.mark.asyncio
-    async def test_i_can_handle_multiple_documents_with_different_statuses(self, in_memory_repository):
+    def test_i_can_handle_multiple_documents_with_different_statuses(self, in_memory_repository):
         """Test managing multiple jobs for different documents with various statuses."""
         # Create jobs for different documents
         doc1 = PyObjectId()
         doc2 = PyObjectId()
         doc3 = PyObjectId()

-        job1 = await in_memory_repository.create_job(doc1, "task-1")
-        job2 = await in_memory_repository.create_job(doc2, "task-2")
-        job3 = await in_memory_repository.create_job(doc3, "task-3")
+        job1 = in_memory_repository.create_job(doc1, "task-1")
+        job2 = in_memory_repository.create_job(doc2, "task-2")
+        job3 = in_memory_repository.create_job(doc3, "task-3")

         # Update to different statuses
-        await in_memory_repository.update_job_status(job1.id, ProcessingStatus.PROCESSING)
-        await in_memory_repository.update_job_status(job2.id, ProcessingStatus.COMPLETED)
-        await in_memory_repository.update_job_status(job3.id, ProcessingStatus.FAILED, "Error occurred")
+        in_memory_repository.update_job_status(job1.id, ProcessingStatus.PROCESSING)
+        in_memory_repository.update_job_status(job2.id, ProcessingStatus.COMPLETED)
+        in_memory_repository.update_job_status(job3.id, ProcessingStatus.FAILED, "Error occurred")

         # Verify status queries
-        pending_jobs = await in_memory_repository.get_jobs_by_status(ProcessingStatus.PENDING)
-        processing_jobs = await in_memory_repository.get_jobs_by_status(ProcessingStatus.PROCESSING)
-        completed_jobs = await in_memory_repository.get_jobs_by_status(ProcessingStatus.COMPLETED)
-        failed_jobs = await in_memory_repository.get_jobs_by_status(ProcessingStatus.FAILED)
+        pending_jobs = in_memory_repository.get_jobs_by_status(ProcessingStatus.PENDING)
+        processing_jobs = in_memory_repository.get_jobs_by_status(ProcessingStatus.PROCESSING)
+        completed_jobs = in_memory_repository.get_jobs_by_status(ProcessingStatus.COMPLETED)
+        failed_jobs = in_memory_repository.get_jobs_by_status(ProcessingStatus.FAILED)

         assert len(pending_jobs) == 0
         assert len(processing_jobs) == 1

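The converted tests above all follow the same synchronous pattern: patch one collection method so it raises PyMongoError, then assert that the repository re-raises it as JobRepositoryError with the operation name in the message. A minimal, self-contained sketch of that pattern, using a hypothetical FakeRepository / FakeRepositoryError pair and stdlib unittest.mock in place of the project's repository class and the mocker fixture:

from unittest.mock import patch

import pytest
from mongomock.mongo_client import MongoClient
from pymongo.errors import PyMongoError


class FakeRepositoryError(Exception):
    """Stand-in for the project's JobRepositoryError."""


class FakeRepository:
    """Stand-in repository that wraps driver failures, mirroring the tests above."""

    def __init__(self, db):
        self.collection = db.jobs

    def create_job(self, document_id):
        try:
            result = self.collection.insert_one({"document_id": document_id, "status": "pending"})
            return result.inserted_id
        except PyMongoError as exc:
            # Wrap the low-level driver error in a domain error, keeping the operation name.
            raise FakeRepositoryError(f"create_job failed: {exc}") from exc


def test_create_job_wraps_pymongo_error():
    repo = FakeRepository(MongoClient().test_database)

    # Simulate a driver failure on insert_one, as the repository tests do with mocker.
    with patch.object(repo.collection, "insert_one", side_effect=PyMongoError("Database error")):
        with pytest.raises(FakeRepositoryError) as exc_info:
            repo.create_job("doc-1")

    assert "create_job" in str(exc_info.value)
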
@@ -1,29 +1,26 @@
 """
-Test suite for UserRepository with async/await support.
+Test suite for UserRepository.

 This module contains comprehensive tests for all UserRepository methods
 using mongomock-motor for in-memory MongoDB testing.
 """

 import pytest
-from datetime import datetime

-import pytest_asyncio
 from bson import ObjectId
+from mongomock.mongo_client import MongoClient
 from pymongo.errors import DuplicateKeyError
-from mongomock_motor import AsyncMongoMockClient

 from app.database.repositories.user_repository import UserRepository
-from app.models.user import UserCreate, UserUpdate, UserInDB
+from app.models.user import UserCreate, UserUpdate


-@pytest_asyncio.fixture
-async def in_memory_repository():
+@pytest.fixture
+def in_memory_repository():
     """Create an in-memory UserRepository for testing."""
-    client = AsyncMongoMockClient()
+    client = MongoClient()
     db = client.test_database
     repo = UserRepository(db)
-    await repo.initialize()
+    repo.initialize()
     return repo


@@ -51,11 +48,10 @@ def sample_user_update():
 class TestUserRepositoryCreation:
     """Tests for user creation functionality."""

-    @pytest.mark.asyncio
-    async def test_i_can_create_user(self, in_memory_repository, sample_user_create):
+    def test_i_can_create_user(self, in_memory_repository, sample_user_create):
         """Test successful user creation."""
         # Act
-        created_user = await in_memory_repository.create_user(sample_user_create)
+        created_user = in_memory_repository.create_user(sample_user_create)

         # Assert
         assert created_user is not None
@@ -68,15 +64,14 @@ class TestUserRepositoryCreation:
         assert created_user.updated_at is not None
         assert created_user.hashed_password != sample_user_create.password # Should be hashed

-    @pytest.mark.asyncio
-    async def test_i_cannot_create_user_with_duplicate_username(self, in_memory_repository, sample_user_create):
+    def test_i_cannot_create_user_with_duplicate_username(self, in_memory_repository, sample_user_create):
         """Test that creating user with duplicate username raises DuplicateKeyError."""
         # Arrange
-        await in_memory_repository.create_user(sample_user_create)
+        in_memory_repository.create_user(sample_user_create)

         # Act & Assert
         with pytest.raises(DuplicateKeyError) as exc_info:
-            await in_memory_repository.create_user(sample_user_create)
+            in_memory_repository.create_user(sample_user_create)

         assert "already exists" in str(exc_info.value)

@@ -84,14 +79,13 @@ class TestUserRepositoryCreation:
 class TestUserRepositoryFinding:
     """Tests for user finding functionality."""

-    @pytest.mark.asyncio
-    async def test_i_can_find_user_by_id(self, in_memory_repository, sample_user_create):
+    def test_i_can_find_user_by_id(self, in_memory_repository, sample_user_create):
         """Test finding user by valid ID."""
         # Arrange
-        created_user = await in_memory_repository.create_user(sample_user_create)
+        created_user = in_memory_repository.create_user(sample_user_create)

         # Act
-        found_user = await in_memory_repository.find_user_by_id(str(created_user.id))
+        found_user = in_memory_repository.find_user_by_id(str(created_user.id))

         # Assert
         assert found_user is not None
@@ -99,69 +93,63 @@ class TestUserRepositoryFinding:
         assert found_user.username == created_user.username
         assert found_user.email == created_user.email

-    @pytest.mark.asyncio
-    async def test_i_cannot_find_user_by_invalid_id(self, in_memory_repository):
+    def test_i_cannot_find_user_by_invalid_id(self, in_memory_repository):
         """Test that invalid ObjectId returns None."""
         # Act
-        found_user = await in_memory_repository.find_user_by_id("invalid_id")
+        found_user = in_memory_repository.find_user_by_id("invalid_id")

         # Assert
         assert found_user is None

-    @pytest.mark.asyncio
-    async def test_i_cannot_find_user_by_nonexistent_id(self, in_memory_repository):
+    def test_i_cannot_find_user_by_nonexistent_id(self, in_memory_repository):
         """Test that nonexistent but valid ObjectId returns None."""
         # Arrange
         nonexistent_id = str(ObjectId())

         # Act
-        found_user = await in_memory_repository.find_user_by_id(nonexistent_id)
+        found_user = in_memory_repository.find_user_by_id(nonexistent_id)

         # Assert
         assert found_user is None

-    @pytest.mark.asyncio
-    async def test_i_can_find_user_by_username(self, in_memory_repository, sample_user_create):
+    def test_i_can_find_user_by_username(self, in_memory_repository, sample_user_create):
         """Test finding user by username."""
         # Arrange
-        created_user = await in_memory_repository.create_user(sample_user_create)
+        created_user = in_memory_repository.create_user(sample_user_create)

         # Act
-        found_user = await in_memory_repository.find_user_by_username(sample_user_create.username)
+        found_user = in_memory_repository.find_user_by_username(sample_user_create.username)

         # Assert
         assert found_user is not None
         assert found_user.username == created_user.username
         assert found_user.id == created_user.id

-    @pytest.mark.asyncio
-    async def test_i_cannot_find_user_by_nonexistent_username(self, in_memory_repository):
+    def test_i_cannot_find_user_by_nonexistent_username(self, in_memory_repository):
         """Test that nonexistent username returns None."""
         # Act
-        found_user = await in_memory_repository.find_user_by_username("nonexistent")
+        found_user = in_memory_repository.find_user_by_username("nonexistent")

         # Assert
         assert found_user is None

-    @pytest.mark.asyncio
-    async def test_i_can_find_user_by_email(self, in_memory_repository, sample_user_create):
+    def test_i_can_find_user_by_email(self, in_memory_repository, sample_user_create):
         """Test finding user by email."""
         # Arrange
-        created_user = await in_memory_repository.create_user(sample_user_create)
+        created_user = in_memory_repository.create_user(sample_user_create)

         # Act
-        found_user = await in_memory_repository.find_user_by_email(str(sample_user_create.email))
+        found_user = in_memory_repository.find_user_by_email(str(sample_user_create.email))

         # Assert
         assert found_user is not None
         assert found_user.email == created_user.email
         assert found_user.id == created_user.id

-    @pytest.mark.asyncio
-    async def test_i_cannot_find_user_by_nonexistent_email(self, in_memory_repository):
+    def test_i_cannot_find_user_by_nonexistent_email(self, in_memory_repository):
         """Test that nonexistent email returns None."""
         # Act
-        found_user = await in_memory_repository.find_user_by_email("nonexistent@example.com")
+        found_user = in_memory_repository.find_user_by_email("nonexistent@example.com")

         # Assert
         assert found_user is None
@@ -170,15 +158,14 @@ class TestUserRepositoryFinding:
 class TestUserRepositoryUpdate:
     """Tests for user update functionality."""

-    @pytest.mark.asyncio
-    async def test_i_can_update_user(self, in_memory_repository, sample_user_create, sample_user_update):
+    def test_i_can_update_user(self, in_memory_repository, sample_user_create, sample_user_update):
         """Test successful user update."""
         # Arrange
-        created_user = await in_memory_repository.create_user(sample_user_create)
+        created_user = in_memory_repository.create_user(sample_user_create)
         original_updated_at = created_user.updated_at

         # Act
-        updated_user = await in_memory_repository.update_user(str(created_user.id), sample_user_update)
+        updated_user = in_memory_repository.update_user(str(created_user.id), sample_user_update)

         # Assert
         assert updated_user is not None
@@ -187,24 +174,22 @@ class TestUserRepositoryUpdate:
         assert updated_user.role == sample_user_update.role
         assert updated_user.id == created_user.id

-    @pytest.mark.asyncio
-    async def test_i_cannot_update_user_with_invalid_id(self, in_memory_repository, sample_user_update):
+    def test_i_cannot_update_user_with_invalid_id(self, in_memory_repository, sample_user_update):
         """Test that updating with invalid ID returns None."""
         # Act
-        result = await in_memory_repository.update_user("invalid_id", sample_user_update)
+        result = in_memory_repository.update_user("invalid_id", sample_user_update)

         # Assert
         assert result is None

-    @pytest.mark.asyncio
-    async def test_i_can_update_user_with_partial_data(self, in_memory_repository, sample_user_create):
+    def test_i_can_update_user_with_partial_data(self, in_memory_repository, sample_user_create):
         """Test updating user with partial data."""
         # Arrange
-        created_user = await in_memory_repository.create_user(sample_user_create)
+        created_user = in_memory_repository.create_user(sample_user_create)
         partial_update = UserUpdate(username="newusername")

         # Act
-        updated_user = await in_memory_repository.update_user(str(created_user.id), partial_update)
+        updated_user = in_memory_repository.update_user(str(created_user.id), partial_update)

         # Assert
         assert updated_user is not None
@@ -212,15 +197,14 @@ class TestUserRepositoryUpdate:
         assert updated_user.email == created_user.email # Should remain unchanged
         assert updated_user.role == created_user.role # Should remain unchanged

-    @pytest.mark.asyncio
-    async def test_i_can_update_user_with_empty_data(self, in_memory_repository, sample_user_create):
+    def test_i_can_update_user_with_empty_data(self, in_memory_repository, sample_user_create):
         """Test updating user with empty data returns current user."""
         # Arrange
-        created_user = await in_memory_repository.create_user(sample_user_create)
+        created_user = in_memory_repository.create_user(sample_user_create)
         empty_update = UserUpdate()

         # Act
-        result = await in_memory_repository.update_user(str(created_user.id), empty_update)
+        result = in_memory_repository.update_user(str(created_user.id), empty_update)

         # Assert
         assert result is not None
@@ -231,39 +215,36 @@ class TestUserRepositoryUpdate:
 class TestUserRepositoryDeletion:
     """Tests for user deletion functionality."""

-    @pytest.mark.asyncio
-    async def test_i_can_delete_user(self, in_memory_repository, sample_user_create):
+    def test_i_can_delete_user(self, in_memory_repository, sample_user_create):
         """Test successful user deletion."""
         # Arrange
-        created_user = await in_memory_repository.create_user(sample_user_create)
+        created_user = in_memory_repository.create_user(sample_user_create)

         # Act
-        deletion_result = await in_memory_repository.delete_user(str(created_user.id))
+        deletion_result = in_memory_repository.delete_user(str(created_user.id))

         # Assert
         assert deletion_result is True

         # Verify user is actually deleted
-        found_user = await in_memory_repository.find_user_by_id(str(created_user.id))
+        found_user = in_memory_repository.find_user_by_id(str(created_user.id))
         assert found_user is None

-    @pytest.mark.asyncio
-    async def test_i_cannot_delete_user_with_invalid_id(self, in_memory_repository):
+    def test_i_cannot_delete_user_with_invalid_id(self, in_memory_repository):
         """Test that deleting with invalid ID returns False."""
         # Act
-        result = await in_memory_repository.delete_user("invalid_id")
+        result = in_memory_repository.delete_user("invalid_id")

         # Assert
         assert result is False

-    @pytest.mark.asyncio
-    async def test_i_cannot_delete_nonexistent_user(self, in_memory_repository):
+    def test_i_cannot_delete_nonexistent_user(self, in_memory_repository):
         """Test that deleting nonexistent user returns False."""
         # Arrange
         nonexistent_id = str(ObjectId())

         # Act
-        result = await in_memory_repository.delete_user(nonexistent_id)
+        result = in_memory_repository.delete_user(nonexistent_id)

         # Assert
         assert result is False
@@ -272,30 +253,27 @@ class TestUserRepositoryDeletion:
 class TestUserRepositoryUtilities:
     """Tests for utility methods."""

-    @pytest.mark.asyncio
-    async def test_i_can_count_users(self, in_memory_repository, sample_user_create):
+    def test_i_can_count_users(self, in_memory_repository, sample_user_create):
         """Test counting users."""
         # Arrange
-        initial_count = await in_memory_repository.count_users()
-        await in_memory_repository.create_user(sample_user_create)
+        initial_count = in_memory_repository.count_users()
+        in_memory_repository.create_user(sample_user_create)

         # Act
-        final_count = await in_memory_repository.count_users()
+        final_count = in_memory_repository.count_users()

         # Assert
         assert final_count == initial_count + 1

-    @pytest.mark.asyncio
-    async def test_i_can_check_user_exists(self, in_memory_repository, sample_user_create):
+    def test_i_can_check_user_exists(self, in_memory_repository, sample_user_create):
         """Test checking if user exists."""
         # Arrange
-        await in_memory_repository.create_user(sample_user_create)
+        in_memory_repository.create_user(sample_user_create)

         # Act
-        exists = await in_memory_repository.user_exists(sample_user_create.username)
-        not_exists = await in_memory_repository.user_exists("nonexistent")
+        exists = in_memory_repository.user_exists(sample_user_create.username)
+        not_exists = in_memory_repository.user_exists("nonexistent")

         # Assert
         assert exists is True
         assert not_exists is False

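For the fixture change above, a minimal, self-contained sketch of the synchronous pattern this diff converges on: a plain pytest.fixture backed by mongomock's MongoClient. The users collection here is a hypothetical stand-in for the project's UserRepository; every test still gets a fresh in-memory database, just without an event loop.

import pytest
from mongomock.mongo_client import MongoClient


@pytest.fixture
def in_memory_db():
    """Plain (non-async) fixture: each test receives a fresh in-memory MongoDB."""
    return MongoClient().test_database


def test_insert_and_find_user(in_memory_db):
    users = in_memory_db.users

    # Insert and read back synchronously; no await, no pytest-asyncio marker needed.
    users.insert_one({"username": "alice", "email": "alice@example.com"})
    found = users.find_one({"username": "alice"})

    assert found is not None
    assert found["email"] == "alice@example.com"

Because the fixture is synchronous, the @pytest.mark.asyncio markers removed throughout this commit are no longer required for these tests.
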
@@ -11,7 +11,7 @@ from unittest.mock import patch
 import pytest
 import pytest_asyncio
 from bson import ObjectId
-from mongomock_motor import AsyncMongoMockClient
+from mongomock.mongo_client import MongoClient

 from app.models.document import FileType
 from app.services.document_service import DocumentService
@@ -24,15 +24,15 @@ def cleanup_test_folder():
     shutil.rmtree("test_folder", ignore_errors=True)


-@pytest_asyncio.fixture
-async def in_memory_database():
+@pytest.fixture
+def in_memory_database():
     """Create an in-memory database for testing."""
-    client = AsyncMongoMockClient()
+    client = MongoClient()
     return client.test_database


 @pytest_asyncio.fixture
-async def document_service(in_memory_database):
+def document_service(in_memory_database):
     """Create DocumentService with in-memory repositories."""
     service = DocumentService(in_memory_database, objects_folder="test_folder")
     return service
@@ -72,8 +72,7 @@ class TestCreateDocument:

     @patch('app.services.document_service.magic.from_buffer')
     @patch('app.services.document_service.datetime')
-    @pytest.mark.asyncio
-    async def test_i_can_create_document_with_new_content(
+    def test_i_can_create_document_with_new_content(
         self,
         mock_datetime,
         mock_magic,
@@ -87,7 +86,7 @@ class TestCreateDocument:
         mock_magic.return_value = "application/pdf"

         # Execute
-        result = await document_service.create_document(
+        result = document_service.create_document(
             "/test/test.pdf",
             sample_file_bytes,
             "utf-8"
@@ -102,7 +101,7 @@ class TestCreateDocument:
         assert result.file_hash == document_service._calculate_file_hash(sample_file_bytes)

         # Verify document created in database
-        doc_in_db = await document_service.document_repository.find_document_by_id(result.id)
+        doc_in_db = document_service.document_repository.find_document_by_id(result.id)
         assert doc_in_db is not None
         assert doc_in_db.id == result.id
         assert doc_in_db.filename == result.filename
@@ -116,8 +115,7 @@ class TestCreateDocument:

     @patch('app.services.document_service.magic.from_buffer')
     @patch('app.services.document_service.datetime')
-    @pytest.mark.asyncio
-    async def test_i_can_create_document_with_existing_content(
+    def test_i_can_create_document_with_existing_content(
         self,
         mock_datetime,
         mock_magic,
@@ -131,14 +129,14 @@ class TestCreateDocument:
         mock_magic.return_value = "application/pdf"

         # Create first document
-        first_doc = await document_service.create_document(
+        first_doc = document_service.create_document(
             "/test/first.pdf",
             sample_file_bytes,
             "utf-8"
         )

         # Create second document with same content
-        second_doc = await document_service.create_document(
+        second_doc = document_service.create_document(
             "/test/second.pdf",
             sample_file_bytes,
             "utf-8"
@@ -149,37 +147,34 @@ class TestCreateDocument:
         assert first_doc.filename != second_doc.filename
         assert first_doc.filepath != second_doc.filepath

-    @pytest.mark.asyncio
-    async def test_i_cannot_create_document_with_unsupported_file_type(
+    def test_i_cannot_create_document_with_unsupported_file_type(
         self,
         document_service,
         sample_file_bytes
     ):
         """Test that unsupported file types raise ValueError."""
         with pytest.raises(ValueError, match="Unsupported file type"):
-            await document_service.create_document(
+            document_service.create_document(
                 "/test/test.xyz", # Unsupported extension
                 sample_file_bytes,
                 "utf-8"
             )

-    @pytest.mark.asyncio
-    async def test_i_cannot_create_document_with_empty_file_path(
+    def test_i_cannot_create_document_with_empty_file_path(
         self,
         document_service,
         sample_file_bytes
     ):
         """Test that empty file path raises ValueError."""
         with pytest.raises(ValueError):
-            await document_service.create_document(
+            document_service.create_document(
                 "", # Empty path
                 sample_file_bytes,
                 "utf-8"
             )

     @patch('app.services.document_service.magic.from_buffer')
-    @pytest.mark.asyncio
-    async def test_i_can_create_document_with_empty_bytes(
+    def test_i_can_create_document_with_empty_bytes(
         self,
         mock_magic,
         document_service
@@ -189,7 +184,7 @@ class TestCreateDocument:
         mock_magic.return_value = "text/plain"

         # Execute with empty bytes
-        result = await document_service.create_document(
+        result = document_service.create_document(
             "/test/empty.txt",
             b"", # Empty bytes
             "utf-8"
@@ -203,8 +198,7 @@ class TestGetMethods:
     """Tests for document retrieval methods."""

     @patch('app.services.document_service.magic.from_buffer')
-    @pytest.mark.asyncio
-    async def test_i_can_get_document_by_id(
+    def test_i_can_get_document_by_id(
         self,
         mock_magic,
         document_service,
@@ -215,14 +209,14 @@ class TestGetMethods:
         mock_magic.return_value = "application/pdf"

         # Create a document first
-        created_doc = await document_service.create_document(
+        created_doc = document_service.create_document(
             "/test/test.pdf",
             sample_file_bytes,
             "utf-8"
         )

         # Execute
-        result = await document_service.get_document_by_id(created_doc.id)
+        result = document_service.get_document_by_id(created_doc.id)

         # Verify
         assert result is not None
@@ -230,8 +224,7 @@ class TestGetMethods:
         assert result.filename == created_doc.filename

     @patch('app.services.document_service.magic.from_buffer')
-    @pytest.mark.asyncio
-    async def test_i_can_get_document_by_hash(
+    def test_i_can_get_document_by_hash(
         self,
         mock_magic,
         document_service,
@@ -242,14 +235,14 @@ class TestGetMethods:
         mock_magic.return_value = "application/pdf"

         # Create a document first
-        created_doc = await document_service.create_document(
+        created_doc = document_service.create_document(
             "/test/test.pdf",
             sample_file_bytes,
             "utf-8"
         )

         # Execute
-        result = await document_service.get_document_by_hash(created_doc.file_hash)
+        result = document_service.get_document_by_hash(created_doc.file_hash)

         # Verify
         assert result is not None
@@ -257,8 +250,7 @@ class TestGetMethods:
         assert result.filename == created_doc.filename

     @patch('app.services.document_service.magic.from_buffer')
-    @pytest.mark.asyncio
-    async def test_i_can_get_document_by_filepath(
+    def test_i_can_get_document_by_filepath(
         self,
         mock_magic,
         document_service,
@@ -270,14 +262,14 @@ class TestGetMethods:
         test_path = "/test/unique_test.pdf"

         # Create a document first
-        created_doc = await document_service.create_document(
+        created_doc = document_service.create_document(
             test_path,
             sample_file_bytes,
             "utf-8"
         )

         # Execute
-        result = await document_service.get_document_by_filepath(test_path)
+        result = document_service.get_document_by_filepath(test_path)

         # Verify
         assert result is not None
@@ -285,8 +277,7 @@ class TestGetMethods:
         assert result.id == created_doc.id

     @patch('app.services.document_service.magic.from_buffer')
-    @pytest.mark.asyncio
-    async def test_i_can_get_document_content(
+    def test_i_can_get_document_content(
         self,
         mock_magic,
         document_service,
@@ -297,38 +288,36 @@ class TestGetMethods:
         mock_magic.return_value = "application/pdf"

         # Create a document first
-        created_doc = await document_service.create_document(
+        created_doc = document_service.create_document(
             "/test/test.pdf",
             sample_file_bytes,
             "utf-8"
         )

         # Execute
-        result = await document_service.get_document_content_by_hash(created_doc.file_hash)
+        result = document_service.get_document_content_by_hash(created_doc.file_hash)

         # Verify
         assert result == sample_file_bytes

-    @pytest.mark.asyncio
-    async def test_i_cannot_get_nonexistent_document_by_id(
+    def test_i_cannot_get_nonexistent_document_by_id(
         self,
         document_service
     ):
         """Test that nonexistent document returns None."""
         # Execute with random ObjectId
-        result = await document_service.get_document_by_id(ObjectId())
+        result = document_service.get_document_by_id(ObjectId())

         # Verify
         assert result is None

-    @pytest.mark.asyncio
-    async def test_i_cannot_get_nonexistent_document_by_hash(
+    def test_i_cannot_get_nonexistent_document_by_hash(
         self,
         document_service
     ):
         """Test that nonexistent document hash returns None."""
         # Execute
-        result = await document_service.get_document_by_hash("nonexistent_hash")
+        result = document_service.get_document_by_hash("nonexistent_hash")

         # Verify
         assert result is None
@@ -338,8 +327,7 @@ class TestPaginationAndCounting:
|
|||||||
"""Tests for document listing and counting."""
|
"""Tests for document listing and counting."""
|
||||||
|
|
||||||
@patch('app.services.document_service.magic.from_buffer')
|
@patch('app.services.document_service.magic.from_buffer')
|
||||||
@pytest.mark.asyncio
|
def test_i_can_list_documents_with_pagination(
|
||||||
async def test_i_can_list_documents_with_pagination(
|
|
||||||
self,
|
self,
|
||||||
mock_magic,
|
mock_magic,
|
||||||
document_service,
|
document_service,
|
||||||
@@ -351,25 +339,24 @@ class TestPaginationAndCounting:
|
|||||||
|
|
||||||
# Create multiple documents
|
# Create multiple documents
|
||||||
for i in range(5):
|
for i in range(5):
|
||||||
await document_service.create_document(
|
document_service.create_document(
|
||||||
f"/test/test{i}.pdf",
|
f"/test/test{i}.pdf",
|
||||||
sample_file_bytes + bytes(str(i), 'utf-8'), # Make each file unique
|
sample_file_bytes + bytes(str(i), 'utf-8'), # Make each file unique
|
||||||
"utf-8"
|
"utf-8"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Execute with pagination
|
# Execute with pagination
|
||||||
result = await document_service.list_documents(skip=1, limit=2)
|
result = document_service.list_documents(skip=1, limit=2)
|
||||||
|
|
||||||
# Verify
|
# Verify
|
||||||
assert len(result) == 2
|
assert len(result) == 2
|
||||||
|
|
||||||
# Test counting
|
# Test counting
|
||||||
total_count = await document_service.count_documents()
|
total_count = document_service.count_documents()
|
||||||
assert total_count == 5
|
assert total_count == 5
|
||||||
|
|
||||||
@patch('app.services.document_service.magic.from_buffer')
|
@patch('app.services.document_service.magic.from_buffer')
|
||||||
@pytest.mark.asyncio
|
def test_i_can_count_documents(
|
||||||
async def test_i_can_count_documents(
|
|
||||||
self,
|
self,
|
||||||
mock_magic,
|
mock_magic,
|
||||||
document_service,
|
document_service,
|
||||||
@@ -380,19 +367,19 @@ class TestPaginationAndCounting:
|
|||||||
mock_magic.return_value = "text/plain"
|
mock_magic.return_value = "text/plain"
|
||||||
|
|
||||||
# Initially should be 0
|
# Initially should be 0
|
||||||
initial_count = await document_service.count_documents()
|
initial_count = document_service.count_documents()
|
||||||
assert initial_count == 0
|
assert initial_count == 0
|
||||||
|
|
||||||
# Create some documents
|
# Create some documents
|
||||||
for i in range(3):
|
for i in range(3):
|
||||||
await document_service.create_document(
|
document_service.create_document(
|
||||||
f"/test/test{i}.txt",
|
f"/test/test{i}.txt",
|
||||||
sample_file_bytes + bytes(str(i), 'utf-8'),
|
sample_file_bytes + bytes(str(i), 'utf-8'),
|
||||||
"utf-8"
|
"utf-8"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Execute
|
# Execute
|
||||||
final_count = await document_service.count_documents()
|
final_count = document_service.count_documents()
|
||||||
|
|
||||||
# Verify
|
# Verify
|
||||||
assert final_count == 3
|
assert final_count == 3
|
||||||
@@ -402,8 +389,7 @@ class TestUpdateAndDelete:
|
|||||||
"""Tests for document update and deletion operations."""
|
"""Tests for document update and deletion operations."""
|
||||||
|
|
||||||
@patch('app.services.document_service.magic.from_buffer')
|
@patch('app.services.document_service.magic.from_buffer')
|
||||||
@pytest.mark.asyncio
|
def test_i_can_update_document_metadata(
|
||||||
async def test_i_can_update_document_metadata(
|
|
||||||
self,
|
self,
|
||||||
mock_magic,
|
mock_magic,
|
||||||
document_service,
|
document_service,
|
||||||
@@ -414,7 +400,7 @@ class TestUpdateAndDelete:
|
|||||||
mock_magic.return_value = "application/pdf"
|
mock_magic.return_value = "application/pdf"
|
||||||
|
|
||||||
# Create a document first
|
# Create a document first
|
||||||
created_doc = await document_service.create_document(
|
created_doc = document_service.create_document(
|
||||||
"/test/test.pdf",
|
"/test/test.pdf",
|
||||||
sample_file_bytes,
|
sample_file_bytes,
|
||||||
"utf-8"
|
"utf-8"
|
||||||
@@ -422,7 +408,7 @@ class TestUpdateAndDelete:
|
|||||||
|
|
||||||
# Execute update
|
# Execute update
|
||||||
update_data = {"metadata": {"page_count": 5}}
|
update_data = {"metadata": {"page_count": 5}}
|
||||||
result = await document_service.update_document(created_doc.id, update_data)
|
result = document_service.update_document(created_doc.id, update_data)
|
||||||
|
|
||||||
# Verify
|
# Verify
|
||||||
assert result is not None
|
assert result is not None
|
||||||
@@ -433,14 +419,13 @@ class TestUpdateAndDelete:
|
|||||||
assert result.file_type == created_doc.file_type
|
assert result.file_type == created_doc.file_type
|
||||||
assert result.metadata == update_data['metadata']
|
assert result.metadata == update_data['metadata']
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
def test_i_can_update_document_content(
|
||||||
async def test_i_can_update_document_content(
|
|
||||||
self,
|
self,
|
||||||
document_service,
|
document_service,
|
||||||
sample_file_bytes
|
sample_file_bytes
|
||||||
):
|
):
|
||||||
# Create a document first
|
# Create a document first
|
||||||
created_doc = await document_service.create_document(
|
created_doc = document_service.create_document(
|
||||||
"/test/test.pdf",
|
"/test/test.pdf",
|
||||||
sample_file_bytes,
|
sample_file_bytes,
|
||||||
"utf-8"
|
"utf-8"
|
||||||
@@ -448,7 +433,7 @@ class TestUpdateAndDelete:
|
|||||||
|
|
||||||
# Execute update
|
# Execute update
|
||||||
update_data = {"file_bytes": b"this is an updated file content"}
|
update_data = {"file_bytes": b"this is an updated file content"}
|
||||||
result = await document_service.update_document(created_doc.id, update_data)
|
result = document_service.update_document(created_doc.id, update_data)
|
||||||
|
|
||||||
assert result.filename == created_doc.filename
|
assert result.filename == created_doc.filename
|
||||||
assert result.filepath == created_doc.filepath
|
assert result.filepath == created_doc.filepath
|
||||||
@@ -460,8 +445,7 @@ class TestUpdateAndDelete:
|
|||||||
validate_file_saved(document_service, result.file_hash, b"this is an updated file content")
|
validate_file_saved(document_service, result.file_hash, b"this is an updated file content")
|
||||||
|
|
||||||
@patch('app.services.document_service.magic.from_buffer')
|
@patch('app.services.document_service.magic.from_buffer')
|
||||||
@pytest.mark.asyncio
|
def test_i_can_delete_document_and_orphaned_content(
|
||||||
async def test_i_can_delete_document_and_orphaned_content(
|
|
||||||
self,
|
self,
|
||||||
mock_magic,
|
mock_magic,
|
||||||
document_service,
|
document_service,
|
||||||
@@ -472,7 +456,7 @@ class TestUpdateAndDelete:
|
|||||||
mock_magic.return_value = "application/pdf"
|
mock_magic.return_value = "application/pdf"
|
||||||
|
|
||||||
# Create a document
|
# Create a document
|
||||||
created_doc = await document_service.create_document(
|
created_doc = document_service.create_document(
|
||||||
"/test/test.pdf",
|
"/test/test.pdf",
|
||||||
sample_file_bytes,
|
sample_file_bytes,
|
||||||
"utf-8"
|
"utf-8"
|
||||||
@@ -482,12 +466,12 @@ class TestUpdateAndDelete:
|
|||||||
validate_file_saved(document_service, created_doc.file_hash, sample_file_bytes)
|
validate_file_saved(document_service, created_doc.file_hash, sample_file_bytes)
|
||||||
|
|
||||||
# Execute deletion
|
# Execute deletion
|
||||||
result = await document_service.delete_document(created_doc.id)
|
result = document_service.delete_document(created_doc.id)
|
||||||
|
|
||||||
# Verify document and content are deleted
|
# Verify document and content are deleted
|
||||||
assert result is True
|
assert result is True
|
||||||
|
|
||||||
deleted_doc = await document_service.get_document_by_id(created_doc.id)
|
deleted_doc = document_service.get_document_by_id(created_doc.id)
|
||||||
assert deleted_doc is None
|
assert deleted_doc is None
|
||||||
|
|
||||||
# validate content is deleted
|
# validate content is deleted
|
||||||
@@ -496,8 +480,7 @@ class TestUpdateAndDelete:
|
|||||||
assert not os.path.exists(target_file_path)
|
assert not os.path.exists(target_file_path)
|
||||||
|
|
||||||
@patch('app.services.document_service.magic.from_buffer')
|
@patch('app.services.document_service.magic.from_buffer')
|
||||||
@pytest.mark.asyncio
|
def test_i_can_delete_document_without_affecting_shared_content(
|
||||||
async def test_i_can_delete_document_without_affecting_shared_content(
|
|
||||||
self,
|
self,
|
||||||
mock_magic,
|
mock_magic,
|
||||||
document_service,
|
document_service,
|
||||||
@@ -508,13 +491,13 @@ class TestUpdateAndDelete:
|
|||||||
mock_magic.return_value = "application/pdf"
|
mock_magic.return_value = "application/pdf"
|
||||||
|
|
||||||
# Create two documents with same content
|
# Create two documents with same content
|
||||||
doc1 = await document_service.create_document(
|
doc1 = document_service.create_document(
|
||||||
"/test/test1.pdf",
|
"/test/test1.pdf",
|
||||||
sample_file_bytes,
|
sample_file_bytes,
|
||||||
"utf-8"
|
"utf-8"
|
||||||
)
|
)
|
||||||
|
|
||||||
doc2 = await document_service.create_document(
|
doc2 = document_service.create_document(
|
||||||
"/test/test2.pdf",
|
"/test/test2.pdf",
|
||||||
sample_file_bytes,
|
sample_file_bytes,
|
||||||
"utf-8"
|
"utf-8"
|
||||||
@@ -524,14 +507,14 @@ class TestUpdateAndDelete:
|
|||||||
assert doc1.file_hash == doc2.file_hash
|
assert doc1.file_hash == doc2.file_hash
|
||||||
|
|
||||||
# Delete first document
|
# Delete first document
|
||||||
result = await document_service.delete_document(doc1.id)
|
result = document_service.delete_document(doc1.id)
|
||||||
assert result is True
|
assert result is True
|
||||||
|
|
||||||
# Verify first document is deleted but content still exists
|
# Verify first document is deleted but content still exists
|
||||||
deleted_doc = await document_service.get_document_by_id(doc1.id)
|
deleted_doc = document_service.get_document_by_id(doc1.id)
|
||||||
assert deleted_doc is None
|
assert deleted_doc is None
|
||||||
|
|
||||||
remaining_doc = await document_service.get_document_by_id(doc2.id)
|
remaining_doc = document_service.get_document_by_id(doc2.id)
|
||||||
assert remaining_doc is not None
|
assert remaining_doc is not None
|
||||||
|
|
||||||
validate_file_saved(document_service, doc2.file_hash, sample_file_bytes)
|
validate_file_saved(document_service, doc2.file_hash, sample_file_bytes)
|
||||||
|
|||||||
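The hunks above all make the same mechanical change to the document-service tests: drop @pytest.mark.asyncio, drop await, and call the now-synchronous service methods directly. For readability, this is the shape a converted test ends up with, written here as a standalone sketch; the document_service and sample_file_bytes fixtures are assumed to be defined earlier in the test module (they are not shown in this excerpt), and the function name is illustrative only.

from unittest.mock import patch


@patch('app.services.document_service.magic.from_buffer')
def test_lookup_by_hash_sketch(mock_magic, document_service, sample_file_bytes):
    # Stub file-type detection so the test does not need a real PDF payload.
    mock_magic.return_value = "application/pdf"

    # Create a document, then read it back by its content hash; no event loop involved.
    created_doc = document_service.create_document("/test/test.pdf", sample_file_bytes, "utf-8")
    result = document_service.get_document_by_hash(created_doc.file_hash)

    assert result is not None
    assert result.filename == created_doc.filename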
@@ -6,9 +6,8 @@ using mongomock for better integration testing.
 """

 import pytest
-import pytest_asyncio
 from bson import ObjectId
-from mongomock_motor import AsyncMongoMockClient
+from mongomock.mongo_client import MongoClient

 from app.exceptions.job_exceptions import InvalidStatusTransitionError
 from app.models.job import ProcessingStatus
@@ -16,17 +15,17 @@ from app.models.types import PyObjectId
 from app.services.job_service import JobService


-@pytest_asyncio.fixture
-async def in_memory_database():
+@pytest.fixture
+def in_memory_database():
     """Create an in-memory database for testing."""
-    client = AsyncMongoMockClient()
+    client = MongoClient()
     return client.test_database


-@pytest_asyncio.fixture
-async def job_service(in_memory_database):
+@pytest.fixture
+def job_service(in_memory_database):
     """Create JobService with in-memory repositories."""
-    service = await JobService(in_memory_database).initialize()
+    service = JobService(in_memory_database).initialize()
     return service

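Assembled from the new side of the two hunks above, the job-service fixtures are now plain synchronous pytest fixtures backed by mongomock's in-process client, so pytest_asyncio is no longer needed. The only behavioural assumption here is the one the dropped await already implies, namely that JobService(...).initialize() now returns the service directly:

import pytest
from mongomock.mongo_client import MongoClient

from app.services.job_service import JobService


@pytest.fixture
def in_memory_database():
    """Create an in-memory database for testing."""
    # mongomock keeps everything in process memory, so no MongoDB server is required.
    client = MongoClient()
    return client.test_database


@pytest.fixture
def job_service(in_memory_database):
    """Create JobService with in-memory repositories."""
    service = JobService(in_memory_database).initialize()
    return service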
@@ -45,8 +44,7 @@ def sample_task_id():
 class TestCreateJob:
     """Tests for create_job method."""

-    @pytest.mark.asyncio
-    async def test_i_can_create_job_with_task_id(
+    def test_i_can_create_job_with_task_id(
         self,
         job_service,
         sample_document_id,
@@ -54,7 +52,7 @@ class TestCreateJob:
     ):
         """Test creating job with task ID."""
         # Execute
-        result = await job_service.create_job(sample_document_id, sample_task_id)
+        result = job_service.create_job(sample_document_id, sample_task_id)

         # Verify job creation
         assert result is not None
@@ -66,22 +64,21 @@ class TestCreateJob:
         assert result.error_message is None

         # Verify job exists in database
-        job_in_db = await job_service.get_job_by_id(result.id)
+        job_in_db = job_service.get_job_by_id(result.id)
         assert job_in_db is not None
         assert job_in_db.id == result.id
         assert job_in_db.document_id == sample_document_id
         assert job_in_db.task_id == sample_task_id
         assert job_in_db.status == ProcessingStatus.PENDING

-    @pytest.mark.asyncio
-    async def test_i_can_create_job_without_task_id(
+    def test_i_can_create_job_without_task_id(
         self,
         job_service,
         sample_document_id
     ):
         """Test creating job without task ID."""
         # Execute
-        result = await job_service.create_job(sample_document_id)
+        result = job_service.create_job(sample_document_id)

         # Verify job creation
         assert result is not None
@@ -96,8 +93,7 @@ class TestCreateJob:
 class TestGetJobMethods:
     """Tests for job retrieval methods."""

-    @pytest.mark.asyncio
-    async def test_i_can_get_job_by_id(
+    def test_i_can_get_job_by_id(
         self,
         job_service,
         sample_document_id,
@@ -105,10 +101,10 @@ class TestGetJobMethods:
     ):
         """Test retrieving job by ID."""
         # Create a job first
-        created_job = await job_service.create_job(sample_document_id, sample_task_id)
+        created_job = job_service.create_job(sample_document_id, sample_task_id)

         # Execute
-        result = await job_service.get_job_by_id(created_job.id)
+        result = job_service.get_job_by_id(created_job.id)

         # Verify
         assert result is not None
@@ -117,25 +113,24 @@ class TestGetJobMethods:
         assert result.task_id == created_job.task_id
         assert result.status == created_job.status

-    @pytest.mark.asyncio
-    async def test_i_can_get_jobs_by_status(
+    def test_i_can_get_jobs_by_status(
         self,
         job_service,
         sample_document_id
     ):
         """Test retrieving jobs by status."""
         # Create jobs with different statuses
-        pending_job = await job_service.create_job(sample_document_id, "pending-task")
+        pending_job = job_service.create_job(sample_document_id, "pending-task")

-        processing_job = await job_service.create_job(ObjectId(), "processing-task")
-        await job_service.mark_job_as_started(processing_job.id)
+        processing_job = job_service.create_job(ObjectId(), "processing-task")
+        job_service.mark_job_as_started(processing_job.id)

-        completed_job = await job_service.create_job(ObjectId(), "completed-task")
-        await job_service.mark_job_as_started(completed_job.id)
-        await job_service.mark_job_as_completed(completed_job.id)
+        completed_job = job_service.create_job(ObjectId(), "completed-task")
+        job_service.mark_job_as_started(completed_job.id)
+        job_service.mark_job_as_completed(completed_job.id)

         # Execute - get pending jobs
-        pending_results = await job_service.get_jobs_by_status(ProcessingStatus.PENDING)
+        pending_results = job_service.get_jobs_by_status(ProcessingStatus.PENDING)

         # Verify
         assert len(pending_results) == 1
@@ -143,12 +138,12 @@ class TestGetJobMethods:
         assert pending_results[0].status == ProcessingStatus.PENDING

         # Execute - get processing jobs
-        processing_results = await job_service.get_jobs_by_status(ProcessingStatus.PROCESSING)
+        processing_results = job_service.get_jobs_by_status(ProcessingStatus.PROCESSING)
         assert len(processing_results) == 1
         assert processing_results[0].status == ProcessingStatus.PROCESSING

         # Execute - get completed jobs
-        completed_results = await job_service.get_jobs_by_status(ProcessingStatus.COMPLETED)
+        completed_results = job_service.get_jobs_by_status(ProcessingStatus.COMPLETED)
         assert len(completed_results) == 1
         assert completed_results[0].status == ProcessingStatus.COMPLETED

@@ -156,8 +151,7 @@ class TestGetJobMethods:
 class TestUpdateStatus:
     """Tests for mark_job_as_started method."""

-    @pytest.mark.asyncio
-    async def test_i_can_mark_pending_job_as_started(
+    def test_i_can_mark_pending_job_as_started(
         self,
         job_service,
         sample_document_id,
@@ -165,11 +159,11 @@ class TestUpdateStatus:
     ):
         """Test marking pending job as started (PENDING → PROCESSING)."""
         # Create a pending job
-        created_job = await job_service.create_job(sample_document_id, sample_task_id)
+        created_job = job_service.create_job(sample_document_id, sample_task_id)
         assert created_job.status == ProcessingStatus.PENDING

         # Execute
-        result = await job_service.mark_job_as_started(created_job.id)
+        result = job_service.mark_job_as_started(created_job.id)

         # Verify status transition
         assert result is not None
@@ -177,11 +171,10 @@ class TestUpdateStatus:
         assert result.status == ProcessingStatus.PROCESSING

         # Verify in database
-        updated_job = await job_service.get_job_by_id(created_job.id)
+        updated_job = job_service.get_job_by_id(created_job.id)
         assert updated_job.status == ProcessingStatus.PROCESSING

-    @pytest.mark.asyncio
-    async def test_i_cannot_mark_processing_job_as_started(
+    def test_i_cannot_mark_processing_job_as_started(
         self,
         job_service,
         sample_document_id,
@@ -189,19 +182,18 @@ class TestUpdateStatus:
     ):
         """Test that processing job cannot be marked as started."""
         # Create and start a job
-        created_job = await job_service.create_job(sample_document_id, sample_task_id)
-        await job_service.mark_job_as_started(created_job.id)
+        created_job = job_service.create_job(sample_document_id, sample_task_id)
+        job_service.mark_job_as_started(created_job.id)

         # Try to start it again
         with pytest.raises(InvalidStatusTransitionError) as exc_info:
-            await job_service.mark_job_as_started(created_job.id)
+            job_service.mark_job_as_started(created_job.id)

         # Verify exception details
         assert exc_info.value.current_status == ProcessingStatus.PROCESSING
         assert exc_info.value.target_status == ProcessingStatus.PROCESSING

-    @pytest.mark.asyncio
-    async def test_i_cannot_mark_completed_job_as_started(
+    def test_i_cannot_mark_completed_job_as_started(
         self,
         job_service,
         sample_document_id,
@@ -209,20 +201,19 @@ class TestUpdateStatus:
     ):
         """Test that completed job cannot be marked as started."""
         # Create, start, and complete a job
-        created_job = await job_service.create_job(sample_document_id, sample_task_id)
-        await job_service.mark_job_as_started(created_job.id)
-        await job_service.mark_job_as_completed(created_job.id)
+        created_job = job_service.create_job(sample_document_id, sample_task_id)
+        job_service.mark_job_as_started(created_job.id)
+        job_service.mark_job_as_completed(created_job.id)

         # Try to start it again
         with pytest.raises(InvalidStatusTransitionError) as exc_info:
-            await job_service.mark_job_as_started(created_job.id)
+            job_service.mark_job_as_started(created_job.id)

         # Verify exception details
         assert exc_info.value.current_status == ProcessingStatus.COMPLETED
         assert exc_info.value.target_status == ProcessingStatus.PROCESSING

-    @pytest.mark.asyncio
-    async def test_i_cannot_mark_failed_job_as_started(
+    def test_i_cannot_mark_failed_job_as_started(
         self,
         job_service,
         sample_document_id,
@@ -230,20 +221,19 @@ class TestUpdateStatus:
     ):
         """Test that failed job cannot be marked as started."""
         # Create, start, and fail a job
-        created_job = await job_service.create_job(sample_document_id, sample_task_id)
-        await job_service.mark_job_as_started(created_job.id)
-        await job_service.mark_job_as_failed(created_job.id, "Test error")
+        created_job = job_service.create_job(sample_document_id, sample_task_id)
+        job_service.mark_job_as_started(created_job.id)
+        job_service.mark_job_as_failed(created_job.id, "Test error")

         # Try to start it again
         with pytest.raises(InvalidStatusTransitionError) as exc_info:
-            await job_service.mark_job_as_started(created_job.id)
+            job_service.mark_job_as_started(created_job.id)

         # Verify exception details
         assert exc_info.value.current_status == ProcessingStatus.FAILED
         assert exc_info.value.target_status == ProcessingStatus.PROCESSING

-    @pytest.mark.asyncio
-    async def test_i_can_mark_processing_job_as_completed(
+    def test_i_can_mark_processing_job_as_completed(
         self,
         job_service,
         sample_document_id,
@@ -251,11 +241,11 @@ class TestUpdateStatus:
     ):
         """Test marking processing job as completed (PROCESSING → COMPLETED)."""
         # Create and start a job
-        created_job = await job_service.create_job(sample_document_id, sample_task_id)
-        started_job = await job_service.mark_job_as_started(created_job.id)
+        created_job = job_service.create_job(sample_document_id, sample_task_id)
+        started_job = job_service.mark_job_as_started(created_job.id)

         # Execute
-        result = await job_service.mark_job_as_completed(created_job.id)
+        result = job_service.mark_job_as_completed(created_job.id)

         # Verify status transition
         assert result is not None
@@ -263,11 +253,10 @@ class TestUpdateStatus:
         assert result.status == ProcessingStatus.COMPLETED

         # Verify in database
-        updated_job = await job_service.get_job_by_id(created_job.id)
+        updated_job = job_service.get_job_by_id(created_job.id)
         assert updated_job.status == ProcessingStatus.COMPLETED

-    @pytest.mark.asyncio
-    async def test_i_cannot_mark_pending_job_as_completed(
+    def test_i_cannot_mark_pending_job_as_completed(
         self,
         job_service,
         sample_document_id,
@@ -275,18 +264,17 @@ class TestUpdateStatus:
     ):
         """Test that pending job cannot be marked as completed."""
         # Create a pending job
-        created_job = await job_service.create_job(sample_document_id, sample_task_id)
+        created_job = job_service.create_job(sample_document_id, sample_task_id)

         # Try to complete it directly
         with pytest.raises(InvalidStatusTransitionError) as exc_info:
-            await job_service.mark_job_as_completed(created_job.id)
+            job_service.mark_job_as_completed(created_job.id)

         # Verify exception details
         assert exc_info.value.current_status == ProcessingStatus.PENDING
         assert exc_info.value.target_status == ProcessingStatus.COMPLETED

-    @pytest.mark.asyncio
-    async def test_i_cannot_mark_completed_job_as_completed(
+    def test_i_cannot_mark_completed_job_as_completed(
         self,
         job_service,
         sample_document_id,
@@ -294,20 +282,19 @@ class TestUpdateStatus:
     ):
         """Test that completed job cannot be marked as completed again."""
         # Create, start, and complete a job
-        created_job = await job_service.create_job(sample_document_id, sample_task_id)
-        await job_service.mark_job_as_started(created_job.id)
-        await job_service.mark_job_as_completed(created_job.id)
+        created_job = job_service.create_job(sample_document_id, sample_task_id)
+        job_service.mark_job_as_started(created_job.id)
+        job_service.mark_job_as_completed(created_job.id)

         # Try to complete it again
         with pytest.raises(InvalidStatusTransitionError) as exc_info:
-            await job_service.mark_job_as_completed(created_job.id)
+            job_service.mark_job_as_completed(created_job.id)

         # Verify exception details
         assert exc_info.value.current_status == ProcessingStatus.COMPLETED
         assert exc_info.value.target_status == ProcessingStatus.COMPLETED

-    @pytest.mark.asyncio
-    async def test_i_cannot_mark_failed_job_as_completed(
+    def test_i_cannot_mark_failed_job_as_completed(
         self,
         job_service,
         sample_document_id,
@@ -315,20 +302,19 @@ class TestUpdateStatus:
     ):
         """Test that failed job cannot be marked as completed."""
         # Create, start, and fail a job
-        created_job = await job_service.create_job(sample_document_id, sample_task_id)
-        await job_service.mark_job_as_started(created_job.id)
-        await job_service.mark_job_as_failed(created_job.id, "Test error")
+        created_job = job_service.create_job(sample_document_id, sample_task_id)
+        job_service.mark_job_as_started(created_job.id)
+        job_service.mark_job_as_failed(created_job.id, "Test error")

         # Try to complete it
         with pytest.raises(InvalidStatusTransitionError) as exc_info:
-            await job_service.mark_job_as_completed(created_job.id)
+            job_service.mark_job_as_completed(created_job.id)

         # Verify exception details
         assert exc_info.value.current_status == ProcessingStatus.FAILED
         assert exc_info.value.target_status == ProcessingStatus.COMPLETED

-    @pytest.mark.asyncio
-    async def test_i_can_mark_processing_job_as_failed_with_error_message(
+    def test_i_can_mark_processing_job_as_failed_with_error_message(
         self,
         job_service,
         sample_document_id,
@@ -336,13 +322,13 @@ class TestUpdateStatus:
     ):
         """Test marking processing job as failed with error message."""
         # Create and start a job
-        created_job = await job_service.create_job(sample_document_id, sample_task_id)
-        started_job = await job_service.mark_job_as_started(created_job.id)
+        created_job = job_service.create_job(sample_document_id, sample_task_id)
+        started_job = job_service.mark_job_as_started(created_job.id)

         error_message = "Processing failed due to invalid file format"

         # Execute
-        result = await job_service.mark_job_as_failed(created_job.id, error_message)
+        result = job_service.mark_job_as_failed(created_job.id, error_message)

         # Verify status transition
         assert result is not None
@@ -351,12 +337,11 @@ class TestUpdateStatus:
         assert result.error_message == error_message

         # Verify in database
-        updated_job = await job_service.get_job_by_id(created_job.id)
+        updated_job = job_service.get_job_by_id(created_job.id)
         assert updated_job.status == ProcessingStatus.FAILED
         assert updated_job.error_message == error_message

-    @pytest.mark.asyncio
-    async def test_i_can_mark_processing_job_as_failed_without_error_message(
+    def test_i_can_mark_processing_job_as_failed_without_error_message(
         self,
         job_service,
         sample_document_id,
@@ -364,19 +349,18 @@ class TestUpdateStatus:
     ):
         """Test marking processing job as failed without error message."""
         # Create and start a job
-        created_job = await job_service.create_job(sample_document_id, sample_task_id)
-        await job_service.mark_job_as_started(created_job.id)
+        created_job = job_service.create_job(sample_document_id, sample_task_id)
+        job_service.mark_job_as_started(created_job.id)

         # Execute without error message
-        result = await job_service.mark_job_as_failed(created_job.id)
+        result = job_service.mark_job_as_failed(created_job.id)

         # Verify status transition
         assert result is not None
         assert result.status == ProcessingStatus.FAILED
         assert result.error_message is None

-    @pytest.mark.asyncio
-    async def test_i_cannot_mark_pending_job_as_failed(
+    def test_i_cannot_mark_pending_job_as_failed(
         self,
         job_service,
         sample_document_id,
@@ -384,18 +368,17 @@ class TestUpdateStatus:
     ):
         """Test that pending job cannot be marked as failed."""
         # Create a pending job
-        created_job = await job_service.create_job(sample_document_id, sample_task_id)
+        created_job = job_service.create_job(sample_document_id, sample_task_id)

         # Try to fail it directly
         with pytest.raises(InvalidStatusTransitionError) as exc_info:
-            await job_service.mark_job_as_failed(created_job.id, "Test error")
+            job_service.mark_job_as_failed(created_job.id, "Test error")

         # Verify exception details
         assert exc_info.value.current_status == ProcessingStatus.PENDING
         assert exc_info.value.target_status == ProcessingStatus.FAILED

-    @pytest.mark.asyncio
-    async def test_i_cannot_mark_completed_job_as_failed(
+    def test_i_cannot_mark_completed_job_as_failed(
         self,
         job_service,
         sample_document_id,
@@ -403,20 +386,19 @@ class TestUpdateStatus:
     ):
         """Test that completed job cannot be marked as failed."""
         # Create, start, and complete a job
-        created_job = await job_service.create_job(sample_document_id, sample_task_id)
-        await job_service.mark_job_as_started(created_job.id)
-        await job_service.mark_job_as_completed(created_job.id)
+        created_job = job_service.create_job(sample_document_id, sample_task_id)
+        job_service.mark_job_as_started(created_job.id)
+        job_service.mark_job_as_completed(created_job.id)

         # Try to fail it
         with pytest.raises(InvalidStatusTransitionError) as exc_info:
-            await job_service.mark_job_as_failed(created_job.id, "Test error")
+            job_service.mark_job_as_failed(created_job.id, "Test error")

         # Verify exception details
         assert exc_info.value.current_status == ProcessingStatus.COMPLETED
         assert exc_info.value.target_status == ProcessingStatus.FAILED

-    @pytest.mark.asyncio
-    async def test_i_cannot_mark_failed_job_as_failed(
+    def test_i_cannot_mark_failed_job_as_failed(
         self,
         job_service,
         sample_document_id,
@@ -424,13 +406,13 @@ class TestUpdateStatus:
     ):
         """Test that failed job cannot be marked as failed again."""
         # Create, start, and fail a job
-        created_job = await job_service.create_job(sample_document_id, sample_task_id)
-        await job_service.mark_job_as_started(created_job.id)
-        await job_service.mark_job_as_failed(created_job.id, "First error")
+        created_job = job_service.create_job(sample_document_id, sample_task_id)
+        job_service.mark_job_as_started(created_job.id)
+        job_service.mark_job_as_failed(created_job.id, "First error")

         # Try to fail it again
         with pytest.raises(InvalidStatusTransitionError) as exc_info:
-            await job_service.mark_job_as_failed(created_job.id, "Second error")
+            job_service.mark_job_as_failed(created_job.id, "Second error")

         # Verify exception details
         assert exc_info.value.current_status == ProcessingStatus.FAILED
@@ -440,8 +422,7 @@ class TestDeleteJob:
 class TestDeleteJob:
     """Tests for delete_job method."""

-    @pytest.mark.asyncio
-    async def test_i_can_delete_existing_job(
+    def test_i_can_delete_existing_job(
         self,
         job_service,
         sample_document_id,
@@ -449,30 +430,29 @@ class TestDeleteJob:
     ):
         """Test deleting an existing job."""
         # Create a job
-        created_job = await job_service.create_job(sample_document_id, sample_task_id)
+        created_job = job_service.create_job(sample_document_id, sample_task_id)

         # Verify job exists
-        job_before_delete = await job_service.get_job_by_id(created_job.id)
+        job_before_delete = job_service.get_job_by_id(created_job.id)
         assert job_before_delete is not None

         # Execute deletion
-        result = await job_service.delete_job(created_job.id)
+        result = job_service.delete_job(created_job.id)

         # Verify deletion
         assert result is True

         # Verify job no longer exists
-        deleted_job = await job_service.get_job_by_id(created_job.id)
+        deleted_job = job_service.get_job_by_id(created_job.id)
         assert deleted_job is None

-    @pytest.mark.asyncio
-    async def test_i_cannot_delete_nonexistent_job(
+    def test_i_cannot_delete_nonexistent_job(
         self,
         job_service
     ):
         """Test deleting a nonexistent job returns False."""
         # Execute deletion with random ObjectId
-        result = await job_service.delete_job(ObjectId())
+        result = job_service.delete_job(ObjectId())

         # Verify
         assert result is False
@@ -481,8 +461,7 @@ class TestDeleteJob:
 class TestStatusTransitionValidation:
     """Tests for status transition validation across different scenarios."""

-    @pytest.mark.asyncio
-    async def test_valid_job_lifecycle_flow(
+    def test_valid_job_lifecycle_flow(
         self,
         job_service,
         sample_document_id,
@@ -490,19 +469,18 @@ class TestStatusTransitionValidation:
     ):
         """Test complete valid job lifecycle: PENDING → PROCESSING → COMPLETED."""
         # Create job (PENDING)
-        job = await job_service.create_job(sample_document_id, sample_task_id)
+        job = job_service.create_job(sample_document_id, sample_task_id)
         assert job.status == ProcessingStatus.PENDING

         # Start job (PENDING → PROCESSING)
-        started_job = await job_service.mark_job_as_started(job.id)
+        started_job = job_service.mark_job_as_started(job.id)
         assert started_job.status == ProcessingStatus.PROCESSING

         # Complete job (PROCESSING → COMPLETED)
-        completed_job = await job_service.mark_job_as_completed(job.id)
+        completed_job = job_service.mark_job_as_completed(job.id)
         assert completed_job.status == ProcessingStatus.COMPLETED

-    @pytest.mark.asyncio
-    async def test_valid_job_failure_flow(
+    def test_valid_job_failure_flow(
         self,
         job_service,
         sample_document_id,
@@ -510,69 +488,31 @@ class TestStatusTransitionValidation:
     ):
         """Test valid job failure: PENDING → PROCESSING → FAILED."""
         # Create job (PENDING)
-        job = await job_service.create_job(sample_document_id, sample_task_id)
+        job = job_service.create_job(sample_document_id, sample_task_id)
         assert job.status == ProcessingStatus.PENDING

         # Start job (PENDING → PROCESSING)
-        started_job = await job_service.mark_job_as_started(job.id)
+        started_job = job_service.mark_job_as_started(job.id)
         assert started_job.status == ProcessingStatus.PROCESSING

         # Fail job (PROCESSING → FAILED)
-        failed_job = await job_service.mark_job_as_failed(job.id, "Test failure")
+        failed_job = job_service.mark_job_as_failed(job.id, "Test failure")
         assert failed_job.status == ProcessingStatus.FAILED
         assert failed_job.error_message == "Test failure"

-class TestEdgeCases:
-    """Tests for edge cases and error conditions."""
-    #
-    # @pytest.mark.asyncio
-    # async def test_multiple_jobs_for_same_file(
-    #         self,
-    #         job_service,
-    #         sample_document_id
-    # ):
-    #     """Test handling multiple jobs for the same file."""
-    #     # Create multiple jobs for same file
-    #     job1 = await job_service.create_job(sample_document_id, "task-1")
-    #     job2 = await job_service.create_job(sample_document_id, "task-2")
-    #     job3 = await job_service.create_job(sample_document_id, "task-3")
-    #
-    #     # Verify all jobs exist and are independent
-    #     jobs_for_file = await job_service.get_jobs_by_file_id(sample_document_id)
-    #     assert len(jobs_for_file) == 3
-    #
-    #     job_ids = [job.id for job in jobs_for_file]
-    #     assert job1.id in job_ids
-    #     assert job2.id in job_ids
-    #     assert job3.id in job_ids
-    #
-    #     # Verify status transitions work independently
-    #     await job_service.mark_job_as_started(job1.id)
-    #     await job_service.mark_job_as_completed(job1.id)
-    #
-    #     # Other jobs should still be pending
-    #     updated_job2 = await job_service.get_job_by_id(job2.id)
-    #     updated_job3 = await job_service.get_job_by_id(job3.id)
-    #
-    #     assert updated_job2.status == ProcessingStatus.PENDING
-    #     assert updated_job3.status == ProcessingStatus.PENDING
-
-    @pytest.mark.asyncio
-    async def test_job_operations_with_empty_database(
+    def test_job_operations_with_empty_database(
         self,
         job_service
     ):
         """Test job operations when database is empty."""
         # Try to get nonexistent job
-        result = await job_service.get_job_by_id(ObjectId())
+        result = job_service.get_job_by_id(ObjectId())
         assert result is None


         # Try to get jobs by status when none exist
-        pending_jobs = await job_service.get_jobs_by_status(ProcessingStatus.PENDING)
+        pending_jobs = job_service.get_jobs_by_status(ProcessingStatus.PENDING)
         assert pending_jobs == []

         # Try to delete nonexistent job
-        delete_result = await job_service.delete_job(ObjectId())
+        delete_result = job_service.delete_job(ObjectId())
         assert delete_result is False
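With the conversion applied end to end, the lifecycle tests above drive the PENDING → PROCESSING → COMPLETED and PENDING → PROCESSING → FAILED transitions as ordinary synchronous calls. A minimal sketch of the happy path, assuming the job_service, sample_document_id and sample_task_id fixtures defined in this module; invalid transitions raise InvalidStatusTransitionError, as the other tests in TestUpdateStatus assert.

from app.models.job import ProcessingStatus


def test_lifecycle_sketch(job_service, sample_document_id, sample_task_id):
    # A new job starts out PENDING.
    job = job_service.create_job(sample_document_id, sample_task_id)
    assert job.status == ProcessingStatus.PENDING

    # PENDING -> PROCESSING
    started = job_service.mark_job_as_started(job.id)
    assert started.status == ProcessingStatus.PROCESSING

    # PROCESSING -> COMPLETED
    completed = job_service.mark_job_as_completed(job.id)
    assert completed.status == ProcessingStatus.COMPLETED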