Removed async

This commit is contained in:
2025-09-24 21:53:48 +02:00
parent e17c4c7e7b
commit 48f5b009ae
23 changed files with 609 additions and 770 deletions

View File

@@ -67,7 +67,7 @@ services:
networks: networks:
- mydocmanager-network - mydocmanager-network
command: celery -A tasks.main worker --loglevel=info command: celery -A tasks.main worker --loglevel=info
#command: celery -A main --loglevel=info # pour la production
volumes: volumes:
mongodb-data: mongodb-data:

View File

@@ -1,6 +1,7 @@
amqp==5.3.1 amqp==5.3.1
annotated-types==0.7.0 annotated-types==0.7.0
anyio==4.10.0 anyio==4.10.0
asgiref==3.9.1
bcrypt==4.3.0 bcrypt==4.3.0
billiard==4.2.1 billiard==4.2.1
celery==5.5.3 celery==5.5.3
@@ -12,9 +13,12 @@ dnspython==2.8.0
email-validator==2.3.0 email-validator==2.3.0
fastapi==0.116.1 fastapi==0.116.1
h11==0.16.0 h11==0.16.0
hiredis==3.2.1
httptools==0.6.4 httptools==0.6.4
idna==3.10 idna==3.10
importlib_metadata==8.7.0
iniconfig==2.1.0 iniconfig==2.1.0
izulu==0.50.0
kombu==5.5.4 kombu==5.5.4
mongomock==4.3.0 mongomock==4.3.0
mongomock-motor==0.0.36 mongomock-motor==0.0.36
@@ -23,9 +27,11 @@ packaging==25.0
pipdeptree==2.28.0 pipdeptree==2.28.0
pluggy==1.6.0 pluggy==1.6.0
prompt_toolkit==3.0.52 prompt_toolkit==3.0.52
pycron==3.2.0
pydantic==2.11.9 pydantic==2.11.9
pydantic_core==2.33.2 pydantic_core==2.33.2
Pygments==2.19.2 Pygments==2.19.2
PyJWT==2.10.1
pymongo==4.15.1 pymongo==4.15.1
pytest==8.4.2 pytest==8.4.2
pytest-asyncio==1.2.0 pytest-asyncio==1.2.0
@@ -35,6 +41,7 @@ python-dotenv==1.1.1
python-magic==0.4.27 python-magic==0.4.27
pytz==2025.2 pytz==2025.2
PyYAML==6.0.2 PyYAML==6.0.2
redis==6.4.0
sentinels==1.1.1 sentinels==1.1.1
six==1.17.0 six==1.17.0
sniffio==1.3.1 sniffio==1.3.1
@@ -49,3 +56,4 @@ watchdog==6.0.0
watchfiles==1.1.0 watchfiles==1.1.0
wcwidth==0.2.13 wcwidth==0.2.13
websockets==15.0.1 websockets==15.0.1
zipp==3.23.0

View File

@@ -30,6 +30,26 @@ def get_mongodb_database_name() -> str:
return os.getenv("MONGODB_DATABASE", "mydocmanager") return os.getenv("MONGODB_DATABASE", "mydocmanager")
def get_redis_url() -> str:
return os.getenv("REDIS_URL", "redis://localhost:6379/0")
# def get_redis_host() -> str:
# redis_url = get_redis_url()
# if redis_url.startswith("redis://"):
# return redis_url.split("redis://")[1].split("/")[0]
# else:
# return redis_url
#
#
# def get_redis_port() -> int:
# redis_url = get_redis_url()
# if redis_url.startswith("redis://"):
# return int(redis_url.split("redis://")[1].split("/")[0].split(":")[1])
# else:
# return int(redis_url.split(":")[1])
def get_jwt_secret_key() -> str: def get_jwt_secret_key() -> str:
""" """
Get JWT secret key from environment variables. Get JWT secret key from environment variables.

View File

