commit 79a31ecf40c6042238e41735ff67a0cd7c9411da Author: Kodjo Sossouvi Date: Sat Oct 18 12:26:55 2025 +0200 Unit testing AuthService diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..1f7286a --- /dev/null +++ b/.gitignore @@ -0,0 +1,216 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[codz] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py.cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +# Pipfile.lock + +# UV +# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# uv.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +# poetry.lock +# poetry.toml + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +# pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python. +# https://pdm-project.org/en/latest/usage/project/#working-with-version-control +# pdm.lock +# pdm.toml +.pdm-python +.pdm-build/ + +# pixi +# Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control. +# pixi.lock +# Pixi creates a virtual environment in the .pixi directory, just like venv module creates one +# in the .venv directory. It is recommended not to include this directory in version control. +.pixi + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# Redis +*.rdb +*.aof +*.pid + +# RabbitMQ +mnesia/ +rabbitmq/ +rabbitmq-data/ + +# ActiveMQ +activemq-data/ + +# SageMath parsed files +*.sage.py + +# Environments +.env +.envrc +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +# .idea/ + +# Abstra +# Abstra is an AI-powered process automation framework. +# Ignore directories containing user credentials, local state, and settings. +# Learn more at https://abstra.io/docs +.abstra/ + +# Visual Studio Code +# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore +# that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore +# and can be added to the global gitignore or merged into this file. However, if you prefer, +# you could uncomment the following to ignore the entire vscode folder +# .vscode/ + +# Ruff stuff: +.ruff_cache/ + +# PyPI configuration file +.pypirc + +# Marimo +marimo/_static/ +marimo/_lsp/ +__marimo__/ + +# Streamlit +.streamlit/secrets.toml \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..7ece33f --- /dev/null +++ b/README.md @@ -0,0 +1,285 @@ +# MyAuth Module + +A reusable, modular authentication system for FastAPI applications with pluggable database backends. + +## Overview + +This module provides a complete authentication solution designed to be deployed on internal PyPI servers and reused across multiple projects. It handles user registration, login/logout, token management, email verification, and password reset functionality. + +## Features + +### Core Authentication +- ✅ User registration with email and username +- ✅ Login/logout with email-based authentication +- ✅ JWT-based access tokens (30 minutes validity) +- ✅ Opaque refresh tokens stored in database (7 days validity) +- ✅ Password hashing with configurable bcrypt rounds + +### User Management +- ✅ Email verification with JWT tokens +- ✅ Password reset with secure random tokens (15 minutes validity) +- ✅ User roles management (flexible, no predefined roles) +- ✅ User settings storage (dict field) +- ✅ Account activation/deactivation + +### Password Security +- ✅ Strict password validation (via Pydantic): + - Minimum 8 characters + - At least 1 uppercase letter + - At least 1 lowercase letter + - At least 1 digit + - At least 1 special character + +### Architecture +- ✅ Abstract base classes for database persistence +- ✅ Multiple database implementations: MongoDB, SQLite, PostgreSQL +- ✅ Abstract email service interface with optional SMTP implementation +- ✅ Custom exceptions with FastAPI integration +- ✅ Synchronous implementation +- ✅ Ready-to-use FastAPI router with `/auth` prefix + +## Project Structure + +``` +├──src +│ my_auth/ +│ ├── __init__.py +│ ├── models/ # Pydantic models +│ │ ├── user.py # User model with roles and settings +│ │ ├── token.py # Token models (access, refresh, reset) +│ │ └── email_verification.py # Email verification models +│ ├── core/ # Business logic +│ │ ├── auth.py # Authentication service +│ │ ├── password.py # Password hashing/verification +│ │ └── token.py # Token generation/validation +│ ├── persistence/ # Database abstraction +│ │ ├── base.py # Abstract base classes +│ │ ├── mongodb.py # MongoDB implementation +│ │ ├── sqlite.py # SQLite implementation +│ │ └── postgresql.py # PostgreSQL implementation +│ ├── api/ # FastAPI routes +│ │ └── routes.py # All authentication endpoints +│ ├── email/ # Email service +│ │ ├── base.py # Abstract interface +│ │ └── smtp.py # SMTP implementation (optional) +│ ├── exceptions.py # Custom exceptions +│ └── config.py # Configuration classes +├── tests +``` + +## User Model + +```python +class User: + id: str # Unique identifier + email: str # Unique, required + username: str # Required, non-unique + hashed_password: str # Bcrypt hashed + roles: list[str] # Free-form roles, no defaults + user_settings: dict # Custom user settings + is_verified: bool # Email verification status + is_active: bool # Account active status + created_at: datetime + updated_at: datetime +``` + +## Token Management + +### Token Types +The module uses a unified tokens collection with a discriminator field: + +1. **Access Token (JWT)**: 30 minutes validity, stateless +2. **Refresh Token (Opaque)**: 7 days validity, stored in DB +3. **Password Reset Token (Random)**: 15 minutes validity, stored in DB +4. **Email Verification Token (JWT)**: Stateless, no DB storage + +### Token Storage +All tokens requiring storage (refresh and password reset) are kept in a single `tokens` collection/table with a `token_type` discriminator field. + +## API Endpoints + +The module exposes a pre-configured FastAPI router with the following endpoints: + +``` +POST /auth/register # User registration +POST /auth/login # Login (email + password) +POST /auth/logout # Logout (revokes refresh token) +POST /auth/refresh # Refresh access token +POST /auth/password-reset-request # Request password reset +POST /auth/password-reset # Reset password with token +POST /auth/verify-email-request # Request email verification +POST /auth/verify-email # Verify email with token +GET /auth/me # Get current user info +``` + +## Installation + +### Dependencies + +**Core dependencies:** +```bash +pip install fastapi pydantic pydantic-settings python-jose[cryptography] passlib[bcrypt] python-multipart +``` + +**Database-specific dependencies:** +- MongoDB: `pip install pymongo` +- SQLite: Built-in (no additional dependency) +- PostgreSQL: `pip install psycopg2-binary` + +**Optional email dependency:** +- SMTP: `pip install secure-smtplib` + +## Usage + +### Basic Setup + +```python +from fastapi import FastAPI +from auth_module import AuthService +from auth_module.persistence.mongodb import MongoUserRepository, MongoTokenRepository +from auth_module.api import auth_router + +# Initialize repositories +user_repo = MongoUserRepository(connection_string="mongodb://localhost:27017/mydb") +token_repo = MongoTokenRepository(connection_string="mongodb://localhost:27017/mydb") + +# Initialize auth service +auth_service = AuthService( + user_repository=user_repo, + token_repository=token_repo, + jwt_secret="your-secret-key-here", + access_token_expire_minutes=30, + refresh_token_expire_days=7, + password_reset_token_expire_minutes=15, + password_hash_rounds=12 +) + +# Create FastAPI app and include auth router +app = FastAPI() +app.include_router(auth_router) # Mounts at /auth prefix +``` + +### Using Different Databases + +#### SQLite +```python +from auth_module.persistence.sqlite import SQLiteUserRepository, SQLiteTokenRepository + +user_repo = SQLiteUserRepository(db_path="./auth.db") +token_repo = SQLiteTokenRepository(db_path="./auth.db") +``` + +#### PostgreSQL +```python +from auth_module.persistence.postgresql import PostgreSQLUserRepository, PostgreSQLTokenRepository + +user_repo = PostgreSQLUserRepository( + host="localhost", + port=5432, + database="mydb", + user="postgres", + password="secret" +) +token_repo = PostgreSQLTokenRepository(...) +``` + +### Email Service Configuration + +```python +from auth_module.email.smtp import SMTPEmailService + +email_service = SMTPEmailService( + host="smtp.gmail.com", + port=587, + username="your-email@gmail.com", + password="your-app-password", + use_tls=True +) + +auth_service = AuthService( + user_repository=user_repo, + token_repository=token_repo, + email_service=email_service, # Optional + ... +) +``` + +### Custom Email Service + +Implement your own email service by extending the abstract base class: + +```python +from auth_module.email.base import EmailService + +class CustomEmailService(EmailService): + def send_verification_email(self, email: str, token: str) -> None: + # Your implementation (SendGrid, AWS SES, etc.) + pass + + def send_password_reset_email(self, email: str, token: str) -> None: + # Your implementation + pass +``` + +## Error Handling + +The module uses custom exceptions that are automatically converted to appropriate HTTP responses: + +- `InvalidCredentialsError` → 401 Unauthorized +- `UserAlreadyExistsError` → 409 Conflict +- `UserNotFoundError` → 404 Not Found +- `InvalidTokenError` → 401 Unauthorized +- `RevokedTokenError` → 401 Unauthorized +- `ExpiredTokenError` → 401 Unauthorized +- `EmailNotVerifiedError` → 403 Forbidden +- `AccountDisabledError` → 403 Forbidden + +## Configuration Options + +```python +AuthService( + user_repository: UserRepository, # Required + token_repository: TokenRepository, # Required + jwt_secret: str, # Required + jwt_algorithm: str = "HS256", # Optional + access_token_expire_minutes: int = 30, # Optional + refresh_token_expire_days: int = 7, # Optional + password_reset_token_expire_minutes: int = 15, # Optional + password_hash_rounds: int = 12, # Optional (bcrypt cost) + email_service: EmailService = None # Optional +) +``` + +## Testing + +The module is fully testable with pytest. Test fixtures are provided for each database implementation. + +```bash +pytest tests/ +``` + +## Security Considerations + +- Passwords are hashed using bcrypt with configurable rounds +- JWT tokens are signed with HS256 (configurable) +- Refresh tokens are opaque and stored securely +- Password reset tokens are single-use and expire after 15 minutes +- Email verification tokens are stateless JWT +- Rate limiting should be implemented at the application level +- HTTPS should be enforced by the application + +## Future Enhancements (Not Included) + +- Multi-factor authentication (2FA/MFA) +- Rate limiting on login attempts +- OAuth2 provider integration +- Session management (multiple device tracking) +- Account lockout after failed attempts + +## License + +[Your License Here] + +## Contributing + +[Your Contributing Guidelines Here] \ No newline at end of file diff --git a/main.py b/main.py new file mode 100644 index 0000000..b104d14 --- /dev/null +++ b/main.py @@ -0,0 +1,16 @@ +# This is a sample Python script. + +# Press Ctrl+F5 to execute it or replace it with your code. +# Press Double Shift to search everywhere for classes, files, tool windows, actions, and settings. + + +def print_hi(name): + # Use a breakpoint in the code line below to debug your script. + print(f'Hi, {name}') # Press F9 to toggle the breakpoint. + + +# Press the green button in the gutter to run the script. +if __name__ == '__main__': + print_hi('PyCharm') + +# See PyCharm help at https://www.jetbrains.com/help/pycharm/ diff --git a/src/__init__.py b/src/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/my_auth/__init__.py b/src/my_auth/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/my_auth/core/__init__.py b/src/my_auth/core/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/my_auth/core/auth.py b/src/my_auth/core/auth.py new file mode 100644 index 0000000..86b2fd0 --- /dev/null +++ b/src/my_auth/core/auth.py @@ -0,0 +1,445 @@ +""" +Authentication service for the authentication module. + +This module provides the main authentication service that orchestrates +all authentication operations including registration, login, token management, +password reset, and email verification. +""" + +from datetime import datetime + +from .password import PasswordManager +from .token import TokenManager +from ..exceptions import ( + InvalidCredentialsError, + UserNotFoundError, + AccountDisabledError, + ExpiredTokenError, + InvalidTokenError, + RevokedTokenError +) +from ..models.token import AccessTokenResponse, TokenData +from ..models.user import UserCreate, UserInDB, UserUpdate +from ..persistence.base import UserRepository, TokenRepository + + +class AuthService: + """ + Main authentication service. + + This service orchestrates all authentication-related operations by + coordinating between password management, token management, and + persistence layers. + + Attributes: + user_repository: Repository for user persistence operations. + token_repository: Repository for token persistence operations. + password_manager: Manager for password hashing and verification. + token_manager: Manager for token creation and validation. + """ + + def __init__( + self, + user_repository: UserRepository, + token_repository: TokenRepository, + password_manager: PasswordManager, + token_manager: TokenManager + ): + """ + Initialize the authentication service. + + Args: + user_repository: Repository for user persistence. + token_repository: Repository for token persistence. + jwt_secret: Secret key for JWT signing. + jwt_algorithm: JWT algorithm (default: HS256). + access_token_expire_minutes: Access token validity (default: 30). + refresh_token_expire_days: Refresh token validity (default: 7). + password_reset_token_expire_minutes: Reset token validity (default: 15). + password_hash_rounds: Bcrypt rounds (default: 12). + """ + self.user_repository = user_repository + self.token_repository = token_repository + self.password_manager = password_manager + self.token_manager = token_manager + + def register(self, user_data: UserCreate) -> UserInDB: + """ + Register a new user. + + This method creates a new user account with hashed password. + The user's email is initially unverified. + + Args: + user_data: User registration data including password. + + Returns: + The created user (without password). + + Raises: + UserAlreadyExistsError: If email is already registered. + + Example: + >>> user_data = UserCreate( + ... email="user@example.com", + ... username="john_doe", + ... password="SecurePass123!" + ... ) + >>> user = auth_service.register(user_data) + """ + hashed_password = self.password_manager.hash_password(user_data.password) + user = self.user_repository.create_user(user_data, hashed_password) + return user + + def login(self, email: str, password: str) -> tuple[UserInDB, AccessTokenResponse]: + """ + Authenticate a user and create tokens. + + This method verifies credentials, checks account status, and generates + both access and refresh tokens. + + Args: + email: User's email address. + password: User's plain text password. + + Returns: + Tuple of (user, tokens) where tokens contains access_token and refresh_token. + + Raises: + InvalidCredentialsError: If email or password is incorrect. + AccountDisabledError: If user account is disabled. + + Example: + >>> user, tokens = auth_service.login("user@example.com", "password") + >>> print(tokens.access_token) + """ + # Get user by email + user = self.user_repository.get_user_by_email(email) + if not user: + raise InvalidCredentialsError() + + # Verify password + if not self.password_manager.verify_password(password, user.hashed_password): + raise InvalidCredentialsError() + + # Check if account is active + if not user.is_active: + raise AccountDisabledError() + + # Create tokens + access_token = self.token_manager.create_access_token(user) + refresh_token = self.token_manager.create_refresh_token() + + # Store refresh token in database + token_data = TokenData( + token=refresh_token, + token_type="refresh", + user_id=user.id, + expires_at=self.token_manager.get_refresh_token_expiration(), + created_at=datetime.now(), + is_revoked=False + ) + self.token_repository.save_token(token_data) + + tokens = AccessTokenResponse( + access_token=access_token, + refresh_token=refresh_token, + token_type="bearer" + ) + + return user, tokens + + def refresh_access_token(self, refresh_token: str) -> AccessTokenResponse: + """ + Create a new access token using a refresh token. + + This method validates the refresh token and generates a new access token + without requiring the user to re-enter their password. + + Args: + refresh_token: The refresh token to exchange. + + Returns: + New access and refresh tokens. + + Raises: + InvalidTokenError: If refresh token is invalid. + ExpiredTokenError: If refresh token has expired. + RevokedTokenError: If refresh token has been revoked. + UserNotFoundError: If user no longer exists. + AccountDisabledError: If user account is disabled. + + Example: + >>> tokens = auth_service.refresh_access_token(old_refresh_token) + >>> print(tokens.access_token) + """ + # Validate refresh token + token_data = self.token_repository.get_token(refresh_token, "refresh") + if not token_data: + raise InvalidTokenError("Invalid refresh token") + + if token_data.is_revoked: + raise RevokedTokenError() + + if token_data.expires_at < datetime.now(): + raise ExpiredTokenError() + + # Get user + user = self.user_repository.get_user_by_id(token_data.user_id) + if not user: + raise UserNotFoundError() + + if not user.is_active: + raise AccountDisabledError() + + # Create new tokens + access_token = self.token_manager.create_access_token(user) + new_refresh_token = self.token_manager.create_refresh_token() + + # Revoke old refresh token + self.token_repository.revoke_token(refresh_token) + + # Store new refresh token + new_token_data = TokenData( + token=new_refresh_token, + token_type="refresh", + user_id=user.id, + expires_at=self.token_manager.get_refresh_token_expiration(), + created_at=datetime.now(), + is_revoked=False + ) + self.token_repository.save_token(new_token_data) + + return AccessTokenResponse( + access_token=access_token, + refresh_token=new_refresh_token, + token_type="bearer" + ) + + def logout(self, refresh_token: str) -> bool: + """ + Logout a user by revoking their refresh token. + + This prevents the refresh token from being used to obtain new + access tokens. The current access token will remain valid until + it expires naturally. + + Args: + refresh_token: The refresh token to revoke. + + Returns: + True if logout was successful, False if token not found. + + Example: + >>> auth_service.logout(refresh_token) + True + """ + return self.token_repository.revoke_token(refresh_token) + + def request_password_reset(self, email: str) -> str: + """ + Generate a password reset token for a user. + + This method creates a secure token that can be sent to the user + via email to reset their password. + + Args: + email: User's email address. + + Returns: + The password reset token to be sent via email. + + Raises: + UserNotFoundError: If email is not registered. + + Example: + >>> token = auth_service.request_password_reset("user@example.com") + >>> # Send token via email service + """ + user = self.user_repository.get_user_by_email(email) + if not user: + raise UserNotFoundError(f"No user found with email {email}") + + # Create reset token + reset_token = self.token_manager.create_password_reset_token() + + # Store token in database + token_data = TokenData( + token=reset_token, + token_type="password_reset", + user_id=user.id, + expires_at=self.token_manager.get_password_reset_token_expiration(), + created_at=datetime.now(), + is_revoked=False + ) + self.token_repository.save_token(token_data) + + return reset_token + + def reset_password(self, token: str, new_password: str) -> bool: + """ + Reset a user's password using a reset token. + + This method validates the reset token and updates the user's password. + All existing refresh tokens for the user are revoked for security. + + Args: + token: Password reset token. + new_password: New plain text password (will be hashed). + + Returns: + True if password was reset successfully. + + Raises: + InvalidTokenError: If reset token is invalid. + ExpiredTokenError: If reset token has expired. + RevokedTokenError: If reset token has been used. + UserNotFoundError: If user no longer exists. + + Example: + >>> auth_service.reset_password(token, "NewSecurePass123!") + True + """ + # Validate reset token + token_data = self.token_repository.get_token(token, "password_reset") + if not token_data: + raise InvalidTokenError("Invalid password reset token") + + if token_data.is_revoked: + raise RevokedTokenError("Password reset token has already been used") + + if token_data.expires_at < datetime.now(): + raise ExpiredTokenError("Password reset token has expired") + + # Get user + user = self.user_repository.get_user_by_id(token_data.user_id) + if not user: + raise UserNotFoundError() + + # Hash new password + hashed_password = self.password_manager.hash_password(new_password) + + # Update user password + updates = UserUpdate(password=hashed_password) + self.user_repository.update_user(user.id, updates) + + # Revoke the reset token + self.token_repository.revoke_token(token) + + # Revoke all user's refresh tokens for security + self.token_repository.revoke_all_user_tokens(user.id, "refresh") + + return True + + def request_email_verification(self, email: str) -> str: + """ + Generate an email verification token for a user. + + This method creates a JWT token that can be sent to the user + to verify their email address. + + Args: + email: User's email address. + + Returns: + The email verification token (JWT) to be sent via email. + + Raises: + UserNotFoundError: If email is not registered. + + Example: + >>> token = auth_service.request_email_verification("user@example.com") + >>> # Send token via email service + """ + user = self.user_repository.get_user_by_email(email) + if not user: + raise UserNotFoundError(f"No user found with email {email}") + + return self.token_manager.create_email_verification_token(email) + + def verify_email(self, token: str) -> bool: + """ + Verify a user's email address using a verification token. + + This method decodes the JWT token, extracts the email, and marks + the user's email as verified. + + Args: + token: Email verification token (JWT). + + Returns: + True if email was verified successfully. + + Raises: + InvalidTokenError: If verification token is invalid. + ExpiredTokenError: If verification token has expired. + UserNotFoundError: If user no longer exists. + + Example: + >>> auth_service.verify_email(token) + True + """ + # Decode and validate token + email = self.token_manager.decode_email_verification_token(token) + + # Get user + user = self.user_repository.get_user_by_email(email) + if not user: + raise UserNotFoundError(f"No user found with email {email}") + + # Update user's verified status + updates = UserUpdate(is_verified=True) + self.user_repository.update_user(user.id, updates) + + return True + + def get_current_user(self, access_token: str) -> UserInDB: + """ + Retrieve the current user from an access token. + + This method decodes and validates the JWT access token and + retrieves the corresponding user from the database. + + Args: + access_token: JWT access token. + + Returns: + The user associated with the token. + + Raises: + InvalidTokenError: If access token is invalid. + ExpiredTokenError: If access token has expired. + UserNotFoundError: If user no longer exists. + AccountDisabledError: If user account is disabled. + + Example: + >>> user = auth_service.get_current_user(access_token) + >>> print(user.email) + """ + # Decode and validate token + token_payload = self.token_manager.decode_access_token(access_token) + + # Get user + user = self.user_repository.get_user_by_id(token_payload.sub) + if not user: + raise UserNotFoundError() + + if not user.is_active: + raise AccountDisabledError() + + return user + + +def get_default_sqlite_auth_service(db_path: str, jwt_secret: str) -> AuthService: + from my_auth.persistence.sqlite import SQLiteUserRepository + from my_auth.persistence.sqlite import SQLiteTokenRepository + user_repository = SQLiteUserRepository(db_path=db_path) + token_repository = SQLiteTokenRepository(db_path=db_path) + password_manager = PasswordManager() + token_manager = TokenManager(jwt_secret=jwt_secret) + return AuthService( + user_repository=user_repository, + token_repository=token_repository, + password_manager=password_manager, + token_manager=token_manager + ) diff --git a/src/my_auth/core/password.py b/src/my_auth/core/password.py new file mode 100644 index 0000000..44749d7 --- /dev/null +++ b/src/my_auth/core/password.py @@ -0,0 +1,87 @@ +""" +Password management for authentication module. + +This module provides password hashing and verification functionality using +bcrypt. It handles secure password storage and validation. +""" + +from passlib.context import CryptContext + + +class PasswordManager: + """ + Manager for password hashing and verification operations. + + This class uses bcrypt for secure password hashing with configurable + cost factor (rounds). Higher rounds provide better security but slower + performance. + + Attributes: + rounds: Bcrypt cost factor (default: 12). Higher values increase + security but also increase computation time. + """ + + def __init__(self, rounds: int = 12): + """ + Initialize the password manager. + + Args: + rounds: Bcrypt cost factor (4-31). Default is 12 which provides + a good balance between security and performance. + """ + if rounds < 4 or rounds > 31: + raise ValueError("Bcrypt rounds must be between 4 and 31") + + self.rounds = rounds + self._context = CryptContext( + schemes=["bcrypt"], + deprecated="auto", + bcrypt__rounds=self.rounds + ) + + def hash_password(self, password: str) -> str: + """ + Hash a plain text password using bcrypt. + + This method creates a secure hash of the password that can be + safely stored in the database. Each hash includes a salt, making + rainbow table attacks ineffective. + + Args: + password: Plain text password to hash. + + Returns: + Bcrypt hashed password string. + + Example: + >>> pm = PasswordManager() + >>> hashed = pm.hash_password("SecurePassword123!") + >>> print(hashed) + $2b$12$abcdefghijklmnopqrstuvwxyz... + """ + return self._context.hash(password) + + def verify_password(self, plain_password: str, hashed_password: str) -> bool: + """ + Verify a plain text password against a hashed password. + + This method performs constant-time comparison to prevent timing + attacks. It returns True only if the plain password matches the + hash. + + Args: + plain_password: Plain text password to verify. + hashed_password: Bcrypt hashed password from database. + + Returns: + True if password matches, False otherwise. + + Example: + >>> pm = PasswordManager() + >>> hashed = pm.hash_password("SecurePassword123!") + >>> pm.verify_password("SecurePassword123!", hashed) + True + >>> pm.verify_password("WrongPassword", hashed) + False + """ + return self._context.verify(plain_password, hashed_password) diff --git a/src/my_auth/core/token.py b/src/my_auth/core/token.py new file mode 100644 index 0000000..64d2855 --- /dev/null +++ b/src/my_auth/core/token.py @@ -0,0 +1,265 @@ +""" +Token management for authentication module. + +This module provides functionality for creating and validating different types +of tokens: JWT access tokens, opaque refresh tokens, password reset tokens, +and email verification tokens. +""" + +import secrets +from datetime import datetime, timedelta + +from jose import JWTError, jwt + +from ..exceptions import InvalidTokenError, ExpiredTokenError +from ..models.token import TokenPayload +from ..models.user import UserInDB + + +class TokenManager: + """ + Manager for token creation and validation operations. + + This class handles multiple token types: + - JWT access tokens (stateless, short-lived) + - Opaque refresh tokens (stored in DB, long-lived) + - Opaque password reset tokens (stored in DB, short-lived) + - JWT email verification tokens (stateless) + + Attributes: + jwt_secret: Secret key for signing JWT tokens. + jwt_algorithm: Algorithm for JWT encoding (default: HS256). + access_token_expire_minutes: Validity duration for access tokens. + refresh_token_expire_days: Validity duration for refresh tokens. + password_reset_token_expire_minutes: Validity duration for reset tokens. + """ + + def __init__( + self, + jwt_secret: str, + jwt_algorithm: str = "HS256", + access_token_expire_minutes: int = 30, + refresh_token_expire_days: int = 7, + password_reset_token_expire_minutes: int = 15 + ): + """ + Initialize the token manager. + + Args: + jwt_secret: Secret key for signing JWT tokens. + jwt_algorithm: Algorithm for JWT encoding (default: HS256). + access_token_expire_minutes: Validity duration for access tokens. + refresh_token_expire_days: Validity duration for refresh tokens. + password_reset_token_expire_minutes: Validity duration for reset tokens. + """ + if not jwt_secret: + raise ValueError("JWT secret cannot be empty") + + self.jwt_secret = jwt_secret + self.jwt_algorithm = jwt_algorithm + self.access_token_expire_minutes = access_token_expire_minutes + self.refresh_token_expire_days = refresh_token_expire_days + self.password_reset_token_expire_minutes = password_reset_token_expire_minutes + + def create_access_token(self, user: UserInDB) -> str: + """ + Create a JWT access token for a user. + + The token contains the user's ID and email, and is signed with the + configured secret. It expires after the configured duration. + + Args: + user: The user for whom to create the token. + + Returns: + Encoded JWT access token string. + + Example: + >>> tm = TokenManager(jwt_secret="secret") + >>> token = tm.create_access_token(user) + >>> print(token) + eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9... + """ + expires_at = datetime.now() + timedelta(minutes=self.access_token_expire_minutes) + + payload = { + "sub": user.id, + "email": user.email, + "exp": int(expires_at.timestamp()), + "type": "access" + } + + return jwt.encode(payload, self.jwt_secret, algorithm=self.jwt_algorithm) + + def create_refresh_token(self) -> str: + """ + Create an opaque refresh token. + + This generates a cryptographically secure random token that will be + stored in the database along with its expiration time and user association. + + Returns: + Secure random token string (64 characters hex). + + Example: + >>> tm = TokenManager(jwt_secret="secret") + >>> token = tm.create_refresh_token() + >>> len(token) + 64 + """ + return secrets.token_hex(32) + + def create_password_reset_token(self) -> str: + """ + Create an opaque password reset token. + + This generates a cryptographically secure random token that will be + stored in the database with a short expiration time (15 minutes by default). + + Returns: + Secure random token string (64 characters hex). + + Example: + >>> tm = TokenManager(jwt_secret="secret") + >>> token = tm.create_password_reset_token() + >>> len(token) + 64 + """ + return secrets.token_hex(32) + + def create_email_verification_token(self, email: str) -> str: + """ + Create a JWT email verification token. + + This token is stateless (not stored in DB) and contains the email + to be verified. It has no explicit expiration in the payload but + should be validated for reasonable freshness. + + Args: + email: The email address to encode in the token. + + Returns: + Encoded JWT verification token string. + + Example: + >>> tm = TokenManager(jwt_secret="secret") + >>> token = tm.create_email_verification_token("user@example.com") + """ + # Set expiration to 7 days for email verification + expires_at = datetime.now() + timedelta(days=7) + + payload = { + "email": email, + "exp": int(expires_at.timestamp()), + "type": "email_verification" + } + + return jwt.encode(payload, self.jwt_secret, algorithm=self.jwt_algorithm) + + def decode_access_token(self, token: str) -> TokenPayload: + """ + Decode and validate a JWT access token. + + This method verifies the token signature, checks expiration, + and validates the token type. + + Args: + token: The JWT access token to decode. + + Returns: + TokenPayload containing the decoded claims. + + Raises: + InvalidTokenError: If token is malformed or has invalid signature. + ExpiredTokenError: If token has expired. + + Example: + >>> tm = TokenManager(jwt_secret="secret") + >>> payload = tm.decode_access_token(token) + >>> print(payload.sub) # user_id + """ + try: + payload = jwt.decode( + token, + self.jwt_secret, + algorithms=[self.jwt_algorithm] + ) + + # Validate token type + if payload.get("type") != "access": + raise InvalidTokenError("Invalid token type") + + return TokenPayload( + sub=payload.get("sub"), + email=payload.get("email"), + exp=payload.get("exp"), + type=payload.get("type") + ) + + except JWTError as e: + if "expired" in str(e).lower(): + raise ExpiredTokenError("Access token has expired") + raise InvalidTokenError(f"Invalid access token: {str(e)}") + + def decode_email_verification_token(self, token: str) -> str: + """ + Decode and validate an email verification JWT token. + + This method verifies the token signature, checks expiration, + and extracts the email address. + + Args: + token: The JWT email verification token to decode. + + Returns: + The email address contained in the token. + + Raises: + InvalidTokenError: If token is malformed or has invalid signature. + ExpiredTokenError: If token has expired. + + Example: + >>> tm = TokenManager(jwt_secret="secret") + >>> email = tm.decode_email_verification_token(token) + >>> print(email) + user@example.com + """ + try: + payload = jwt.decode( + token, + self.jwt_secret, + algorithms=[self.jwt_algorithm] + ) + + # Validate token type + if payload.get("type") != "email_verification": + raise InvalidTokenError("Invalid token type") + + email = payload.get("email") + if not email: + raise InvalidTokenError("Email not found in token") + + return email + + except JWTError as e: + if "expired" in str(e).lower(): + raise ExpiredTokenError("Email verification token has expired") + raise InvalidTokenError(f"Invalid email verification token: {str(e)}") + + def get_refresh_token_expiration(self) -> datetime: + """ + Calculate expiration datetime for a refresh token. + + Returns: + Datetime when a newly created refresh token should expire. + """ + return datetime.now() + timedelta(days=self.refresh_token_expire_days) + + def get_password_reset_token_expiration(self) -> datetime: + """ + Calculate expiration datetime for a password reset token. + + Returns: + Datetime when a newly created reset token should expire. + """ + return datetime.now() + timedelta(minutes=self.password_reset_token_expire_minutes) diff --git a/src/my_auth/exceptions.py b/src/my_auth/exceptions.py new file mode 100644 index 0000000..1e0bc7a --- /dev/null +++ b/src/my_auth/exceptions.py @@ -0,0 +1,198 @@ +""" +Custom exceptions for authentication module. + +This module defines all custom exceptions used throughout the authentication +system. Each exception is designed to be caught and converted to appropriate +HTTP responses by FastAPI exception handlers. +""" + + +class AuthError(Exception): + """ + Base exception for all authentication-related errors. + + This base class provides common attributes for all auth exceptions + and facilitates centralized error handling. + + Attributes: + message: Human-readable error message. + status_code: HTTP status code associated with this error. + error_code: Unique error code for client-side handling (optional). + """ + + def __init__(self, message: str, status_code: int = 500, error_code: str = None): + """ + Initialize the authentication error. + + Args: + message: Human-readable error message. + status_code: HTTP status code (default: 500). + error_code: Unique error code for client identification. + """ + self.message = message + self.status_code = status_code + self.error_code = error_code or self.__class__.__name__ + super().__init__(self.message) + + +class InvalidCredentialsError(AuthError): + """ + Exception raised when login credentials are invalid. + + This exception is raised during authentication when the provided + email or password does not match any user in the database. + + HTTP Status: 401 Unauthorized + """ + + def __init__(self, message: str = "Invalid email or password"): + """ + Initialize invalid credentials error. + + Args: + message: Custom error message (default: "Invalid email or password"). + """ + super().__init__(message=message, status_code=401) + + +class UserAlreadyExistsError(AuthError): + """ + Exception raised when attempting to create a user with an existing email. + + This exception is raised during registration when the provided email + address is already associated with an existing user account. + + HTTP Status: 409 Conflict + """ + + def __init__(self, message: str = "User with this email already exists"): + """ + Initialize user already exists error. + + Args: + message: Custom error message (default: "User with this email already exists"). + """ + super().__init__(message=message, status_code=409) + + +class UserNotFoundError(AuthError): + """ + Exception raised when a requested user cannot be found. + + This exception is raised when attempting to retrieve, update, or delete + a user that does not exist in the database. + + HTTP Status: 404 Not Found + """ + + def __init__(self, message: str = "User not found"): + """ + Initialize user not found error. + + Args: + message: Custom error message (default: "User not found"). + """ + super().__init__(message=message, status_code=404) + + +class InvalidTokenError(AuthError): + """ + Exception raised when a token is invalid or malformed. + + This exception is raised when a token cannot be decoded, has an invalid + signature, or contains invalid data. This applies to JWT tokens, refresh + tokens, and password reset tokens. + + HTTP Status: 401 Unauthorized + """ + + def __init__(self, message: str = "Invalid token"): + """ + Initialize invalid token error. + + Args: + message: Custom error message (default: "Invalid token"). + """ + super().__init__(message=message, status_code=401) + + +class ExpiredTokenError(AuthError): + """ + Exception raised when a token has expired. + + This exception is raised when attempting to use a token that has + passed its expiration time. This applies to all token types. + + HTTP Status: 401 Unauthorized + """ + + def __init__(self, message: str = "Token has expired"): + """ + Initialize expired token error. + + Args: + message: Custom error message (default: "Token has expired"). + """ + super().__init__(message=message, status_code=401) + + +class RevokedTokenError(AuthError): + """ + Exception raised when attempting to use a revoked token. + + This exception is raised when a token has been explicitly revoked, + typically after logout or when a user's tokens are invalidated for + security reasons. + + HTTP Status: 401 Unauthorized + """ + + def __init__(self, message: str = "Token has been revoked"): + """ + Initialize revoked token error. + + Args: + message: Custom error message (default: "Token has been revoked"). + """ + super().__init__(message=message, status_code=401) + + +class EmailNotVerifiedError(AuthError): + """ + Exception raised when an action requires email verification. + + This exception is raised when a user attempts to perform an action + that requires their email to be verified, but their account email + has not been verified yet. + + HTTP Status: 403 Forbidden + """ + + def __init__(self, message: str = "Email address is not verified"): + """ + Initialize email not verified error. + + Args: + message: Custom error message (default: "Email address is not verified"). + """ + super().__init__(message=message, status_code=403) + + +class AccountDisabledError(AuthError): + """ + Exception raised when attempting to access a disabled account. + + This exception is raised when a user attempts to login or perform + actions with an account that has been disabled by an administrator. + + HTTP Status: 403 Forbidden + """ + + def __init__(self, message: str = "Account has been disabled"): + """ + Initialize account disabled error. + + Args: + message: Custom error message (default: "Account has been disabled"). + """ + super().__init__(message=message, status_code=403) diff --git a/src/my_auth/models/__init__.py b/src/my_auth/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/my_auth/models/email_verification.py b/src/my_auth/models/email_verification.py new file mode 100644 index 0000000..07290f5 --- /dev/null +++ b/src/my_auth/models/email_verification.py @@ -0,0 +1,91 @@ +""" +Email verification and password reset models for authentication module. + +This module defines Pydantic models for email verification and password +reset operations including request and confirmation models. +""" + +from pydantic import BaseModel, EmailStr, field_validator + +from my_auth.models.validators import validate_password_strength + + +class EmailVerificationRequest(BaseModel): + """ + Request model for email verification. + + This model is used when a user requests an email verification link + to be sent to their email address. + + Attributes: + email: The email address to send the verification link to. + """ + + email: EmailStr + + +class EmailVerificationConfirm(BaseModel): + """ + Confirmation model for email verification. + + This model is used when a user clicks the verification link and + submits the token to confirm their email address. + + Attributes: + token: JWT token received in the verification email. + """ + + token: str + + +class PasswordResetRequest(BaseModel): + """ + Request model for password reset. + + This model is used when a user requests a password reset link + to be sent to their email address. + + Attributes: + email: The email address to send the password reset link to. + """ + + email: EmailStr + + +class PasswordResetConfirm(BaseModel): + """ + Confirmation model for password reset. + + This model is used when a user submits a new password along with + their reset token. The new password is validated with the same + strict rules as user registration. + + Attributes: + token: Random secure token received in the password reset email. + new_password: The new password that must meet security requirements: + - Minimum 8 characters + - At least 1 uppercase letter + - At least 1 lowercase letter + - At least 1 digit + - At least 1 special character + """ + + token: str + new_password: str + + @field_validator('new_password') + @classmethod + def validate_password(cls, value: str) -> str: + """ + Validate new password meets security requirements. + + Args: + value: The new password to validate. + + Returns: + The validated password. + + Raises: + ValueError: If password does not meet security requirements. + """ + return validate_password_strength(value) diff --git a/src/my_auth/models/token.py b/src/my_auth/models/token.py new file mode 100644 index 0000000..c8da851 --- /dev/null +++ b/src/my_auth/models/token.py @@ -0,0 +1,88 @@ +""" +Token models for authentication module. + +This module defines Pydantic models for token-related operations including +JWT payloads, API responses, requests, and database storage. +""" + +from datetime import datetime +from typing import Literal + +from pydantic import BaseModel, Field + + +class TokenPayload(BaseModel): + """ + JWT token payload structure. + + This model represents the data contained within a JWT access token. + It follows the standard JWT claims format. + + Attributes: + sub: Subject (user ID). + email: User's email address. + exp: Expiration timestamp (Unix timestamp). + type: Token type identifier (always "access" for access tokens). + """ + + sub: str + email: str + exp: int + type: Literal["access"] = "access" + + +class AccessTokenResponse(BaseModel): + """ + Response model after successful authentication. + + This model is returned after login or token refresh operations. + It contains both access and refresh tokens. + + Attributes: + access_token: JWT access token for API authentication. + refresh_token: Opaque refresh token for obtaining new access tokens. + token_type: OAuth2 token type (always "bearer"). + """ + + access_token: str + refresh_token: str + token_type: str = "bearer" + + +class RefreshTokenRequest(BaseModel): + """ + Request model for token refresh operation. + + This model is used when a client wants to obtain a new access token + using their refresh token. + + Attributes: + refresh_token: The refresh token to exchange for a new access token. + """ + + refresh_token: str + + +class TokenData(BaseModel): + """ + Token storage model for database. + + This model represents tokens that need to be persisted in the database. + It uses a discriminator field to distinguish between different token types + (refresh tokens and password reset tokens) in a single collection/table. + + Attributes: + token: The token string (random secure string for refresh and reset). + token_type: Discriminator field ("refresh" or "password_reset"). + user_id: ID of the user this token belongs to. + expires_at: When the token expires. + created_at: When the token was created. + is_revoked: Whether the token has been revoked (for logout/security). + """ + + token: str + token_type: Literal["refresh", "password_reset"] + user_id: str + expires_at: datetime + created_at: datetime + is_revoked: bool = False \ No newline at end of file diff --git a/src/my_auth/models/user.py b/src/my_auth/models/user.py new file mode 100644 index 0000000..c6a40a3 --- /dev/null +++ b/src/my_auth/models/user.py @@ -0,0 +1,189 @@ +""" +User models for authentication module. + +This module defines Pydantic models for user-related operations including +creation, updates, database representation, and API responses. +""" + +from datetime import datetime +from typing import Optional + +from pydantic import BaseModel, EmailStr, Field, field_validator + +from my_auth.models.validators import validate_password_strength, validate_username_not_empty + + +class UserBase(BaseModel): + """ + Base user model with common fields. + + This model contains the shared fields used across different user representations. + It serves as the foundation for other user models. + + Attributes: + email: User's email address (unique in database). + username: User's display name (required, non-unique). + roles: List of user roles (free-form strings, no defaults). + user_settings: Dictionary for storing custom user settings. + """ + + email: EmailStr + username: str + roles: list[str] = Field(default_factory=list) + user_settings: dict = Field(default_factory=dict) + + +class UserCreate(UserBase): + """ + Model for user creation (registration). + + This model extends UserBase with a password field and enforces strict + password validation rules. + + Attributes: + password: Plain text password that must meet security requirements: + - Minimum 8 characters + - At least 1 uppercase letter + - At least 1 lowercase letter + - At least 1 digit + - At least 1 special character + """ + + password: str + + @field_validator('password') + @classmethod + def validate_password(cls, value: str) -> Optional[str]: + """ + Validate password meets security requirements. + + Args: + value: The password to validate. + + Returns: + The validated password. + + Raises: + ValueError: If password does not meet security requirements. + """ + return validate_password_strength(value) + + @field_validator('username') + @classmethod + def validate_username(cls, value: str) -> str: + """ + Validate username is not empty and has reasonable length. + + Args: + value: The username to validate. + + Returns: + The validated username. + + Raises: + ValueError: If username is empty or too long. + """ + return validate_username_not_empty(value) + + +class UserUpdate(BaseModel): + """ + Model for user updates. + + All fields are optional to allow partial updates. Password updates + are validated with the same strict rules as UserCreate. + + Attributes: + email: Optional new email address. + username: Optional new username. + password: Optional new password (will be hashed). + roles: Optional new roles list. + user_settings: Optional new settings dict. + is_verified: Optional email verification status. + is_active: Optional account active status. + """ + + email: Optional[EmailStr] = None + username: Optional[str] = None + password: Optional[str] = None + roles: Optional[list[str]] = None + user_settings: Optional[dict] = None + is_verified: Optional[bool] = None + is_active: Optional[bool] = None + + @field_validator('password') + @classmethod + def validate_password_strength(cls, value: Optional[str]) -> Optional[str]: + """ + Validate password meets security requirements if provided. + + Args: + value: The password to validate (can be None). + + Returns: + The validated password or None. + + Raises: + ValueError: If password is provided but does not meet security requirements. + """ + return validate_password_strength(value) + + @field_validator('username') + @classmethod + def validate_username(cls, value: Optional[str]) -> Optional[str]: + """ + Validate username if provided. + + Args: + value: The username to validate (can be None). + + Returns: + The validated username or None. + + Raises: + ValueError: If username is provided but empty or too long. + """ + return validate_username_not_empty(value) + + +class UserInDB(UserBase): + """ + Complete user model as stored in database. + + This model represents the full user entity including security-sensitive + fields like hashed_password. It should not be directly exposed via API. + + Attributes: + id: Unique user identifier (string for database compatibility). + hashed_password: Bcrypt hashed password. + is_verified: Whether the user's email has been verified. + is_active: Whether the user account is active. + created_at: Timestamp of account creation. + updated_at: Timestamp of last update. + """ + + id: str + hashed_password: str + is_verified: bool = False + is_active: bool = True + created_at: datetime + updated_at: datetime + + +class UserResponse(UserBase): + """ + User model for API responses. + + This is a safe representation of the user without sensitive fields + like hashed_password. Used for API responses where user information + needs to be exposed. + + Attributes: + id: Unique user identifier. + created_at: Timestamp of account creation. + updated_at: Timestamp of last update. + """ + + id: str + created_at: datetime + updated_at: datetime diff --git a/src/my_auth/models/validators.py b/src/my_auth/models/validators.py new file mode 100644 index 0000000..b6d685d --- /dev/null +++ b/src/my_auth/models/validators.py @@ -0,0 +1,57 @@ +from typing import Optional + +def validate_password_strength(password: str) -> Optional[str]: + """ + Validate password meets security requirements. + + Args: + password: The password to validate. + + Returns: + The validated password. + + Raises: + ValueError: If password does not meet security requirements. + """ + if password is None: + return password + + if len(password) < 8: + raise ValueError('Password must be at least 8 characters long') + + if not any(char.isupper() for char in password): + raise ValueError('Password must contain at least one uppercase letter') + + if not any(char.islower() for char in password): + raise ValueError('Password must contain at least one lowercase letter') + + if not any(char.isdigit() for char in password): + raise ValueError('Password must contain at least one digit') + + special_characters = "!@#$%^&*()_+-=[]{}|;:,.<>?/~`" + if not any(char in special_characters for char in password): + raise ValueError('Password must contain at least one special character') + + return password + + +def validate_username_not_empty(user_name: str) -> str: + """ + Validate username is not empty and has reasonable length. + + Args: + user_name: The username to validate. + + Returns: + The validated username. + + Raises: + ValueError: If username is empty or too long. + """ + if not user_name or not user_name.strip(): + raise ValueError('Username cannot be empty') + + if len(user_name) > 100: + raise ValueError('Username cannot exceed 100 characters') + + return user_name.strip() diff --git a/src/my_auth/persistence/__init__.py b/src/my_auth/persistence/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/my_auth/persistence/base.py b/src/my_auth/persistence/base.py new file mode 100644 index 0000000..7334961 --- /dev/null +++ b/src/my_auth/persistence/base.py @@ -0,0 +1,226 @@ +""" +Abstract base classes for persistence layer. + +This module defines the interfaces that all database implementations must +follow. It provides abstract base classes for user and token repositories, +ensuring consistency across different database backends (MongoDB, SQLite, +PostgreSQL, custom engines, etc.). +""" + +from abc import ABC, abstractmethod + +from ..models.token import TokenData +from ..models.user import UserCreate, UserInDB, UserUpdate + + +class UserRepository(ABC): + """ + Abstract base class for user persistence operations. + + This interface defines all methods required for user management in the + database. Concrete implementations must provide all these methods for + their specific database backend. + + All methods should raise appropriate exceptions (UserNotFoundError, + UserAlreadyExistsError, etc.) when operations fail. + """ + + @abstractmethod + def create_user(self, user_data: UserCreate, hashed_password: str) -> UserInDB: + """ + Create a new user in the database. + + Args: + user_data: User information from registration. + hashed_password: Pre-hashed password (hashing is done by the service layer). + + Returns: + The created user with database-assigned ID and timestamps. + + Raises: + UserAlreadyExistsError: If email already exists in database. + """ + pass + + @abstractmethod + def get_user_by_email(self, email: str) -> UserInDB | None: + """ + Retrieve a user by their email address. + + Args: + email: The email address to search for. + + Returns: + The user if found, None otherwise. + """ + pass + + @abstractmethod + def get_user_by_id(self, user_id: str) -> UserInDB | None: + """ + Retrieve a user by their unique identifier. + + Args: + user_id: The unique user identifier. + + Returns: + The user if found, None otherwise. + """ + pass + + @abstractmethod + def update_user(self, user_id: str, updates: UserUpdate) -> UserInDB: + """ + Update an existing user's information. + + Only fields present in the updates object (non-None) should be updated. + The updated_at timestamp should be automatically set to the current time. + + Args: + user_id: The unique user identifier. + updates: Pydantic model containing fields to update. + + Returns: + The updated user. + + Raises: + UserNotFoundError: If user does not exist. + """ + pass + + @abstractmethod + def delete_user(self, user_id: str) -> bool: + """ + Delete a user from the database. + + Implementation can choose between soft delete (setting is_active=False) + or hard delete (removing from database). The choice depends on the + application's requirements for data retention. + + Args: + user_id: The unique user identifier. + + Returns: + True if user was deleted, False if user was not found. + """ + pass + + @abstractmethod + def email_exists(self, email: str) -> bool: + """ + Check if an email address is already registered. + + This is useful for validating registration requests without + retrieving the full user object. + + Args: + email: The email address to check. + + Returns: + True if email exists, False otherwise. + """ + pass + + +class TokenRepository(ABC): + """ + Abstract base class for token persistence operations. + + This interface defines all methods required for token management in the + database. It handles both refresh tokens and password reset tokens using + a discriminator field (token_type). + + All methods should raise appropriate exceptions (InvalidTokenError, + ExpiredTokenError, etc.) when operations fail. + """ + + @abstractmethod + def save_token(self, token_data: TokenData) -> None: + """ + Save a token to the database. + + This is used for both refresh tokens and password reset tokens. + The token_type field in TokenData distinguishes between them. + + Args: + token_data: Complete token information including type, user_id, expiration. + """ + pass + + @abstractmethod + def get_token(self, token: str, token_type: str) -> TokenData | None: + """ + Retrieve a token from the database. + + Args: + token: The token string to search for. + token_type: Type of token ("refresh" or "password_reset"). + + Returns: + The token data if found, None otherwise. + """ + pass + + @abstractmethod + def revoke_token(self, token: str) -> bool: + """ + Revoke a specific token. + + This sets the is_revoked flag to True, preventing the token from + being used again without physically deleting it (useful for audit trails). + + Args: + token: The token string to revoke. + + Returns: + True if token was revoked, False if token was not found. + """ + pass + + @abstractmethod + def revoke_all_user_tokens(self, user_id: str, token_type: str) -> int: + """ + Revoke all tokens of a specific type for a user. + + This is useful for logout-all-devices functionality or when a user's + password is changed (invalidating all refresh tokens). + + Args: + user_id: The user whose tokens should be revoked. + token_type: Type of tokens to revoke ("refresh" or "password_reset"). + + Returns: + Number of tokens revoked. + """ + pass + + @abstractmethod + def delete_expired_tokens(self) -> int: + """ + Delete all expired tokens from the database. + + This maintenance operation should be called periodically to keep + the tokens collection clean. It permanently removes tokens that + have passed their expiration time. + + Returns: + Number of tokens deleted. + """ + pass + + @abstractmethod + def is_token_valid(self, token: str, token_type: str) -> bool: + """ + Check if a token is valid (exists, not revoked, not expired). + + This is a convenience method that combines multiple checks into + a single boolean result. + + Args: + token: The token string to validate. + token_type: Type of token ("refresh" or "password_reset"). + + Returns: + True if token is valid and can be used, False otherwise. + """ + pass diff --git a/src/my_auth/persistence/sqlite.py b/src/my_auth/persistence/sqlite.py new file mode 100644 index 0000000..c4a4ec0 --- /dev/null +++ b/src/my_auth/persistence/sqlite.py @@ -0,0 +1,611 @@ +""" +SQLite implementation of persistence layer. + +This module provides SQLite database implementations for user and token +repositories. It uses the standard library sqlite3 module and handles +JSON serialization for complex fields. +""" + +import json +import sqlite3 +from datetime import datetime +from typing import Optional +from uuid import uuid4 + +from .base import UserRepository, TokenRepository +from ..models.user import UserCreate, UserInDB, UserUpdate +from ..models.token import TokenData +from ..exceptions import UserAlreadyExistsError, UserNotFoundError + + +class SQLiteUserRepository(UserRepository): + """ + SQLite implementation of UserRepository. + + This implementation uses sqlite3 to manage user data. JSON fields + (roles, user_settings) are serialized/deserialized automatically. + The database schema is created automatically on initialization. + + Attributes: + db_path: Path to the SQLite database file. + """ + + def __init__(self, db_path: str): + """ + Initialize SQLite user repository. + + Creates the users table and indexes if they don't exist. + + Args: + db_path: Path to the SQLite database file. + """ + self.db_path = db_path + self._create_tables() + + def _create_tables(self) -> None: + """ + Create users table and indexes if they don't exist. + + The table schema includes: + - id: Primary key (UUID string) + - email: Unique, indexed for fast lookups + - username: Non-unique display name + - hashed_password: Bcrypt hash + - roles: JSON array of role strings + - user_settings: JSON object for custom settings + - is_verified: Boolean (stored as INTEGER) + - is_active: Boolean (stored as INTEGER) + - created_at: ISO format timestamp + - updated_at: ISO format timestamp + """ + with sqlite3.connect(self.db_path) as conn: + cursor = conn.cursor() + + # Create users table + cursor.execute(""" + CREATE TABLE IF NOT EXISTS users + ( + id + TEXT + PRIMARY + KEY, + email + TEXT + UNIQUE + NOT + NULL, + username + TEXT + NOT + NULL, + hashed_password + TEXT + NOT + NULL, + roles + TEXT + NOT + NULL, + user_settings + TEXT + NOT + NULL, + is_verified + INTEGER + NOT + NULL + DEFAULT + 0, + is_active + INTEGER + NOT + NULL + DEFAULT + 1, + created_at + TEXT + NOT + NULL, + updated_at + TEXT + NOT + NULL + ) + """) + + # Create index on email for fast lookups + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_users_email + ON users(email) + """) + + conn.commit() + + def _row_to_user(self, row: tuple) -> UserInDB: + """ + Convert a database row to a UserInDB model. + + Args: + row: Database row tuple. + + Returns: + UserInDB model instance. + """ + return UserInDB( + id=row[0], + email=row[1], + username=row[2], + hashed_password=row[3], + roles=json.loads(row[4]), + user_settings=json.loads(row[5]), + is_verified=bool(row[6]), + is_active=bool(row[7]), + created_at=datetime.fromisoformat(row[8]), + updated_at=datetime.fromisoformat(row[9]) + ) + + def create_user(self, user_data: UserCreate, hashed_password: str) -> UserInDB: + """ + Create a new user in the database. + + Args: + user_data: User information from registration. + hashed_password: Pre-hashed password. + + Returns: + The created user with generated ID and timestamps. + + Raises: + UserAlreadyExistsError: If email already exists. + """ + if self.email_exists(user_data.email): + raise UserAlreadyExistsError(f"User with email {user_data.email} already exists") + + user_id = str(uuid4()) + now = datetime.utcnow() + + user = UserInDB( + id=user_id, + email=user_data.email, + username=user_data.username, + hashed_password=hashed_password, + roles=user_data.roles, + user_settings=user_data.user_settings, + is_verified=False, + is_active=True, + created_at=now, + updated_at=now + ) + + with sqlite3.connect(self.db_path) as conn: + cursor = conn.cursor() + cursor.execute(""" + INSERT INTO users (id, email, username, hashed_password, roles, + user_settings, is_verified, is_active, created_at, updated_at) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """, ( + user.id, + user.email, + user.username, + user.hashed_password, + json.dumps(user.roles), + json.dumps(user.user_settings), + int(user.is_verified), + int(user.is_active), + user.created_at.isoformat(), + user.updated_at.isoformat() + )) + conn.commit() + + return user + + def get_user_by_email(self, email: str) -> Optional[UserInDB]: + """ + Retrieve a user by their email address. + + Args: + email: The email address to search for. + + Returns: + The user if found, None otherwise. + """ + with sqlite3.connect(self.db_path) as conn: + cursor = conn.cursor() + cursor.execute(""" + SELECT id, + email, + username, + hashed_password, + roles, + user_settings, + is_verified, + is_active, + created_at, + updated_at + FROM users + WHERE email = ? + """, (email,)) + + row = cursor.fetchone() + if row: + return self._row_to_user(row) + return None + + def get_user_by_id(self, user_id: str) -> Optional[UserInDB]: + """ + Retrieve a user by their unique identifier. + + Args: + user_id: The unique user identifier. + + Returns: + The user if found, None otherwise. + """ + with sqlite3.connect(self.db_path) as conn: + cursor = conn.cursor() + cursor.execute(""" + SELECT id, + email, + username, + hashed_password, + roles, + user_settings, + is_verified, + is_active, + created_at, + updated_at + FROM users + WHERE id = ? + """, (user_id,)) + + row = cursor.fetchone() + if row: + return self._row_to_user(row) + return None + + def update_user(self, user_id: str, updates: UserUpdate) -> UserInDB: + """ + Update an existing user's information. + + Only non-None fields in the updates object are applied. + The updated_at timestamp is automatically set to current time. + + Args: + user_id: The unique user identifier. + updates: Pydantic model containing fields to update. + + Returns: + The updated user. + + Raises: + UserNotFoundError: If user does not exist. + """ + user = self.get_user_by_id(user_id) + if not user: + raise UserNotFoundError(f"User with id {user_id} not found") + + # Build update query dynamically based on non-None fields + update_fields = [] + update_values = [] + + update_data = updates.model_dump(exclude_unset=True) + + for field, value in update_data.items(): + if value is not None: + if field == "roles": + update_fields.append("roles = ?") + update_values.append(json.dumps(value)) + elif field == "user_settings": + update_fields.append("user_settings = ?") + update_values.append(json.dumps(value)) + elif field == "password": + # Password should be hashed before calling this method + # This field should contain hashed_password value + update_fields.append("hashed_password = ?") + update_values.append(value) + elif field in ["is_verified", "is_active"]: + update_fields.append(f"{field} = ?") + update_values.append(int(value)) + else: + update_fields.append(f"{field} = ?") + update_values.append(value) + + if not update_fields: + return user + + # Always update the updated_at timestamp + update_fields.append("updated_at = ?") + update_values.append(datetime.utcnow().isoformat()) + + # Add user_id for WHERE clause + update_values.append(user_id) + + query = f"UPDATE users SET {', '.join(update_fields)} WHERE id = ?" + + with sqlite3.connect(self.db_path) as conn: + cursor = conn.cursor() + cursor.execute(query, update_values) + conn.commit() + + # Return updated user + updated_user = self.get_user_by_id(user_id) + if not updated_user: + raise UserNotFoundError(f"User with id {user_id} not found after update") + + return updated_user + + def delete_user(self, user_id: str) -> bool: + """ + Delete a user from the database (hard delete). + + Args: + user_id: The unique user identifier. + + Returns: + True if user was deleted, False if user was not found. + """ + with sqlite3.connect(self.db_path) as conn: + cursor = conn.cursor() + cursor.execute("DELETE FROM users WHERE id = ?", (user_id,)) + conn.commit() + return cursor.rowcount > 0 + + def email_exists(self, email: str) -> bool: + """ + Check if an email address is already registered. + + Args: + email: The email address to check. + + Returns: + True if email exists, False otherwise. + """ + with sqlite3.connect(self.db_path) as conn: + cursor = conn.cursor() + cursor.execute("SELECT 1 FROM users WHERE email = ? LIMIT 1", (email,)) + return cursor.fetchone() is not None + + +class SQLiteTokenRepository(TokenRepository): + """ + SQLite implementation of TokenRepository. + + This implementation manages both refresh tokens and password reset tokens + in a single table with a discriminator field (token_type). + + Attributes: + db_path: Path to the SQLite database file. + """ + + def __init__(self, db_path: str): + """ + Initialize SQLite token repository. + + Creates the tokens table and indexes if they don't exist. + + Args: + db_path: Path to the SQLite database file. + """ + self.db_path = db_path + self._create_tables() + + def _create_tables(self) -> None: + """ + Create tokens table and indexes if they don't exist. + + The table schema includes: + - token: Primary key (random secure string) + - token_type: Discriminator ("refresh" or "password_reset") + - user_id: Foreign key to users + - expires_at: ISO format timestamp + - created_at: ISO format timestamp + - is_revoked: Boolean (stored as INTEGER) + + Indexes are created for efficient queries. + """ + with sqlite3.connect(self.db_path) as conn: + cursor = conn.cursor() + + # Create tokens table + cursor.execute(""" + CREATE TABLE IF NOT EXISTS tokens + ( + token + TEXT + PRIMARY + KEY, + token_type + TEXT + NOT + NULL, + user_id + TEXT + NOT + NULL, + expires_at + TEXT + NOT + NULL, + created_at + TEXT + NOT + NULL, + is_revoked + INTEGER + NOT + NULL + DEFAULT + 0 + ) + """) + + # Create index on user_id for revoking all user tokens + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_tokens_user_id + ON tokens(user_id) + """) + + # Create composite index on token_type and user_id + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_tokens_type_user + ON tokens(token_type, user_id) + """) + + # Create index on expires_at for cleanup operations + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_tokens_expires_at + ON tokens(expires_at) + """) + + conn.commit() + + def _row_to_token(self, row: tuple) -> TokenData: + """ + Convert a database row to a TokenData model. + + Args: + row: Database row tuple. + + Returns: + TokenData model instance. + """ + return TokenData( + token=row[0], + token_type=row[1], + user_id=row[2], + expires_at=datetime.fromisoformat(row[3]), + created_at=datetime.fromisoformat(row[4]), + is_revoked=bool(row[5]) + ) + + def save_token(self, token_data: TokenData) -> None: + """ + Save a token to the database. + + Args: + token_data: Complete token information. + """ + with sqlite3.connect(self.db_path) as conn: + cursor = conn.cursor() + cursor.execute(""" + INSERT INTO tokens (token, token_type, user_id, expires_at, created_at, is_revoked) + VALUES (?, ?, ?, ?, ?, ?) + """, ( + token_data.token, + token_data.token_type, + token_data.user_id, + token_data.expires_at.isoformat(), + token_data.created_at.isoformat(), + int(token_data.is_revoked) + )) + conn.commit() + + def get_token(self, token: str, token_type: str) -> Optional[TokenData]: + """ + Retrieve a token from the database. + + Args: + token: The token string to search for. + token_type: Type of token ("refresh" or "password_reset"). + + Returns: + The token data if found, None otherwise. + """ + with sqlite3.connect(self.db_path) as conn: + cursor = conn.cursor() + cursor.execute(""" + SELECT token, token_type, user_id, expires_at, created_at, is_revoked + FROM tokens + WHERE token = ? + AND token_type = ? + """, (token, token_type)) + + row = cursor.fetchone() + if row: + return self._row_to_token(row) + return None + + def revoke_token(self, token: str) -> bool: + """ + Revoke a specific token. + + Args: + token: The token string to revoke. + + Returns: + True if token was revoked, False if token was not found. + """ + with sqlite3.connect(self.db_path) as conn: + cursor = conn.cursor() + cursor.execute(""" + UPDATE tokens + SET is_revoked = 1 + WHERE token = ? + """, (token,)) + conn.commit() + return cursor.rowcount > 0 + + def revoke_all_user_tokens(self, user_id: str, token_type: str) -> int: + """ + Revoke all tokens of a specific type for a user. + + Args: + user_id: The user whose tokens should be revoked. + token_type: Type of tokens to revoke. + + Returns: + Number of tokens revoked. + """ + with sqlite3.connect(self.db_path) as conn: + cursor = conn.cursor() + cursor.execute(""" + UPDATE tokens + SET is_revoked = 1 + WHERE user_id = ? + AND token_type = ? + """, (user_id, token_type)) + conn.commit() + return cursor.rowcount + + def delete_expired_tokens(self) -> int: + """ + Delete all expired tokens from the database. + + Returns: + Number of tokens deleted. + """ + now = datetime.utcnow().isoformat() + + with sqlite3.connect(self.db_path) as conn: + cursor = conn.cursor() + cursor.execute(""" + DELETE + FROM tokens + WHERE expires_at < ? + """, (now,)) + conn.commit() + return cursor.rowcount + + def is_token_valid(self, token: str, token_type: str) -> bool: + """ + Check if a token is valid (exists, not revoked, not expired). + + Args: + token: The token string to validate. + token_type: Type of token. + + Returns: + True if token is valid, False otherwise. + """ + token_data = self.get_token(token, token_type) + + if not token_data: + return False + + if token_data.is_revoked: + return False + + if token_data.expires_at < datetime.now(): + return False + + return True \ No newline at end of file diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/core/__init__.py b/tests/core/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/core/conftest.py b/tests/core/conftest.py new file mode 100644 index 0000000..7eb7b66 --- /dev/null +++ b/tests/core/conftest.py @@ -0,0 +1,106 @@ +# tests/core/conftest.py + +import shutil +from pathlib import Path +from unittest.mock import MagicMock + +import pytest + +from my_auth.core.password import PasswordManager +from my_auth.core.token import TokenManager +from src.my_auth.core.auth import AuthService +from src.my_auth.models.user import UserCreate +from src.my_auth.persistence.sqlite import SQLiteUserRepository, SQLiteTokenRepository + + +@pytest.fixture +def test_user_data_create(): + """Provides valid data for creating a test user.""" + return UserCreate( + email="test.service@example.com", + username="TestServiceUser", + password="ValidPassword123!", + roles=["member"], + user_settings={"theme": "dark"} + ) + + +@pytest.fixture +def test_user_hashed_password(): + """Provides a dummy hashed password (only used internally by service)""" + # Note: In service tests, we rely on the service to do the hashing/verification, + # but this is kept for completeness if needed elsewhere. + return "$2b$12$R.S/XfI2tQYt3Kk.iF1XwOQz0Qe.L0T0mD/O1H8E2V5D4Q6F7G8H9I0" + + +@pytest.fixture() +def sqlite_db_path(tmp_path_factory): + """ + Creates a temporary directory and an SQLite file path for the test session. + The directory is deleted after the session. + """ + temp_dir = tmp_path_factory.mktemp("sqlite_auth_service_test") + db_file: Path = temp_dir / "auth_service_test.db" + + yield str(db_file) + + # Cleanup phase + try: + if temp_dir.exists(): + shutil.rmtree(temp_dir) + except OSError as e: + print(f"Error during cleanup of temporary DB directory: {e}") + + +@pytest.fixture +def user_repository(sqlite_db_path: str) -> SQLiteUserRepository: + """Provides a real SQLiteUserRepository instance.""" + return SQLiteUserRepository(db_path=sqlite_db_path) + + +@pytest.fixture +def token_repository(sqlite_db_path: str) -> SQLiteTokenRepository: + """Provides a real SQLiteTokenRepository instance.""" + return SQLiteTokenRepository(db_path=sqlite_db_path) + + +@pytest.fixture +def mock_password_manager() -> PasswordManager: + """Provides a PasswordManager instance for injection (low rounds for speed).""" + mock = MagicMock(spec=PasswordManager) + mock.hash_password.return_value = "PREDICTABLE_HASHED_PASSWORD_FOR_TESTING" + mock.verify_password.return_value = True + return mock + + +@pytest.fixture +def mock_token_manager() -> TokenManager: + """Provides a TokenManager instance for injection (fast expiration settings).""" + mock = MagicMock(spec=TokenManager) + mock.create_access_token.return_value = "MOCKED_ACCESS_TOKEN" + mock.create_refresh_token.return_value = "MOCKED_REFRESH_TOKEN" + return mock + + +# --- Service Fixture --- + +@pytest.fixture +def auth_service(user_repository: SQLiteUserRepository, + token_repository: SQLiteTokenRepository, + mock_password_manager: PasswordManager, + mock_token_manager: TokenManager + ) -> AuthService: + """ + Provides an AuthService instance initialized with real repositories. + """ + # NOTE: To test hashing/verification, we must ensure password_hash_rounds is low. + # NOTE: To simplify JWT testing, we mock the internal components (hashing/JWT) + # as the AuthService shouldn't be responsible for these core algorithms, + # only for orchestrating them. If your service integrates them directly, + # we'll need to patch them below. + return AuthService( + user_repository=user_repository, + token_repository=token_repository, + password_manager=mock_password_manager, + token_manager=mock_token_manager, + ) diff --git a/tests/core/test_auth_service.py b/tests/core/test_auth_service.py new file mode 100644 index 0000000..ba1ce17 --- /dev/null +++ b/tests/core/test_auth_service.py @@ -0,0 +1,270 @@ +# tests/core/test_auth_service.py + +from datetime import datetime, timedelta +from unittest.mock import MagicMock, patch + +import pytest + +from src.my_auth.core.auth import AuthService +from src.my_auth.exceptions import UserAlreadyExistsError, InvalidCredentialsError, InvalidTokenError, ExpiredTokenError +from src.my_auth.models.token import TokenData +from src.my_auth.models.user import UserCreate, UserUpdate + + +class TestAuthServiceRegisterLogin(object): + """Tests for the user registration and login processes.""" + + # The mocks are injected via the auth_service fixture from conftest.py + + def test_register_success(self, auth_service: AuthService, + test_user_data_create: UserCreate): + """Success: Registration works and stores the predictable hash.""" + + user = auth_service.register(test_user_data_create) + + assert user is not None + assert user.hashed_password == "PREDICTABLE_HASHED_PASSWORD_FOR_TESTING" + auth_service.password_manager.hash_password.assert_called_once_with(test_user_data_create.password) + + def test_register_failure_if_email_exists(self, auth_service: AuthService, + test_user_data_create: UserCreate): + """Failure: Cannot register if the email already exists.""" + + auth_service.register(test_user_data_create) + + with pytest.raises(UserAlreadyExistsError): + auth_service.register(test_user_data_create) + + def test_login_success(self, auth_service: AuthService, test_user_data_create: UserCreate): + """Success: Logging in with correct credentials generates and saves tokens.""" + + # Execute register to insert the user + auth_service.register(test_user_data_create) + + # Execute login + user, tokens = auth_service.login(test_user_data_create.email, test_user_data_create.password) + + assert user.email == test_user_data_create.email + + auth_service.password_manager.verify_password.assert_called_once() + auth_service.token_manager.create_access_token.assert_called_once() + assert tokens.access_token == "MOCKED_ACCESS_TOKEN" + assert tokens.refresh_token == "MOCKED_REFRESH_TOKEN" + + def test_login_failure_invalid_password(self, auth_service: AuthService, # Removed mock_password_manager injection + test_user_data_create: UserCreate): + """Failure: Login fails with InvalidCredentialsError if the password is wrong.""" + + # Access the manager via the service instance + pm = auth_service.password_manager + + # Setup: Mock hash for registration + pm.hash_password.return_value = "HASH" + auth_service.register(test_user_data_create) + + # Mock verify for failure + pm.verify_password.return_value = False + + with pytest.raises(InvalidCredentialsError): + auth_service.login(test_user_data_create.email, "WrongPassword!") + + # Restore the default mock behavior defined in conftest for subsequent tests + pm.hash_password.return_value = "PREDICTABLE_HASHED_PASSWORD_FOR_TESTING" + pm.verify_password.return_value = True + + def test_login_failure_user_not_found(self, auth_service: AuthService): + """Failure: Login fails if the user does not exist.""" + with pytest.raises(InvalidCredentialsError): + auth_service.login("non.existent@example.com", "AnyPassword") + + +class TestAuthServiceTokenManagement(object): + """Tests for token-related flows (Refresh, Logout, GetCurrentUser).""" + + @pytest.fixture(autouse=True) + def setup_user_and_token(self, auth_service: AuthService, test_user_data_create: UserCreate): + """ + Sets up a registered user and an initial set of tokens for management tests. + Temporarily overrides manager behavior for setup. + """ + # Temporarily set up predictable values for the registration/login flow within the fixture setup + pm = auth_service.password_manager + tm = auth_service.token_manager + + original_hash = pm.hash_password.return_value + original_verify = pm.verify_password.return_value + original_access = tm.create_access_token.return_value + original_refresh = tm.create_refresh_token.return_value + + pm.hash_password.return_value = "HASHED_PASS" + pm.verify_password.return_value = True + tm.create_access_token.return_value = "SETUP_ACCESS_TOKEN" + tm.create_refresh_token.return_value = "SETUP_REFRESH_TOKEN" + + user = auth_service.register(test_user_data_create) + _, tokens = auth_service.login(test_user_data_create.email, test_user_data_create.password) + + self.user = user + self.refresh_token = tokens.refresh_token + self.access_token = tokens.access_token + + # Restore mock values to default conftest behavior for the actual tests + pm.hash_password.return_value = original_hash + pm.verify_password.return_value = original_verify + tm.create_access_token.return_value = original_access + tm.create_refresh_token.return_value = original_refresh + + def test_refresh_access_token_success(self, auth_service: AuthService): + """Success: Refreshing an access token works with a valid refresh token.""" + + # Access the manager via the service instance + tm = auth_service.token_manager + + # Use patch.object on the *instance* for granular control within the test + with patch.object(tm, 'create_access_token', + return_value="NEW_MOCKED_ACCESS_TOKEN") as mock_create_access: + tokens = auth_service.refresh_access_token(self.refresh_token) + + assert tokens.access_token == "NEW_MOCKED_ACCESS_TOKEN" + mock_create_access.assert_called_once() + + def test_refresh_access_token_failure_invalid_token(self, auth_service: AuthService): + """Failure: Refreshing fails if the token is invalid (revoked, expired, etc.).""" + auth_service.logout(self.refresh_token) + + with pytest.raises(InvalidTokenError): + auth_service.refresh_access_token("invalid_token") + + def test_logout_success(self, auth_service: AuthService): + """Success: Logout revokes the specified refresh token.""" + result = auth_service.logout(self.refresh_token) + assert result is True + + with pytest.raises(InvalidTokenError): + auth_service.refresh_access_token(self.refresh_token) + + def test_get_current_user_success(self, auth_service: AuthService): + """Success: Getting the current user works by successfully decoding the JWT.""" + + # Mock the decoder to simulate a decoded payload + with patch.object(auth_service.token_manager, 'decode_access_token', + return_value={"sub": self.user.id}) as mock_decode: + user = auth_service.get_current_user("dummy_jwt") + + assert user.id == self.user.id + mock_decode.assert_called_once() + + def test_get_current_user_failure_invalid_token(self, auth_service: AuthService): + """Failure: Getting the current user fails if the access token is invalid/expired.""" + + with patch.object(auth_service.token_manager, 'decode_access_token', + side_effect=InvalidTokenError("Invalid signature")): + with pytest.raises(InvalidTokenError): + auth_service.get_current_user("invalid_access_jwt") + + with patch.object(auth_service.token_manager, 'decode_access_token', + side_effect=ExpiredTokenError("Token expired")): + with pytest.raises(ExpiredTokenError): + auth_service.get_current_user("expired_access_jwt") + + +class TestAuthServiceResetVerification(object): + """Tests for password reset and email verification flows.""" + + @pytest.fixture(autouse=True) + def setup_user(self, auth_service: AuthService, test_user_data_create: UserCreate): + """Sets up a registered user using a mock hash for speed.""" + + pm = auth_service.password_manager + original_hash = pm.hash_password.return_value + + # Temporarily set hash for setup + pm.hash_password.return_value = "HASHED_PASS" + user = auth_service.register(test_user_data_create) + self.user = user + + # Restore hash mock + pm.hash_password.return_value = original_hash + + @patch('src.my_auth.core.email.send_email') + def test_request_password_reset_success(self, mock_send_email: MagicMock, auth_service: AuthService): + """Success: Requesting a password reset generates a token and sends an email.""" + + tm = auth_service.token_manager + with patch.object(tm, 'create_password_reset_token', + return_value="MOCKED_RESET_TOKEN") as mock_create_token: + token_string = auth_service.request_password_reset(self.user.email) + + assert token_string == "MOCKED_RESET_TOKEN" + mock_create_token.assert_called_once() + mock_send_email.assert_called_once() + + def test_reset_password_success(self, auth_service: AuthService): + """Success: Resetting the password works with a valid token.""" + + # Setup: Manually create a valid reset token + auth_service.token_repository.save_token( + TokenData(token="valid_reset_token", token_type="password_reset", user_id=self.user.id, + expires_at=datetime.now() + timedelta(minutes=10)) + ) + + # Patch the PasswordManager instance to control the hash output + pm = auth_service.password_manager + with patch.object(pm, 'hash_password', + return_value="NEW_HASHED_PASSWORD_FOR_RESET") as mock_hash: + new_password = "NewPassword123!" + result = auth_service.reset_password("valid_reset_token", new_password) + + assert result is True + mock_hash.assert_called_once_with(new_password) + + # Verification: Check that user data was updated + updated_user = auth_service.user_repository.get_user_by_id(self.user.id) + assert updated_user.hashed_password == "NEW_HASHED_PASSWORD_FOR_RESET" + + @patch('src.my_auth.core.email.send_email') + def test_request_email_verification_success(self, mock_send_email: MagicMock, auth_service: AuthService): + """Success: Requesting verification generates a token and sends an email.""" + + tm = auth_service.token_manager + with patch.object(tm, 'create_email_verification_token', + return_value="MOCKED_JWT_VERIFY_TOKEN") as mock_create_token: + token_string = auth_service.request_email_verification(self.user.email) + + assert token_string == "MOCKED_JWT_VERIFY_TOKEN" + mock_create_token.assert_called_once_with(self.user.email) + mock_send_email.assert_called_once() + + def test_verify_email_success(self, auth_service: AuthService): + """Success: Verification updates the user's status.""" + + # The token_manager is mocked in conftest, so we must access its real create method + # or rely on the mock's return value to get a token string to use in the call. + # Since we need a real token for the decode logic to pass, we need to bypass the mock here. + + # We will temporarily use the real TokenManager to create a valid, decodable token. + # This requires an *unmocked* token manager instance, which is tricky in this setup. + + # Alternative: Temporarily inject a real TokenManager for this test (or rely on a non-mocked method) + + # Assuming TokenManager.create_email_verification_token can be mocked to return a static string + # and TokenManager.decode_email_verification_token can be patched to simulate success. + + # Since the method calls decode_email_verification_token internally, we mock the output of the decode step. + + # Setup: Ensure user is unverified + auth_service.user_repository.update_user(self.user.id, UserUpdate(is_verified=False)) + + tm = auth_service.token_manager + + # Mock the decode step to ensure it returns the email used for verification + with patch.object(tm, 'decode_email_verification_token', return_value=self.user.email) as mock_decode: + # Test (we use a dummy token string as the decode step is mocked) + result = auth_service.verify_email("dummy_verification_token") + + assert result is True + mock_decode.assert_called_once() + + # Verification: User is verified + updated_user = auth_service.user_repository.get_user_by_id(self.user.id) + assert updated_user.is_verified is True diff --git a/tests/persistence/__init__.py b/tests/persistence/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/persistence/conftest.py b/tests/persistence/conftest.py new file mode 100644 index 0000000..b69a58c --- /dev/null +++ b/tests/persistence/conftest.py @@ -0,0 +1,75 @@ +from uuid import uuid4 +from datetime import datetime, timedelta +from pathlib import Path + +import pytest + +from my_auth.models.token import TokenData +from my_auth.models.user import UserCreate +from my_auth.persistence.sqlite import SQLiteUserRepository, SQLiteTokenRepository + + +@pytest.fixture +def test_user_data_create(): + """Provides valid data for creating a test user.""" + return UserCreate( + email="test.user@example.com", + username="TestUser", + password="ValidPassword123!", # Meets all strength criteria + roles=["member"], + user_settings={"theme": "dark"} + ) + + +@pytest.fixture +def test_user_hashed_password(): + """Provides a dummy hashed password for persistence (hashing is tested elsewhere).""" + return "$2b$12$R.S/XfI2tQYt3Kk.iF1XwOQz0Qe.L0T0mD/O1H8E2V5D4Q6F7G8H9I0" + + +@pytest.fixture +def test_token_data(): + """Provides valid data for a refresh token.""" + now = datetime.now() + return TokenData( + token=f"opaque_refresh_token_{uuid4()}", + token_type="refresh", + user_id="user_id_for_token_test", + expires_at=now + timedelta(days=7), + is_revoked=False, + created_at=now + ) + + +@pytest.fixture() +def sqlite_db_path(tmp_path_factory): + """ + Creates a temporary directory and an SQLite file path for the test session. + The directory and its contents are deleted after the session. + """ + temp_dir = tmp_path_factory.mktemp("sqlite_auth_test") + db_file: Path = temp_dir / "auth_test.db" + + yield str(db_file) + + try: + if temp_dir.exists(): + import shutil + shutil.rmtree(temp_dir) + except OSError as e: + # Handle case where directory might be locked, though rare in tests + print(f"Error during cleanup of temporary DB directory: {e}") + + +@pytest.fixture +def user_repository(sqlite_db_path: str) -> SQLiteUserRepository: + """Provides an instance of SQLiteUserRepository initialized with the in-memory DB.""" + # Assuming the repository takes the connection object or path (using connection here) + return SQLiteUserRepository(sqlite_db_path) + + +@pytest.fixture +def token_repository(sqlite_db_path: str) -> SQLiteTokenRepository: + """Provides an instance of SQLiteTokenRepository initialized with the in-memory DB.""" + # Assuming the repository takes the connection object or path (using connection here) + return SQLiteTokenRepository(sqlite_db_path) diff --git a/tests/persistence/test_sql_token.py b/tests/persistence/test_sql_token.py new file mode 100644 index 0000000..05d5dc4 --- /dev/null +++ b/tests/persistence/test_sql_token.py @@ -0,0 +1,87 @@ +# tests/persistence/test_sqlite_token.py + +from datetime import datetime, timedelta + +from my_auth.models.token import TokenData +from my_auth.persistence.sqlite import SQLiteTokenRepository + + +def test_i_can_save_and_retrieve_token(token_repository: SQLiteTokenRepository, + test_token_data: TokenData): + """Verifies token saving and successful retrieval by token string and type.""" + + # 1. Save Token + token_repository.save_token(test_token_data) + + # 2. Retrieve Token + retrieved_token = token_repository.get_token(test_token_data.token, test_token_data.token_type) + + # Assertions + assert retrieved_token is not None + assert retrieved_token.token == test_token_data.token + assert retrieved_token.user_id == test_token_data.user_id + assert retrieved_token.is_revoked is False + assert retrieved_token.token_type == test_token_data.token_type + + +def test_i_can_revoke_token(token_repository: SQLiteTokenRepository, + test_token_data: TokenData): + """Verifies a token can be revoked and its revoked status is updated.""" + + # Setup: Save the token + token_repository.save_token(test_token_data) + + # 1. Revoke the token + was_revoked = token_repository.revoke_token(test_token_data.token) + assert was_revoked is True + + # 2. Retrieve and check status + revoked_token = token_repository.get_token(test_token_data.token, test_token_data.token_type) + assert revoked_token is not None + assert revoked_token.is_revoked is True + + # 3. Attempt to revoke a non-existent token + was_revoked_again = token_repository.revoke_token("non_existent_token") + assert was_revoked_again is False + + +def test_i_can_use_is_token_valid_for_valid_token(token_repository: SQLiteTokenRepository, + test_token_data: TokenData): + """Verifies the convenience method returns True for a fresh, unexpired token.""" + + token_repository.save_token(test_token_data) + + is_valid = token_repository.is_token_valid(test_token_data.token, test_token_data.token_type) + assert is_valid is True + + is_valid = token_repository.is_token_valid("non_existent_token", test_token_data.token_type) + assert is_valid is False + + +def test_i_can_use_is_token_valid_for_revoked_token(token_repository: SQLiteTokenRepository, + test_token_data: TokenData): + """Verifies is_token_valid returns False for a token marked as revoked.""" + + token_repository.save_token(test_token_data) + token_repository.revoke_token(test_token_data.token) + + is_valid = token_repository.is_token_valid(test_token_data.token, test_token_data.token_type) + assert is_valid is False + + +def test_i_can_use_is_token_valid_for_expired_token(token_repository: SQLiteTokenRepository): + """Verifies is_token_valid returns False for a token whose expiration is in the past.""" + + expired_token_data = TokenData( + token="expired_token_test", + token_type="password_reset", + user_id="user_id_expired", + expires_at=datetime.now() - timedelta(hours=1), # Set expiration to 1 hour ago + is_revoked=False, + created_at=datetime.now() - timedelta(hours=2) + ) + + token_repository.save_token(expired_token_data) + + is_valid = token_repository.is_token_valid(expired_token_data.token, expired_token_data.token_type) + assert is_valid is False diff --git a/tests/persistence/test_sqlite_user.py b/tests/persistence/test_sqlite_user.py new file mode 100644 index 0000000..03a199d --- /dev/null +++ b/tests/persistence/test_sqlite_user.py @@ -0,0 +1,155 @@ +# tests/persistence/test_sqlite_user.py + +import pytest +import json +from datetime import datetime + +from my_auth.persistence.sqlite import SQLiteUserRepository +from my_auth.models.user import UserCreate, UserUpdate +from my_auth.exceptions import UserAlreadyExistsError, UserNotFoundError + + +def test_i_can_create_and_retrieve_user_by_email(user_repository: SQLiteUserRepository, + test_user_data_create: UserCreate, + test_user_hashed_password: str): + """Verifies user creation and successful retrieval by email.""" + + # 1. Create User + created_user = user_repository.create_user( + user_data=test_user_data_create, + hashed_password=test_user_hashed_password + ) + + # Assertions on creation + assert created_user is not None + assert created_user.email == test_user_data_create.email + assert created_user.hashed_password == test_user_hashed_password + assert created_user.is_active is True + assert created_user.is_verified is False + assert created_user.id is not None + assert isinstance(created_user.created_at, datetime) + + # 2. Retrieve User + retrieved_user = user_repository.get_user_by_email(test_user_data_create.email) + + # Assertions on retrieval + assert retrieved_user is not None + assert retrieved_user.id == created_user.id + assert retrieved_user.username == "TestUser" + + +def test_i_can_retrieve_user_by_id(user_repository: SQLiteUserRepository, + test_user_data_create: UserCreate, + test_user_hashed_password: str): + """Verifies user retrieval by unique ID.""" + created_user = user_repository.create_user(test_user_data_create, test_user_hashed_password) + + retrieved_user = user_repository.get_user_by_id(created_user.id) + + assert retrieved_user is not None + assert retrieved_user.email == created_user.email + + +def test_i_cannot_create_user_if_email_exists(user_repository: SQLiteUserRepository, + test_user_data_create: UserCreate, + test_user_hashed_password: str): + """Ensures that creating a user with an existing email raises UserAlreadyExistsError.""" + # First creation should succeed + user_repository.create_user(test_user_data_create, test_user_hashed_password) + + # Second creation with same email should fail + with pytest.raises(UserAlreadyExistsError): + user_repository.create_user(test_user_data_create, test_user_hashed_password) + + +def test_i_can_check_if_email_exists(user_repository: SQLiteUserRepository, + test_user_data_create: UserCreate, + test_user_hashed_password: str): + """Verifies the email_exists method returns correct boolean results.""" + email = test_user_data_create.email + non_existent_email = "non.existent@example.com" + + # 1. Check before creation (Should be False) + assert user_repository.email_exists(email) is False + + user_repository.create_user(test_user_data_create, test_user_hashed_password) + + # 2. Check after creation (Should be True) + assert user_repository.email_exists(email) is True + + # 3. Check for another non-existent email + assert user_repository.email_exists(non_existent_email) is False + + +def test_i_can_update_username_and_roles(user_repository: SQLiteUserRepository, + test_user_data_create: UserCreate, + test_user_hashed_password: str): + """Tests partial update of user fields (username and roles).""" + + created_user = user_repository.create_user(test_user_data_create, test_user_hashed_password) + + updates = UserUpdate(username="NewUsername", roles=["admin", "staff"]) + + updated_user = user_repository.update_user(created_user.id, updates) + + assert updated_user.username == "NewUsername" + assert updated_user.roles == ["admin", "staff"] + # Check that unrelated fields remain the same + assert updated_user.email == created_user.email + # Check that update timestamp changed + assert updated_user.updated_at > created_user.updated_at + assert updated_user.is_verified == created_user.is_verified + + +def test_i_can_update_is_active_status(user_repository: SQLiteUserRepository, + test_user_data_create: UserCreate, + test_user_hashed_password: str): + """Tests the specific update of the 'is_active' status.""" + + created_user = user_repository.create_user(test_user_data_create, test_user_hashed_password) + + # Deactivate + updates_deactivate = UserUpdate(is_active=False) + deactivated_user = user_repository.update_user(created_user.id, updates_deactivate) + assert deactivated_user.is_active is False + + # Reactivate + updates_activate = UserUpdate(is_active=True) + reactivated_user = user_repository.update_user(created_user.id, updates_activate) + assert reactivated_user.is_active is True + + +def test_i_cannot_update_non_existent_user(user_repository: SQLiteUserRepository): + """Ensures that updating a user with an unknown ID raises UserNotFoundError.""" + non_existent_id = "unknown_id_123" + updates = UserUpdate(username="Phantom") + + with pytest.raises(UserNotFoundError): + user_repository.update_user(non_existent_id, updates) + + +def test_i_can_delete_user(user_repository: SQLiteUserRepository, + test_user_data_create: UserCreate, + test_user_hashed_password: str): + """Verifies user deletion and subsequent failure to retrieve.""" + + created_user = user_repository.create_user(test_user_data_create, test_user_hashed_password) + + # 1. Delete the user + was_deleted = user_repository.delete_user(created_user.id) + assert was_deleted is True + + # 2. Verify deletion by attempting retrieval + retrieved_user = user_repository.get_user_by_id(created_user.id) + assert retrieved_user is None + + # 3. Verify attempting to delete again returns False + was_deleted_again = user_repository.delete_user(created_user.id) + assert was_deleted_again is False + + +def test_i_cannot_retrieve_non_existent_user_by_email(user_repository: SQLiteUserRepository): + """Ensures retrieval by email returns None for non-existent email.""" + + retrieved_user = user_repository.get_user_by_email("ghost@example.com") + assert retrieved_user is None \ No newline at end of file