@@ -7,10 +7,11 @@ The application will terminate if MongoDB is not accessible at startup.
import sys import sys
from typing import Optional from typing import Optional
from pymongo import MongoClient from pymongo import MongoClient
from pymongo.database import Database from pymongo.database import Database
from pymongo.errors import ConnectionFailure, ServerSelectionTimeoutError from pymongo.errors import ConnectionFailure, ServerSelectionTimeoutError
from motor.motor_asyncio import AsyncIOMotorClient, AsyncIOMotorDatabase
from app.config.settings import get_mongodb_url, get_mongodb_database_name from app.config.settings import get_mongodb_url, get_mongodb_database_name
# Global variables for singleton pattern # Global variables for singleton pattern
@@ -18,7 +19,7 @@ _client: Optional[MongoClient] = None
_database: Optional[Database] = None _database: Optional[Database] = None
def create_mongodb_client() -> AsyncIOMotorClient: def create_mongodb_client() -> MongoClient:
""" """
Create MongoDB client with connection validation. Create MongoDB client with connection validation.
@@ -32,7 +33,7 @@ def create_mongodb_client() -> AsyncIOMotorClient:
try: try:
# Create client with short timeout for fail-fast behavior # Create client with short timeout for fail-fast behavior
client = AsyncIOMotorClient( client = MongoClient(
mongodb_url, mongodb_url,
serverSelectionTimeoutMS=5000, # 5 seconds timeout serverSelectionTimeoutMS=5000, # 5 seconds timeout
connectTimeoutMS=5000, connectTimeoutMS=5000,

View File

@@ -8,7 +8,8 @@ in MongoDB with proper error handling and type safety.
from typing import Optional, List from typing import Optional, List
from bson import ObjectId from bson import ObjectId
from motor.motor_asyncio import AsyncIOMotorCollection, AsyncIOMotorDatabase from pymongo.collection import Collection
from pymongo.database import Database
from pymongo.errors import DuplicateKeyError, PyMongoError from pymongo.errors import DuplicateKeyError, PyMongoError
from app.database.connection import get_extra_args from app.database.connection import get_extra_args
@@ -37,33 +38,29 @@ class FileDocumentRepository:
with proper error handling and data validation. with proper error handling and data validation.
""" """
def __init__(self, database: AsyncIOMotorDatabase): def __init__(self, database: Database):
"""Initialize file repository with database connection.""" """Initialize file repository with database connection."""
self.db = database self.db = database
self.collection: AsyncIOMotorCollection = self.db.documents self.collection: Collection = self.db.documents
async def initialize(self): def initialize(self):
""" """
Initialize repository by ensuring required indexes exist. Initialize repository by ensuring required indexes exist.
Should be called after repository instantiation to setup database indexes. Should be called after repository instantiation to setup database indexes.
""" """
await self._ensure_indexes() self._ensure_indexes()
return self return self
async def _ensure_indexes(self): def _ensure_indexes(self):
""" """
Ensure required database indexes exist. Ensure required database indexes exist.
Creates unique index on username field to prevent duplicates. Creates unique index on username field to prevent duplicates.
""" """
try:
await self.collection.create_index("filepath", unique=True)
except PyMongoError:
# Index might already exist, ignore error
pass pass
async def create_document(self, file_data: FileDocument, session=None) -> FileDocument: def create_document(self, file_data: FileDocument, session=None) -> FileDocument:
""" """
Create a new file document in database. Create a new file document in database.
@@ -83,7 +80,7 @@ class FileDocumentRepository:
if "_id" in file_dict and file_dict["_id"] is None: if "_id" in file_dict and file_dict["_id"] is None:
del file_dict["_id"] del file_dict["_id"]
result = await self.collection.insert_one(file_dict, **get_extra_args(session)) result = self.collection.insert_one(file_dict, **get_extra_args(session))
file_data.id = result.inserted_id file_data.id = result.inserted_id
return file_data return file_data
@@ -92,7 +89,7 @@ class FileDocumentRepository:
except PyMongoError as e: except PyMongoError as e:
raise ValueError(f"Failed to create file document: {e}") raise ValueError(f"Failed to create file document: {e}")
async def find_document_by_id(self, file_id: str) -> Optional[FileDocument]: def find_document_by_id(self, file_id: str) -> Optional[FileDocument]:
""" """
Find file document by ID. Find file document by ID.
@@ -106,7 +103,7 @@ class FileDocumentRepository:
if not ObjectId.is_valid(file_id): if not ObjectId.is_valid(file_id):
return None return None
file_doc = await self.collection.find_one({"_id": ObjectId(file_id)}) file_doc = self.collection.find_one({"_id": ObjectId(file_id)})
if file_doc: if file_doc:
return FileDocument(**file_doc) return FileDocument(**file_doc)
return None return None
@@ -114,7 +111,7 @@ class FileDocumentRepository:
except PyMongoError: except PyMongoError:
return None return None
async def find_document_by_hash(self, file_hash: str) -> Optional[FileDocument]: def find_document_by_hash(self, file_hash: str) -> Optional[FileDocument]:
""" """
Find file document by file hash to detect duplicates. Find file document by file hash to detect duplicates.
@@ -125,7 +122,7 @@ class FileDocumentRepository:
FileDocument or None: File document if found, None otherwise FileDocument or None: File document if found, None otherwise
""" """
try: try:
file_doc = await self.collection.find_one({"file_hash": file_hash}) file_doc = self.collection.find_one({"file_hash": file_hash})
if file_doc: if file_doc:
return FileDocument(**file_doc) return FileDocument(**file_doc)
return None return None
@@ -133,7 +130,7 @@ class FileDocumentRepository:
except PyMongoError: except PyMongoError:
return None return None
async def find_document_by_filepath(self, filepath: str) -> Optional[FileDocument]: def find_document_by_filepath(self, filepath: str) -> Optional[FileDocument]:
""" """
Find file document by exact filepath. Find file document by exact filepath.
@@ -144,7 +141,7 @@ class FileDocumentRepository:
FileDocument or None: File document if found, None otherwise FileDocument or None: File document if found, None otherwise
""" """
try: try:
file_doc = await self.collection.find_one({"filepath": filepath}) file_doc = self.collection.find_one({"filepath": filepath})
if file_doc: if file_doc:
return FileDocument(**file_doc) return FileDocument(**file_doc)
return None return None
@@ -152,7 +149,7 @@ class FileDocumentRepository:
except PyMongoError: except PyMongoError:
return None return None
async def find_document_by_name(self, filename: str, matching_method: MatchMethodBase = None) -> List[FileDocument]: def find_document_by_name(self, filename: str, matching_method: MatchMethodBase = None) -> List[FileDocument]:
""" """
Find file documents by filename using fuzzy matching. Find file documents by filename using fuzzy matching.
@@ -166,8 +163,7 @@ class FileDocumentRepository:
try: try:
# Get all files from database # Get all files from database
cursor = self.collection.find({}) cursor = self.collection.find({})
all_files = await cursor.to_list(length=None) all_documents = [FileDocument(**file_doc) for file_doc in cursor]
all_documents = [FileDocument(**file_doc) for file_doc in all_files]
if isinstance(matching_method, FuzzyMatching): if isinstance(matching_method, FuzzyMatching):
return fuzzy_matching(filename, all_documents, matching_method.threshold) return fuzzy_matching(filename, all_documents, matching_method.threshold)
@@ -177,7 +173,7 @@ class FileDocumentRepository:
except PyMongoError: except PyMongoError:
return [] return []
async def list_documents(self, skip: int = 0, limit: int = 100) -> List[FileDocument]: def list_documents(self, skip: int = 0, limit: int = 100) -> List[FileDocument]:
""" """
List file documents with pagination. List file documents with pagination.
@@ -190,13 +186,12 @@ class FileDocumentRepository:
""" """
try: try:
cursor = self.collection.find({}).skip(skip).limit(limit).sort("detected_at", -1) cursor = self.collection.find({}).skip(skip).limit(limit).sort("detected_at", -1)
file_docs = await cursor.to_list(length=limit) return [FileDocument(**doc) for doc in cursor]
return [FileDocument(**doc) for doc in file_docs]
except PyMongoError: except PyMongoError:
return [] return []
async def count_documents(self) -> int: def count_documents(self) -> int:
""" """
Count total number of file documents. Count total number of file documents.
@@ -204,11 +199,11 @@ class FileDocumentRepository:
int: Total number of file documents in collection int: Total number of file documents in collection
""" """
try: try:
return await self.collection.count_documents({}) return self.collection.count_documents({})
except PyMongoError: except PyMongoError:
return 0 return 0
async def update_document(self, file_id: str, update_data: dict, session=None) -> Optional[FileDocument]: def update_document(self, file_id: str, update_data: dict, session=None) -> Optional[FileDocument]:
""" """
Update file document with new data. Update file document with new data.
@@ -228,9 +223,9 @@ class FileDocumentRepository:
clean_update_data = {k: v for k, v in update_data.items() if v is not None} clean_update_data = {k: v for k, v in update_data.items() if v is not None}
if not clean_update_data: if not clean_update_data:
return await self.find_document_by_id(file_id) return self.find_document_by_id(file_id)
result = await self.collection.find_one_and_update( result = self.collection.find_one_and_update(
{"_id": ObjectId(file_id)}, {"_id": ObjectId(file_id)},
{"$set": clean_update_data}, {"$set": clean_update_data},
return_document=True, return_document=True,
@@ -244,7 +239,7 @@ class FileDocumentRepository:
except PyMongoError: except PyMongoError:
return None return None
async def delete_document(self, file_id: str, session=None) -> bool: def delete_document(self, file_id: str, session=None) -> bool:
""" """
Delete file document from database. Delete file document from database.
@@ -259,7 +254,7 @@ class FileDocumentRepository:
if not ObjectId.is_valid(file_id): if not ObjectId.is_valid(file_id):
return False return False
result = await self.collection.delete_one({"_id": ObjectId(file_id)}, **get_extra_args(session)) result = self.collection.delete_one({"_id": ObjectId(file_id)}, **get_extra_args(session))
return result.deleted_count > 0 return result.deleted_count > 0
except PyMongoError: except PyMongoError:

View File

@@ -8,7 +8,8 @@ with automatic timestamp management and error handling.
from datetime import datetime from datetime import datetime
from typing import List, Optional from typing import List, Optional
from motor.motor_asyncio import AsyncIOMotorCollection, AsyncIOMotorDatabase from pymongo.collection import Collection
from pymongo.database import Database
from pymongo.errors import PyMongoError from pymongo.errors import PyMongoError
from app.exceptions.job_exceptions import JobRepositoryError from app.exceptions.job_exceptions import JobRepositoryError
@@ -24,33 +25,33 @@ class JobRepository:
timestamp management and proper error handling. timestamp management and proper error handling.
""" """
def __init__(self, database: AsyncIOMotorDatabase): def __init__(self, database: Database):
"""Initialize repository with MongoDB collection reference.""" """Initialize repository with MongoDB collection reference."""
self.db = database self.db = database
self.collection: AsyncIOMotorCollection = self.db.processing_jobs self.collection: Collection = self.db.processing_jobs
async def _ensure_indexes(self): def _ensure_indexes(self):
""" """
Ensure required database indexes exist. Ensure required database indexes exist.
Creates unique index on username field to prevent duplicates. Creates unique index on username field to prevent duplicates.
""" """
try: try:
await self.collection.create_index("document_id", unique=True) self.collection.create_index("document_id", unique=True)
except PyMongoError: except PyMongoError:
# Index might already exist, ignore error # Index might already exist, ignore error
pass pass
async def initialize(self): def initialize(self):
""" """
Initialize repository by ensuring required indexes exist. Initialize repository by ensuring required indexes exist.
Should be called after repository instantiation to setup database indexes. Should be called after repository instantiation to setup database indexes.
""" """
await self._ensure_indexes() self._ensure_indexes()
return self return self
async def create_job(self, document_id: PyObjectId, task_id: Optional[str] = None) -> ProcessingJob: def create_job(self, document_id: PyObjectId, task_id: Optional[str] = None) -> ProcessingJob:
""" """
Create a new processing job. Create a new processing job.
@@ -75,7 +76,7 @@ class JobRepository:
"error_message": None "error_message": None
} }
result = await self.collection.insert_one(job_data) result = self.collection.insert_one(job_data)
job_data["_id"] = result.inserted_id job_data["_id"] = result.inserted_id
return ProcessingJob(**job_data) return ProcessingJob(**job_data)
@@ -83,7 +84,7 @@ class JobRepository:
except PyMongoError as e: except PyMongoError as e:
raise JobRepositoryError("create_job", e) raise JobRepositoryError("create_job", e)
async def find_job_by_id(self, job_id: PyObjectId) -> Optional[ProcessingJob]: def find_job_by_id(self, job_id: PyObjectId) -> Optional[ProcessingJob]:
""" """
Retrieve a job by its ID. Retrieve a job by its ID.
@@ -98,7 +99,7 @@ class JobRepository:
JobRepositoryError: If database operation fails JobRepositoryError: If database operation fails
""" """
try: try:
job_data = await self.collection.find_one({"_id": job_id}) job_data = self.collection.find_one({"_id": job_id})
if job_data: if job_data:
return ProcessingJob(**job_data) return ProcessingJob(**job_data)
@@ -107,7 +108,7 @@ class JobRepository:
except PyMongoError as e: except PyMongoError as e:
raise JobRepositoryError("get_job_by_id", e) raise JobRepositoryError("get_job_by_id", e)
async def update_job_status( def update_job_status(
self, self,
job_id: PyObjectId, job_id: PyObjectId,
status: ProcessingStatus, status: ProcessingStatus,
@@ -143,7 +144,7 @@ class JobRepository:
if error_message is not None: if error_message is not None:
update_data["error_message"] = error_message update_data["error_message"] = error_message
result = await self.collection.find_one_and_update( result = self.collection.find_one_and_update(
{"_id": job_id}, {"_id": job_id},
{"$set": update_data}, {"$set": update_data},
return_document=True return_document=True
@@ -157,7 +158,7 @@ class JobRepository:
except PyMongoError as e: except PyMongoError as e:
raise JobRepositoryError("update_job_status", e) raise JobRepositoryError("update_job_status", e)
async def delete_job(self, job_id: PyObjectId) -> bool: def delete_job(self, job_id: PyObjectId) -> bool:
""" """
Delete a job from the database. Delete a job from the database.
@@ -171,14 +172,14 @@ class JobRepository:
JobRepositoryError: If database operation fails JobRepositoryError: If database operation fails
""" """
try: try:
result = await self.collection.delete_one({"_id": job_id}) result = self.collection.delete_one({"_id": job_id})
return result.deleted_count > 0 return result.deleted_count > 0
except PyMongoError as e: except PyMongoError as e:
raise JobRepositoryError("delete_job", e) raise JobRepositoryError("delete_job", e)
async def find_jobs_by_document_id(self, document_id: PyObjectId) -> List[ProcessingJob]: def find_jobs_by_document_id(self, document_id: PyObjectId) -> List[ProcessingJob]:
""" """
Retrieve all jobs for a specific file. Retrieve all jobs for a specific file.
@@ -195,7 +196,7 @@ class JobRepository:
cursor = self.collection.find({"document_id": document_id}) cursor = self.collection.find({"document_id": document_id})
jobs = [] jobs = []
async for job_data in cursor: for job_data in cursor:
jobs.append(ProcessingJob(**job_data)) jobs.append(ProcessingJob(**job_data))
return jobs return jobs
@@ -203,7 +204,7 @@ class JobRepository:
except PyMongoError as e: except PyMongoError as e:
raise JobRepositoryError("get_jobs_by_file_id", e) raise JobRepositoryError("get_jobs_by_file_id", e)
async def get_jobs_by_status(self, status: ProcessingStatus) -> List[ProcessingJob]: def get_jobs_by_status(self, status: ProcessingStatus) -> List[ProcessingJob]:
""" """
Retrieve all jobs with a specific status. Retrieve all jobs with a specific status.
@@ -220,7 +221,7 @@ class JobRepository:
cursor = self.collection.find({"status": status}) cursor = self.collection.find({"status": status})
jobs = [] jobs = []
async for job_data in cursor: for job_data in cursor:
jobs.append(ProcessingJob(**job_data)) jobs.append(ProcessingJob(**job_data))
return jobs return jobs

View File

@@ -5,10 +5,12 @@ This module implements the repository pattern for user CRUD operations
with dependency injection of the database connection using async/await. with dependency injection of the database connection using async/await.
""" """
from typing import Optional, List
from datetime import datetime from datetime import datetime
from typing import Optional, List
from bson import ObjectId from bson import ObjectId
from motor.motor_asyncio import AsyncIOMotorDatabase, AsyncIOMotorCollection from pymongo.collection import Collection
from pymongo.database import Database
from pymongo.errors import DuplicateKeyError, PyMongoError from pymongo.errors import DuplicateKeyError, PyMongoError
from app.models.user import UserCreate, UserInDB, UserUpdate from app.models.user import UserCreate, UserInDB, UserUpdate
@@ -23,7 +25,7 @@ class UserRepository:
following the repository pattern with dependency injection and async/await. following the repository pattern with dependency injection and async/await.
""" """
def __init__(self, database: AsyncIOMotorDatabase): def __init__(self, database: Database):
""" """
Initialize repository with database dependency. Initialize repository with database dependency.
@@ -31,30 +33,30 @@ class UserRepository:
database (AsyncIOMotorDatabase): MongoDB database instance database (AsyncIOMotorDatabase): MongoDB database instance
""" """
self.db = database self.db = database
self.collection: AsyncIOMotorCollection = database.users self.collection: Collection = database.users
async def initialize(self): def initialize(self):
""" """
Initialize repository by ensuring required indexes exist. Initialize repository by ensuring required indexes exist.
Should be called after repository instantiation to setup database indexes. Should be called after repository instantiation to setup database indexes.
""" """
await self._ensure_indexes() self._ensure_indexes()
return self return self
async def _ensure_indexes(self): def _ensure_indexes(self):
""" """
Ensure required database indexes exist. Ensure required database indexes exist.
Creates unique index on username field to prevent duplicates. Creates unique index on username field to prevent duplicates.
""" """
try: try:
await self.collection.create_index("username", unique=True) self.collection.create_index("username", unique=True)
except PyMongoError: except PyMongoError:
# Index might already exist, ignore error # Index might already exist, ignore error
pass pass
async def create_user(self, user_data: UserCreate) -> UserInDB: def create_user(self, user_data: UserCreate) -> UserInDB:
""" """
Create a new user in the database. Create a new user in the database.
@@ -79,7 +81,7 @@ class UserRepository:
} }
try: try:
result = await self.collection.insert_one(user_dict) result = self.collection.insert_one(user_dict)
user_dict["_id"] = result.inserted_id user_dict["_id"] = result.inserted_id
return UserInDB(**user_dict) return UserInDB(**user_dict)
except DuplicateKeyError as e: except DuplicateKeyError as e:
@@ -87,7 +89,7 @@ class UserRepository:
except PyMongoError as e: except PyMongoError as e:
raise ValueError(f"Failed to create user: {e}") raise ValueError(f"Failed to create user: {e}")
async def find_user_by_username(self, username: str) -> Optional[UserInDB]: def find_user_by_username(self, username: str) -> Optional[UserInDB]:
""" """
Find user by username. Find user by username.
@@ -98,14 +100,14 @@ class UserRepository:
UserInDB or None: User if found, None otherwise UserInDB or None: User if found, None otherwise
""" """
try: try:
user_doc = await self.collection.find_one({"username": username}) user_doc = self.collection.find_one({"username": username})
if user_doc: if user_doc:
return UserInDB(**user_doc) return UserInDB(**user_doc)
return None return None
except PyMongoError: except PyMongoError:
return None return None
async def find_user_by_id(self, user_id: str) -> Optional[UserInDB]: def find_user_by_id(self, user_id: str) -> Optional[UserInDB]:
""" """
Find user by ID. Find user by ID.
@@ -119,14 +121,14 @@ class UserRepository:
if not ObjectId.is_valid(user_id): if not ObjectId.is_valid(user_id):
return None return None
user_doc = await self.collection.find_one({"_id": ObjectId(user_id)}) user_doc = self.collection.find_one({"_id": ObjectId(user_id)})
if user_doc: if user_doc:
return UserInDB(**user_doc) return UserInDB(**user_doc)
return None return None
except PyMongoError: except PyMongoError:
return None return None
async def find_user_by_email(self, email: str) -> Optional[UserInDB]: def find_user_by_email(self, email: str) -> Optional[UserInDB]:
""" """
Find user by email address. Find user by email address.
@@ -137,14 +139,14 @@ class UserRepository:
UserInDB or None: User if found, None otherwise UserInDB or None: User if found, None otherwise
""" """
try: try:
user_doc = await self.collection.find_one({"email": email}) user_doc = self.collection.find_one({"email": email})
if user_doc: if user_doc:
return UserInDB(**user_doc) return UserInDB(**user_doc)
return None return None
except PyMongoError: except PyMongoError:
return None return None
async def update_user(self, user_id: str, user_update: UserUpdate) -> Optional[UserInDB]: def update_user(self, user_id: str, user_update: UserUpdate) -> Optional[UserInDB]:
""" """
Update user information. Update user information.
@@ -177,9 +179,9 @@ class UserRepository:
clean_update_data = {k: v for k, v in update_data.items() if v is not None} clean_update_data = {k: v for k, v in update_data.items() if v is not None}
if not clean_update_data: if not clean_update_data:
return await self.find_user_by_id(user_id) return self.find_user_by_id(user_id)
result = await self.collection.find_one_and_update( result = self.collection.find_one_and_update(
{"_id": ObjectId(user_id)}, {"_id": ObjectId(user_id)},
{"$set": clean_update_data}, {"$set": clean_update_data},
return_document=True return_document=True
@@ -192,7 +194,7 @@ class UserRepository:
except PyMongoError: except PyMongoError:
return None return None
async def delete_user(self, user_id: str) -> bool: def delete_user(self, user_id: str) -> bool:
""" """
Delete user from database. Delete user from database.
@@ -206,12 +208,12 @@ class UserRepository:
if not ObjectId.is_valid(user_id): if not ObjectId.is_valid(user_id):
return False return False
result = await self.collection.delete_one({"_id": ObjectId(user_id)}) result = self.collection.delete_one({"_id": ObjectId(user_id)})
return result.deleted_count > 0 return result.deleted_count > 0
except PyMongoError: except PyMongoError:
return False return False
async def list_users(self, skip: int = 0, limit: int = 100) -> List[UserInDB]: def list_users(self, skip: int = 0, limit: int = 100) -> List[UserInDB]:
""" """
List users with pagination. List users with pagination.
@@ -224,12 +226,12 @@ class UserRepository:
""" """
try: try:
cursor = self.collection.find({}).skip(skip).limit(limit).sort("created_at", -1) cursor = self.collection.find({}).skip(skip).limit(limit).sort("created_at", -1)
user_docs = await cursor.to_list(length=limit) user_docs = cursor.to_list(length=limit)
return [UserInDB(**user_doc) for user_doc in user_docs] return [UserInDB(**user_doc) for user_doc in user_docs]
except PyMongoError: except PyMongoError:
return [] return []
async def count_users(self) -> int: def count_users(self) -> int:
""" """
Count total number of users. Count total number of users.
@@ -237,11 +239,11 @@ class UserRepository:
int: Total number of users in database int: Total number of users in database
""" """
try: try:
return await self.collection.count_documents({}) return self.collection.count_documents({})
except PyMongoError: except PyMongoError:
return 0 return 0
async def user_exists(self, username: str) -> bool: def user_exists(self, username: str) -> bool:
""" """
Check if user exists by username. Check if user exists by username.
@@ -252,7 +254,7 @@ class UserRepository:
bool: True if user exists, False otherwise bool: True if user exists, False otherwise
""" """
try: try:
count = await self.collection.count_documents({"username": username}) count = self.collection.count_documents({"username": username})
return count > 0 return count > 0
except PyMongoError: except PyMongoError:
return False return False

View File

@@ -63,15 +63,17 @@ class DocumentFileEventHandler(FileSystemEventHandler):
logger.info(f"Processing new file: {filepath}") logger.info(f"Processing new file: {filepath}")
try: # try:
from tasks.main import process_document from tasks.document_processing import process_document
celery_result = process_document.delay(filepath) task_result = process_document.delay(filepath)
celery_task_id = celery_result.id print(task_result)
logger.info(f"Dispatched Celery task with ID: {celery_task_id}") print("hello world")
# task_id = task_result.task_id
# logger.info(f"Dispatched Celery task with ID: {task_id}")
except Exception as e: # except Exception as e:
logger.error(f"Failed to process file {filepath}: {str(e)}") # logger.error(f"Failed to process file {filepath}: {str(e)}")
# Note: We don't re-raise the exception to keep the watcher running # # Note: We don't re-raise the exception to keep the watcher running
class FileWatcher: class FileWatcher:

View File

@@ -64,7 +64,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
# Create default admin user # Create default admin user
init_service = InitializationService(user_service) init_service = InitializationService(user_service)
await init_service.initialize_application() init_service.initialize_application()
logger.info("Default admin user initialization completed") logger.info("Default admin user initialization completed")
# Create and start file watcher # Create and start file watcher

View File

@@ -44,8 +44,8 @@ class DocumentService:
self.document_repository = FileDocumentRepository(self.db) self.document_repository = FileDocumentRepository(self.db)
self.objects_folder = objects_folder or get_objects_folder() self.objects_folder = objects_folder or get_objects_folder()
async def initialize(self): def initialize(self):
await self.document_repository.initialize() self.document_repository.initialize()
return self return self
@staticmethod @staticmethod
@@ -136,7 +136,7 @@ class DocumentService:
with open(target_path, "wb") as f: with open(target_path, "wb") as f:
f.write(content) f.write(content)
async def create_document( def create_document(
self, self,
file_path: str, file_path: str,
file_bytes: bytes | None = None, file_bytes: bytes | None = None,
@@ -187,7 +187,7 @@ class DocumentService:
mime_type=mime_type mime_type=mime_type
) )
created_file = await self.document_repository.create_document(file_data) created_file = self.document_repository.create_document(file_data)
return created_file return created_file
@@ -195,7 +195,7 @@ class DocumentService:
# Transaction will automatically rollback if supported # Transaction will automatically rollback if supported
raise PyMongoError(f"Failed to create document: {str(e)}") raise PyMongoError(f"Failed to create document: {str(e)}")
async def get_document_by_id(self, document_id: PyObjectId) -> Optional[FileDocument]: def get_document_by_id(self, document_id: PyObjectId) -> Optional[FileDocument]:
""" """
Retrieve a document by its ID. Retrieve a document by its ID.
@@ -205,9 +205,9 @@ class DocumentService:
Returns: Returns:
FileDocument if found, None otherwise FileDocument if found, None otherwise
""" """
return await self.document_repository.find_document_by_id(str(document_id)) return self.document_repository.find_document_by_id(str(document_id))
async def get_document_by_hash(self, file_hash: str) -> Optional[FileDocument]: def get_document_by_hash(self, file_hash: str) -> Optional[FileDocument]:
""" """
Retrieve a document by its file hash. Retrieve a document by its file hash.
@@ -217,9 +217,9 @@ class DocumentService:
Returns: Returns:
FileDocument if found, None otherwise FileDocument if found, None otherwise
""" """
return await self.document_repository.find_document_by_hash(file_hash) return self.document_repository.find_document_by_hash(file_hash)
async def get_document_by_filepath(self, filepath: str) -> Optional[FileDocument]: def get_document_by_filepath(self, filepath: str) -> Optional[FileDocument]:
""" """
Retrieve a document by its file path. Retrieve a document by its file path.
@@ -229,9 +229,9 @@ class DocumentService:
Returns: Returns:
FileDocument if found, None otherwise FileDocument if found, None otherwise
""" """
return await self.document_repository.find_document_by_filepath(filepath) return self.document_repository.find_document_by_filepath(filepath)
async def get_document_content_by_hash(self, file_hash): def get_document_content_by_hash(self, file_hash):
target_path = self._get_document_path(file_hash) target_path = self._get_document_path(file_hash)
if not os.path.exists(target_path): if not os.path.exists(target_path):
return None return None
@@ -239,7 +239,7 @@ class DocumentService:
with open(target_path, "rb") as f: with open(target_path, "rb") as f:
return f.read() return f.read()
async def list_documents( def list_documents(
self, self,
skip: int = 0, skip: int = 0,
limit: int = 100 limit: int = 100
@@ -254,18 +254,18 @@ class DocumentService:
Returns: Returns:
List of FileDocument instances List of FileDocument instances
""" """
return await self.document_repository.list_documents(skip=skip, limit=limit) return self.document_repository.list_documents(skip=skip, limit=limit)
async def count_documents(self) -> int: def count_documents(self) -> int:
""" """
Get total number of documents. Get total number of documents.
Returns: Returns:
Total document count Total document count
""" """
return await self.document_repository.count_documents() return self.document_repository.count_documents()
async def update_document( def update_document(
self, self,
document_id: PyObjectId, document_id: PyObjectId,
update_data: Dict[str, Any] update_data: Dict[str, Any]
@@ -285,9 +285,9 @@ class DocumentService:
update_data["file_hash"] = file_hash update_data["file_hash"] = file_hash
self.save_content_if_needed(file_hash, update_data["file_bytes"]) self.save_content_if_needed(file_hash, update_data["file_bytes"])
return await self.document_repository.update_document(document_id, update_data) return self.document_repository.update_document(document_id, update_data)
async def delete_document(self, document_id: PyObjectId) -> bool: def delete_document(self, document_id: PyObjectId) -> bool:
""" """
Delete a document and its orphaned content. Delete a document and its orphaned content.
@@ -308,17 +308,17 @@ class DocumentService:
try: try:
# Get document to find its hash # Get document to find its hash
document = await self.document_repository.find_document_by_id(document_id) document = self.document_repository.find_document_by_id(document_id)
if not document: if not document:
return False return False
# Delete the document # Delete the document
deleted = await self.document_repository.delete_document(document_id) deleted = self.document_repository.delete_document(document_id)
if not deleted: if not deleted:
return False return False
# Check if content is orphaned # Check if content is orphaned
remaining_files = await self.document_repository.find_document_by_hash(document.file_hash) remaining_files = self.document_repository.find_document_by_hash(document.file_hash)
# If no other files reference this content, delete it # If no other files reference this content, delete it
if not remaining_files: if not remaining_files:

View File

@@ -32,7 +32,7 @@ class InitializationService:
""" """
self.user_service = user_service self.user_service = user_service
async def ensure_admin_user_exists(self) -> Optional[UserInDB]: def ensure_admin_user_exists(self) -> Optional[UserInDB]:
""" """
Ensure default admin user exists in the system. Ensure default admin user exists in the system.
@@ -48,7 +48,7 @@ class InitializationService:
logger.info("Checking if admin user exists...") logger.info("Checking if admin user exists...")
# Check if any admin user already exists # Check if any admin user already exists
if await self._admin_user_exists(): if self._admin_user_exists():
logger.info("Admin user already exists, skipping creation") logger.info("Admin user already exists, skipping creation")
return None return None
@@ -63,7 +63,7 @@ class InitializationService:
role=UserRole.ADMIN role=UserRole.ADMIN
) )
created_user = await self.user_service.create_user(admin_data) created_user = self.user_service.create_user(admin_data)
logger.info(f"Default admin user created successfully with ID: {created_user.id}") logger.info(f"Default admin user created successfully with ID: {created_user.id}")
logger.warning( logger.warning(
"Default admin user created with username 'admin' and password 'admin'. " "Default admin user created with username 'admin' and password 'admin'. "
@@ -76,7 +76,7 @@ class InitializationService:
logger.error(f"Failed to create default admin user: {str(e)}") logger.error(f"Failed to create default admin user: {str(e)}")
raise Exception(f"Admin user creation failed: {str(e)}") raise Exception(f"Admin user creation failed: {str(e)}")
async def _admin_user_exists(self) -> bool: def _admin_user_exists(self) -> bool:
""" """
Check if any admin user exists in the system. Check if any admin user exists in the system.
@@ -85,7 +85,7 @@ class InitializationService:
""" """
try: try:
# Get all users and check if any have admin role # Get all users and check if any have admin role
users = await self.user_service.list_users(limit=1000) # Reasonable limit for admin check users = self.user_service.list_users(limit=1000) # Reasonable limit for admin check
for user in users: for user in users:
if user.role == UserRole.ADMIN and user.is_active: if user.role == UserRole.ADMIN and user.is_active:
@@ -98,7 +98,7 @@ class InitializationService:
# In case of error, assume admin exists to avoid creating duplicates # In case of error, assume admin exists to avoid creating duplicates
return True return True
async def initialize_application(self) -> dict: def initialize_application(self) -> dict:
""" """
Perform all application initialization tasks. Perform all application initialization tasks.
@@ -118,7 +118,7 @@ class InitializationService:
try: try:
# Ensure admin user exists # Ensure admin user exists
created_admin = await self.ensure_admin_user_exists() created_admin = self.ensure_admin_user_exists()
if created_admin: if created_admin:
initialization_summary["admin_user_created"] = True initialization_summary["admin_user_created"] = True

View File

@@ -31,11 +31,11 @@ class JobService:
self.db = database self.db = database
self.repository = JobRepository(database) self.repository = JobRepository(database)
async def initialize(self): def initialize(self):
await self.repository.initialize() self.repository.initialize()
return self return self
async def create_job(self, document_id: PyObjectId, task_id: Optional[str] = None) -> ProcessingJob: def create_job(self, document_id: PyObjectId, task_id: Optional[str] = None) -> ProcessingJob:
""" """
Create a new processing job. Create a new processing job.
@@ -49,9 +49,9 @@ class JobService:
Raises: Raises:
JobRepositoryError: If database operation fails JobRepositoryError: If database operation fails
""" """
return await self.repository.create_job(document_id, task_id) return self.repository.create_job(document_id, task_id)
async def get_job_by_id(self, job_id: PyObjectId) -> ProcessingJob: def get_job_by_id(self, job_id: PyObjectId) -> ProcessingJob:
""" """
Retrieve a job by its ID. Retrieve a job by its ID.
@@ -65,9 +65,9 @@ class JobService:
JobNotFoundError: If job doesn't exist JobNotFoundError: If job doesn't exist
JobRepositoryError: If database operation fails JobRepositoryError: If database operation fails
""" """
return await self.repository.find_job_by_id(job_id) return self.repository.find_job_by_id(job_id)
async def mark_job_as_started(self, job_id: PyObjectId) -> ProcessingJob: def mark_job_as_started(self, job_id: PyObjectId) -> ProcessingJob:
""" """
Mark a job as started (PENDING → PROCESSING). Mark a job as started (PENDING → PROCESSING).
@@ -83,16 +83,16 @@ class JobService:
JobRepositoryError: If database operation fails JobRepositoryError: If database operation fails
""" """
# Get current job to validate transition # Get current job to validate transition
current_job = await self.repository.find_job_by_id(job_id) current_job = self.repository.find_job_by_id(job_id)
# Validate status transition # Validate status transition
if current_job.status != ProcessingStatus.PENDING: if current_job.status != ProcessingStatus.PENDING:
raise InvalidStatusTransitionError(current_job.status, ProcessingStatus.PROCESSING) raise InvalidStatusTransitionError(current_job.status, ProcessingStatus.PROCESSING)
# Update status # Update status
return await self.repository.update_job_status(job_id, ProcessingStatus.PROCESSING) return self.repository.update_job_status(job_id, ProcessingStatus.PROCESSING)
async def mark_job_as_completed(self, job_id: PyObjectId) -> ProcessingJob: def mark_job_as_completed(self, job_id: PyObjectId) -> ProcessingJob:
""" """
Mark a job as completed (PROCESSING → COMPLETED). Mark a job as completed (PROCESSING → COMPLETED).
@@ -108,16 +108,16 @@ class JobService:
JobRepositoryError: If database operation fails JobRepositoryError: If database operation fails
""" """
# Get current job to validate transition # Get current job to validate transition
current_job = await self.repository.find_job_by_id(job_id) current_job = self.repository.find_job_by_id(job_id)
# Validate status transition # Validate status transition
if current_job.status != ProcessingStatus.PROCESSING: if current_job.status != ProcessingStatus.PROCESSING:
raise InvalidStatusTransitionError(current_job.status, ProcessingStatus.COMPLETED) raise InvalidStatusTransitionError(current_job.status, ProcessingStatus.COMPLETED)
# Update status # Update status
return await self.repository.update_job_status(job_id, ProcessingStatus.COMPLETED) return self.repository.update_job_status(job_id, ProcessingStatus.COMPLETED)
async def mark_job_as_failed( def mark_job_as_failed(
self, self,
job_id: PyObjectId, job_id: PyObjectId,
error_message: Optional[str] = None error_message: Optional[str] = None
@@ -138,20 +138,20 @@ class JobService:
JobRepositoryError: If database operation fails JobRepositoryError: If database operation fails
""" """
# Get current job to validate transition # Get current job to validate transition
current_job = await self.repository.find_job_by_id(job_id) current_job = self.repository.find_job_by_id(job_id)
# Validate status transition # Validate status transition
if current_job.status != ProcessingStatus.PROCESSING: if current_job.status != ProcessingStatus.PROCESSING:
raise InvalidStatusTransitionError(current_job.status, ProcessingStatus.FAILED) raise InvalidStatusTransitionError(current_job.status, ProcessingStatus.FAILED)
# Update status with error message # Update status with error message
return await self.repository.update_job_status( return self.repository.update_job_status(
job_id, job_id,
ProcessingStatus.FAILED, ProcessingStatus.FAILED,
error_message error_message
) )
async def delete_job(self, job_id: PyObjectId) -> bool: def delete_job(self, job_id: PyObjectId) -> bool:
""" """
Delete a job from the database. Delete a job from the database.
@@ -164,9 +164,9 @@ class JobService:
Raises: Raises:
JobRepositoryError: If database operation fails JobRepositoryError: If database operation fails
""" """
return await self.repository.delete_job(job_id) return self.repository.delete_job(job_id)
async def get_jobs_by_status(self, status: ProcessingStatus) -> list[ProcessingJob]: def get_jobs_by_status(self, status: ProcessingStatus) -> list[ProcessingJob]:
""" """
Retrieve all jobs with a specific status. Retrieve all jobs with a specific status.
@@ -179,4 +179,4 @@ class JobService:
Raises: Raises:
JobRepositoryError: If database operation fails JobRepositoryError: If database operation fails
""" """
return await self.repository.get_jobs_by_status(status) return self.repository.get_jobs_by_status(status)

View File

@@ -33,11 +33,11 @@ class UserService:
self.user_repository = UserRepository(self.db) self.user_repository = UserRepository(self.db)
self.auth_service = AuthService() self.auth_service = AuthService()
async def initialize(self): def initialize(self):
await self.user_repository.initialize() self.user_repository.initialize()
return self return self
async def create_user(self, user_data: UserCreate | UserCreateNoValidation) -> UserInDB: def create_user(self, user_data: UserCreate | UserCreateNoValidation) -> UserInDB:
""" """
Create a new user with business logic validation. Create a new user with business logic validation.
@@ -60,11 +60,11 @@ class UserService:
raise ValueError(f"User with email '{user_data.email}' already exists") raise ValueError(f"User with email '{user_data.email}' already exists")
try: try:
return await self.user_repository.create_user(user_data) return self.user_repository.create_user(user_data)
except DuplicateKeyError: except DuplicateKeyError:
raise ValueError(f"User with username '{user_data.username}' already exists") raise ValueError(f"User with username '{user_data.username}' already exists")
async def get_user_by_username(self, username: str) -> Optional[UserInDB]: def get_user_by_username(self, username: str) -> Optional[UserInDB]:
""" """
Retrieve user by username. Retrieve user by username.
@@ -74,9 +74,9 @@ class UserService:
Returns: Returns:
UserInDB or None: User if found, None otherwise UserInDB or None: User if found, None otherwise
""" """
return await self.user_repository.find_user_by_username(username) return self.user_repository.find_user_by_username(username)
async def get_user_by_id(self, user_id: str) -> Optional[UserInDB]: def get_user_by_id(self, user_id: str) -> Optional[UserInDB]:
""" """
Retrieve user by ID. Retrieve user by ID.
@@ -86,9 +86,9 @@ class UserService:
Returns: Returns:
UserInDB or None: User if found, None otherwise UserInDB or None: User if found, None otherwise
""" """
return await self.user_repository.find_user_by_id(user_id) return self.user_repository.find_user_by_id(user_id)
async def authenticate_user(self, username: str, password: str) -> Optional[UserInDB]: def authenticate_user(self, username: str, password: str) -> Optional[UserInDB]:
""" """
Authenticate user with username and password. Authenticate user with username and password.
@@ -111,7 +111,7 @@ class UserService:
return user return user
async def update_user(self, user_id: str, user_update: UserUpdate) -> Optional[UserInDB]: def update_user(self, user_id: str, user_update: UserUpdate) -> Optional[UserInDB]:
""" """
Update user information. Update user information.
@@ -137,9 +137,9 @@ class UserService:
if existing_user and str(existing_user.id) != user_id: if existing_user and str(existing_user.id) != user_id:
raise ValueError(f"Email '{user_update.email}' is already taken") raise ValueError(f"Email '{user_update.email}' is already taken")
return await self.user_repository.update_user(user_id, user_update) return self.user_repository.update_user(user_id, user_update)
async def delete_user(self, user_id: str) -> bool: def delete_user(self, user_id: str) -> bool:
""" """
Delete user from system. Delete user from system.
@@ -151,7 +151,7 @@ class UserService:
""" """
return self.user_repository.delete_user(user_id) return self.user_repository.delete_user(user_id)
async def list_users(self, skip: int = 0, limit: int = 100) -> List[UserInDB]: def list_users(self, skip: int = 0, limit: int = 100) -> List[UserInDB]:
""" """
List users with pagination. List users with pagination.
@@ -162,18 +162,18 @@ class UserService:
Returns: Returns:
List[UserInDB]: List of users List[UserInDB]: List of users
""" """
return await self.user_repository.list_users(skip=skip, limit=limit) return self.user_repository.list_users(skip=skip, limit=limit)
async def count_users(self) -> int: def count_users(self) -> int:
""" """
Count total number of users. Count total number of users.
Returns: Returns:
int: Total number of users in system int: Total number of users in system
""" """
return await self.user_repository.count_users() return self.user_repository.count_users()
async def user_exists(self, username: str) -> bool: def user_exists(self, username: str) -> bool:
""" """
Check if user exists by username. Check if user exists by username.
@@ -183,4 +183,4 @@ class UserService:
Returns: Returns:
bool: True if user exists, False otherwise bool: True if user exists, False otherwise
""" """
return await self.user_repository.user_exists(username) return self.user_repository.user_exists(username)

View File

@@ -1,3 +1,4 @@
asgiref==3.9.1
bcrypt==4.3.0 bcrypt==4.3.0
celery==5.5.3 celery==5.5.3
email-validator==2.3.0 email-validator==2.3.0

View File

@@ -1,3 +1,4 @@
asgiref==3.9.1
bcrypt==4.3.0 bcrypt==4.3.0
celery==5.5.3 celery==5.5.3
email-validator==2.3.0 email-validator==2.3.0

View File

@@ -11,11 +11,12 @@ from typing import Any, Dict
from app.config import settings from app.config import settings
from app.database.connection import get_database from app.database.connection import get_database
from app.services.document_service import DocumentService from app.services.document_service import DocumentService
from tasks.main import celery_app
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@celery_app.task(bind=True, autoretry_for=(Exception,), retry_kwargs={'max_retries': 3, 'countdown': 60})
async def process_document_async(self, filepath: str) -> Dict[str, Any]: def process_document(self, filepath: str) -> Dict[str, Any]:
""" """
Process a document file and extract its content. Process a document file and extract its content.
@@ -45,18 +46,18 @@ async def process_document_async(self, filepath: str) -> Dict[str, Any]:
job = None job = None
try: try:
# Step 1: Insert the document in DB # Step 1: Insert the document in DB
document = await document_service.create_document(filepath) document = document_service.create_document(filepath)
logger.info(f"Job {task_id} created for document {document.id} with file path: {filepath}") logger.info(f"Job {task_id} created for document {document.id} with file path: {filepath}")
# Step 2: Create a new job record for the document # Step 2: Create a new job record for the document
job = await job_service.create_job(task_id=task_id, document_id=document.id) job = job_service.create_job(task_id=task_id, document_id=document.id)
# Step 3: Mark job as started # Step 3: Mark job as started
await job_service.mark_job_as_started(job_id=job.id) job_service.mark_job_as_started(job_id=job.id)
logger.info(f"Job {task_id} marked as PROCESSING") logger.info(f"Job {task_id} marked as PROCESSING")
# Step 4: Mark job as completed # Step 4: Mark job as completed
await job_service.mark_job_as_completed(job_id=job.id) job_service.mark_job_as_completed(job_id=job.id)
logger.info(f"Job {task_id} marked as COMPLETED") logger.info(f"Job {task_id} marked as COMPLETED")
return { return {
@@ -72,7 +73,7 @@ async def process_document_async(self, filepath: str) -> Dict[str, Any]:
try: try:
# Mark job as failed # Mark job as failed
if job is not None: if job is not None:
await job_service.mark_job_as_failed(job_id=job.id, error_message=error_message) job_service.mark_job_as_failed(job_id=job.id, error_message=error_message)
logger.info(f"Job {task_id} marked as FAILED") logger.info(f"Job {task_id} marked as FAILED")
else: else:
logger.error(f"Failed to process {filepath}. error = {str(e)}") logger.error(f"Failed to process {filepath}. error = {str(e)}")
@@ -81,3 +82,4 @@ async def process_document_async(self, filepath: str) -> Dict[str, Any]:
# Re-raise the exception to trigger Celery retry mechanism # Re-raise the exception to trigger Celery retry mechanism
raise raise

View File

@@ -3,13 +3,10 @@ Celery worker for MyDocManager document processing tasks.
This module contains all Celery tasks for processing documents. This module contains all Celery tasks for processing documents.
""" """
import asyncio
import os import os
from celery import Celery from celery import Celery
from tasks.document_processing import process_document_async
# Environment variables # Environment variables
REDIS_URL = os.getenv("REDIS_URL", "redis://localhost:6379/0") REDIS_URL = os.getenv("REDIS_URL", "redis://localhost:6379/0")
MONGODB_URL = os.getenv("MONGODB_URL", "mongodb://localhost:27017") MONGODB_URL = os.getenv("MONGODB_URL", "mongodb://localhost:27017")
@@ -21,6 +18,8 @@ celery_app = Celery(
backend=REDIS_URL, backend=REDIS_URL,
) )
celery_app.autodiscover_tasks(["tasks.document_processing"])
# Celery configuration # Celery configuration
celery_app.conf.update( celery_app.conf.update(
task_serializer="json", task_serializer="json",
@@ -33,11 +32,5 @@ celery_app.conf.update(
task_soft_time_limit=240, # 4 minutes task_soft_time_limit=240, # 4 minutes
) )
@celery_app.task(bind=True, autoretry_for=(Exception,), retry_kwargs={'max_retries': 3, 'countdown': 60})
def process_document(self, filepath: str):
return asyncio.run(process_document_async(self, filepath))
if __name__ == "__main__": if __name__ == "__main__":
celery_app.start() celery_app.start()

View File

@@ -1,18 +1,16 @@
""" """
Test suite for FileDocumentRepository with async/await support. Test suite for the synchronous FileDocumentRepository.
This module contains comprehensive tests for all FileDocumentRepository methods This module contains comprehensive tests for all FileDocumentRepository methods
using mongomock-motor for in-memory MongoDB testing. using mongomock-motor for in-memory MongoDB testing.
""" """
import pytest
from datetime import datetime from datetime import datetime
from typing import Dict, Any
import pytest_asyncio import pytest
from bson import ObjectId from bson import ObjectId
from pymongo.errors import DuplicateKeyError, PyMongoError from mongomock.mongo_client import MongoClient
from mongomock_motor import AsyncMongoMockClient from pymongo.errors import PyMongoError
from app.database.repositories.document_repository import ( from app.database.repositories.document_repository import (
FileDocumentRepository, FileDocumentRepository,
@@ -23,13 +21,13 @@ from app.database.repositories.document_repository import (
from app.models.document import FileDocument, FileType, ExtractionMethod from app.models.document import FileDocument, FileType, ExtractionMethod
@pytest_asyncio.fixture @pytest.fixture
async def in_memory_repository(): def in_memory_repository():
"""Create an in-memory FileDocumentRepository for testing.""" """Create an in-memory FileDocumentRepository for testing."""
client = AsyncMongoMockClient() client = MongoClient()
db = client.test_database db = client.test_database
repo = FileDocumentRepository(db) repo = FileDocumentRepository(db)
await repo.initialize() repo.initialize()
return repo return repo
@@ -107,14 +105,13 @@ def multiple_sample_files():
class TestFileDocumentRepositoryInitialization: class TestFileDocumentRepositoryInitialization:
"""Tests for repository initialization.""" """Tests for repository initialization."""
@pytest.mark.asyncio def test_i_can_initialize_repository(self):
async def test_i_can_initialize_repository(self):
"""Test repository initialization.""" """Test repository initialization."""
# Arrange # Arrange
client = AsyncMongoMockClient() client = MongoClient()
db = client.test_database db = client.test_database
repo = FileDocumentRepository(db) repo = FileDocumentRepository(db)
await repo.initialize() repo.initialize()
# Act & Assert (should not raise any exception) # Act & Assert (should not raise any exception)
assert repo.db is not None assert repo.db is not None
@@ -125,11 +122,10 @@ class TestFileDocumentRepositoryInitialization:
class TestFileDocumentRepositoryCreation: class TestFileDocumentRepositoryCreation:
"""Tests for file document creation functionality.""" """Tests for file document creation functionality."""
@pytest.mark.asyncio def test_i_can_create_file_document(self, in_memory_repository, sample_file_document):
async def test_i_can_create_file_document(self, in_memory_repository, sample_file_document):
"""Test successful file document creation.""" """Test successful file document creation."""
# Act # Act
created_file = await in_memory_repository.create_document(sample_file_document) created_file = in_memory_repository.create_document(sample_file_document)
# Assert # Assert
assert created_file is not None assert created_file is not None
@@ -144,46 +140,20 @@ class TestFileDocumentRepositoryCreation:
assert created_file.id is not None assert created_file.id is not None
assert isinstance(created_file.id, ObjectId) assert isinstance(created_file.id, ObjectId)
@pytest.mark.asyncio def test_i_can_create_file_document_without_id(self, in_memory_repository, sample_file_document):
async def test_i_can_create_file_document_without_id(self, in_memory_repository, sample_file_document):
"""Test creating file document with _id set to None (should be removed).""" """Test creating file document with _id set to None (should be removed)."""
# Arrange # Arrange
sample_file_document.id = None sample_file_document.id = None
# Act # Act
created_file = await in_memory_repository.create_document(sample_file_document) created_file = in_memory_repository.create_document(sample_file_document)
# Assert # Assert
assert created_file is not None assert created_file is not None
assert created_file.id is not None assert created_file.id is not None
assert isinstance(created_file.id, ObjectId) assert isinstance(created_file.id, ObjectId)
@pytest.mark.asyncio def test_i_cannot_create_file_document_with_pymongo_error(self, in_memory_repository,
async def test_i_cannot_create_duplicate_file_document(self, in_memory_repository, sample_file_document):
"""Test that creating file document with duplicate filepath raises DuplicateKeyError."""
# Arrange
await in_memory_repository.create_document(sample_file_document)
duplicate_file = FileDocument(
filename="different_name.pdf",
filepath=sample_file_document.filepath, # Same filepath
file_type=FileType.PDF,
extraction_method=ExtractionMethod.OCR,
metadata={"different": "metadata"},
detected_at=datetime.now(),
file_hash="different_hash_123456789012345678901234567890123456789012345678",
encoding="utf-8",
file_size=2000,
mime_type="application/pdf"
)
# Act & Assert
with pytest.raises(DuplicateKeyError) as exc_info:
await in_memory_repository.create_document(duplicate_file)
assert "already exists" in str(exc_info.value)
@pytest.mark.asyncio
async def test_i_cannot_create_file_document_with_pymongo_error(self, in_memory_repository,
sample_file_document, mocker): sample_file_document, mocker):
"""Test handling of PyMongo errors during file document creation.""" """Test handling of PyMongo errors during file document creation."""
# Arrange # Arrange
@@ -191,7 +161,7 @@ class TestFileDocumentRepositoryCreation:
# Act & Assert # Act & Assert
with pytest.raises(ValueError) as exc_info: with pytest.raises(ValueError) as exc_info:
await in_memory_repository.create_document(sample_file_document) in_memory_repository.create_document(sample_file_document)
assert "Failed to create file document" in str(exc_info.value) assert "Failed to create file document" in str(exc_info.value)
@@ -199,14 +169,13 @@ class TestFileDocumentRepositoryCreation:
class TestFileDocumentRepositoryFinding: class TestFileDocumentRepositoryFinding:
"""Tests for file document finding functionality.""" """Tests for file document finding functionality."""
@pytest.mark.asyncio def test_i_can_find_document_by_valid_id(self, in_memory_repository, sample_file_document):
async def test_i_can_find_document_by_valid_id(self, in_memory_repository, sample_file_document):
"""Test finding file document by valid ObjectId.""" """Test finding file document by valid ObjectId."""
# Arrange # Arrange
created_file = await in_memory_repository.create_document(sample_file_document) created_file = in_memory_repository.create_document(sample_file_document)
# Act # Act
found_file = await in_memory_repository.find_document_by_id(str(created_file.id)) found_file = in_memory_repository.find_document_by_id(str(created_file.id))
# Assert # Assert
assert found_file is not None assert found_file is not None
@@ -214,81 +183,74 @@ class TestFileDocumentRepositoryFinding:
assert found_file.filename == created_file.filename assert found_file.filename == created_file.filename
assert found_file.filepath == created_file.filepath assert found_file.filepath == created_file.filepath
@pytest.mark.asyncio def test_i_cannot_find_document_with_invalid_id(self, in_memory_repository):
async def test_i_cannot_find_document_with_invalid_id(self, in_memory_repository):
"""Test that invalid ObjectId returns None.""" """Test that invalid ObjectId returns None."""
# Act # Act
found_file = await in_memory_repository.find_document_by_id("invalid_id") found_file = in_memory_repository.find_document_by_id("invalid_id")
# Assert # Assert
assert found_file is None assert found_file is None
@pytest.mark.asyncio def test_i_cannot_find_document_by_nonexistent_id(self, in_memory_repository):
async def test_i_cannot_find_document_by_nonexistent_id(self, in_memory_repository):
"""Test that nonexistent but valid ObjectId returns None.""" """Test that nonexistent but valid ObjectId returns None."""
# Arrange # Arrange
nonexistent_id = str(ObjectId()) nonexistent_id = str(ObjectId())
# Act # Act
found_file = await in_memory_repository.find_document_by_id(nonexistent_id) found_file = in_memory_repository.find_document_by_id(nonexistent_id)
# Assert # Assert
assert found_file is None assert found_file is None
@pytest.mark.asyncio def test_i_can_find_document_by_file_hash(self, in_memory_repository, sample_file_document):
async def test_i_can_find_document_by_file_hash(self, in_memory_repository, sample_file_document):
"""Test finding file document by file hash.""" """Test finding file document by file hash."""
# Arrange # Arrange
created_file = await in_memory_repository.create_document(sample_file_document) created_file = in_memory_repository.create_document(sample_file_document)
# Act # Act
found_file = await in_memory_repository.find_document_by_hash(sample_file_document.file_hash) found_file = in_memory_repository.find_document_by_hash(sample_file_document.file_hash)
# Assert # Assert
assert found_file is not None assert found_file is not None
assert found_file.file_hash == created_file.file_hash assert found_file.file_hash == created_file.file_hash
assert found_file.id == created_file.id assert found_file.id == created_file.id
@pytest.mark.asyncio def test_i_cannot_find_document_with_nonexistent_file_hash(self, in_memory_repository):
async def test_i_cannot_find_document_with_nonexistent_file_hash(self, in_memory_repository):
"""Test that nonexistent file hash returns None.""" """Test that nonexistent file hash returns None."""
# Act # Act
found_file = await in_memory_repository.find_document_by_hash("nonexistent_hash") found_file = in_memory_repository.find_document_by_hash("nonexistent_hash")
# Assert # Assert
assert found_file is None assert found_file is None
@pytest.mark.asyncio def test_i_can_find_document_by_filepath(self, in_memory_repository, sample_file_document):
async def test_i_can_find_document_by_filepath(self, in_memory_repository, sample_file_document):
"""Test finding file document by filepath.""" """Test finding file document by filepath."""
# Arrange # Arrange
created_file = await in_memory_repository.create_document(sample_file_document) created_file = in_memory_repository.create_document(sample_file_document)
# Act # Act
found_file = await in_memory_repository.find_document_by_filepath(sample_file_document.filepath) found_file = in_memory_repository.find_document_by_filepath(sample_file_document.filepath)
# Assert # Assert
assert found_file is not None assert found_file is not None
assert found_file.filepath == created_file.filepath assert found_file.filepath == created_file.filepath
assert found_file.id == created_file.id assert found_file.id == created_file.id
@pytest.mark.asyncio def test_i_cannot_find_document_with_nonexistent_filepath(self, in_memory_repository):
async def test_i_cannot_find_document_with_nonexistent_filepath(self, in_memory_repository):
"""Test that nonexistent filepath returns None.""" """Test that nonexistent filepath returns None."""
# Act # Act
found_file = await in_memory_repository.find_document_by_filepath("/nonexistent/path/file.pdf") found_file = in_memory_repository.find_document_by_filepath("/nonexistent/path/file.pdf")
# Assert # Assert
assert found_file is None assert found_file is None
@pytest.mark.asyncio def test_i_cannot_find_document_with_pymongo_error(self, in_memory_repository, mocker):
async def test_i_cannot_find_document_with_pymongo_error(self, in_memory_repository, mocker):
"""Test handling of PyMongo errors during file document finding.""" """Test handling of PyMongo errors during file document finding."""
# Arrange # Arrange
mocker.patch.object(in_memory_repository.collection, 'find_one', side_effect=PyMongoError("Database error")) mocker.patch.object(in_memory_repository.collection, 'find_one', side_effect=PyMongoError("Database error"))
# Act # Act
found_file = await in_memory_repository.find_document_by_hash("test_hash") found_file = in_memory_repository.find_document_by_hash("test_hash")
# Assert # Assert
assert found_file is None assert found_file is None
@@ -297,16 +259,15 @@ class TestFileDocumentRepositoryFinding:
class TestFileDocumentRepositoryNameMatching: class TestFileDocumentRepositoryNameMatching:
"""Tests for file document name matching functionality.""" """Tests for file document name matching functionality."""
@pytest.mark.asyncio def test_i_can_find_documents_by_name_with_fuzzy_matching(self, in_memory_repository, multiple_sample_files):
async def test_i_can_find_documents_by_name_with_fuzzy_matching(self, in_memory_repository, multiple_sample_files):
"""Test finding file documents by filename using fuzzy matching.""" """Test finding file documents by filename using fuzzy matching."""
# Arrange # Arrange
for file_doc in multiple_sample_files: for file_doc in multiple_sample_files:
await in_memory_repository.create_document(file_doc) in_memory_repository.create_document(file_doc)
# Act # Act
fuzzy_method = FuzzyMatching(threshold=0.5) fuzzy_method = FuzzyMatching(threshold=0.5)
found_files = await in_memory_repository.find_document_by_name("document", fuzzy_method) found_files = in_memory_repository.find_document_by_name("document", fuzzy_method)
# Assert # Assert
assert len(found_files) >= 1 assert len(found_files) >= 1
@@ -315,44 +276,41 @@ class TestFileDocumentRepositoryNameMatching:
found_filenames = [f.filename for f in found_files] found_filenames = [f.filename for f in found_files]
assert any("document" in fname.lower() for fname in found_filenames) assert any("document" in fname.lower() for fname in found_filenames)
@pytest.mark.asyncio def test_i_can_find_documents_by_name_with_subsequence_matching(self, in_memory_repository,
async def test_i_can_find_documents_by_name_with_subsequence_matching(self, in_memory_repository,
multiple_sample_files): multiple_sample_files):
"""Test finding file documents by filename using subsequence matching.""" """Test finding file documents by filename using subsequence matching."""
# Arrange # Arrange
for file_doc in multiple_sample_files: for file_doc in multiple_sample_files:
await in_memory_repository.create_document(file_doc) in_memory_repository.create_document(file_doc)
# Act # Act
subsequence_method = SubsequenceMatching() subsequence_method = SubsequenceMatching()
found_files = await in_memory_repository.find_document_by_name("doc", subsequence_method) found_files = in_memory_repository.find_document_by_name("doc", subsequence_method)
# Assert # Assert
assert len(found_files) >= 1 assert len(found_files) >= 1
assert all(isinstance(file_doc, FileDocument) for file_doc in found_files) assert all(isinstance(file_doc, FileDocument) for file_doc in found_files)
@pytest.mark.asyncio def test_i_can_find_documents_by_name_with_default_method(self, in_memory_repository, multiple_sample_files):
async def test_i_can_find_documents_by_name_with_default_method(self, in_memory_repository, multiple_sample_files):
"""Test finding file documents by filename with default matching method.""" """Test finding file documents by filename with default matching method."""
# Arrange # Arrange
for file_doc in multiple_sample_files: for file_doc in multiple_sample_files:
await in_memory_repository.create_document(file_doc) in_memory_repository.create_document(file_doc)
# Act # Act
found_files = await in_memory_repository.find_document_by_name("first") found_files = in_memory_repository.find_document_by_name("first")
# Assert # Assert
assert len(found_files) >= 0 assert len(found_files) >= 0
assert all(isinstance(file_doc, FileDocument) for file_doc in found_files) assert all(isinstance(file_doc, FileDocument) for file_doc in found_files)
@pytest.mark.asyncio def test_i_cannot_find_documents_by_name_with_pymongo_error(self, in_memory_repository, mocker):
async def test_i_cannot_find_documents_by_name_with_pymongo_error(self, in_memory_repository, mocker):
"""Test handling of PyMongo errors during document name matching.""" """Test handling of PyMongo errors during document name matching."""
# Arrange # Arrange
mocker.patch.object(in_memory_repository.collection, 'find', side_effect=PyMongoError("Database error")) mocker.patch.object(in_memory_repository.collection, 'find', side_effect=PyMongoError("Database error"))
# Act # Act
found_files = await in_memory_repository.find_document_by_name("test") found_files = in_memory_repository.find_document_by_name("test")
# Assert # Assert
assert found_files == [] assert found_files == []
@@ -361,30 +319,28 @@ class TestFileDocumentRepositoryNameMatching:
class TestFileDocumentRepositoryListing: class TestFileDocumentRepositoryListing:
"""Tests for file document listing functionality.""" """Tests for file document listing functionality."""
@pytest.mark.asyncio def test_i_can_list_documents_with_default_pagination(self, in_memory_repository, multiple_sample_files):
async def test_i_can_list_documents_with_default_pagination(self, in_memory_repository, multiple_sample_files):
"""Test listing file documents with default pagination.""" """Test listing file documents with default pagination."""
# Arrange # Arrange
for file_doc in multiple_sample_files: for file_doc in multiple_sample_files:
await in_memory_repository.create_document(file_doc) in_memory_repository.create_document(file_doc)
# Act # Act
files = await in_memory_repository.list_documents() files = in_memory_repository.list_documents()
# Assert # Assert
assert len(files) == len(multiple_sample_files) assert len(files) == len(multiple_sample_files)
assert all(isinstance(file_doc, FileDocument) for file_doc in files) assert all(isinstance(file_doc, FileDocument) for file_doc in files)
@pytest.mark.asyncio def test_i_can_list_documents_with_custom_pagination(self, in_memory_repository, multiple_sample_files):
async def test_i_can_list_documents_with_custom_pagination(self, in_memory_repository, multiple_sample_files):
"""Test listing file documents with custom pagination.""" """Test listing file documents with custom pagination."""
# Arrange # Arrange
for file_doc in multiple_sample_files: for file_doc in multiple_sample_files:
await in_memory_repository.create_document(file_doc) in_memory_repository.create_document(file_doc)
# Act # Act
files_page1 = await in_memory_repository.list_documents(skip=0, limit=2) files_page1 = in_memory_repository.list_documents(skip=0, limit=2)
files_page2 = await in_memory_repository.list_documents(skip=2, limit=2) files_page2 = in_memory_repository.list_documents(skip=2, limit=2)
# Assert # Assert
assert len(files_page1) == 2 assert len(files_page1) == 2
@@ -395,8 +351,7 @@ class TestFileDocumentRepositoryListing:
page2_ids = [file_doc.id for file_doc in files_page2] page2_ids = [file_doc.id for file_doc in files_page2]
assert len(set(page1_ids).intersection(set(page2_ids))) == 0 assert len(set(page1_ids).intersection(set(page2_ids))) == 0
@pytest.mark.asyncio def test_i_can_list_documents_sorted_by_detected_at(self, in_memory_repository, sample_file_document):
async def test_i_can_list_documents_sorted_by_detected_at(self, in_memory_repository, sample_file_document):
"""Test that file documents are sorted by detected_at in descending order.""" """Test that file documents are sorted by detected_at in descending order."""
# Arrange # Arrange
file1 = sample_file_document.model_copy() file1 = sample_file_document.model_copy()
@@ -411,11 +366,11 @@ class TestFileDocumentRepositoryListing:
file2.file_hash = "hash2" + "0" * 58 file2.file_hash = "hash2" + "0" * 58
file2.detected_at = datetime(2024, 1, 2, 10, 0, 0) # Later date file2.detected_at = datetime(2024, 1, 2, 10, 0, 0) # Later date
created_file1 = await in_memory_repository.create_document(file1) created_file1 = in_memory_repository.create_document(file1)
created_file2 = await in_memory_repository.create_document(file2) created_file2 = in_memory_repository.create_document(file2)
# Act # Act
files = await in_memory_repository.list_documents() files = in_memory_repository.list_documents()
# Assert # Assert
assert len(files) == 2 assert len(files) == 2
@@ -423,23 +378,21 @@ class TestFileDocumentRepositoryListing:
assert files[0].id == created_file2.id assert files[0].id == created_file2.id
assert files[1].id == created_file1.id assert files[1].id == created_file1.id
@pytest.mark.asyncio def test_i_can_list_empty_documents(self, in_memory_repository):
async def test_i_can_list_empty_documents(self, in_memory_repository):
"""Test listing file documents from empty collection.""" """Test listing file documents from empty collection."""
# Act # Act
files = await in_memory_repository.list_documents() files = in_memory_repository.list_documents()
# Assert # Assert
assert files == [] assert files == []
@pytest.mark.asyncio def test_i_cannot_list_documents_with_pymongo_error(self, in_memory_repository, mocker):
async def test_i_cannot_list_documents_with_pymongo_error(self, in_memory_repository, mocker):
"""Test handling of PyMongo errors during file document listing.""" """Test handling of PyMongo errors during file document listing."""
# Arrange # Arrange
mocker.patch.object(in_memory_repository.collection, 'find', side_effect=PyMongoError("Database error")) mocker.patch.object(in_memory_repository.collection, 'find', side_effect=PyMongoError("Database error"))
# Act # Act
files = await in_memory_repository.list_documents() files = in_memory_repository.list_documents()
# Assert # Assert
assert files == [] assert files == []
@@ -448,15 +401,14 @@ class TestFileDocumentRepositoryListing:
class TestFileDocumentRepositoryUpdate: class TestFileDocumentRepositoryUpdate:
"""Tests for file document update functionality.""" """Tests for file document update functionality."""
@pytest.mark.asyncio def test_i_can_update_document_successfully(self, in_memory_repository, sample_file_document,
async def test_i_can_update_document_successfully(self, in_memory_repository, sample_file_document,
sample_update_data): sample_update_data):
"""Test successful file document update.""" """Test successful file document update."""
# Arrange # Arrange
created_file = await in_memory_repository.create_document(sample_file_document) created_file = in_memory_repository.create_document(sample_file_document)
# Act # Act
updated_file = await in_memory_repository.update_document(str(created_file.id), sample_update_data) updated_file = in_memory_repository.update_document(str(created_file.id), sample_update_data)
# Assert # Assert
assert updated_file is not None assert updated_file is not None
@@ -467,15 +419,14 @@ class TestFileDocumentRepositoryUpdate:
assert updated_file.filename == created_file.filename # Unchanged fields remain assert updated_file.filename == created_file.filename # Unchanged fields remain
assert updated_file.filepath == created_file.filepath assert updated_file.filepath == created_file.filepath
@pytest.mark.asyncio def test_i_can_update_document_with_partial_data(self, in_memory_repository, sample_file_document):
async def test_i_can_update_document_with_partial_data(self, in_memory_repository, sample_file_document):
"""Test updating file document with partial data.""" """Test updating file document with partial data."""
# Arrange # Arrange
created_file = await in_memory_repository.create_document(sample_file_document) created_file = in_memory_repository.create_document(sample_file_document)
partial_update = {"file_size": 999999} partial_update = {"file_size": 999999}
# Act # Act
updated_file = await in_memory_repository.update_document(str(created_file.id), partial_update) updated_file = in_memory_repository.update_document(str(created_file.id), partial_update)
# Assert # Assert
assert updated_file is not None assert updated_file is not None
@@ -483,30 +434,28 @@ class TestFileDocumentRepositoryUpdate:
assert updated_file.filename == created_file.filename # Should remain unchanged assert updated_file.filename == created_file.filename # Should remain unchanged
assert updated_file.metadata == created_file.metadata # Should remain unchanged assert updated_file.metadata == created_file.metadata # Should remain unchanged
@pytest.mark.asyncio def test_i_can_update_document_filtering_none_values(self, in_memory_repository, sample_file_document):
async def test_i_can_update_document_filtering_none_values(self, in_memory_repository, sample_file_document):
"""Test that None values are filtered out from update data.""" """Test that None values are filtered out from update data."""
# Arrange # Arrange
created_file = await in_memory_repository.create_document(sample_file_document) created_file = in_memory_repository.create_document(sample_file_document)
update_with_none = {"file_size": 777777, "metadata": None} update_with_none = {"file_size": 777777, "metadata": None}
# Act # Act
updated_file = await in_memory_repository.update_document(str(created_file.id), update_with_none) updated_file = in_memory_repository.update_document(str(created_file.id), update_with_none)
# Assert # Assert
assert updated_file is not None assert updated_file is not None
assert updated_file.file_size == 777777 assert updated_file.file_size == 777777
assert updated_file.metadata == created_file.metadata # Should remain unchanged (None filtered out) assert updated_file.metadata == created_file.metadata # Should remain unchanged (None filtered out)
@pytest.mark.asyncio def test_i_can_update_document_with_empty_data(self, in_memory_repository, sample_file_document):
async def test_i_can_update_document_with_empty_data(self, in_memory_repository, sample_file_document):
"""Test updating file document with empty data returns current document.""" """Test updating file document with empty data returns current document."""
# Arrange # Arrange
created_file = await in_memory_repository.create_document(sample_file_document) created_file = in_memory_repository.create_document(sample_file_document)
empty_update = {} empty_update = {}
# Act # Act
result = await in_memory_repository.update_document(str(created_file.id), empty_update) result = in_memory_repository.update_document(str(created_file.id), empty_update)
# Assert # Assert
assert result is not None assert result is not None
@@ -514,38 +463,35 @@ class TestFileDocumentRepositoryUpdate:
assert result.filepath == created_file.filepath assert result.filepath == created_file.filepath
assert result.metadata == created_file.metadata assert result.metadata == created_file.metadata
@pytest.mark.asyncio def test_i_cannot_update_document_with_invalid_id(self, in_memory_repository, sample_update_data):
async def test_i_cannot_update_document_with_invalid_id(self, in_memory_repository, sample_update_data):
"""Test that updating with invalid ID returns None.""" """Test that updating with invalid ID returns None."""
# Act # Act
result = await in_memory_repository.update_document("invalid_id", sample_update_data) result = in_memory_repository.update_document("invalid_id", sample_update_data)
# Assert # Assert
assert result is None assert result is None
@pytest.mark.asyncio def test_i_cannot_update_nonexistent_document(self, in_memory_repository, sample_update_data):
async def test_i_cannot_update_nonexistent_document(self, in_memory_repository, sample_update_data):
"""Test that updating nonexistent file document returns None.""" """Test that updating nonexistent file document returns None."""
# Arrange # Arrange
nonexistent_id = str(ObjectId()) nonexistent_id = str(ObjectId())
# Act # Act
result = await in_memory_repository.update_document(nonexistent_id, sample_update_data) result = in_memory_repository.update_document(nonexistent_id, sample_update_data)
# Assert # Assert
assert result is None assert result is None
@pytest.mark.asyncio def test_i_cannot_update_document_with_pymongo_error(self, in_memory_repository, sample_file_document,
async def test_i_cannot_update_document_with_pymongo_error(self, in_memory_repository, sample_file_document,
sample_update_data, mocker): sample_update_data, mocker):
"""Test handling of PyMongo errors during file document update.""" """Test handling of PyMongo errors during file document update."""
# Arrange # Arrange
created_file = await in_memory_repository.create_document(sample_file_document) created_file = in_memory_repository.create_document(sample_file_document)
mocker.patch.object(in_memory_repository.collection, 'find_one_and_update', mocker.patch.object(in_memory_repository.collection, 'find_one_and_update',
side_effect=PyMongoError("Database error")) side_effect=PyMongoError("Database error"))
# Act # Act
result = await in_memory_repository.update_document(str(created_file.id), sample_update_data) result = in_memory_repository.update_document(str(created_file.id), sample_update_data)
# Assert # Assert
assert result is None assert result is None
@@ -554,52 +500,48 @@ class TestFileDocumentRepositoryUpdate:
class TestFileDocumentRepositoryDeletion: class TestFileDocumentRepositoryDeletion:
"""Tests for file document deletion functionality.""" """Tests for file document deletion functionality."""
@pytest.mark.asyncio def test_i_can_delete_existing_document(self, in_memory_repository, sample_file_document):
async def test_i_can_delete_existing_document(self, in_memory_repository, sample_file_document):
"""Test successful file document deletion.""" """Test successful file document deletion."""
# Arrange # Arrange
created_file = await in_memory_repository.create_document(sample_file_document) created_file = in_memory_repository.create_document(sample_file_document)
# Act # Act
deletion_result = await in_memory_repository.delete_document(str(created_file.id)) deletion_result = in_memory_repository.delete_document(str(created_file.id))
# Assert # Assert
assert deletion_result is True assert deletion_result is True
# Verify document is actually deleted # Verify document is actually deleted
found_file = await in_memory_repository.find_document_by_id(str(created_file.id)) found_file = in_memory_repository.find_document_by_id(str(created_file.id))
assert found_file is None assert found_file is None
@pytest.mark.asyncio def test_i_cannot_delete_document_with_invalid_id(self, in_memory_repository):
async def test_i_cannot_delete_document_with_invalid_id(self, in_memory_repository):
"""Test that deleting with invalid ID returns False.""" """Test that deleting with invalid ID returns False."""
# Act # Act
result = await in_memory_repository.delete_document("invalid_id") result = in_memory_repository.delete_document("invalid_id")
# Assert # Assert
assert result is False assert result is False
@pytest.mark.asyncio def test_i_cannot_delete_nonexistent_document(self, in_memory_repository):
async def test_i_cannot_delete_nonexistent_document(self, in_memory_repository):
"""Test that deleting nonexistent file document returns False.""" """Test that deleting nonexistent file document returns False."""
# Arrange # Arrange
nonexistent_id = str(ObjectId()) nonexistent_id = str(ObjectId())
# Act # Act
result = await in_memory_repository.delete_document(nonexistent_id) result = in_memory_repository.delete_document(nonexistent_id)
# Assert # Assert
assert result is False assert result is False
@pytest.mark.asyncio def test_i_cannot_delete_document_with_pymongo_error(self, in_memory_repository, sample_file_document, mocker):
async def test_i_cannot_delete_document_with_pymongo_error(self, in_memory_repository, sample_file_document, mocker):
"""Test handling of PyMongo errors during file document deletion.""" """Test handling of PyMongo errors during file document deletion."""
# Arrange # Arrange
created_file = await in_memory_repository.create_document(sample_file_document) created_file = in_memory_repository.create_document(sample_file_document)
mocker.patch.object(in_memory_repository.collection, 'delete_one', side_effect=PyMongoError("Database error")) mocker.patch.object(in_memory_repository.collection, 'delete_one', side_effect=PyMongoError("Database error"))
# Act # Act
result = await in_memory_repository.delete_document(str(created_file.id)) result = in_memory_repository.delete_document(str(created_file.id))
# Assert # Assert
assert result is False assert result is False
@@ -608,36 +550,33 @@ class TestFileDocumentRepositoryDeletion:
class TestFileDocumentRepositoryUtilities: class TestFileDocumentRepositoryUtilities:
"""Tests for utility methods.""" """Tests for utility methods."""
@pytest.mark.asyncio def test_i_can_count_documents(self, in_memory_repository, sample_file_document):
async def test_i_can_count_documents(self, in_memory_repository, sample_file_document):
"""Test counting file documents.""" """Test counting file documents."""
# Arrange # Arrange
initial_count = await in_memory_repository.count_documents() initial_count = in_memory_repository.count_documents()
await in_memory_repository.create_document(sample_file_document) in_memory_repository.create_document(sample_file_document)
# Act # Act
final_count = await in_memory_repository.count_documents() final_count = in_memory_repository.count_documents()
# Assert # Assert
assert final_count == initial_count + 1 assert final_count == initial_count + 1
@pytest.mark.asyncio def test_i_can_count_zero_documents(self, in_memory_repository):
async def test_i_can_count_zero_documents(self, in_memory_repository):
"""Test counting file documents in empty collection.""" """Test counting file documents in empty collection."""
# Act # Act
count = await in_memory_repository.count_documents() count = in_memory_repository.count_documents()
# Assert # Assert
assert count == 0 assert count == 0
@pytest.mark.asyncio def test_i_cannot_count_documents_with_pymongo_error(self, in_memory_repository, mocker):
async def test_i_cannot_count_documents_with_pymongo_error(self, in_memory_repository, mocker):
"""Test handling of PyMongo errors during file document counting.""" """Test handling of PyMongo errors during file document counting."""
# Arrange # Arrange
mocker.patch.object(in_memory_repository.collection, 'count_documents', side_effect=PyMongoError("Database error")) mocker.patch.object(in_memory_repository.collection, 'count_documents', side_effect=PyMongoError("Database error"))
# Act # Act
count = await in_memory_repository.count_documents() count = in_memory_repository.count_documents()
# Assert # Assert
assert count == 0 assert count == 0

View File

@@ -1,5 +1,5 @@
""" """
Test suite for JobRepository with async/await support. Test suite for JobRepository with synchronous support.
This module contains comprehensive tests for all JobRepository methods This module contains comprehensive tests for all JobRepository methods
using mongomock-motor for in-memory MongoDB testing. using mongomock-motor for in-memory MongoDB testing.
@@ -8,8 +8,8 @@ using mongomock-motor for in-memory MongoDB testing.
from datetime import datetime from datetime import datetime
import pytest import pytest
import pytest_asyncio
from bson import ObjectId from bson import ObjectId
from mongomock.mongo_client import MongoClient
from mongomock_motor import AsyncMongoMockClient from mongomock_motor import AsyncMongoMockClient
from pymongo.errors import PyMongoError from pymongo.errors import PyMongoError
@@ -19,13 +19,13 @@ from app.models.job import ProcessingJob, ProcessingStatus
from app.models.types import PyObjectId from app.models.types import PyObjectId
@pytest_asyncio.fixture @pytest.fixture
async def in_memory_repository(): def in_memory_repository():
"""Create an in-memory JobRepository for testing.""" """Create an in-memory JobRepository for testing."""
client = AsyncMongoMockClient() client = MongoClient()
db = client.test_database db = client.test_database
repo = JobRepository(db) repo = JobRepository(db)
await repo.initialize() repo.initialize()
return repo return repo
@@ -82,8 +82,7 @@ def multiple_sample_jobs():
class TestJobRepositoryInitialization: class TestJobRepositoryInitialization:
"""Tests for repository initialization.""" """Tests for repository initialization."""
@pytest.mark.asyncio def test_i_can_initialize_repository(self):
async def test_i_can_initialize_repository(self):
"""Test repository initialization.""" """Test repository initialization."""
# Arrange # Arrange
client = AsyncMongoMockClient() client = AsyncMongoMockClient()
@@ -91,7 +90,7 @@ class TestJobRepositoryInitialization:
repo = JobRepository(db) repo = JobRepository(db)
# Act # Act
initialized_repo = await repo.initialize() initialized_repo = repo.initialize()
# Assert # Assert
assert initialized_repo is repo assert initialized_repo is repo
@@ -102,11 +101,10 @@ class TestJobRepositoryInitialization:
class TestJobRepositoryCreation: class TestJobRepositoryCreation:
"""Tests for job creation functionality.""" """Tests for job creation functionality."""
@pytest.mark.asyncio def test_i_can_create_job_with_task_id(self, in_memory_repository, sample_document_id, sample_task_id):
async def test_i_can_create_job_with_task_id(self, in_memory_repository, sample_document_id, sample_task_id):
"""Test successful job creation with task ID.""" """Test successful job creation with task ID."""
# Act # Act
created_job = await in_memory_repository.create_job(sample_document_id, sample_task_id) created_job = in_memory_repository.create_job(sample_document_id, sample_task_id)
# Assert # Assert
assert created_job is not None assert created_job is not None
@@ -120,11 +118,10 @@ class TestJobRepositoryCreation:
assert created_job.id is not None assert created_job.id is not None
assert isinstance(created_job.id, ObjectId) assert isinstance(created_job.id, ObjectId)
@pytest.mark.asyncio def test_i_can_create_job_without_task_id(self, in_memory_repository, sample_document_id):
async def test_i_can_create_job_without_task_id(self, in_memory_repository, sample_document_id):
"""Test successful job creation without task ID.""" """Test successful job creation without task ID."""
# Act # Act
created_job = await in_memory_repository.create_job(sample_document_id) created_job = in_memory_repository.create_job(sample_document_id)
# Assert # Assert
assert created_job is not None assert created_job is not None
@@ -138,28 +135,26 @@ class TestJobRepositoryCreation:
assert created_job.id is not None assert created_job.id is not None
assert isinstance(created_job.id, ObjectId) assert isinstance(created_job.id, ObjectId)
@pytest.mark.asyncio def test_i_cannot_create_duplicate_job_for_document(self, in_memory_repository, sample_document_id,
async def test_i_cannot_create_duplicate_job_for_document(self, in_memory_repository, sample_document_id,
sample_task_id): sample_task_id):
"""Test that creating job with duplicate document_id raises JobRepositoryError.""" """Test that creating job with duplicate document_id raises JobRepositoryError."""
# Arrange # Arrange
await in_memory_repository.create_job(sample_document_id, sample_task_id) in_memory_repository.create_job(sample_document_id, sample_task_id)
# Act & Assert # Act & Assert
with pytest.raises(JobRepositoryError) as exc_info: with pytest.raises(JobRepositoryError) as exc_info:
await in_memory_repository.create_job(sample_document_id, "different-task-id") in_memory_repository.create_job(sample_document_id, "different-task-id")
assert "create_job" in str(exc_info.value) assert "create_job" in str(exc_info.value)
@pytest.mark.asyncio def test_i_cannot_create_job_with_pymongo_error(self, in_memory_repository, sample_document_id, mocker):
async def test_i_cannot_create_job_with_pymongo_error(self, in_memory_repository, sample_document_id, mocker):
"""Test handling of PyMongo errors during job creation.""" """Test handling of PyMongo errors during job creation."""
# Arrange # Arrange
mocker.patch.object(in_memory_repository.collection, 'insert_one', side_effect=PyMongoError("Database error")) mocker.patch.object(in_memory_repository.collection, 'insert_one', side_effect=PyMongoError("Database error"))
# Act & Assert # Act & Assert
with pytest.raises(JobRepositoryError) as exc_info: with pytest.raises(JobRepositoryError) as exc_info:
await in_memory_repository.create_job(sample_document_id) in_memory_repository.create_job(sample_document_id)
assert "create_job" in str(exc_info.value) assert "create_job" in str(exc_info.value)
@@ -167,14 +162,13 @@ class TestJobRepositoryCreation:
class TestJobRepositoryFinding: class TestJobRepositoryFinding:
"""Tests for job finding functionality.""" """Tests for job finding functionality."""
@pytest.mark.asyncio def test_i_can_find_job_by_valid_id(self, in_memory_repository, sample_document_id, sample_task_id):
async def test_i_can_find_job_by_valid_id(self, in_memory_repository, sample_document_id, sample_task_id):
"""Test finding job by valid ObjectId.""" """Test finding job by valid ObjectId."""
# Arrange # Arrange
created_job = await in_memory_repository.create_job(sample_document_id, sample_task_id) created_job = in_memory_repository.create_job(sample_document_id, sample_task_id)
# Act # Act
found_job = await in_memory_repository.find_job_by_id(created_job.id) found_job = in_memory_repository.find_job_by_id(created_job.id)
# Assert # Assert
assert found_job is not None assert found_job is not None
@@ -183,97 +177,90 @@ class TestJobRepositoryFinding:
assert found_job.task_id == created_job.task_id assert found_job.task_id == created_job.task_id
assert found_job.status == created_job.status assert found_job.status == created_job.status
@pytest.mark.asyncio def test_i_cannot_find_job_by_nonexistent_id(self, in_memory_repository):
async def test_i_cannot_find_job_by_nonexistent_id(self, in_memory_repository):
"""Test that nonexistent ObjectId returns None.""" """Test that nonexistent ObjectId returns None."""
# Arrange # Arrange
nonexistent_id = PyObjectId() nonexistent_id = PyObjectId()
# Act # Act
found_job = await in_memory_repository.find_job_by_id(nonexistent_id) found_job = in_memory_repository.find_job_by_id(nonexistent_id)
# Assert # Assert
assert found_job is None assert found_job is None
@pytest.mark.asyncio def test_i_cannot_find_job_with_pymongo_error(self, in_memory_repository, mocker):
async def test_i_cannot_find_job_with_pymongo_error(self, in_memory_repository, mocker):
"""Test handling of PyMongo errors during job finding.""" """Test handling of PyMongo errors during job finding."""
# Arrange # Arrange
mocker.patch.object(in_memory_repository.collection, 'find_one', side_effect=PyMongoError("Database error")) mocker.patch.object(in_memory_repository.collection, 'find_one', side_effect=PyMongoError("Database error"))
# Act & Assert # Act & Assert
with pytest.raises(JobRepositoryError) as exc_info: with pytest.raises(JobRepositoryError) as exc_info:
await in_memory_repository.find_job_by_id(PyObjectId()) in_memory_repository.find_job_by_id(PyObjectId())
assert "get_job_by_id" in str(exc_info.value) assert "get_job_by_id" in str(exc_info.value)
@pytest.mark.asyncio def test_i_can_find_jobs_by_document_id(self, in_memory_repository, sample_document_id, sample_task_id):
async def test_i_can_find_jobs_by_document_id(self, in_memory_repository, sample_document_id, sample_task_id):
"""Test finding jobs by document ID.""" """Test finding jobs by document ID."""
# Arrange # Arrange
created_job = await in_memory_repository.create_job(sample_document_id, sample_task_id) created_job = in_memory_repository.create_job(sample_document_id, sample_task_id)
# Act # Act
found_jobs = await in_memory_repository.find_jobs_by_document_id(sample_document_id) found_jobs = in_memory_repository.find_jobs_by_document_id(sample_document_id)
# Assert # Assert
assert len(found_jobs) == 1 assert len(found_jobs) == 1
assert found_jobs[0].id == created_job.id assert found_jobs[0].id == created_job.id
assert found_jobs[0].document_id == sample_document_id assert found_jobs[0].document_id == sample_document_id
@pytest.mark.asyncio def test_i_can_find_empty_jobs_list_for_nonexistent_document(self, in_memory_repository):
async def test_i_can_find_empty_jobs_list_for_nonexistent_document(self, in_memory_repository):
"""Test that nonexistent document ID returns empty list.""" """Test that nonexistent document ID returns empty list."""
# Arrange # Arrange
nonexistent_id = ObjectId() nonexistent_id = ObjectId()
# Act # Act
found_jobs = await in_memory_repository.find_jobs_by_document_id(nonexistent_id) found_jobs = in_memory_repository.find_jobs_by_document_id(nonexistent_id)
# Assert # Assert
assert found_jobs == [] assert found_jobs == []
@pytest.mark.asyncio def test_i_cannot_find_jobs_by_document_with_pymongo_error(self, in_memory_repository, mocker):
async def test_i_cannot_find_jobs_by_document_with_pymongo_error(self, in_memory_repository, mocker):
"""Test handling of PyMongo errors during finding jobs by document ID.""" """Test handling of PyMongo errors during finding jobs by document ID."""
# Arrange # Arrange
mocker.patch.object(in_memory_repository.collection, 'find', side_effect=PyMongoError("Database error")) mocker.patch.object(in_memory_repository.collection, 'find', side_effect=PyMongoError("Database error"))
# Act & Assert # Act & Assert
with pytest.raises(JobRepositoryError) as exc_info: with pytest.raises(JobRepositoryError) as exc_info:
await in_memory_repository.find_jobs_by_document_id(PyObjectId()) in_memory_repository.find_jobs_by_document_id(PyObjectId())
assert "get_jobs_by_file_id" in str(exc_info.value) assert "get_jobs_by_file_id" in str(exc_info.value)
@pytest.mark.asyncio
@pytest.mark.parametrize("status", [ @pytest.mark.parametrize("status", [
ProcessingStatus.PENDING, ProcessingStatus.PENDING,
ProcessingStatus.PROCESSING, ProcessingStatus.PROCESSING,
ProcessingStatus.COMPLETED ProcessingStatus.COMPLETED
]) ])
async def test_i_can_find_jobs_by_pending_status(self, in_memory_repository, sample_document_id, status): def test_i_can_find_jobs_by_pending_status(self, in_memory_repository, sample_document_id, status):
"""Test finding jobs by PENDING status.""" """Test finding jobs by PENDING status."""
# Arrange # Arrange
created_job = await in_memory_repository.create_job(sample_document_id) created_job = in_memory_repository.create_job(sample_document_id)
await in_memory_repository.update_job_status(created_job.id, status) in_memory_repository.update_job_status(created_job.id, status)
# Act # Act
found_jobs = await in_memory_repository.get_jobs_by_status(status) found_jobs = in_memory_repository.get_jobs_by_status(status)
# Assert # Assert
assert len(found_jobs) == 1 assert len(found_jobs) == 1
assert found_jobs[0].id == created_job.id assert found_jobs[0].id == created_job.id
assert found_jobs[0].status == status assert found_jobs[0].status == status
@pytest.mark.asyncio def test_i_can_find_jobs_by_failed_status(self, in_memory_repository, sample_document_id):
async def test_i_can_find_jobs_by_failed_status(self, in_memory_repository, sample_document_id):
"""Test finding jobs by FAILED status.""" """Test finding jobs by FAILED status."""
# Arrange # Arrange
created_job = await in_memory_repository.create_job(sample_document_id) created_job = in_memory_repository.create_job(sample_document_id)
await in_memory_repository.update_job_status(created_job.id, ProcessingStatus.FAILED, "Test error") in_memory_repository.update_job_status(created_job.id, ProcessingStatus.FAILED, "Test error")
# Act # Act
found_jobs = await in_memory_repository.get_jobs_by_status(ProcessingStatus.FAILED) found_jobs = in_memory_repository.get_jobs_by_status(ProcessingStatus.FAILED)
# Assert # Assert
assert len(found_jobs) == 1 assert len(found_jobs) == 1
@@ -281,24 +268,22 @@ class TestJobRepositoryFinding:
assert found_jobs[0].status == ProcessingStatus.FAILED assert found_jobs[0].status == ProcessingStatus.FAILED
assert found_jobs[0].error_message == "Test error" assert found_jobs[0].error_message == "Test error"
@pytest.mark.asyncio def test_i_can_find_empty_jobs_list_for_unused_status(self, in_memory_repository):
async def test_i_can_find_empty_jobs_list_for_unused_status(self, in_memory_repository):
"""Test that unused status returns empty list.""" """Test that unused status returns empty list."""
# Act # Act
found_jobs = await in_memory_repository.get_jobs_by_status(ProcessingStatus.COMPLETED) found_jobs = in_memory_repository.get_jobs_by_status(ProcessingStatus.COMPLETED)
# Assert # Assert
assert found_jobs == [] assert found_jobs == []
@pytest.mark.asyncio def test_i_cannot_find_jobs_by_status_with_pymongo_error(self, in_memory_repository, mocker):
async def test_i_cannot_find_jobs_by_status_with_pymongo_error(self, in_memory_repository, mocker):
"""Test handling of PyMongo errors during finding jobs by status.""" """Test handling of PyMongo errors during finding jobs by status."""
# Arrange # Arrange
mocker.patch.object(in_memory_repository.collection, 'find', side_effect=PyMongoError("Database error")) mocker.patch.object(in_memory_repository.collection, 'find', side_effect=PyMongoError("Database error"))
# Act & Assert # Act & Assert
with pytest.raises(JobRepositoryError) as exc_info: with pytest.raises(JobRepositoryError) as exc_info:
await in_memory_repository.get_jobs_by_status(ProcessingStatus.PENDING) in_memory_repository.get_jobs_by_status(ProcessingStatus.PENDING)
assert "get_jobs_by_status" in str(exc_info.value) assert "get_jobs_by_status" in str(exc_info.value)
@@ -306,14 +291,13 @@ class TestJobRepositoryFinding:
class TestJobRepositoryStatusUpdate: class TestJobRepositoryStatusUpdate:
"""Tests for job status update functionality.""" """Tests for job status update functionality."""
@pytest.mark.asyncio def test_i_can_update_job_status_to_processing(self, in_memory_repository, sample_document_id):
async def test_i_can_update_job_status_to_processing(self, in_memory_repository, sample_document_id):
"""Test updating job status to PROCESSING with started_at timestamp.""" """Test updating job status to PROCESSING with started_at timestamp."""
# Arrange # Arrange
created_job = await in_memory_repository.create_job(sample_document_id) created_job = in_memory_repository.create_job(sample_document_id)
# Act # Act
updated_job = await in_memory_repository.update_job_status(created_job.id, ProcessingStatus.PROCESSING) updated_job = in_memory_repository.update_job_status(created_job.id, ProcessingStatus.PROCESSING)
# Assert # Assert
assert updated_job is not None assert updated_job is not None
@@ -323,15 +307,14 @@ class TestJobRepositoryStatusUpdate:
assert updated_job.completed_at is None assert updated_job.completed_at is None
assert updated_job.error_message is None assert updated_job.error_message is None
@pytest.mark.asyncio def test_i_can_update_job_status_to_completed(self, in_memory_repository, sample_document_id):
async def test_i_can_update_job_status_to_completed(self, in_memory_repository, sample_document_id):
"""Test updating job status to COMPLETED with completed_at timestamp.""" """Test updating job status to COMPLETED with completed_at timestamp."""
# Arrange # Arrange
created_job = await in_memory_repository.create_job(sample_document_id) created_job = in_memory_repository.create_job(sample_document_id)
await in_memory_repository.update_job_status(created_job.id, ProcessingStatus.PROCESSING) in_memory_repository.update_job_status(created_job.id, ProcessingStatus.PROCESSING)
# Act # Act
updated_job = await in_memory_repository.update_job_status(created_job.id, ProcessingStatus.COMPLETED) updated_job = in_memory_repository.update_job_status(created_job.id, ProcessingStatus.COMPLETED)
# Assert # Assert
assert updated_job is not None assert updated_job is not None
@@ -341,15 +324,14 @@ class TestJobRepositoryStatusUpdate:
assert updated_job.completed_at is not None assert updated_job.completed_at is not None
assert updated_job.error_message is None assert updated_job.error_message is None
@pytest.mark.asyncio def test_i_can_update_job_status_to_failed_with_error(self, in_memory_repository, sample_document_id):
async def test_i_can_update_job_status_to_failed_with_error(self, in_memory_repository, sample_document_id):
"""Test updating job status to FAILED with error message and completed_at timestamp.""" """Test updating job status to FAILED with error message and completed_at timestamp."""
# Arrange # Arrange
created_job = await in_memory_repository.create_job(sample_document_id) created_job = in_memory_repository.create_job(sample_document_id)
error_message = "Processing failed due to invalid format" error_message = "Processing failed due to invalid format"
# Act # Act
updated_job = await in_memory_repository.update_job_status( updated_job = in_memory_repository.update_job_status(
created_job.id, ProcessingStatus.FAILED, error_message created_job.id, ProcessingStatus.FAILED, error_message
) )
@@ -360,14 +342,13 @@ class TestJobRepositoryStatusUpdate:
assert updated_job.completed_at is not None assert updated_job.completed_at is not None
assert updated_job.error_message == error_message assert updated_job.error_message == error_message
@pytest.mark.asyncio def test_i_can_update_job_status_to_failed_without_error(self, in_memory_repository, sample_document_id):
async def test_i_can_update_job_status_to_failed_without_error(self, in_memory_repository, sample_document_id):
"""Test updating job status to FAILED without error message.""" """Test updating job status to FAILED without error message."""
# Arrange # Arrange
created_job = await in_memory_repository.create_job(sample_document_id) created_job = in_memory_repository.create_job(sample_document_id)
# Act # Act
updated_job = await in_memory_repository.update_job_status(created_job.id, ProcessingStatus.FAILED) updated_job = in_memory_repository.update_job_status(created_job.id, ProcessingStatus.FAILED)
# Assert # Assert
assert updated_job is not None assert updated_job is not None
@@ -376,29 +357,27 @@ class TestJobRepositoryStatusUpdate:
assert updated_job.completed_at is not None assert updated_job.completed_at is not None
assert updated_job.error_message is None assert updated_job.error_message is None
@pytest.mark.asyncio def test_i_cannot_update_nonexistent_job_status(self, in_memory_repository):
async def test_i_cannot_update_nonexistent_job_status(self, in_memory_repository):
"""Test that updating nonexistent job returns None.""" """Test that updating nonexistent job returns None."""
# Arrange # Arrange
nonexistent_id = ObjectId() nonexistent_id = ObjectId()
# Act # Act
result = await in_memory_repository.update_job_status(nonexistent_id, ProcessingStatus.COMPLETED) result = in_memory_repository.update_job_status(nonexistent_id, ProcessingStatus.COMPLETED)
# Assert # Assert
assert result is None assert result is None
@pytest.mark.asyncio def test_i_cannot_update_job_status_with_pymongo_error(self, in_memory_repository, sample_document_id, mocker):
async def test_i_cannot_update_job_status_with_pymongo_error(self, in_memory_repository, sample_document_id, mocker):
"""Test handling of PyMongo errors during job status update.""" """Test handling of PyMongo errors during job status update."""
# Arrange # Arrange
created_job = await in_memory_repository.create_job(sample_document_id) created_job = in_memory_repository.create_job(sample_document_id)
mocker.patch.object(in_memory_repository.collection, 'find_one_and_update', mocker.patch.object(in_memory_repository.collection, 'find_one_and_update',
side_effect=PyMongoError("Database error")) side_effect=PyMongoError("Database error"))
# Act & Assert # Act & Assert
with pytest.raises(JobRepositoryError) as exc_info: with pytest.raises(JobRepositoryError) as exc_info:
await in_memory_repository.update_job_status(created_job.id, ProcessingStatus.COMPLETED) in_memory_repository.update_job_status(created_job.id, ProcessingStatus.COMPLETED)
assert "update_job_status" in str(exc_info.value) assert "update_job_status" in str(exc_info.value)
@@ -406,44 +385,41 @@ class TestJobRepositoryStatusUpdate:
class TestJobRepositoryDeletion: class TestJobRepositoryDeletion:
"""Tests for job deletion functionality.""" """Tests for job deletion functionality."""
@pytest.mark.asyncio def test_i_can_delete_existing_job(self, in_memory_repository, sample_document_id):
async def test_i_can_delete_existing_job(self, in_memory_repository, sample_document_id):
"""Test successful job deletion.""" """Test successful job deletion."""
# Arrange # Arrange
created_job = await in_memory_repository.create_job(sample_document_id) created_job = in_memory_repository.create_job(sample_document_id)
# Act # Act
deletion_result = await in_memory_repository.delete_job(created_job.id) deletion_result = in_memory_repository.delete_job(created_job.id)
# Assert # Assert
assert deletion_result is True assert deletion_result is True
# Verify job is actually deleted # Verify job is actually deleted
found_job = await in_memory_repository.find_job_by_id(created_job.id) found_job = in_memory_repository.find_job_by_id(created_job.id)
assert found_job is None assert found_job is None
@pytest.mark.asyncio def test_i_cannot_delete_nonexistent_job(self, in_memory_repository):
async def test_i_cannot_delete_nonexistent_job(self, in_memory_repository):
"""Test that deleting nonexistent job returns False.""" """Test that deleting nonexistent job returns False."""
# Arrange # Arrange
nonexistent_id = ObjectId() nonexistent_id = ObjectId()
# Act # Act
result = await in_memory_repository.delete_job(nonexistent_id) result = in_memory_repository.delete_job(nonexistent_id)
# Assert # Assert
assert result is False assert result is False
@pytest.mark.asyncio def test_i_cannot_delete_job_with_pymongo_error(self, in_memory_repository, sample_document_id, mocker):
async def test_i_cannot_delete_job_with_pymongo_error(self, in_memory_repository, sample_document_id, mocker):
"""Test handling of PyMongo errors during job deletion.""" """Test handling of PyMongo errors during job deletion."""
# Arrange # Arrange
created_job = await in_memory_repository.create_job(sample_document_id) created_job = in_memory_repository.create_job(sample_document_id)
mocker.patch.object(in_memory_repository.collection, 'delete_one', side_effect=PyMongoError("Database error")) mocker.patch.object(in_memory_repository.collection, 'delete_one', side_effect=PyMongoError("Database error"))
# Act & Assert # Act & Assert
with pytest.raises(JobRepositoryError) as exc_info: with pytest.raises(JobRepositoryError) as exc_info:
await in_memory_repository.delete_job(created_job.id) in_memory_repository.delete_job(created_job.id)
assert "delete_job" in str(exc_info.value) assert "delete_job" in str(exc_info.value)
@@ -451,38 +427,36 @@ class TestJobRepositoryDeletion:
class TestJobRepositoryComplexScenarios: class TestJobRepositoryComplexScenarios:
"""Tests for complex job repository scenarios.""" """Tests for complex job repository scenarios."""
@pytest.mark.asyncio def test_i_can_handle_complete_job_lifecycle(self, in_memory_repository, sample_document_id, sample_task_id):
async def test_i_can_handle_complete_job_lifecycle(self, in_memory_repository, sample_document_id, sample_task_id):
"""Test complete job lifecycle from creation to completion.""" """Test complete job lifecycle from creation to completion."""
# Create job # Create job
job = await in_memory_repository.create_job(sample_document_id, sample_task_id) job = in_memory_repository.create_job(sample_document_id, sample_task_id)
assert job.status == ProcessingStatus.PENDING assert job.status == ProcessingStatus.PENDING
assert job.started_at is None assert job.started_at is None
assert job.completed_at is None assert job.completed_at is None
# Start processing # Start processing
job = await in_memory_repository.update_job_status(job.id, ProcessingStatus.PROCESSING) job = in_memory_repository.update_job_status(job.id, ProcessingStatus.PROCESSING)
assert job.status == ProcessingStatus.PROCESSING assert job.status == ProcessingStatus.PROCESSING
assert job.started_at is not None assert job.started_at is not None
assert job.completed_at is None assert job.completed_at is None
# Complete job # Complete job
job = await in_memory_repository.update_job_status(job.id, ProcessingStatus.COMPLETED) job = in_memory_repository.update_job_status(job.id, ProcessingStatus.COMPLETED)
assert job.status == ProcessingStatus.COMPLETED assert job.status == ProcessingStatus.COMPLETED
assert job.started_at is not None assert job.started_at is not None
assert job.completed_at is not None assert job.completed_at is not None
assert job.error_message is None assert job.error_message is None
@pytest.mark.asyncio def test_i_can_handle_job_failure_scenario(self, in_memory_repository, sample_document_id, sample_task_id):
async def test_i_can_handle_job_failure_scenario(self, in_memory_repository, sample_document_id, sample_task_id):
"""Test job failure scenario with error message.""" """Test job failure scenario with error message."""
# Create and start job # Create and start job
job = await in_memory_repository.create_job(sample_document_id, sample_task_id) job = in_memory_repository.create_job(sample_document_id, sample_task_id)
job = await in_memory_repository.update_job_status(job.id, ProcessingStatus.PROCESSING) job = in_memory_repository.update_job_status(job.id, ProcessingStatus.PROCESSING)
# Fail job with error # Fail job with error
error_msg = "File format not supported" error_msg = "File format not supported"
job = await in_memory_repository.update_job_status(job.id, ProcessingStatus.FAILED, error_msg) job = in_memory_repository.update_job_status(job.id, ProcessingStatus.FAILED, error_msg)
# Assert failure state # Assert failure state
assert job.status == ProcessingStatus.FAILED assert job.status == ProcessingStatus.FAILED
@@ -490,28 +464,27 @@ class TestJobRepositoryComplexScenarios:
assert job.completed_at is not None assert job.completed_at is not None
assert job.error_message == error_msg assert job.error_message == error_msg
@pytest.mark.asyncio def test_i_can_handle_multiple_documents_with_different_statuses(self, in_memory_repository):
async def test_i_can_handle_multiple_documents_with_different_statuses(self, in_memory_repository):
"""Test managing multiple jobs for different documents with various statuses.""" """Test managing multiple jobs for different documents with various statuses."""
# Create jobs for different documents # Create jobs for different documents
doc1 = PyObjectId() doc1 = PyObjectId()
doc2 = PyObjectId() doc2 = PyObjectId()
doc3 = PyObjectId() doc3 = PyObjectId()
job1 = await in_memory_repository.create_job(doc1, "task-1") job1 = in_memory_repository.create_job(doc1, "task-1")
job2 = await in_memory_repository.create_job(doc2, "task-2") job2 = in_memory_repository.create_job(doc2, "task-2")
job3 = await in_memory_repository.create_job(doc3, "task-3") job3 = in_memory_repository.create_job(doc3, "task-3")
# Update to different statuses # Update to different statuses
await in_memory_repository.update_job_status(job1.id, ProcessingStatus.PROCESSING) in_memory_repository.update_job_status(job1.id, ProcessingStatus.PROCESSING)
await in_memory_repository.update_job_status(job2.id, ProcessingStatus.COMPLETED) in_memory_repository.update_job_status(job2.id, ProcessingStatus.COMPLETED)
await in_memory_repository.update_job_status(job3.id, ProcessingStatus.FAILED, "Error occurred") in_memory_repository.update_job_status(job3.id, ProcessingStatus.FAILED, "Error occurred")
# Verify status queries # Verify status queries
pending_jobs = await in_memory_repository.get_jobs_by_status(ProcessingStatus.PENDING) pending_jobs = in_memory_repository.get_jobs_by_status(ProcessingStatus.PENDING)
processing_jobs = await in_memory_repository.get_jobs_by_status(ProcessingStatus.PROCESSING) processing_jobs = in_memory_repository.get_jobs_by_status(ProcessingStatus.PROCESSING)
completed_jobs = await in_memory_repository.get_jobs_by_status(ProcessingStatus.COMPLETED) completed_jobs = in_memory_repository.get_jobs_by_status(ProcessingStatus.COMPLETED)
failed_jobs = await in_memory_repository.get_jobs_by_status(ProcessingStatus.FAILED) failed_jobs = in_memory_repository.get_jobs_by_status(ProcessingStatus.FAILED)
assert len(pending_jobs) == 0 assert len(pending_jobs) == 0
assert len(processing_jobs) == 1 assert len(processing_jobs) == 1

View File

@@ -1,29 +1,26 @@
""" """
Test suite for UserRepository with async/await support. Test suite for the synchronous UserRepository.
This module contains comprehensive tests for all UserRepository methods This module contains comprehensive tests for all UserRepository methods
using mongomock-motor for in-memory MongoDB testing. using mongomock for in-memory MongoDB testing.
""" """
import pytest import pytest
from datetime import datetime
import pytest_asyncio
from bson import ObjectId from bson import ObjectId
from mongomock.mongo_client import MongoClient
from pymongo.errors import DuplicateKeyError from pymongo.errors import DuplicateKeyError
from mongomock_motor import AsyncMongoMockClient
from app.database.repositories.user_repository import UserRepository from app.database.repositories.user_repository import UserRepository
from app.models.user import UserCreate, UserUpdate, UserInDB from app.models.user import UserCreate, UserUpdate
@pytest_asyncio.fixture @pytest.fixture
async def in_memory_repository(): def in_memory_repository():
"""Create an in-memory UserRepository for testing.""" """Create an in-memory UserRepository for testing."""
client = AsyncMongoMockClient() client = MongoClient()
db = client.test_database db = client.test_database
repo = UserRepository(db) repo = UserRepository(db)
await repo.initialize() repo.initialize()
return repo return repo
@@ -51,11 +48,10 @@ def sample_user_update():
class TestUserRepositoryCreation: class TestUserRepositoryCreation:
"""Tests for user creation functionality.""" """Tests for user creation functionality."""
@pytest.mark.asyncio def test_i_can_create_user(self, in_memory_repository, sample_user_create):
async def test_i_can_create_user(self, in_memory_repository, sample_user_create):
"""Test successful user creation.""" """Test successful user creation."""
# Act # Act
created_user = await in_memory_repository.create_user(sample_user_create) created_user = in_memory_repository.create_user(sample_user_create)
# Assert # Assert
assert created_user is not None assert created_user is not None
@@ -68,15 +64,14 @@ class TestUserRepositoryCreation:
assert created_user.updated_at is not None assert created_user.updated_at is not None
assert created_user.hashed_password != sample_user_create.password # Should be hashed assert created_user.hashed_password != sample_user_create.password # Should be hashed
@pytest.mark.asyncio def test_i_cannot_create_user_with_duplicate_username(self, in_memory_repository, sample_user_create):
async def test_i_cannot_create_user_with_duplicate_username(self, in_memory_repository, sample_user_create):
"""Test that creating user with duplicate username raises DuplicateKeyError.""" """Test that creating user with duplicate username raises DuplicateKeyError."""
# Arrange # Arrange
await in_memory_repository.create_user(sample_user_create) in_memory_repository.create_user(sample_user_create)
# Act & Assert # Act & Assert
with pytest.raises(DuplicateKeyError) as exc_info: with pytest.raises(DuplicateKeyError) as exc_info:
await in_memory_repository.create_user(sample_user_create) in_memory_repository.create_user(sample_user_create)
assert "already exists" in str(exc_info.value) assert "already exists" in str(exc_info.value)
@@ -84,14 +79,13 @@ class TestUserRepositoryCreation:
class TestUserRepositoryFinding: class TestUserRepositoryFinding:
"""Tests for user finding functionality.""" """Tests for user finding functionality."""
@pytest.mark.asyncio def test_i_can_find_user_by_id(self, in_memory_repository, sample_user_create):
async def test_i_can_find_user_by_id(self, in_memory_repository, sample_user_create):
"""Test finding user by valid ID.""" """Test finding user by valid ID."""
# Arrange # Arrange
created_user = await in_memory_repository.create_user(sample_user_create) created_user = in_memory_repository.create_user(sample_user_create)
# Act # Act
found_user = await in_memory_repository.find_user_by_id(str(created_user.id)) found_user = in_memory_repository.find_user_by_id(str(created_user.id))
# Assert # Assert
assert found_user is not None assert found_user is not None
@@ -99,69 +93,63 @@ class TestUserRepositoryFinding:
assert found_user.username == created_user.username assert found_user.username == created_user.username
assert found_user.email == created_user.email assert found_user.email == created_user.email
@pytest.mark.asyncio def test_i_cannot_find_user_by_invalid_id(self, in_memory_repository):
async def test_i_cannot_find_user_by_invalid_id(self, in_memory_repository):
"""Test that invalid ObjectId returns None.""" """Test that invalid ObjectId returns None."""
# Act # Act
found_user = await in_memory_repository.find_user_by_id("invalid_id") found_user = in_memory_repository.find_user_by_id("invalid_id")
# Assert # Assert
assert found_user is None assert found_user is None
@pytest.mark.asyncio def test_i_cannot_find_user_by_nonexistent_id(self, in_memory_repository):
async def test_i_cannot_find_user_by_nonexistent_id(self, in_memory_repository):
"""Test that nonexistent but valid ObjectId returns None.""" """Test that nonexistent but valid ObjectId returns None."""
# Arrange # Arrange
nonexistent_id = str(ObjectId()) nonexistent_id = str(ObjectId())
# Act # Act
found_user = await in_memory_repository.find_user_by_id(nonexistent_id) found_user = in_memory_repository.find_user_by_id(nonexistent_id)
# Assert # Assert
assert found_user is None assert found_user is None
@pytest.mark.asyncio def test_i_can_find_user_by_username(self, in_memory_repository, sample_user_create):
async def test_i_can_find_user_by_username(self, in_memory_repository, sample_user_create):
"""Test finding user by username.""" """Test finding user by username."""
# Arrange # Arrange
created_user = await in_memory_repository.create_user(sample_user_create) created_user = in_memory_repository.create_user(sample_user_create)
# Act # Act
found_user = await in_memory_repository.find_user_by_username(sample_user_create.username) found_user = in_memory_repository.find_user_by_username(sample_user_create.username)
# Assert # Assert
assert found_user is not None assert found_user is not None
assert found_user.username == created_user.username assert found_user.username == created_user.username
assert found_user.id == created_user.id assert found_user.id == created_user.id
@pytest.mark.asyncio def test_i_cannot_find_user_by_nonexistent_username(self, in_memory_repository):
async def test_i_cannot_find_user_by_nonexistent_username(self, in_memory_repository):
"""Test that nonexistent username returns None.""" """Test that nonexistent username returns None."""
# Act # Act
found_user = await in_memory_repository.find_user_by_username("nonexistent") found_user = in_memory_repository.find_user_by_username("nonexistent")
# Assert # Assert
assert found_user is None assert found_user is None
@pytest.mark.asyncio def test_i_can_find_user_by_email(self, in_memory_repository, sample_user_create):
async def test_i_can_find_user_by_email(self, in_memory_repository, sample_user_create):
"""Test finding user by email.""" """Test finding user by email."""
# Arrange # Arrange
created_user = await in_memory_repository.create_user(sample_user_create) created_user = in_memory_repository.create_user(sample_user_create)
# Act # Act
found_user = await in_memory_repository.find_user_by_email(str(sample_user_create.email)) found_user = in_memory_repository.find_user_by_email(str(sample_user_create.email))
# Assert # Assert
assert found_user is not None assert found_user is not None
assert found_user.email == created_user.email assert found_user.email == created_user.email
assert found_user.id == created_user.id assert found_user.id == created_user.id
@pytest.mark.asyncio def test_i_cannot_find_user_by_nonexistent_email(self, in_memory_repository):
async def test_i_cannot_find_user_by_nonexistent_email(self, in_memory_repository):
"""Test that nonexistent email returns None.""" """Test that nonexistent email returns None."""
# Act # Act
found_user = await in_memory_repository.find_user_by_email("nonexistent@example.com") found_user = in_memory_repository.find_user_by_email("nonexistent@example.com")
# Assert # Assert
assert found_user is None assert found_user is None
@@ -170,15 +158,14 @@ class TestUserRepositoryFinding:
class TestUserRepositoryUpdate: class TestUserRepositoryUpdate:
"""Tests for user update functionality.""" """Tests for user update functionality."""
@pytest.mark.asyncio def test_i_can_update_user(self, in_memory_repository, sample_user_create, sample_user_update):
async def test_i_can_update_user(self, in_memory_repository, sample_user_create, sample_user_update):
"""Test successful user update.""" """Test successful user update."""
# Arrange # Arrange
created_user = await in_memory_repository.create_user(sample_user_create) created_user = in_memory_repository.create_user(sample_user_create)
original_updated_at = created_user.updated_at original_updated_at = created_user.updated_at
# Act # Act
updated_user = await in_memory_repository.update_user(str(created_user.id), sample_user_update) updated_user = in_memory_repository.update_user(str(created_user.id), sample_user_update)
# Assert # Assert
assert updated_user is not None assert updated_user is not None
@@ -187,24 +174,22 @@ class TestUserRepositoryUpdate:
assert updated_user.role == sample_user_update.role assert updated_user.role == sample_user_update.role
assert updated_user.id == created_user.id assert updated_user.id == created_user.id
@pytest.mark.asyncio def test_i_cannot_update_user_with_invalid_id(self, in_memory_repository, sample_user_update):
async def test_i_cannot_update_user_with_invalid_id(self, in_memory_repository, sample_user_update):
"""Test that updating with invalid ID returns None.""" """Test that updating with invalid ID returns None."""
# Act # Act
result = await in_memory_repository.update_user("invalid_id", sample_user_update) result = in_memory_repository.update_user("invalid_id", sample_user_update)
# Assert # Assert
assert result is None assert result is None
@pytest.mark.asyncio def test_i_can_update_user_with_partial_data(self, in_memory_repository, sample_user_create):
async def test_i_can_update_user_with_partial_data(self, in_memory_repository, sample_user_create):
"""Test updating user with partial data.""" """Test updating user with partial data."""
# Arrange # Arrange
created_user = await in_memory_repository.create_user(sample_user_create) created_user = in_memory_repository.create_user(sample_user_create)
partial_update = UserUpdate(username="newusername") partial_update = UserUpdate(username="newusername")
# Act # Act
updated_user = await in_memory_repository.update_user(str(created_user.id), partial_update) updated_user = in_memory_repository.update_user(str(created_user.id), partial_update)
# Assert # Assert
assert updated_user is not None assert updated_user is not None
@@ -212,15 +197,14 @@ class TestUserRepositoryUpdate:
assert updated_user.email == created_user.email # Should remain unchanged assert updated_user.email == created_user.email # Should remain unchanged
assert updated_user.role == created_user.role # Should remain unchanged assert updated_user.role == created_user.role # Should remain unchanged
@pytest.mark.asyncio def test_i_can_update_user_with_empty_data(self, in_memory_repository, sample_user_create):
async def test_i_can_update_user_with_empty_data(self, in_memory_repository, sample_user_create):
"""Test updating user with empty data returns current user.""" """Test updating user with empty data returns current user."""
# Arrange # Arrange
created_user = await in_memory_repository.create_user(sample_user_create) created_user = in_memory_repository.create_user(sample_user_create)
empty_update = UserUpdate() empty_update = UserUpdate()
# Act # Act
result = await in_memory_repository.update_user(str(created_user.id), empty_update) result = in_memory_repository.update_user(str(created_user.id), empty_update)
# Assert # Assert
assert result is not None assert result is not None
@@ -231,39 +215,36 @@ class TestUserRepositoryUpdate:
class TestUserRepositoryDeletion: class TestUserRepositoryDeletion:
"""Tests for user deletion functionality.""" """Tests for user deletion functionality."""
@pytest.mark.asyncio def test_i_can_delete_user(self, in_memory_repository, sample_user_create):
async def test_i_can_delete_user(self, in_memory_repository, sample_user_create):
"""Test successful user deletion.""" """Test successful user deletion."""
# Arrange # Arrange
created_user = await in_memory_repository.create_user(sample_user_create) created_user = in_memory_repository.create_user(sample_user_create)
# Act # Act
deletion_result = await in_memory_repository.delete_user(str(created_user.id)) deletion_result = in_memory_repository.delete_user(str(created_user.id))
# Assert # Assert
assert deletion_result is True assert deletion_result is True
# Verify user is actually deleted # Verify user is actually deleted
found_user = await in_memory_repository.find_user_by_id(str(created_user.id)) found_user = in_memory_repository.find_user_by_id(str(created_user.id))
assert found_user is None assert found_user is None
@pytest.mark.asyncio def test_i_cannot_delete_user_with_invalid_id(self, in_memory_repository):
async def test_i_cannot_delete_user_with_invalid_id(self, in_memory_repository):
"""Test that deleting with invalid ID returns False.""" """Test that deleting with invalid ID returns False."""
# Act # Act
result = await in_memory_repository.delete_user("invalid_id") result = in_memory_repository.delete_user("invalid_id")
# Assert # Assert
assert result is False assert result is False
@pytest.mark.asyncio def test_i_cannot_delete_nonexistent_user(self, in_memory_repository):
async def test_i_cannot_delete_nonexistent_user(self, in_memory_repository):
"""Test that deleting nonexistent user returns False.""" """Test that deleting nonexistent user returns False."""
# Arrange # Arrange
nonexistent_id = str(ObjectId()) nonexistent_id = str(ObjectId())
# Act # Act
result = await in_memory_repository.delete_user(nonexistent_id) result = in_memory_repository.delete_user(nonexistent_id)
# Assert # Assert
assert result is False assert result is False
@@ -272,30 +253,27 @@ class TestUserRepositoryDeletion:
class TestUserRepositoryUtilities: class TestUserRepositoryUtilities:
"""Tests for utility methods.""" """Tests for utility methods."""
@pytest.mark.asyncio def test_i_can_count_users(self, in_memory_repository, sample_user_create):
async def test_i_can_count_users(self, in_memory_repository, sample_user_create):
"""Test counting users.""" """Test counting users."""
# Arrange # Arrange
initial_count = await in_memory_repository.count_users() initial_count = in_memory_repository.count_users()
await in_memory_repository.create_user(sample_user_create) in_memory_repository.create_user(sample_user_create)
# Act # Act
final_count = await in_memory_repository.count_users() final_count = in_memory_repository.count_users()
# Assert # Assert
assert final_count == initial_count + 1 assert final_count == initial_count + 1
@pytest.mark.asyncio def test_i_can_check_user_exists(self, in_memory_repository, sample_user_create):
async def test_i_can_check_user_exists(self, in_memory_repository, sample_user_create):
"""Test checking if user exists.""" """Test checking if user exists."""
# Arrange # Arrange
await in_memory_repository.create_user(sample_user_create) in_memory_repository.create_user(sample_user_create)
# Act # Act
exists = await in_memory_repository.user_exists(sample_user_create.username) exists = in_memory_repository.user_exists(sample_user_create.username)
not_exists = await in_memory_repository.user_exists("nonexistent") not_exists = in_memory_repository.user_exists("nonexistent")
# Assert # Assert
assert exists is True assert exists is True
assert not_exists is False assert not_exists is False

View File

@@ -11,7 +11,7 @@ from unittest.mock import patch
import pytest import pytest
import pytest_asyncio import pytest_asyncio
from bson import ObjectId from bson import ObjectId
from mongomock_motor import AsyncMongoMockClient from mongomock.mongo_client import MongoClient
from app.models.document import FileType from app.models.document import FileType
from app.services.document_service import DocumentService from app.services.document_service import DocumentService
@@ -24,15 +24,15 @@ def cleanup_test_folder():
shutil.rmtree("test_folder", ignore_errors=True) shutil.rmtree("test_folder", ignore_errors=True)
@pytest_asyncio.fixture @pytest.fixture
async def in_memory_database(): def in_memory_database():
"""Create an in-memory database for testing.""" """Create an in-memory database for testing."""
client = AsyncMongoMockClient() client = MongoClient()
return client.test_database return client.test_database
@pytest_asyncio.fixture @pytest_asyncio.fixture
async def document_service(in_memory_database): def document_service(in_memory_database):
"""Create DocumentService with in-memory repositories.""" """Create DocumentService with in-memory repositories."""
service = DocumentService(in_memory_database, objects_folder="test_folder") service = DocumentService(in_memory_database, objects_folder="test_folder")
return service return service
@@ -72,8 +72,7 @@ class TestCreateDocument:
@patch('app.services.document_service.magic.from_buffer') @patch('app.services.document_service.magic.from_buffer')
@patch('app.services.document_service.datetime') @patch('app.services.document_service.datetime')
@pytest.mark.asyncio def test_i_can_create_document_with_new_content(
async def test_i_can_create_document_with_new_content(
self, self,
mock_datetime, mock_datetime,
mock_magic, mock_magic,
@@ -87,7 +86,7 @@ class TestCreateDocument:
mock_magic.return_value = "application/pdf" mock_magic.return_value = "application/pdf"
# Execute # Execute
result = await document_service.create_document( result = document_service.create_document(
"/test/test.pdf", "/test/test.pdf",
sample_file_bytes, sample_file_bytes,
"utf-8" "utf-8"
@@ -102,7 +101,7 @@ class TestCreateDocument:
assert result.file_hash == document_service._calculate_file_hash(sample_file_bytes) assert result.file_hash == document_service._calculate_file_hash(sample_file_bytes)
# Verify document created in database # Verify document created in database
doc_in_db = await document_service.document_repository.find_document_by_id(result.id) doc_in_db = document_service.document_repository.find_document_by_id(result.id)
assert doc_in_db is not None assert doc_in_db is not None
assert doc_in_db.id == result.id assert doc_in_db.id == result.id
assert doc_in_db.filename == result.filename assert doc_in_db.filename == result.filename
@@ -116,8 +115,7 @@ class TestCreateDocument:
@patch('app.services.document_service.magic.from_buffer') @patch('app.services.document_service.magic.from_buffer')
@patch('app.services.document_service.datetime') @patch('app.services.document_service.datetime')
@pytest.mark.asyncio def test_i_can_create_document_with_existing_content(
async def test_i_can_create_document_with_existing_content(
self, self,
mock_datetime, mock_datetime,
mock_magic, mock_magic,
@@ -131,14 +129,14 @@ class TestCreateDocument:
mock_magic.return_value = "application/pdf" mock_magic.return_value = "application/pdf"
# Create first document # Create first document
first_doc = await document_service.create_document( first_doc = document_service.create_document(
"/test/first.pdf", "/test/first.pdf",
sample_file_bytes, sample_file_bytes,
"utf-8" "utf-8"
) )
# Create second document with same content # Create second document with same content
second_doc = await document_service.create_document( second_doc = document_service.create_document(
"/test/second.pdf", "/test/second.pdf",
sample_file_bytes, sample_file_bytes,
"utf-8" "utf-8"
@@ -149,37 +147,34 @@ class TestCreateDocument:
assert first_doc.filename != second_doc.filename assert first_doc.filename != second_doc.filename
assert first_doc.filepath != second_doc.filepath assert first_doc.filepath != second_doc.filepath
@pytest.mark.asyncio def test_i_cannot_create_document_with_unsupported_file_type(
async def test_i_cannot_create_document_with_unsupported_file_type(
self, self,
document_service, document_service,
sample_file_bytes sample_file_bytes
): ):
"""Test that unsupported file types raise ValueError.""" """Test that unsupported file types raise ValueError."""
with pytest.raises(ValueError, match="Unsupported file type"): with pytest.raises(ValueError, match="Unsupported file type"):
await document_service.create_document( document_service.create_document(
"/test/test.xyz", # Unsupported extension "/test/test.xyz", # Unsupported extension
sample_file_bytes, sample_file_bytes,
"utf-8" "utf-8"
) )
@pytest.mark.asyncio def test_i_cannot_create_document_with_empty_file_path(
async def test_i_cannot_create_document_with_empty_file_path(
self, self,
document_service, document_service,
sample_file_bytes sample_file_bytes
): ):
"""Test that empty file path raises ValueError.""" """Test that empty file path raises ValueError."""
with pytest.raises(ValueError): with pytest.raises(ValueError):
await document_service.create_document( document_service.create_document(
"", # Empty path "", # Empty path
sample_file_bytes, sample_file_bytes,
"utf-8" "utf-8"
) )
@patch('app.services.document_service.magic.from_buffer') @patch('app.services.document_service.magic.from_buffer')
@pytest.mark.asyncio def test_i_can_create_document_with_empty_bytes(
async def test_i_can_create_document_with_empty_bytes(
self, self,
mock_magic, mock_magic,
document_service document_service
@@ -189,7 +184,7 @@ class TestCreateDocument:
mock_magic.return_value = "text/plain" mock_magic.return_value = "text/plain"
# Execute with empty bytes # Execute with empty bytes
result = await document_service.create_document( result = document_service.create_document(
"/test/empty.txt", "/test/empty.txt",
b"", # Empty bytes b"", # Empty bytes
"utf-8" "utf-8"
@@ -203,8 +198,7 @@ class TestGetMethods:
"""Tests for document retrieval methods.""" """Tests for document retrieval methods."""
@patch('app.services.document_service.magic.from_buffer') @patch('app.services.document_service.magic.from_buffer')
@pytest.mark.asyncio def test_i_can_get_document_by_id(
async def test_i_can_get_document_by_id(
self, self,
mock_magic, mock_magic,
document_service, document_service,
@@ -215,14 +209,14 @@ class TestGetMethods:
mock_magic.return_value = "application/pdf" mock_magic.return_value = "application/pdf"
# Create a document first # Create a document first
created_doc = await document_service.create_document( created_doc = document_service.create_document(
"/test/test.pdf", "/test/test.pdf",
sample_file_bytes, sample_file_bytes,
"utf-8" "utf-8"
) )
# Execute # Execute
result = await document_service.get_document_by_id(created_doc.id) result = document_service.get_document_by_id(created_doc.id)
# Verify # Verify
assert result is not None assert result is not None
@@ -230,8 +224,7 @@ class TestGetMethods:
assert result.filename == created_doc.filename assert result.filename == created_doc.filename
@patch('app.services.document_service.magic.from_buffer') @patch('app.services.document_service.magic.from_buffer')
@pytest.mark.asyncio def test_i_can_get_document_by_hash(
async def test_i_can_get_document_by_hash(
self, self,
mock_magic, mock_magic,
document_service, document_service,
@@ -242,14 +235,14 @@ class TestGetMethods:
mock_magic.return_value = "application/pdf" mock_magic.return_value = "application/pdf"
# Create a document first # Create a document first
created_doc = await document_service.create_document( created_doc = document_service.create_document(
"/test/test.pdf", "/test/test.pdf",
sample_file_bytes, sample_file_bytes,
"utf-8" "utf-8"
) )
# Execute # Execute
result = await document_service.get_document_by_hash(created_doc.file_hash) result = document_service.get_document_by_hash(created_doc.file_hash)
# Verify # Verify
assert result is not None assert result is not None
@@ -257,8 +250,7 @@ class TestGetMethods:
assert result.filename == created_doc.filename assert result.filename == created_doc.filename
@patch('app.services.document_service.magic.from_buffer') @patch('app.services.document_service.magic.from_buffer')
@pytest.mark.asyncio def test_i_can_get_document_by_filepath(
async def test_i_can_get_document_by_filepath(
self, self,
mock_magic, mock_magic,
document_service, document_service,
@@ -270,14 +262,14 @@ class TestGetMethods:
test_path = "/test/unique_test.pdf" test_path = "/test/unique_test.pdf"
# Create a document first # Create a document first
created_doc = await document_service.create_document( created_doc = document_service.create_document(
test_path, test_path,
sample_file_bytes, sample_file_bytes,
"utf-8" "utf-8"
) )
# Execute # Execute
result = await document_service.get_document_by_filepath(test_path) result = document_service.get_document_by_filepath(test_path)
# Verify # Verify
assert result is not None assert result is not None
@@ -285,8 +277,7 @@ class TestGetMethods:
assert result.id == created_doc.id assert result.id == created_doc.id
@patch('app.services.document_service.magic.from_buffer') @patch('app.services.document_service.magic.from_buffer')
@pytest.mark.asyncio def test_i_can_get_document_content(
async def test_i_can_get_document_content(
self, self,
mock_magic, mock_magic,
document_service, document_service,
@@ -297,38 +288,36 @@ class TestGetMethods:
mock_magic.return_value = "application/pdf" mock_magic.return_value = "application/pdf"
# Create a document first # Create a document first
created_doc = await document_service.create_document( created_doc = document_service.create_document(
"/test/test.pdf", "/test/test.pdf",
sample_file_bytes, sample_file_bytes,
"utf-8" "utf-8"
) )
# Execute # Execute
result = await document_service.get_document_content_by_hash(created_doc.file_hash) result = document_service.get_document_content_by_hash(created_doc.file_hash)
# Verify # Verify
assert result == sample_file_bytes assert result == sample_file_bytes
@pytest.mark.asyncio def test_i_cannot_get_nonexistent_document_by_id(
async def test_i_cannot_get_nonexistent_document_by_id(
self, self,
document_service document_service
): ):
"""Test that nonexistent document returns None.""" """Test that nonexistent document returns None."""
# Execute with random ObjectId # Execute with random ObjectId
result = await document_service.get_document_by_id(ObjectId()) result = document_service.get_document_by_id(ObjectId())
# Verify # Verify
assert result is None assert result is None
@pytest.mark.asyncio def test_i_cannot_get_nonexistent_document_by_hash(
async def test_i_cannot_get_nonexistent_document_by_hash(
self, self,
document_service document_service
): ):
"""Test that nonexistent document hash returns None.""" """Test that nonexistent document hash returns None."""
# Execute # Execute
result = await document_service.get_document_by_hash("nonexistent_hash") result = document_service.get_document_by_hash("nonexistent_hash")
# Verify # Verify
assert result is None assert result is None
@@ -338,8 +327,7 @@ class TestPaginationAndCounting:
"""Tests for document listing and counting.""" """Tests for document listing and counting."""
@patch('app.services.document_service.magic.from_buffer') @patch('app.services.document_service.magic.from_buffer')
@pytest.mark.asyncio def test_i_can_list_documents_with_pagination(
async def test_i_can_list_documents_with_pagination(
self, self,
mock_magic, mock_magic,
document_service, document_service,
@@ -351,25 +339,24 @@ class TestPaginationAndCounting:
# Create multiple documents # Create multiple documents
for i in range(5): for i in range(5):
await document_service.create_document( document_service.create_document(
f"/test/test{i}.pdf", f"/test/test{i}.pdf",
sample_file_bytes + bytes(str(i), 'utf-8'), # Make each file unique sample_file_bytes + bytes(str(i), 'utf-8'), # Make each file unique
"utf-8" "utf-8"
) )
# Execute with pagination # Execute with pagination
result = await document_service.list_documents(skip=1, limit=2) result = document_service.list_documents(skip=1, limit=2)
# Verify # Verify
assert len(result) == 2 assert len(result) == 2
# Test counting # Test counting
total_count = await document_service.count_documents() total_count = document_service.count_documents()
assert total_count == 5 assert total_count == 5
@patch('app.services.document_service.magic.from_buffer') @patch('app.services.document_service.magic.from_buffer')
@pytest.mark.asyncio def test_i_can_count_documents(
async def test_i_can_count_documents(
self, self,
mock_magic, mock_magic,
document_service, document_service,
@@ -380,19 +367,19 @@ class TestPaginationAndCounting:
mock_magic.return_value = "text/plain" mock_magic.return_value = "text/plain"
# Initially should be 0 # Initially should be 0
initial_count = await document_service.count_documents() initial_count = document_service.count_documents()
assert initial_count == 0 assert initial_count == 0
# Create some documents # Create some documents
for i in range(3): for i in range(3):
await document_service.create_document( document_service.create_document(
f"/test/test{i}.txt", f"/test/test{i}.txt",
sample_file_bytes + bytes(str(i), 'utf-8'), sample_file_bytes + bytes(str(i), 'utf-8'),
"utf-8" "utf-8"
) )
# Execute # Execute
final_count = await document_service.count_documents() final_count = document_service.count_documents()
# Verify # Verify
assert final_count == 3 assert final_count == 3
@@ -402,8 +389,7 @@ class TestUpdateAndDelete:
"""Tests for document update and deletion operations.""" """Tests for document update and deletion operations."""
@patch('app.services.document_service.magic.from_buffer') @patch('app.services.document_service.magic.from_buffer')
@pytest.mark.asyncio def test_i_can_update_document_metadata(
async def test_i_can_update_document_metadata(
self, self,
mock_magic, mock_magic,
document_service, document_service,
@@ -414,7 +400,7 @@ class TestUpdateAndDelete:
mock_magic.return_value = "application/pdf" mock_magic.return_value = "application/pdf"
# Create a document first # Create a document first
created_doc = await document_service.create_document( created_doc = document_service.create_document(
"/test/test.pdf", "/test/test.pdf",
sample_file_bytes, sample_file_bytes,
"utf-8" "utf-8"
@@ -422,7 +408,7 @@ class TestUpdateAndDelete:
# Execute update # Execute update
update_data = {"metadata": {"page_count": 5}} update_data = {"metadata": {"page_count": 5}}
result = await document_service.update_document(created_doc.id, update_data) result = document_service.update_document(created_doc.id, update_data)
# Verify # Verify
assert result is not None assert result is not None
@@ -433,14 +419,13 @@ class TestUpdateAndDelete:
assert result.file_type == created_doc.file_type assert result.file_type == created_doc.file_type
assert result.metadata == update_data['metadata'] assert result.metadata == update_data['metadata']
@pytest.mark.asyncio def test_i_can_update_document_content(
async def test_i_can_update_document_content(
self, self,
document_service, document_service,
sample_file_bytes sample_file_bytes
): ):
# Create a document first # Create a document first
created_doc = await document_service.create_document( created_doc = document_service.create_document(
"/test/test.pdf", "/test/test.pdf",
sample_file_bytes, sample_file_bytes,
"utf-8" "utf-8"
@@ -448,7 +433,7 @@ class TestUpdateAndDelete:
# Execute update # Execute update
update_data = {"file_bytes": b"this is an updated file content"} update_data = {"file_bytes": b"this is an updated file content"}
result = await document_service.update_document(created_doc.id, update_data) result = document_service.update_document(created_doc.id, update_data)
assert result.filename == created_doc.filename assert result.filename == created_doc.filename
assert result.filepath == created_doc.filepath assert result.filepath == created_doc.filepath
@@ -460,8 +445,7 @@ class TestUpdateAndDelete:
validate_file_saved(document_service, result.file_hash, b"this is an updated file content") validate_file_saved(document_service, result.file_hash, b"this is an updated file content")
@patch('app.services.document_service.magic.from_buffer') @patch('app.services.document_service.magic.from_buffer')
@pytest.mark.asyncio def test_i_can_delete_document_and_orphaned_content(
async def test_i_can_delete_document_and_orphaned_content(
self, self,
mock_magic, mock_magic,
document_service, document_service,
@@ -472,7 +456,7 @@ class TestUpdateAndDelete:
mock_magic.return_value = "application/pdf" mock_magic.return_value = "application/pdf"
# Create a document # Create a document
created_doc = await document_service.create_document( created_doc = document_service.create_document(
"/test/test.pdf", "/test/test.pdf",
sample_file_bytes, sample_file_bytes,
"utf-8" "utf-8"
@@ -482,12 +466,12 @@ class TestUpdateAndDelete:
validate_file_saved(document_service, created_doc.file_hash, sample_file_bytes) validate_file_saved(document_service, created_doc.file_hash, sample_file_bytes)
# Execute deletion # Execute deletion
result = await document_service.delete_document(created_doc.id) result = document_service.delete_document(created_doc.id)
# Verify document and content are deleted # Verify document and content are deleted
assert result is True assert result is True
deleted_doc = await document_service.get_document_by_id(created_doc.id) deleted_doc = document_service.get_document_by_id(created_doc.id)
assert deleted_doc is None assert deleted_doc is None
# validate content is deleted # validate content is deleted
@@ -496,8 +480,7 @@ class TestUpdateAndDelete:
assert not os.path.exists(target_file_path) assert not os.path.exists(target_file_path)
@patch('app.services.document_service.magic.from_buffer') @patch('app.services.document_service.magic.from_buffer')
@pytest.mark.asyncio def test_i_can_delete_document_without_affecting_shared_content(
async def test_i_can_delete_document_without_affecting_shared_content(
self, self,
mock_magic, mock_magic,
document_service, document_service,
@@ -508,13 +491,13 @@ class TestUpdateAndDelete:
mock_magic.return_value = "application/pdf" mock_magic.return_value = "application/pdf"
# Create two documents with same content # Create two documents with same content
doc1 = await document_service.create_document( doc1 = document_service.create_document(
"/test/test1.pdf", "/test/test1.pdf",
sample_file_bytes, sample_file_bytes,
"utf-8" "utf-8"
) )
doc2 = await document_service.create_document( doc2 = document_service.create_document(
"/test/test2.pdf", "/test/test2.pdf",
sample_file_bytes, sample_file_bytes,
"utf-8" "utf-8"
@@ -524,14 +507,14 @@ class TestUpdateAndDelete:
assert doc1.file_hash == doc2.file_hash assert doc1.file_hash == doc2.file_hash
# Delete first document # Delete first document
result = await document_service.delete_document(doc1.id) result = document_service.delete_document(doc1.id)
assert result is True assert result is True
# Verify first document is deleted but content still exists # Verify first document is deleted but content still exists
deleted_doc = await document_service.get_document_by_id(doc1.id) deleted_doc = document_service.get_document_by_id(doc1.id)
assert deleted_doc is None assert deleted_doc is None
remaining_doc = await document_service.get_document_by_id(doc2.id) remaining_doc = document_service.get_document_by_id(doc2.id)
assert remaining_doc is not None assert remaining_doc is not None
validate_file_saved(document_service, doc2.file_hash, sample_file_bytes) validate_file_saved(document_service, doc2.file_hash, sample_file_bytes)

View File

@@ -6,9 +6,8 @@ using mongomock for better integration testing.
""" """
import pytest import pytest
import pytest_asyncio
from bson import ObjectId from bson import ObjectId
from mongomock_motor import AsyncMongoMockClient from mongomock.mongo_client import MongoClient
from app.exceptions.job_exceptions import InvalidStatusTransitionError from app.exceptions.job_exceptions import InvalidStatusTransitionError
from app.models.job import ProcessingStatus from app.models.job import ProcessingStatus
@@ -16,17 +15,17 @@ from app.models.types import PyObjectId
from app.services.job_service import JobService from app.services.job_service import JobService
@pytest_asyncio.fixture @pytest.fixture
async def in_memory_database(): def in_memory_database():
"""Create an in-memory database for testing.""" """Create an in-memory database for testing."""
client = AsyncMongoMockClient() client = MongoClient()
return client.test_database return client.test_database
@pytest_asyncio.fixture @pytest.fixture
async def job_service(in_memory_database): def job_service(in_memory_database):
"""Create JobService with in-memory repositories.""" """Create JobService with in-memory repositories."""
service = await JobService(in_memory_database).initialize() service = JobService(in_memory_database).initialize()
return service return service
@@ -45,8 +44,7 @@ def sample_task_id():
class TestCreateJob: class TestCreateJob:
"""Tests for create_job method.""" """Tests for create_job method."""
@pytest.mark.asyncio def test_i_can_create_job_with_task_id(
async def test_i_can_create_job_with_task_id(
self, self,
job_service, job_service,
sample_document_id, sample_document_id,
@@ -54,7 +52,7 @@ class TestCreateJob:
): ):
"""Test creating job with task ID.""" """Test creating job with task ID."""
# Execute # Execute
result = await job_service.create_job(sample_document_id, sample_task_id) result = job_service.create_job(sample_document_id, sample_task_id)
# Verify job creation # Verify job creation
assert result is not None assert result is not None
@@ -66,22 +64,21 @@ class TestCreateJob:
assert result.error_message is None assert result.error_message is None
# Verify job exists in database # Verify job exists in database
job_in_db = await job_service.get_job_by_id(result.id) job_in_db = job_service.get_job_by_id(result.id)
assert job_in_db is not None assert job_in_db is not None
assert job_in_db.id == result.id assert job_in_db.id == result.id
assert job_in_db.document_id == sample_document_id assert job_in_db.document_id == sample_document_id
assert job_in_db.task_id == sample_task_id assert job_in_db.task_id == sample_task_id
assert job_in_db.status == ProcessingStatus.PENDING assert job_in_db.status == ProcessingStatus.PENDING
@pytest.mark.asyncio def test_i_can_create_job_without_task_id(
async def test_i_can_create_job_without_task_id(
self, self,
job_service, job_service,
sample_document_id sample_document_id
): ):
"""Test creating job without task ID.""" """Test creating job without task ID."""
# Execute # Execute
result = await job_service.create_job(sample_document_id) result = job_service.create_job(sample_document_id)
# Verify job creation # Verify job creation
assert result is not None assert result is not None
@@ -96,8 +93,7 @@ class TestCreateJob:
class TestGetJobMethods: class TestGetJobMethods:
"""Tests for job retrieval methods.""" """Tests for job retrieval methods."""
@pytest.mark.asyncio def test_i_can_get_job_by_id(
async def test_i_can_get_job_by_id(
self, self,
job_service, job_service,
sample_document_id, sample_document_id,
@@ -105,10 +101,10 @@ class TestGetJobMethods:
): ):
"""Test retrieving job by ID.""" """Test retrieving job by ID."""
# Create a job first # Create a job first
created_job = await job_service.create_job(sample_document_id, sample_task_id) created_job = job_service.create_job(sample_document_id, sample_task_id)
# Execute # Execute
result = await job_service.get_job_by_id(created_job.id) result = job_service.get_job_by_id(created_job.id)
# Verify # Verify
assert result is not None assert result is not None
@@ -117,25 +113,24 @@ class TestGetJobMethods:
assert result.task_id == created_job.task_id assert result.task_id == created_job.task_id
assert result.status == created_job.status assert result.status == created_job.status
@pytest.mark.asyncio def test_i_can_get_jobs_by_status(
async def test_i_can_get_jobs_by_status(
self, self,
job_service, job_service,
sample_document_id sample_document_id
): ):
"""Test retrieving jobs by status.""" """Test retrieving jobs by status."""
# Create jobs with different statuses # Create jobs with different statuses
pending_job = await job_service.create_job(sample_document_id, "pending-task") pending_job = job_service.create_job(sample_document_id, "pending-task")
processing_job = await job_service.create_job(ObjectId(), "processing-task") processing_job = job_service.create_job(ObjectId(), "processing-task")
await job_service.mark_job_as_started(processing_job.id) job_service.mark_job_as_started(processing_job.id)
completed_job = await job_service.create_job(ObjectId(), "completed-task") completed_job = job_service.create_job(ObjectId(), "completed-task")
await job_service.mark_job_as_started(completed_job.id) job_service.mark_job_as_started(completed_job.id)
await job_service.mark_job_as_completed(completed_job.id) job_service.mark_job_as_completed(completed_job.id)
# Execute - get pending jobs # Execute - get pending jobs
pending_results = await job_service.get_jobs_by_status(ProcessingStatus.PENDING) pending_results = job_service.get_jobs_by_status(ProcessingStatus.PENDING)
# Verify # Verify
assert len(pending_results) == 1 assert len(pending_results) == 1
@@ -143,12 +138,12 @@ class TestGetJobMethods:
assert pending_results[0].status == ProcessingStatus.PENDING assert pending_results[0].status == ProcessingStatus.PENDING
# Execute - get processing jobs # Execute - get processing jobs
processing_results = await job_service.get_jobs_by_status(ProcessingStatus.PROCESSING) processing_results = job_service.get_jobs_by_status(ProcessingStatus.PROCESSING)
assert len(processing_results) == 1 assert len(processing_results) == 1
assert processing_results[0].status == ProcessingStatus.PROCESSING assert processing_results[0].status == ProcessingStatus.PROCESSING
# Execute - get completed jobs # Execute - get completed jobs
completed_results = await job_service.get_jobs_by_status(ProcessingStatus.COMPLETED) completed_results = job_service.get_jobs_by_status(ProcessingStatus.COMPLETED)
assert len(completed_results) == 1 assert len(completed_results) == 1
assert completed_results[0].status == ProcessingStatus.COMPLETED assert completed_results[0].status == ProcessingStatus.COMPLETED
@@ -156,8 +151,7 @@ class TestGetJobMethods:
class TestUpdateStatus: class TestUpdateStatus:
"""Tests for mark_job_as_started method.""" """Tests for mark_job_as_started method."""
@pytest.mark.asyncio def test_i_can_mark_pending_job_as_started(
async def test_i_can_mark_pending_job_as_started(
self, self,
job_service, job_service,
sample_document_id, sample_document_id,
@@ -165,11 +159,11 @@ class TestUpdateStatus:
): ):
"""Test marking pending job as started (PENDING → PROCESSING).""" """Test marking pending job as started (PENDING → PROCESSING)."""
# Create a pending job # Create a pending job
created_job = await job_service.create_job(sample_document_id, sample_task_id) created_job = job_service.create_job(sample_document_id, sample_task_id)
assert created_job.status == ProcessingStatus.PENDING assert created_job.status == ProcessingStatus.PENDING
# Execute # Execute
result = await job_service.mark_job_as_started(created_job.id) result = job_service.mark_job_as_started(created_job.id)
# Verify status transition # Verify status transition
assert result is not None assert result is not None
@@ -177,11 +171,10 @@ class TestUpdateStatus:
assert result.status == ProcessingStatus.PROCESSING assert result.status == ProcessingStatus.PROCESSING
# Verify in database # Verify in database
updated_job = await job_service.get_job_by_id(created_job.id) updated_job = job_service.get_job_by_id(created_job.id)
assert updated_job.status == ProcessingStatus.PROCESSING assert updated_job.status == ProcessingStatus.PROCESSING
@pytest.mark.asyncio def test_i_cannot_mark_processing_job_as_started(
async def test_i_cannot_mark_processing_job_as_started(
self, self,
job_service, job_service,
sample_document_id, sample_document_id,
@@ -189,19 +182,18 @@ class TestUpdateStatus:
): ):
"""Test that processing job cannot be marked as started.""" """Test that processing job cannot be marked as started."""
# Create and start a job # Create and start a job
created_job = await job_service.create_job(sample_document_id, sample_task_id) created_job = job_service.create_job(sample_document_id, sample_task_id)
await job_service.mark_job_as_started(created_job.id) job_service.mark_job_as_started(created_job.id)
# Try to start it again # Try to start it again
with pytest.raises(InvalidStatusTransitionError) as exc_info: with pytest.raises(InvalidStatusTransitionError) as exc_info:
await job_service.mark_job_as_started(created_job.id) job_service.mark_job_as_started(created_job.id)
# Verify exception details # Verify exception details
assert exc_info.value.current_status == ProcessingStatus.PROCESSING assert exc_info.value.current_status == ProcessingStatus.PROCESSING
assert exc_info.value.target_status == ProcessingStatus.PROCESSING assert exc_info.value.target_status == ProcessingStatus.PROCESSING
@pytest.mark.asyncio def test_i_cannot_mark_completed_job_as_started(
async def test_i_cannot_mark_completed_job_as_started(
self, self,
job_service, job_service,
sample_document_id, sample_document_id,
@@ -209,20 +201,19 @@ class TestUpdateStatus:
): ):
"""Test that completed job cannot be marked as started.""" """Test that completed job cannot be marked as started."""
# Create, start, and complete a job # Create, start, and complete a job
created_job = await job_service.create_job(sample_document_id, sample_task_id) created_job = job_service.create_job(sample_document_id, sample_task_id)
await job_service.mark_job_as_started(created_job.id) job_service.mark_job_as_started(created_job.id)
await job_service.mark_job_as_completed(created_job.id) job_service.mark_job_as_completed(created_job.id)
# Try to start it again # Try to start it again
with pytest.raises(InvalidStatusTransitionError) as exc_info: with pytest.raises(InvalidStatusTransitionError) as exc_info:
await job_service.mark_job_as_started(created_job.id) job_service.mark_job_as_started(created_job.id)
# Verify exception details # Verify exception details
assert exc_info.value.current_status == ProcessingStatus.COMPLETED assert exc_info.value.current_status == ProcessingStatus.COMPLETED
assert exc_info.value.target_status == ProcessingStatus.PROCESSING assert exc_info.value.target_status == ProcessingStatus.PROCESSING
@pytest.mark.asyncio def test_i_cannot_mark_failed_job_as_started(
async def test_i_cannot_mark_failed_job_as_started(
self, self,
job_service, job_service,
sample_document_id, sample_document_id,
@@ -230,20 +221,19 @@ class TestUpdateStatus:
): ):
"""Test that failed job cannot be marked as started.""" """Test that failed job cannot be marked as started."""
# Create, start, and fail a job # Create, start, and fail a job
created_job = await job_service.create_job(sample_document_id, sample_task_id) created_job = job_service.create_job(sample_document_id, sample_task_id)
await job_service.mark_job_as_started(created_job.id) job_service.mark_job_as_started(created_job.id)
await job_service.mark_job_as_failed(created_job.id, "Test error") job_service.mark_job_as_failed(created_job.id, "Test error")
# Try to start it again # Try to start it again
with pytest.raises(InvalidStatusTransitionError) as exc_info: with pytest.raises(InvalidStatusTransitionError) as exc_info:
await job_service.mark_job_as_started(created_job.id) job_service.mark_job_as_started(created_job.id)
# Verify exception details # Verify exception details
assert exc_info.value.current_status == ProcessingStatus.FAILED assert exc_info.value.current_status == ProcessingStatus.FAILED
assert exc_info.value.target_status == ProcessingStatus.PROCESSING assert exc_info.value.target_status == ProcessingStatus.PROCESSING
@pytest.mark.asyncio def test_i_can_mark_processing_job_as_completed(
async def test_i_can_mark_processing_job_as_completed(
self, self,
job_service, job_service,
sample_document_id, sample_document_id,
@@ -251,11 +241,11 @@ class TestUpdateStatus:
): ):
"""Test marking processing job as completed (PROCESSING → COMPLETED).""" """Test marking processing job as completed (PROCESSING → COMPLETED)."""
# Create and start a job # Create and start a job
created_job = await job_service.create_job(sample_document_id, sample_task_id) created_job = job_service.create_job(sample_document_id, sample_task_id)
started_job = await job_service.mark_job_as_started(created_job.id) started_job = job_service.mark_job_as_started(created_job.id)
# Execute # Execute
result = await job_service.mark_job_as_completed(created_job.id) result = job_service.mark_job_as_completed(created_job.id)
# Verify status transition # Verify status transition
assert result is not None assert result is not None
@@ -263,11 +253,10 @@ class TestUpdateStatus:
assert result.status == ProcessingStatus.COMPLETED assert result.status == ProcessingStatus.COMPLETED
# Verify in database # Verify in database
updated_job = await job_service.get_job_by_id(created_job.id) updated_job = job_service.get_job_by_id(created_job.id)
assert updated_job.status == ProcessingStatus.COMPLETED assert updated_job.status == ProcessingStatus.COMPLETED
@pytest.mark.asyncio def test_i_cannot_mark_pending_job_as_completed(
async def test_i_cannot_mark_pending_job_as_completed(
self, self,
job_service, job_service,
sample_document_id, sample_document_id,
@@ -275,18 +264,17 @@ class TestUpdateStatus:
): ):
"""Test that pending job cannot be marked as completed.""" """Test that pending job cannot be marked as completed."""
# Create a pending job # Create a pending job
created_job = await job_service.create_job(sample_document_id, sample_task_id) created_job = job_service.create_job(sample_document_id, sample_task_id)
# Try to complete it directly # Try to complete it directly
with pytest.raises(InvalidStatusTransitionError) as exc_info: with pytest.raises(InvalidStatusTransitionError) as exc_info:
await job_service.mark_job_as_completed(created_job.id) job_service.mark_job_as_completed(created_job.id)
# Verify exception details # Verify exception details
assert exc_info.value.current_status == ProcessingStatus.PENDING assert exc_info.value.current_status == ProcessingStatus.PENDING
assert exc_info.value.target_status == ProcessingStatus.COMPLETED assert exc_info.value.target_status == ProcessingStatus.COMPLETED
@pytest.mark.asyncio def test_i_cannot_mark_completed_job_as_completed(
async def test_i_cannot_mark_completed_job_as_completed(
self, self,
job_service, job_service,
sample_document_id, sample_document_id,
@@ -294,20 +282,19 @@ class TestUpdateStatus:
): ):
"""Test that completed job cannot be marked as completed again.""" """Test that completed job cannot be marked as completed again."""
# Create, start, and complete a job # Create, start, and complete a job
created_job = await job_service.create_job(sample_document_id, sample_task_id) created_job = job_service.create_job(sample_document_id, sample_task_id)
await job_service.mark_job_as_started(created_job.id) job_service.mark_job_as_started(created_job.id)
await job_service.mark_job_as_completed(created_job.id) job_service.mark_job_as_completed(created_job.id)
# Try to complete it again # Try to complete it again
with pytest.raises(InvalidStatusTransitionError) as exc_info: with pytest.raises(InvalidStatusTransitionError) as exc_info:
await job_service.mark_job_as_completed(created_job.id) job_service.mark_job_as_completed(created_job.id)
# Verify exception details # Verify exception details
assert exc_info.value.current_status == ProcessingStatus.COMPLETED assert exc_info.value.current_status == ProcessingStatus.COMPLETED
assert exc_info.value.target_status == ProcessingStatus.COMPLETED assert exc_info.value.target_status == ProcessingStatus.COMPLETED
@pytest.mark.asyncio def test_i_cannot_mark_failed_job_as_completed(
async def test_i_cannot_mark_failed_job_as_completed(
self, self,
job_service, job_service,
sample_document_id, sample_document_id,
@@ -315,20 +302,19 @@ class TestUpdateStatus:
): ):
"""Test that failed job cannot be marked as completed.""" """Test that failed job cannot be marked as completed."""
# Create, start, and fail a job # Create, start, and fail a job
created_job = await job_service.create_job(sample_document_id, sample_task_id) created_job = job_service.create_job(sample_document_id, sample_task_id)
await job_service.mark_job_as_started(created_job.id) job_service.mark_job_as_started(created_job.id)
await job_service.mark_job_as_failed(created_job.id, "Test error") job_service.mark_job_as_failed(created_job.id, "Test error")
# Try to complete it # Try to complete it
with pytest.raises(InvalidStatusTransitionError) as exc_info: with pytest.raises(InvalidStatusTransitionError) as exc_info:
await job_service.mark_job_as_completed(created_job.id) job_service.mark_job_as_completed(created_job.id)
# Verify exception details # Verify exception details
assert exc_info.value.current_status == ProcessingStatus.FAILED assert exc_info.value.current_status == ProcessingStatus.FAILED
assert exc_info.value.target_status == ProcessingStatus.COMPLETED assert exc_info.value.target_status == ProcessingStatus.COMPLETED
@pytest.mark.asyncio def test_i_can_mark_processing_job_as_failed_with_error_message(
async def test_i_can_mark_processing_job_as_failed_with_error_message(
self, self,
job_service, job_service,
sample_document_id, sample_document_id,
@@ -336,13 +322,13 @@ class TestUpdateStatus:
): ):
"""Test marking processing job as failed with error message.""" """Test marking processing job as failed with error message."""
# Create and start a job # Create and start a job
created_job = await job_service.create_job(sample_document_id, sample_task_id) created_job = job_service.create_job(sample_document_id, sample_task_id)
started_job = await job_service.mark_job_as_started(created_job.id) started_job = job_service.mark_job_as_started(created_job.id)
error_message = "Processing failed due to invalid file format" error_message = "Processing failed due to invalid file format"
# Execute # Execute
result = await job_service.mark_job_as_failed(created_job.id, error_message) result = job_service.mark_job_as_failed(created_job.id, error_message)
# Verify status transition # Verify status transition
assert result is not None assert result is not None
@@ -351,12 +337,11 @@ class TestUpdateStatus:
assert result.error_message == error_message assert result.error_message == error_message
# Verify in database # Verify in database
updated_job = await job_service.get_job_by_id(created_job.id) updated_job = job_service.get_job_by_id(created_job.id)
assert updated_job.status == ProcessingStatus.FAILED assert updated_job.status == ProcessingStatus.FAILED
assert updated_job.error_message == error_message assert updated_job.error_message == error_message
@pytest.mark.asyncio def test_i_can_mark_processing_job_as_failed_without_error_message(
async def test_i_can_mark_processing_job_as_failed_without_error_message(
self, self,
job_service, job_service,
sample_document_id, sample_document_id,
@@ -364,19 +349,18 @@ class TestUpdateStatus:
): ):
"""Test marking processing job as failed without error message.""" """Test marking processing job as failed without error message."""
# Create and start a job # Create and start a job
created_job = await job_service.create_job(sample_document_id, sample_task_id) created_job = job_service.create_job(sample_document_id, sample_task_id)
await job_service.mark_job_as_started(created_job.id) job_service.mark_job_as_started(created_job.id)
# Execute without error message # Execute without error message
result = await job_service.mark_job_as_failed(created_job.id) result = job_service.mark_job_as_failed(created_job.id)
# Verify status transition # Verify status transition
assert result is not None assert result is not None
assert result.status == ProcessingStatus.FAILED assert result.status == ProcessingStatus.FAILED
assert result.error_message is None assert result.error_message is None
@pytest.mark.asyncio def test_i_cannot_mark_pending_job_as_failed(
async def test_i_cannot_mark_pending_job_as_failed(
self, self,
job_service, job_service,
sample_document_id, sample_document_id,
@@ -384,18 +368,17 @@ class TestUpdateStatus:
): ):
"""Test that pending job cannot be marked as failed.""" """Test that pending job cannot be marked as failed."""
# Create a pending job # Create a pending job
created_job = await job_service.create_job(sample_document_id, sample_task_id) created_job = job_service.create_job(sample_document_id, sample_task_id)
# Try to fail it directly # Try to fail it directly
with pytest.raises(InvalidStatusTransitionError) as exc_info: with pytest.raises(InvalidStatusTransitionError) as exc_info:
await job_service.mark_job_as_failed(created_job.id, "Test error") job_service.mark_job_as_failed(created_job.id, "Test error")
# Verify exception details # Verify exception details
assert exc_info.value.current_status == ProcessingStatus.PENDING assert exc_info.value.current_status == ProcessingStatus.PENDING
assert exc_info.value.target_status == ProcessingStatus.FAILED assert exc_info.value.target_status == ProcessingStatus.FAILED
@pytest.mark.asyncio def test_i_cannot_mark_completed_job_as_failed(
async def test_i_cannot_mark_completed_job_as_failed(
self, self,
job_service, job_service,
sample_document_id, sample_document_id,
@@ -403,20 +386,19 @@ class TestUpdateStatus:
): ):
"""Test that completed job cannot be marked as failed.""" """Test that completed job cannot be marked as failed."""
# Create, start, and complete a job # Create, start, and complete a job
created_job = await job_service.create_job(sample_document_id, sample_task_id) created_job = job_service.create_job(sample_document_id, sample_task_id)
await job_service.mark_job_as_started(created_job.id) job_service.mark_job_as_started(created_job.id)
await job_service.mark_job_as_completed(created_job.id) job_service.mark_job_as_completed(created_job.id)
# Try to fail it # Try to fail it
with pytest.raises(InvalidStatusTransitionError) as exc_info: with pytest.raises(InvalidStatusTransitionError) as exc_info:
await job_service.mark_job_as_failed(created_job.id, "Test error") job_service.mark_job_as_failed(created_job.id, "Test error")
# Verify exception details # Verify exception details
assert exc_info.value.current_status == ProcessingStatus.COMPLETED assert exc_info.value.current_status == ProcessingStatus.COMPLETED
assert exc_info.value.target_status == ProcessingStatus.FAILED assert exc_info.value.target_status == ProcessingStatus.FAILED
@pytest.mark.asyncio def test_i_cannot_mark_failed_job_as_failed(
async def test_i_cannot_mark_failed_job_as_failed(
self, self,
job_service, job_service,
sample_document_id, sample_document_id,
@@ -424,13 +406,13 @@ class TestUpdateStatus:
): ):
"""Test that failed job cannot be marked as failed again.""" """Test that failed job cannot be marked as failed again."""
# Create, start, and fail a job # Create, start, and fail a job
created_job = await job_service.create_job(sample_document_id, sample_task_id) created_job = job_service.create_job(sample_document_id, sample_task_id)
await job_service.mark_job_as_started(created_job.id) job_service.mark_job_as_started(created_job.id)
await job_service.mark_job_as_failed(created_job.id, "First error") job_service.mark_job_as_failed(created_job.id, "First error")
# Try to fail it again # Try to fail it again
with pytest.raises(InvalidStatusTransitionError) as exc_info: with pytest.raises(InvalidStatusTransitionError) as exc_info:
await job_service.mark_job_as_failed(created_job.id, "Second error") job_service.mark_job_as_failed(created_job.id, "Second error")
# Verify exception details # Verify exception details
assert exc_info.value.current_status == ProcessingStatus.FAILED assert exc_info.value.current_status == ProcessingStatus.FAILED
@@ -440,8 +422,7 @@ class TestUpdateStatus:
class TestDeleteJob: class TestDeleteJob:
"""Tests for delete_job method.""" """Tests for delete_job method."""
@pytest.mark.asyncio def test_i_can_delete_existing_job(
async def test_i_can_delete_existing_job(
self, self,
job_service, job_service,
sample_document_id, sample_document_id,
@@ -449,30 +430,29 @@ class TestDeleteJob:
): ):
"""Test deleting an existing job.""" """Test deleting an existing job."""
# Create a job # Create a job
created_job = await job_service.create_job(sample_document_id, sample_task_id) created_job = job_service.create_job(sample_document_id, sample_task_id)
# Verify job exists # Verify job exists
job_before_delete = await job_service.get_job_by_id(created_job.id) job_before_delete = job_service.get_job_by_id(created_job.id)
assert job_before_delete is not None assert job_before_delete is not None
# Execute deletion # Execute deletion
result = await job_service.delete_job(created_job.id) result = job_service.delete_job(created_job.id)
# Verify deletion # Verify deletion
assert result is True assert result is True
# Verify job no longer exists # Verify job no longer exists
deleted_job = await job_service.get_job_by_id(created_job.id) deleted_job = job_service.get_job_by_id(created_job.id)
assert deleted_job is None assert deleted_job is None
@pytest.mark.asyncio def test_i_cannot_delete_nonexistent_job(
async def test_i_cannot_delete_nonexistent_job(
self, self,
job_service job_service
): ):
"""Test deleting a nonexistent job returns False.""" """Test deleting a nonexistent job returns False."""
# Execute deletion with random ObjectId # Execute deletion with random ObjectId
result = await job_service.delete_job(ObjectId()) result = job_service.delete_job(ObjectId())
# Verify # Verify
assert result is False assert result is False
@@ -481,8 +461,7 @@ class TestDeleteJob:
class TestStatusTransitionValidation: class TestStatusTransitionValidation:
"""Tests for status transition validation across different scenarios.""" """Tests for status transition validation across different scenarios."""
@pytest.mark.asyncio def test_valid_job_lifecycle_flow(
async def test_valid_job_lifecycle_flow(
self, self,
job_service, job_service,
sample_document_id, sample_document_id,
@@ -490,19 +469,18 @@ class TestStatusTransitionValidation:
): ):
"""Test complete valid job lifecycle: PENDING → PROCESSING → COMPLETED.""" """Test complete valid job lifecycle: PENDING → PROCESSING → COMPLETED."""
# Create job (PENDING) # Create job (PENDING)
job = await job_service.create_job(sample_document_id, sample_task_id) job = job_service.create_job(sample_document_id, sample_task_id)
assert job.status == ProcessingStatus.PENDING assert job.status == ProcessingStatus.PENDING
# Start job (PENDING → PROCESSING) # Start job (PENDING → PROCESSING)
started_job = await job_service.mark_job_as_started(job.id) started_job = job_service.mark_job_as_started(job.id)
assert started_job.status == ProcessingStatus.PROCESSING assert started_job.status == ProcessingStatus.PROCESSING
# Complete job (PROCESSING → COMPLETED) # Complete job (PROCESSING → COMPLETED)
completed_job = await job_service.mark_job_as_completed(job.id) completed_job = job_service.mark_job_as_completed(job.id)
assert completed_job.status == ProcessingStatus.COMPLETED assert completed_job.status == ProcessingStatus.COMPLETED
@pytest.mark.asyncio def test_valid_job_failure_flow(
async def test_valid_job_failure_flow(
self, self,
job_service, job_service,
sample_document_id, sample_document_id,
@@ -510,69 +488,31 @@ class TestStatusTransitionValidation:
): ):
"""Test valid job failure: PENDING → PROCESSING → FAILED.""" """Test valid job failure: PENDING → PROCESSING → FAILED."""
# Create job (PENDING) # Create job (PENDING)
job = await job_service.create_job(sample_document_id, sample_task_id) job = job_service.create_job(sample_document_id, sample_task_id)
assert job.status == ProcessingStatus.PENDING assert job.status == ProcessingStatus.PENDING
# Start job (PENDING → PROCESSING) # Start job (PENDING → PROCESSING)
started_job = await job_service.mark_job_as_started(job.id) started_job = job_service.mark_job_as_started(job.id)
assert started_job.status == ProcessingStatus.PROCESSING assert started_job.status == ProcessingStatus.PROCESSING
# Fail job (PROCESSING → FAILED) # Fail job (PROCESSING → FAILED)
failed_job = await job_service.mark_job_as_failed(job.id, "Test failure") failed_job = job_service.mark_job_as_failed(job.id, "Test failure")
assert failed_job.status == ProcessingStatus.FAILED assert failed_job.status == ProcessingStatus.FAILED
assert failed_job.error_message == "Test failure" assert failed_job.error_message == "Test failure"
def test_job_operations_with_empty_database(
class TestEdgeCases:
"""Tests for edge cases and error conditions."""
#
# @pytest.mark.asyncio
# async def test_multiple_jobs_for_same_file(
# self,
# job_service,
# sample_document_id
# ):
# """Test handling multiple jobs for the same file."""
# # Create multiple jobs for same file
# job1 = await job_service.create_job(sample_document_id, "task-1")
# job2 = await job_service.create_job(sample_document_id, "task-2")
# job3 = await job_service.create_job(sample_document_id, "task-3")
#
# # Verify all jobs exist and are independent
# jobs_for_file = await job_service.get_jobs_by_file_id(sample_document_id)
# assert len(jobs_for_file) == 3
#
# job_ids = [job.id for job in jobs_for_file]
# assert job1.id in job_ids
# assert job2.id in job_ids
# assert job3.id in job_ids
#
# # Verify status transitions work independently
# await job_service.mark_job_as_started(job1.id)
# await job_service.mark_job_as_completed(job1.id)
#
# # Other jobs should still be pending
# updated_job2 = await job_service.get_job_by_id(job2.id)
# updated_job3 = await job_service.get_job_by_id(job3.id)
#
# assert updated_job2.status == ProcessingStatus.PENDING
# assert updated_job3.status == ProcessingStatus.PENDING
@pytest.mark.asyncio
async def test_job_operations_with_empty_database(
self, self,
job_service job_service
): ):
"""Test job operations when database is empty.""" """Test job operations when database is empty."""
# Try to get nonexistent job # Try to get nonexistent job
result = await job_service.get_job_by_id(ObjectId()) result = job_service.get_job_by_id(ObjectId())
assert result is None assert result is None
# Try to get jobs by status when none exist # Try to get jobs by status when none exist
pending_jobs = await job_service.get_jobs_by_status(ProcessingStatus.PENDING) pending_jobs = job_service.get_jobs_by_status(ProcessingStatus.PENDING)
assert pending_jobs == [] assert pending_jobs == []
# Try to delete nonexistent job # Try to delete nonexistent job
delete_result = await job_service.delete_job(ObjectId()) delete_result = job_service.delete_job(ObjectId())
assert delete_result is False assert delete_result is False