Unit testing AuthService
This commit is contained in:
216
.gitignore
vendored
Normal file
216
.gitignore
vendored
Normal file
@@ -0,0 +1,216 @@
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[codz]
|
||||
*$py.class
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
share/python-wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
*.py.cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
cover/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
db.sqlite3-journal
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
|
||||
# PyBuilder
|
||||
.pybuilder/
|
||||
target/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
|
||||
# IPython
|
||||
profile_default/
|
||||
ipython_config.py
|
||||
|
||||
# pyenv
|
||||
# For a library or package, you might want to ignore these files since the code is
|
||||
# intended to run in multiple environments; otherwise, check them in:
|
||||
# .python-version
|
||||
|
||||
# pipenv
|
||||
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||
# install all needed dependencies.
|
||||
# Pipfile.lock
|
||||
|
||||
# UV
|
||||
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
|
||||
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||
# commonly ignored for libraries.
|
||||
# uv.lock
|
||||
|
||||
# poetry
|
||||
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
||||
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||
# commonly ignored for libraries.
|
||||
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
||||
# poetry.lock
|
||||
# poetry.toml
|
||||
|
||||
# pdm
|
||||
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
||||
# pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python.
|
||||
# https://pdm-project.org/en/latest/usage/project/#working-with-version-control
|
||||
# pdm.lock
|
||||
# pdm.toml
|
||||
.pdm-python
|
||||
.pdm-build/
|
||||
|
||||
# pixi
|
||||
# Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control.
|
||||
# pixi.lock
|
||||
# Pixi creates a virtual environment in the .pixi directory, just like venv module creates one
|
||||
# in the .venv directory. It is recommended not to include this directory in version control.
|
||||
.pixi
|
||||
|
||||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
||||
__pypackages__/
|
||||
|
||||
# Celery stuff
|
||||
celerybeat-schedule
|
||||
celerybeat.pid
|
||||
|
||||
# Redis
|
||||
*.rdb
|
||||
*.aof
|
||||
*.pid
|
||||
|
||||
# RabbitMQ
|
||||
mnesia/
|
||||
rabbitmq/
|
||||
rabbitmq-data/
|
||||
|
||||
# ActiveMQ
|
||||
activemq-data/
|
||||
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.envrc
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
|
||||
# Pyre type checker
|
||||
.pyre/
|
||||
|
||||
# pytype static type analyzer
|
||||
.pytype/
|
||||
|
||||
# Cython debug symbols
|
||||
cython_debug/
|
||||
|
||||
# PyCharm
|
||||
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
||||
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
||||
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||
# .idea/
|
||||
|
||||
# Abstra
|
||||
# Abstra is an AI-powered process automation framework.
|
||||
# Ignore directories containing user credentials, local state, and settings.
|
||||
# Learn more at https://abstra.io/docs
|
||||
.abstra/
|
||||
|
||||
# Visual Studio Code
|
||||
# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
|
||||
# that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
|
||||
# and can be added to the global gitignore or merged into this file. However, if you prefer,
|
||||
# you could uncomment the following to ignore the entire vscode folder
|
||||
# .vscode/
|
||||
|
||||
# Ruff stuff:
|
||||
.ruff_cache/
|
||||
|
||||
# PyPI configuration file
|
||||
.pypirc
|
||||
|
||||
# Marimo
|
||||
marimo/_static/
|
||||
marimo/_lsp/
|
||||
__marimo__/
|
||||
|
||||
# Streamlit
|
||||
.streamlit/secrets.toml
|
||||
285
README.md
Normal file
285
README.md
Normal file
@@ -0,0 +1,285 @@
|
||||
# MyAuth Module
|
||||
|
||||
A reusable, modular authentication system for FastAPI applications with pluggable database backends.
|
||||
|
||||
## Overview
|
||||
|
||||
This module provides a complete authentication solution designed to be deployed on internal PyPI servers and reused across multiple projects. It handles user registration, login/logout, token management, email verification, and password reset functionality.
|
||||
|
||||
## Features
|
||||
|
||||
### Core Authentication
|
||||
- ✅ User registration with email and username
|
||||
- ✅ Login/logout with email-based authentication
|
||||
- ✅ JWT-based access tokens (30 minutes validity)
|
||||
- ✅ Opaque refresh tokens stored in database (7 days validity)
|
||||
- ✅ Password hashing with configurable bcrypt rounds
|
||||
|
||||
### User Management
|
||||
- ✅ Email verification with JWT tokens
|
||||
- ✅ Password reset with secure random tokens (15 minutes validity)
|
||||
- ✅ User roles management (flexible, no predefined roles)
|
||||
- ✅ User settings storage (dict field)
|
||||
- ✅ Account activation/deactivation
|
||||
|
||||
### Password Security
|
||||
- ✅ Strict password validation (via Pydantic):
|
||||
- Minimum 8 characters
|
||||
- At least 1 uppercase letter
|
||||
- At least 1 lowercase letter
|
||||
- At least 1 digit
|
||||
- At least 1 special character
|
||||
|
||||
### Architecture
|
||||
- ✅ Abstract base classes for database persistence
|
||||
- ✅ Multiple database implementations: MongoDB, SQLite, PostgreSQL
|
||||
- ✅ Abstract email service interface with optional SMTP implementation
|
||||
- ✅ Custom exceptions with FastAPI integration
|
||||
- ✅ Synchronous implementation
|
||||
- ✅ Ready-to-use FastAPI router with `/auth` prefix
|
||||
|
||||
## Project Structure
|
||||
|
||||
```
|
||||
├──src
|
||||
│ my_auth/
|
||||
│ ├── __init__.py
|
||||
│ ├── models/ # Pydantic models
|
||||
│ │ ├── user.py # User model with roles and settings
|
||||
│ │ ├── token.py # Token models (access, refresh, reset)
|
||||
│ │ └── email_verification.py # Email verification models
|
||||
│ ├── core/ # Business logic
|
||||
│ │ ├── auth.py # Authentication service
|
||||
│ │ ├── password.py # Password hashing/verification
|
||||
│ │ └── token.py # Token generation/validation
|
||||
│ ├── persistence/ # Database abstraction
|
||||
│ │ ├── base.py # Abstract base classes
|
||||
│ │ ├── mongodb.py # MongoDB implementation
|
||||
│ │ ├── sqlite.py # SQLite implementation
|
||||
│ │ └── postgresql.py # PostgreSQL implementation
|
||||
│ ├── api/ # FastAPI routes
|
||||
│ │ └── routes.py # All authentication endpoints
|
||||
│ ├── email/ # Email service
|
||||
│ │ ├── base.py # Abstract interface
|
||||
│ │ └── smtp.py # SMTP implementation (optional)
|
||||
│ ├── exceptions.py # Custom exceptions
|
||||
│ └── config.py # Configuration classes
|
||||
├── tests
|
||||
```
|
||||
|
||||
## User Model
|
||||
|
||||
```python
|
||||
class User:
|
||||
id: str # Unique identifier
|
||||
email: str # Unique, required
|
||||
username: str # Required, non-unique
|
||||
hashed_password: str # Bcrypt hashed
|
||||
roles: list[str] # Free-form roles, no defaults
|
||||
user_settings: dict # Custom user settings
|
||||
is_verified: bool # Email verification status
|
||||
is_active: bool # Account active status
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
```
|
||||
|
||||
## Token Management
|
||||
|
||||
### Token Types
|
||||
The module uses a unified tokens collection with a discriminator field:
|
||||
|
||||
1. **Access Token (JWT)**: 30 minutes validity, stateless
|
||||
2. **Refresh Token (Opaque)**: 7 days validity, stored in DB
|
||||
3. **Password Reset Token (Random)**: 15 minutes validity, stored in DB
|
||||
4. **Email Verification Token (JWT)**: Stateless, no DB storage
|
||||
|
||||
### Token Storage
|
||||
All tokens requiring storage (refresh and password reset) are kept in a single `tokens` collection/table with a `token_type` discriminator field.
|
||||
|
||||
## API Endpoints
|
||||
|
||||
The module exposes a pre-configured FastAPI router with the following endpoints:
|
||||
|
||||
```
|
||||
POST /auth/register # User registration
|
||||
POST /auth/login # Login (email + password)
|
||||
POST /auth/logout # Logout (revokes refresh token)
|
||||
POST /auth/refresh # Refresh access token
|
||||
POST /auth/password-reset-request # Request password reset
|
||||
POST /auth/password-reset # Reset password with token
|
||||
POST /auth/verify-email-request # Request email verification
|
||||
POST /auth/verify-email # Verify email with token
|
||||
GET /auth/me # Get current user info
|
||||
```
|
||||
|
||||
## Installation
|
||||
|
||||
### Dependencies
|
||||
|
||||
**Core dependencies:**
|
||||
```bash
|
||||
pip install fastapi pydantic pydantic-settings python-jose[cryptography] passlib[bcrypt] python-multipart
|
||||
```
|
||||
|
||||
**Database-specific dependencies:**
|
||||
- MongoDB: `pip install pymongo`
|
||||
- SQLite: Built-in (no additional dependency)
|
||||
- PostgreSQL: `pip install psycopg2-binary`
|
||||
|
||||
**Optional email dependency:**
|
||||
- SMTP: `pip install secure-smtplib`
|
||||
|
||||
## Usage
|
||||
|
||||
### Basic Setup
|
||||
|
||||
```python
|
||||
from fastapi import FastAPI
|
||||
from auth_module import AuthService
|
||||
from auth_module.persistence.mongodb import MongoUserRepository, MongoTokenRepository
|
||||
from auth_module.api import auth_router
|
||||
|
||||
# Initialize repositories
|
||||
user_repo = MongoUserRepository(connection_string="mongodb://localhost:27017/mydb")
|
||||
token_repo = MongoTokenRepository(connection_string="mongodb://localhost:27017/mydb")
|
||||
|
||||
# Initialize auth service
|
||||
auth_service = AuthService(
|
||||
user_repository=user_repo,
|
||||
token_repository=token_repo,
|
||||
jwt_secret="your-secret-key-here",
|
||||
access_token_expire_minutes=30,
|
||||
refresh_token_expire_days=7,
|
||||
password_reset_token_expire_minutes=15,
|
||||
password_hash_rounds=12
|
||||
)
|
||||
|
||||
# Create FastAPI app and include auth router
|
||||
app = FastAPI()
|
||||
app.include_router(auth_router) # Mounts at /auth prefix
|
||||
```
|
||||
|
||||
### Using Different Databases
|
||||
|
||||
#### SQLite
|
||||
```python
|
||||
from auth_module.persistence.sqlite import SQLiteUserRepository, SQLiteTokenRepository
|
||||
|
||||
user_repo = SQLiteUserRepository(db_path="./auth.db")
|
||||
token_repo = SQLiteTokenRepository(db_path="./auth.db")
|
||||
```
|
||||
|
||||
#### PostgreSQL
|
||||
```python
|
||||
from auth_module.persistence.postgresql import PostgreSQLUserRepository, PostgreSQLTokenRepository
|
||||
|
||||
user_repo = PostgreSQLUserRepository(
|
||||
host="localhost",
|
||||
port=5432,
|
||||
database="mydb",
|
||||
user="postgres",
|
||||
password="secret"
|
||||
)
|
||||
token_repo = PostgreSQLTokenRepository(...)
|
||||
```
|
||||
|
||||
### Email Service Configuration
|
||||
|
||||
```python
|
||||
from auth_module.email.smtp import SMTPEmailService
|
||||
|
||||
email_service = SMTPEmailService(
|
||||
host="smtp.gmail.com",
|
||||
port=587,
|
||||
username="your-email@gmail.com",
|
||||
password="your-app-password",
|
||||
use_tls=True
|
||||
)
|
||||
|
||||
auth_service = AuthService(
|
||||
user_repository=user_repo,
|
||||
token_repository=token_repo,
|
||||
email_service=email_service, # Optional
|
||||
...
|
||||
)
|
||||
```
|
||||
|
||||
### Custom Email Service
|
||||
|
||||
Implement your own email service by extending the abstract base class:
|
||||
|
||||
```python
|
||||
from auth_module.email.base import EmailService
|
||||
|
||||
class CustomEmailService(EmailService):
|
||||
def send_verification_email(self, email: str, token: str) -> None:
|
||||
# Your implementation (SendGrid, AWS SES, etc.)
|
||||
pass
|
||||
|
||||
def send_password_reset_email(self, email: str, token: str) -> None:
|
||||
# Your implementation
|
||||
pass
|
||||
```
|
||||
|
||||
## Error Handling
|
||||
|
||||
The module uses custom exceptions that are automatically converted to appropriate HTTP responses:
|
||||
|
||||
- `InvalidCredentialsError` → 401 Unauthorized
|
||||
- `UserAlreadyExistsError` → 409 Conflict
|
||||
- `UserNotFoundError` → 404 Not Found
|
||||
- `InvalidTokenError` → 401 Unauthorized
|
||||
- `RevokedTokenError` → 401 Unauthorized
|
||||
- `ExpiredTokenError` → 401 Unauthorized
|
||||
- `EmailNotVerifiedError` → 403 Forbidden
|
||||
- `AccountDisabledError` → 403 Forbidden
|
||||
|
||||
## Configuration Options
|
||||
|
||||
```python
|
||||
AuthService(
|
||||
user_repository: UserRepository, # Required
|
||||
token_repository: TokenRepository, # Required
|
||||
jwt_secret: str, # Required
|
||||
jwt_algorithm: str = "HS256", # Optional
|
||||
access_token_expire_minutes: int = 30, # Optional
|
||||
refresh_token_expire_days: int = 7, # Optional
|
||||
password_reset_token_expire_minutes: int = 15, # Optional
|
||||
password_hash_rounds: int = 12, # Optional (bcrypt cost)
|
||||
email_service: EmailService = None # Optional
|
||||
)
|
||||
```
|
||||
|
||||
## Testing
|
||||
|
||||
The module is fully testable with pytest. Test fixtures are provided for each database implementation.
|
||||
|
||||
```bash
|
||||
pytest tests/
|
||||
```
|
||||
|
||||
## Security Considerations
|
||||
|
||||
- Passwords are hashed using bcrypt with configurable rounds
|
||||
- JWT tokens are signed with HS256 (configurable)
|
||||
- Refresh tokens are opaque and stored securely
|
||||
- Password reset tokens are single-use and expire after 15 minutes
|
||||
- Email verification tokens are stateless JWT
|
||||
- Rate limiting should be implemented at the application level
|
||||
- HTTPS should be enforced by the application
|
||||
|
||||
## Future Enhancements (Not Included)
|
||||
|
||||
- Multi-factor authentication (2FA/MFA)
|
||||
- Rate limiting on login attempts
|
||||
- OAuth2 provider integration
|
||||
- Session management (multiple device tracking)
|
||||
- Account lockout after failed attempts
|
||||
|
||||
## License
|
||||
|
||||
[Your License Here]
|
||||
|
||||
## Contributing
|
||||
|
||||
[Your Contributing Guidelines Here]
|
||||
16
main.py
Normal file
16
main.py
Normal file
@@ -0,0 +1,16 @@
|
||||
# This is a sample Python script.
|
||||
|
||||
# Press Ctrl+F5 to execute it or replace it with your code.
|
||||
# Press Double Shift to search everywhere for classes, files, tool windows, actions, and settings.
|
||||
|
||||
|
||||
def print_hi(name):
|
||||
# Use a breakpoint in the code line below to debug your script.
|
||||
print(f'Hi, {name}') # Press F9 to toggle the breakpoint.
|
||||
|
||||
|
||||
# Press the green button in the gutter to run the script.
|
||||
if __name__ == '__main__':
|
||||
print_hi('PyCharm')
|
||||
|
||||
# See PyCharm help at https://www.jetbrains.com/help/pycharm/
|
||||
0
src/__init__.py
Normal file
0
src/__init__.py
Normal file
0
src/my_auth/__init__.py
Normal file
0
src/my_auth/__init__.py
Normal file
0
src/my_auth/core/__init__.py
Normal file
0
src/my_auth/core/__init__.py
Normal file
445
src/my_auth/core/auth.py
Normal file
445
src/my_auth/core/auth.py
Normal file
@@ -0,0 +1,445 @@
|
||||
"""
|
||||
Authentication service for the authentication module.
|
||||
|
||||
This module provides the main authentication service that orchestrates
|
||||
all authentication operations including registration, login, token management,
|
||||
password reset, and email verification.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from .password import PasswordManager
|
||||
from .token import TokenManager
|
||||
from ..exceptions import (
|
||||
InvalidCredentialsError,
|
||||
UserNotFoundError,
|
||||
AccountDisabledError,
|
||||
ExpiredTokenError,
|
||||
InvalidTokenError,
|
||||
RevokedTokenError
|
||||
)
|
||||
from ..models.token import AccessTokenResponse, TokenData
|
||||
from ..models.user import UserCreate, UserInDB, UserUpdate
|
||||
from ..persistence.base import UserRepository, TokenRepository
|
||||
|
||||
|
||||
class AuthService:
|
||||
"""
|
||||
Main authentication service.
|
||||
|
||||
This service orchestrates all authentication-related operations by
|
||||
coordinating between password management, token management, and
|
||||
persistence layers.
|
||||
|
||||
Attributes:
|
||||
user_repository: Repository for user persistence operations.
|
||||
token_repository: Repository for token persistence operations.
|
||||
password_manager: Manager for password hashing and verification.
|
||||
token_manager: Manager for token creation and validation.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
user_repository: UserRepository,
|
||||
token_repository: TokenRepository,
|
||||
password_manager: PasswordManager,
|
||||
token_manager: TokenManager
|
||||
):
|
||||
"""
|
||||
Initialize the authentication service.
|
||||
|
||||
Args:
|
||||
user_repository: Repository for user persistence.
|
||||
token_repository: Repository for token persistence.
|
||||
jwt_secret: Secret key for JWT signing.
|
||||
jwt_algorithm: JWT algorithm (default: HS256).
|
||||
access_token_expire_minutes: Access token validity (default: 30).
|
||||
refresh_token_expire_days: Refresh token validity (default: 7).
|
||||
password_reset_token_expire_minutes: Reset token validity (default: 15).
|
||||
password_hash_rounds: Bcrypt rounds (default: 12).
|
||||
"""
|
||||
self.user_repository = user_repository
|
||||
self.token_repository = token_repository
|
||||
self.password_manager = password_manager
|
||||
self.token_manager = token_manager
|
||||
|
||||
def register(self, user_data: UserCreate) -> UserInDB:
|
||||
"""
|
||||
Register a new user.
|
||||
|
||||
This method creates a new user account with hashed password.
|
||||
The user's email is initially unverified.
|
||||
|
||||
Args:
|
||||
user_data: User registration data including password.
|
||||
|
||||
Returns:
|
||||
The created user (without password).
|
||||
|
||||
Raises:
|
||||
UserAlreadyExistsError: If email is already registered.
|
||||
|
||||
Example:
|
||||
>>> user_data = UserCreate(
|
||||
... email="user@example.com",
|
||||
... username="john_doe",
|
||||
... password="SecurePass123!"
|
||||
... )
|
||||
>>> user = auth_service.register(user_data)
|
||||
"""
|
||||
hashed_password = self.password_manager.hash_password(user_data.password)
|
||||
user = self.user_repository.create_user(user_data, hashed_password)
|
||||
return user
|
||||
|
||||
def login(self, email: str, password: str) -> tuple[UserInDB, AccessTokenResponse]:
|
||||
"""
|
||||
Authenticate a user and create tokens.
|
||||
|
||||
This method verifies credentials, checks account status, and generates
|
||||
both access and refresh tokens.
|
||||
|
||||
Args:
|
||||
email: User's email address.
|
||||
password: User's plain text password.
|
||||
|
||||
Returns:
|
||||
Tuple of (user, tokens) where tokens contains access_token and refresh_token.
|
||||
|
||||
Raises:
|
||||
InvalidCredentialsError: If email or password is incorrect.
|
||||
AccountDisabledError: If user account is disabled.
|
||||
|
||||
Example:
|
||||
>>> user, tokens = auth_service.login("user@example.com", "password")
|
||||
>>> print(tokens.access_token)
|
||||
"""
|
||||
# Get user by email
|
||||
user = self.user_repository.get_user_by_email(email)
|
||||
if not user:
|
||||
raise InvalidCredentialsError()
|
||||
|
||||
# Verify password
|
||||
if not self.password_manager.verify_password(password, user.hashed_password):
|
||||
raise InvalidCredentialsError()
|
||||
|
||||
# Check if account is active
|
||||
if not user.is_active:
|
||||
raise AccountDisabledError()
|
||||
|
||||
# Create tokens
|
||||
access_token = self.token_manager.create_access_token(user)
|
||||
refresh_token = self.token_manager.create_refresh_token()
|
||||
|
||||
# Store refresh token in database
|
||||
token_data = TokenData(
|
||||
token=refresh_token,
|
||||
token_type="refresh",
|
||||
user_id=user.id,
|
||||
expires_at=self.token_manager.get_refresh_token_expiration(),
|
||||
created_at=datetime.now(),
|
||||
is_revoked=False
|
||||
)
|
||||
self.token_repository.save_token(token_data)
|
||||
|
||||
tokens = AccessTokenResponse(
|
||||
access_token=access_token,
|
||||
refresh_token=refresh_token,
|
||||
token_type="bearer"
|
||||
)
|
||||
|
||||
return user, tokens
|
||||
|
||||
def refresh_access_token(self, refresh_token: str) -> AccessTokenResponse:
|
||||
"""
|
||||
Create a new access token using a refresh token.
|
||||
|
||||
This method validates the refresh token and generates a new access token
|
||||
without requiring the user to re-enter their password.
|
||||
|
||||
Args:
|
||||
refresh_token: The refresh token to exchange.
|
||||
|
||||
Returns:
|
||||
New access and refresh tokens.
|
||||
|
||||
Raises:
|
||||
InvalidTokenError: If refresh token is invalid.
|
||||
ExpiredTokenError: If refresh token has expired.
|
||||
RevokedTokenError: If refresh token has been revoked.
|
||||
UserNotFoundError: If user no longer exists.
|
||||
AccountDisabledError: If user account is disabled.
|
||||
|
||||
Example:
|
||||
>>> tokens = auth_service.refresh_access_token(old_refresh_token)
|
||||
>>> print(tokens.access_token)
|
||||
"""
|
||||
# Validate refresh token
|
||||
token_data = self.token_repository.get_token(refresh_token, "refresh")
|
||||
if not token_data:
|
||||
raise InvalidTokenError("Invalid refresh token")
|
||||
|
||||
if token_data.is_revoked:
|
||||
raise RevokedTokenError()
|
||||
|
||||
if token_data.expires_at < datetime.now():
|
||||
raise ExpiredTokenError()
|
||||
|
||||
# Get user
|
||||
user = self.user_repository.get_user_by_id(token_data.user_id)
|
||||
if not user:
|
||||
raise UserNotFoundError()
|
||||
|
||||
if not user.is_active:
|
||||
raise AccountDisabledError()
|
||||
|
||||
# Create new tokens
|
||||
access_token = self.token_manager.create_access_token(user)
|
||||
new_refresh_token = self.token_manager.create_refresh_token()
|
||||
|
||||
# Revoke old refresh token
|
||||
self.token_repository.revoke_token(refresh_token)
|
||||
|
||||
# Store new refresh token
|
||||
new_token_data = TokenData(
|
||||
token=new_refresh_token,
|
||||
token_type="refresh",
|
||||
user_id=user.id,
|
||||
expires_at=self.token_manager.get_refresh_token_expiration(),
|
||||
created_at=datetime.now(),
|
||||
is_revoked=False
|
||||
)
|
||||
self.token_repository.save_token(new_token_data)
|
||||
|
||||
return AccessTokenResponse(
|
||||
access_token=access_token,
|
||||
refresh_token=new_refresh_token,
|
||||
token_type="bearer"
|
||||
)
|
||||
|
||||
def logout(self, refresh_token: str) -> bool:
|
||||
"""
|
||||
Logout a user by revoking their refresh token.
|
||||
|
||||
This prevents the refresh token from being used to obtain new
|
||||
access tokens. The current access token will remain valid until
|
||||
it expires naturally.
|
||||
|
||||
Args:
|
||||
refresh_token: The refresh token to revoke.
|
||||
|
||||
Returns:
|
||||
True if logout was successful, False if token not found.
|
||||
|
||||
Example:
|
||||
>>> auth_service.logout(refresh_token)
|
||||
True
|
||||
"""
|
||||
return self.token_repository.revoke_token(refresh_token)
|
||||
|
||||
def request_password_reset(self, email: str) -> str:
|
||||
"""
|
||||
Generate a password reset token for a user.
|
||||
|
||||
This method creates a secure token that can be sent to the user
|
||||
via email to reset their password.
|
||||
|
||||
Args:
|
||||
email: User's email address.
|
||||
|
||||
Returns:
|
||||
The password reset token to be sent via email.
|
||||
|
||||
Raises:
|
||||
UserNotFoundError: If email is not registered.
|
||||
|
||||
Example:
|
||||
>>> token = auth_service.request_password_reset("user@example.com")
|
||||
>>> # Send token via email service
|
||||
"""
|
||||
user = self.user_repository.get_user_by_email(email)
|
||||
if not user:
|
||||
raise UserNotFoundError(f"No user found with email {email}")
|
||||
|
||||
# Create reset token
|
||||
reset_token = self.token_manager.create_password_reset_token()
|
||||
|
||||
# Store token in database
|
||||
token_data = TokenData(
|
||||
token=reset_token,
|
||||
token_type="password_reset",
|
||||
user_id=user.id,
|
||||
expires_at=self.token_manager.get_password_reset_token_expiration(),
|
||||
created_at=datetime.now(),
|
||||
is_revoked=False
|
||||
)
|
||||
self.token_repository.save_token(token_data)
|
||||
|
||||
return reset_token
|
||||
|
||||
def reset_password(self, token: str, new_password: str) -> bool:
|
||||
"""
|
||||
Reset a user's password using a reset token.
|
||||
|
||||
This method validates the reset token and updates the user's password.
|
||||
All existing refresh tokens for the user are revoked for security.
|
||||
|
||||
Args:
|
||||
token: Password reset token.
|
||||
new_password: New plain text password (will be hashed).
|
||||
|
||||
Returns:
|
||||
True if password was reset successfully.
|
||||
|
||||
Raises:
|
||||
InvalidTokenError: If reset token is invalid.
|
||||
ExpiredTokenError: If reset token has expired.
|
||||
RevokedTokenError: If reset token has been used.
|
||||
UserNotFoundError: If user no longer exists.
|
||||
|
||||
Example:
|
||||
>>> auth_service.reset_password(token, "NewSecurePass123!")
|
||||
True
|
||||
"""
|
||||
# Validate reset token
|
||||
token_data = self.token_repository.get_token(token, "password_reset")
|
||||
if not token_data:
|
||||
raise InvalidTokenError("Invalid password reset token")
|
||||
|
||||
if token_data.is_revoked:
|
||||
raise RevokedTokenError("Password reset token has already been used")
|
||||
|
||||
if token_data.expires_at < datetime.now():
|
||||
raise ExpiredTokenError("Password reset token has expired")
|
||||
|
||||
# Get user
|
||||
user = self.user_repository.get_user_by_id(token_data.user_id)
|
||||
if not user:
|
||||
raise UserNotFoundError()
|
||||
|
||||
# Hash new password
|
||||
hashed_password = self.password_manager.hash_password(new_password)
|
||||
|
||||
# Update user password
|
||||
updates = UserUpdate(password=hashed_password)
|
||||
self.user_repository.update_user(user.id, updates)
|
||||
|
||||
# Revoke the reset token
|
||||
self.token_repository.revoke_token(token)
|
||||
|
||||
# Revoke all user's refresh tokens for security
|
||||
self.token_repository.revoke_all_user_tokens(user.id, "refresh")
|
||||
|
||||
return True
|
||||
|
||||
def request_email_verification(self, email: str) -> str:
|
||||
"""
|
||||
Generate an email verification token for a user.
|
||||
|
||||
This method creates a JWT token that can be sent to the user
|
||||
to verify their email address.
|
||||
|
||||
Args:
|
||||
email: User's email address.
|
||||
|
||||
Returns:
|
||||
The email verification token (JWT) to be sent via email.
|
||||
|
||||
Raises:
|
||||
UserNotFoundError: If email is not registered.
|
||||
|
||||
Example:
|
||||
>>> token = auth_service.request_email_verification("user@example.com")
|
||||
>>> # Send token via email service
|
||||
"""
|
||||
user = self.user_repository.get_user_by_email(email)
|
||||
if not user:
|
||||
raise UserNotFoundError(f"No user found with email {email}")
|
||||
|
||||
return self.token_manager.create_email_verification_token(email)
|
||||
|
||||
def verify_email(self, token: str) -> bool:
|
||||
"""
|
||||
Verify a user's email address using a verification token.
|
||||
|
||||
This method decodes the JWT token, extracts the email, and marks
|
||||
the user's email as verified.
|
||||
|
||||
Args:
|
||||
token: Email verification token (JWT).
|
||||
|
||||
Returns:
|
||||
True if email was verified successfully.
|
||||
|
||||
Raises:
|
||||
InvalidTokenError: If verification token is invalid.
|
||||
ExpiredTokenError: If verification token has expired.
|
||||
UserNotFoundError: If user no longer exists.
|
||||
|
||||
Example:
|
||||
>>> auth_service.verify_email(token)
|
||||
True
|
||||
"""
|
||||
# Decode and validate token
|
||||
email = self.token_manager.decode_email_verification_token(token)
|
||||
|
||||
# Get user
|
||||
user = self.user_repository.get_user_by_email(email)
|
||||
if not user:
|
||||
raise UserNotFoundError(f"No user found with email {email}")
|
||||
|
||||
# Update user's verified status
|
||||
updates = UserUpdate(is_verified=True)
|
||||
self.user_repository.update_user(user.id, updates)
|
||||
|
||||
return True
|
||||
|
||||
def get_current_user(self, access_token: str) -> UserInDB:
|
||||
"""
|
||||
Retrieve the current user from an access token.
|
||||
|
||||
This method decodes and validates the JWT access token and
|
||||
retrieves the corresponding user from the database.
|
||||
|
||||
Args:
|
||||
access_token: JWT access token.
|
||||
|
||||
Returns:
|
||||
The user associated with the token.
|
||||
|
||||
Raises:
|
||||
InvalidTokenError: If access token is invalid.
|
||||
ExpiredTokenError: If access token has expired.
|
||||
UserNotFoundError: If user no longer exists.
|
||||
AccountDisabledError: If user account is disabled.
|
||||
|
||||
Example:
|
||||
>>> user = auth_service.get_current_user(access_token)
|
||||
>>> print(user.email)
|
||||
"""
|
||||
# Decode and validate token
|
||||
token_payload = self.token_manager.decode_access_token(access_token)
|
||||
|
||||
# Get user
|
||||
user = self.user_repository.get_user_by_id(token_payload.sub)
|
||||
if not user:
|
||||
raise UserNotFoundError()
|
||||
|
||||
if not user.is_active:
|
||||
raise AccountDisabledError()
|
||||
|
||||
return user
|
||||
|
||||
|
||||
def get_default_sqlite_auth_service(db_path: str, jwt_secret: str) -> AuthService:
|
||||
from my_auth.persistence.sqlite import SQLiteUserRepository
|
||||
from my_auth.persistence.sqlite import SQLiteTokenRepository
|
||||
user_repository = SQLiteUserRepository(db_path=db_path)
|
||||
token_repository = SQLiteTokenRepository(db_path=db_path)
|
||||
password_manager = PasswordManager()
|
||||
token_manager = TokenManager(jwt_secret=jwt_secret)
|
||||
return AuthService(
|
||||
user_repository=user_repository,
|
||||
token_repository=token_repository,
|
||||
password_manager=password_manager,
|
||||
token_manager=token_manager
|
||||
)
|
||||
87
src/my_auth/core/password.py
Normal file
87
src/my_auth/core/password.py
Normal file
@@ -0,0 +1,87 @@
|
||||
"""
|
||||
Password management for authentication module.
|
||||
|
||||
This module provides password hashing and verification functionality using
|
||||
bcrypt. It handles secure password storage and validation.
|
||||
"""
|
||||
|
||||
from passlib.context import CryptContext
|
||||
|
||||
|
||||
class PasswordManager:
|
||||
"""
|
||||
Manager for password hashing and verification operations.
|
||||
|
||||
This class uses bcrypt for secure password hashing with configurable
|
||||
cost factor (rounds). Higher rounds provide better security but slower
|
||||
performance.
|
||||
|
||||
Attributes:
|
||||
rounds: Bcrypt cost factor (default: 12). Higher values increase
|
||||
security but also increase computation time.
|
||||
"""
|
||||
|
||||
def __init__(self, rounds: int = 12):
|
||||
"""
|
||||
Initialize the password manager.
|
||||
|
||||
Args:
|
||||
rounds: Bcrypt cost factor (4-31). Default is 12 which provides
|
||||
a good balance between security and performance.
|
||||
"""
|
||||
if rounds < 4 or rounds > 31:
|
||||
raise ValueError("Bcrypt rounds must be between 4 and 31")
|
||||
|
||||
self.rounds = rounds
|
||||
self._context = CryptContext(
|
||||
schemes=["bcrypt"],
|
||||
deprecated="auto",
|
||||
bcrypt__rounds=self.rounds
|
||||
)
|
||||
|
||||
def hash_password(self, password: str) -> str:
|
||||
"""
|
||||
Hash a plain text password using bcrypt.
|
||||
|
||||
This method creates a secure hash of the password that can be
|
||||
safely stored in the database. Each hash includes a salt, making
|
||||
rainbow table attacks ineffective.
|
||||
|
||||
Args:
|
||||
password: Plain text password to hash.
|
||||
|
||||
Returns:
|
||||
Bcrypt hashed password string.
|
||||
|
||||
Example:
|
||||
>>> pm = PasswordManager()
|
||||
>>> hashed = pm.hash_password("SecurePassword123!")
|
||||
>>> print(hashed)
|
||||
$2b$12$abcdefghijklmnopqrstuvwxyz...
|
||||
"""
|
||||
return self._context.hash(password)
|
||||
|
||||
def verify_password(self, plain_password: str, hashed_password: str) -> bool:
|
||||
"""
|
||||
Verify a plain text password against a hashed password.
|
||||
|
||||
This method performs constant-time comparison to prevent timing
|
||||
attacks. It returns True only if the plain password matches the
|
||||
hash.
|
||||
|
||||
Args:
|
||||
plain_password: Plain text password to verify.
|
||||
hashed_password: Bcrypt hashed password from database.
|
||||
|
||||
Returns:
|
||||
True if password matches, False otherwise.
|
||||
|
||||
Example:
|
||||
>>> pm = PasswordManager()
|
||||
>>> hashed = pm.hash_password("SecurePassword123!")
|
||||
>>> pm.verify_password("SecurePassword123!", hashed)
|
||||
True
|
||||
>>> pm.verify_password("WrongPassword", hashed)
|
||||
False
|
||||
"""
|
||||
return self._context.verify(plain_password, hashed_password)
|
||||
265
src/my_auth/core/token.py
Normal file
265
src/my_auth/core/token.py
Normal file
@@ -0,0 +1,265 @@
|
||||
"""
|
||||
Token management for authentication module.
|
||||
|
||||
This module provides functionality for creating and validating different types
|
||||
of tokens: JWT access tokens, opaque refresh tokens, password reset tokens,
|
||||
and email verification tokens.
|
||||
"""
|
||||
|
||||
import secrets
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from jose import JWTError, jwt
|
||||
|
||||
from ..exceptions import InvalidTokenError, ExpiredTokenError
|
||||
from ..models.token import TokenPayload
|
||||
from ..models.user import UserInDB
|
||||
|
||||
|
||||
class TokenManager:
|
||||
"""
|
||||
Manager for token creation and validation operations.
|
||||
|
||||
This class handles multiple token types:
|
||||
- JWT access tokens (stateless, short-lived)
|
||||
- Opaque refresh tokens (stored in DB, long-lived)
|
||||
- Opaque password reset tokens (stored in DB, short-lived)
|
||||
- JWT email verification tokens (stateless)
|
||||
|
||||
Attributes:
|
||||
jwt_secret: Secret key for signing JWT tokens.
|
||||
jwt_algorithm: Algorithm for JWT encoding (default: HS256).
|
||||
access_token_expire_minutes: Validity duration for access tokens.
|
||||
refresh_token_expire_days: Validity duration for refresh tokens.
|
||||
password_reset_token_expire_minutes: Validity duration for reset tokens.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
jwt_secret: str,
|
||||
jwt_algorithm: str = "HS256",
|
||||
access_token_expire_minutes: int = 30,
|
||||
refresh_token_expire_days: int = 7,
|
||||
password_reset_token_expire_minutes: int = 15
|
||||
):
|
||||
"""
|
||||
Initialize the token manager.
|
||||
|
||||
Args:
|
||||
jwt_secret: Secret key for signing JWT tokens.
|
||||
jwt_algorithm: Algorithm for JWT encoding (default: HS256).
|
||||
access_token_expire_minutes: Validity duration for access tokens.
|
||||
refresh_token_expire_days: Validity duration for refresh tokens.
|
||||
password_reset_token_expire_minutes: Validity duration for reset tokens.
|
||||
"""
|
||||
if not jwt_secret:
|
||||
raise ValueError("JWT secret cannot be empty")
|
||||
|
||||
self.jwt_secret = jwt_secret
|
||||
self.jwt_algorithm = jwt_algorithm
|
||||
self.access_token_expire_minutes = access_token_expire_minutes
|
||||
self.refresh_token_expire_days = refresh_token_expire_days
|
||||
self.password_reset_token_expire_minutes = password_reset_token_expire_minutes
|
||||
|
||||
def create_access_token(self, user: UserInDB) -> str:
|
||||
"""
|
||||
Create a JWT access token for a user.
|
||||
|
||||
The token contains the user's ID and email, and is signed with the
|
||||
configured secret. It expires after the configured duration.
|
||||
|
||||
Args:
|
||||
user: The user for whom to create the token.
|
||||
|
||||
Returns:
|
||||
Encoded JWT access token string.
|
||||
|
||||
Example:
|
||||
>>> tm = TokenManager(jwt_secret="secret")
|
||||
>>> token = tm.create_access_token(user)
|
||||
>>> print(token)
|
||||
eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9...
|
||||
"""
|
||||
expires_at = datetime.now() + timedelta(minutes=self.access_token_expire_minutes)
|
||||
|
||||
payload = {
|
||||
"sub": user.id,
|
||||
"email": user.email,
|
||||
"exp": int(expires_at.timestamp()),
|
||||
"type": "access"
|
||||
}
|
||||
|
||||
return jwt.encode(payload, self.jwt_secret, algorithm=self.jwt_algorithm)
|
||||
|
||||
def create_refresh_token(self) -> str:
|
||||
"""
|
||||
Create an opaque refresh token.
|
||||
|
||||
This generates a cryptographically secure random token that will be
|
||||
stored in the database along with its expiration time and user association.
|
||||
|
||||
Returns:
|
||||
Secure random token string (64 characters hex).
|
||||
|
||||
Example:
|
||||
>>> tm = TokenManager(jwt_secret="secret")
|
||||
>>> token = tm.create_refresh_token()
|
||||
>>> len(token)
|
||||
64
|
||||
"""
|
||||
return secrets.token_hex(32)
|
||||
|
||||
def create_password_reset_token(self) -> str:
|
||||
"""
|
||||
Create an opaque password reset token.
|
||||
|
||||
This generates a cryptographically secure random token that will be
|
||||
stored in the database with a short expiration time (15 minutes by default).
|
||||
|
||||
Returns:
|
||||
Secure random token string (64 characters hex).
|
||||
|
||||
Example:
|
||||
>>> tm = TokenManager(jwt_secret="secret")
|
||||
>>> token = tm.create_password_reset_token()
|
||||
>>> len(token)
|
||||
64
|
||||
"""
|
||||
return secrets.token_hex(32)
|
||||
|
||||
def create_email_verification_token(self, email: str) -> str:
|
||||
"""
|
||||
Create a JWT email verification token.
|
||||
|
||||
This token is stateless (not stored in DB) and contains the email
|
||||
to be verified. It has no explicit expiration in the payload but
|
||||
should be validated for reasonable freshness.
|
||||
|
||||
Args:
|
||||
email: The email address to encode in the token.
|
||||
|
||||
Returns:
|
||||
Encoded JWT verification token string.
|
||||
|
||||
Example:
|
||||
>>> tm = TokenManager(jwt_secret="secret")
|
||||
>>> token = tm.create_email_verification_token("user@example.com")
|
||||
"""
|
||||
# Set expiration to 7 days for email verification
|
||||
expires_at = datetime.now() + timedelta(days=7)
|
||||
|
||||
payload = {
|
||||
"email": email,
|
||||
"exp": int(expires_at.timestamp()),
|
||||
"type": "email_verification"
|
||||
}
|
||||
|
||||
return jwt.encode(payload, self.jwt_secret, algorithm=self.jwt_algorithm)
|
||||
|
||||
def decode_access_token(self, token: str) -> TokenPayload:
|
||||
"""
|
||||
Decode and validate a JWT access token.
|
||||
|
||||
This method verifies the token signature, checks expiration,
|
||||
and validates the token type.
|
||||
|
||||
Args:
|
||||
token: The JWT access token to decode.
|
||||
|
||||
Returns:
|
||||
TokenPayload containing the decoded claims.
|
||||
|
||||
Raises:
|
||||
InvalidTokenError: If token is malformed or has invalid signature.
|
||||
ExpiredTokenError: If token has expired.
|
||||
|
||||
Example:
|
||||
>>> tm = TokenManager(jwt_secret="secret")
|
||||
>>> payload = tm.decode_access_token(token)
|
||||
>>> print(payload.sub) # user_id
|
||||
"""
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
self.jwt_secret,
|
||||
algorithms=[self.jwt_algorithm]
|
||||
)
|
||||
|
||||
# Validate token type
|
||||
if payload.get("type") != "access":
|
||||
raise InvalidTokenError("Invalid token type")
|
||||
|
||||
return TokenPayload(
|
||||
sub=payload.get("sub"),
|
||||
email=payload.get("email"),
|
||||
exp=payload.get("exp"),
|
||||
type=payload.get("type")
|
||||
)
|
||||
|
||||
except JWTError as e:
|
||||
if "expired" in str(e).lower():
|
||||
raise ExpiredTokenError("Access token has expired")
|
||||
raise InvalidTokenError(f"Invalid access token: {str(e)}")
|
||||
|
||||
def decode_email_verification_token(self, token: str) -> str:
|
||||
"""
|
||||
Decode and validate an email verification JWT token.
|
||||
|
||||
This method verifies the token signature, checks expiration,
|
||||
and extracts the email address.
|
||||
|
||||
Args:
|
||||
token: The JWT email verification token to decode.
|
||||
|
||||
Returns:
|
||||
The email address contained in the token.
|
||||
|
||||
Raises:
|
||||
InvalidTokenError: If token is malformed or has invalid signature.
|
||||
ExpiredTokenError: If token has expired.
|
||||
|
||||
Example:
|
||||
>>> tm = TokenManager(jwt_secret="secret")
|
||||
>>> email = tm.decode_email_verification_token(token)
|
||||
>>> print(email)
|
||||
user@example.com
|
||||
"""
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
self.jwt_secret,
|
||||
algorithms=[self.jwt_algorithm]
|
||||
)
|
||||
|
||||
# Validate token type
|
||||
if payload.get("type") != "email_verification":
|
||||
raise InvalidTokenError("Invalid token type")
|
||||
|
||||
email = payload.get("email")
|
||||
if not email:
|
||||
raise InvalidTokenError("Email not found in token")
|
||||
|
||||
return email
|
||||
|
||||
except JWTError as e:
|
||||
if "expired" in str(e).lower():
|
||||
raise ExpiredTokenError("Email verification token has expired")
|
||||
raise InvalidTokenError(f"Invalid email verification token: {str(e)}")
|
||||
|
||||
def get_refresh_token_expiration(self) -> datetime:
|
||||
"""
|
||||
Calculate expiration datetime for a refresh token.
|
||||
|
||||
Returns:
|
||||
Datetime when a newly created refresh token should expire.
|
||||
"""
|
||||
return datetime.now() + timedelta(days=self.refresh_token_expire_days)
|
||||
|
||||
def get_password_reset_token_expiration(self) -> datetime:
|
||||
"""
|
||||
Calculate expiration datetime for a password reset token.
|
||||
|
||||
Returns:
|
||||
Datetime when a newly created reset token should expire.
|
||||
"""
|
||||
return datetime.now() + timedelta(minutes=self.password_reset_token_expire_minutes)
|
||||
198
src/my_auth/exceptions.py
Normal file
198
src/my_auth/exceptions.py
Normal file
@@ -0,0 +1,198 @@
|
||||
"""
|
||||
Custom exceptions for authentication module.
|
||||
|
||||
This module defines all custom exceptions used throughout the authentication
|
||||
system. Each exception is designed to be caught and converted to appropriate
|
||||
HTTP responses by FastAPI exception handlers.
|
||||
"""
|
||||
|
||||
|
||||
class AuthError(Exception):
|
||||
"""
|
||||
Base exception for all authentication-related errors.
|
||||
|
||||
This base class provides common attributes for all auth exceptions
|
||||
and facilitates centralized error handling.
|
||||
|
||||
Attributes:
|
||||
message: Human-readable error message.
|
||||
status_code: HTTP status code associated with this error.
|
||||
error_code: Unique error code for client-side handling (optional).
|
||||
"""
|
||||
|
||||
def __init__(self, message: str, status_code: int = 500, error_code: str = None):
|
||||
"""
|
||||
Initialize the authentication error.
|
||||
|
||||
Args:
|
||||
message: Human-readable error message.
|
||||
status_code: HTTP status code (default: 500).
|
||||
error_code: Unique error code for client identification.
|
||||
"""
|
||||
self.message = message
|
||||
self.status_code = status_code
|
||||
self.error_code = error_code or self.__class__.__name__
|
||||
super().__init__(self.message)
|
||||
|
||||
|
||||
class InvalidCredentialsError(AuthError):
|
||||
"""
|
||||
Exception raised when login credentials are invalid.
|
||||
|
||||
This exception is raised during authentication when the provided
|
||||
email or password does not match any user in the database.
|
||||
|
||||
HTTP Status: 401 Unauthorized
|
||||
"""
|
||||
|
||||
def __init__(self, message: str = "Invalid email or password"):
|
||||
"""
|
||||
Initialize invalid credentials error.
|
||||
|
||||
Args:
|
||||
message: Custom error message (default: "Invalid email or password").
|
||||
"""
|
||||
super().__init__(message=message, status_code=401)
|
||||
|
||||
|
||||
class UserAlreadyExistsError(AuthError):
|
||||
"""
|
||||
Exception raised when attempting to create a user with an existing email.
|
||||
|
||||
This exception is raised during registration when the provided email
|
||||
address is already associated with an existing user account.
|
||||
|
||||
HTTP Status: 409 Conflict
|
||||
"""
|
||||
|
||||
def __init__(self, message: str = "User with this email already exists"):
|
||||
"""
|
||||
Initialize user already exists error.
|
||||
|
||||
Args:
|
||||
message: Custom error message (default: "User with this email already exists").
|
||||
"""
|
||||
super().__init__(message=message, status_code=409)
|
||||
|
||||
|
||||
class UserNotFoundError(AuthError):
|
||||
"""
|
||||
Exception raised when a requested user cannot be found.
|
||||
|
||||
This exception is raised when attempting to retrieve, update, or delete
|
||||
a user that does not exist in the database.
|
||||
|
||||
HTTP Status: 404 Not Found
|
||||
"""
|
||||
|
||||
def __init__(self, message: str = "User not found"):
|
||||
"""
|
||||
Initialize user not found error.
|
||||
|
||||
Args:
|
||||
message: Custom error message (default: "User not found").
|
||||
"""
|
||||
super().__init__(message=message, status_code=404)
|
||||
|
||||
|
||||
class InvalidTokenError(AuthError):
|
||||
"""
|
||||
Exception raised when a token is invalid or malformed.
|
||||
|
||||
This exception is raised when a token cannot be decoded, has an invalid
|
||||
signature, or contains invalid data. This applies to JWT tokens, refresh
|
||||
tokens, and password reset tokens.
|
||||
|
||||
HTTP Status: 401 Unauthorized
|
||||
"""
|
||||
|
||||
def __init__(self, message: str = "Invalid token"):
|
||||
"""
|
||||
Initialize invalid token error.
|
||||
|
||||
Args:
|
||||
message: Custom error message (default: "Invalid token").
|
||||
"""
|
||||
super().__init__(message=message, status_code=401)
|
||||
|
||||
|
||||
class ExpiredTokenError(AuthError):
|
||||
"""
|
||||
Exception raised when a token has expired.
|
||||
|
||||
This exception is raised when attempting to use a token that has
|
||||
passed its expiration time. This applies to all token types.
|
||||
|
||||
HTTP Status: 401 Unauthorized
|
||||
"""
|
||||
|
||||
def __init__(self, message: str = "Token has expired"):
|
||||
"""
|
||||
Initialize expired token error.
|
||||
|
||||
Args:
|
||||
message: Custom error message (default: "Token has expired").
|
||||
"""
|
||||
super().__init__(message=message, status_code=401)
|
||||
|
||||
|
||||
class RevokedTokenError(AuthError):
|
||||
"""
|
||||
Exception raised when attempting to use a revoked token.
|
||||
|
||||
This exception is raised when a token has been explicitly revoked,
|
||||
typically after logout or when a user's tokens are invalidated for
|
||||
security reasons.
|
||||
|
||||
HTTP Status: 401 Unauthorized
|
||||
"""
|
||||
|
||||
def __init__(self, message: str = "Token has been revoked"):
|
||||
"""
|
||||
Initialize revoked token error.
|
||||
|
||||
Args:
|
||||
message: Custom error message (default: "Token has been revoked").
|
||||
"""
|
||||
super().__init__(message=message, status_code=401)
|
||||
|
||||
|
||||
class EmailNotVerifiedError(AuthError):
|
||||
"""
|
||||
Exception raised when an action requires email verification.
|
||||
|
||||
This exception is raised when a user attempts to perform an action
|
||||
that requires their email to be verified, but their account email
|
||||
has not been verified yet.
|
||||
|
||||
HTTP Status: 403 Forbidden
|
||||
"""
|
||||
|
||||
def __init__(self, message: str = "Email address is not verified"):
|
||||
"""
|
||||
Initialize email not verified error.
|
||||
|
||||
Args:
|
||||
message: Custom error message (default: "Email address is not verified").
|
||||
"""
|
||||
super().__init__(message=message, status_code=403)
|
||||
|
||||
|
||||
class AccountDisabledError(AuthError):
|
||||
"""
|
||||
Exception raised when attempting to access a disabled account.
|
||||
|
||||
This exception is raised when a user attempts to login or perform
|
||||
actions with an account that has been disabled by an administrator.
|
||||
|
||||
HTTP Status: 403 Forbidden
|
||||
"""
|
||||
|
||||
def __init__(self, message: str = "Account has been disabled"):
|
||||
"""
|
||||
Initialize account disabled error.
|
||||
|
||||
Args:
|
||||
message: Custom error message (default: "Account has been disabled").
|
||||
"""
|
||||
super().__init__(message=message, status_code=403)
|
||||
0
src/my_auth/models/__init__.py
Normal file
0
src/my_auth/models/__init__.py
Normal file
91
src/my_auth/models/email_verification.py
Normal file
91
src/my_auth/models/email_verification.py
Normal file
@@ -0,0 +1,91 @@
|
||||
"""
|
||||
Email verification and password reset models for authentication module.
|
||||
|
||||
This module defines Pydantic models for email verification and password
|
||||
reset operations including request and confirmation models.
|
||||
"""
|
||||
|
||||
from pydantic import BaseModel, EmailStr, field_validator
|
||||
|
||||
from my_auth.models.validators import validate_password_strength
|
||||
|
||||
|
||||
class EmailVerificationRequest(BaseModel):
|
||||
"""
|
||||
Request model for email verification.
|
||||
|
||||
This model is used when a user requests an email verification link
|
||||
to be sent to their email address.
|
||||
|
||||
Attributes:
|
||||
email: The email address to send the verification link to.
|
||||
"""
|
||||
|
||||
email: EmailStr
|
||||
|
||||
|
||||
class EmailVerificationConfirm(BaseModel):
|
||||
"""
|
||||
Confirmation model for email verification.
|
||||
|
||||
This model is used when a user clicks the verification link and
|
||||
submits the token to confirm their email address.
|
||||
|
||||
Attributes:
|
||||
token: JWT token received in the verification email.
|
||||
"""
|
||||
|
||||
token: str
|
||||
|
||||
|
||||
class PasswordResetRequest(BaseModel):
|
||||
"""
|
||||
Request model for password reset.
|
||||
|
||||
This model is used when a user requests a password reset link
|
||||
to be sent to their email address.
|
||||
|
||||
Attributes:
|
||||
email: The email address to send the password reset link to.
|
||||
"""
|
||||
|
||||
email: EmailStr
|
||||
|
||||
|
||||
class PasswordResetConfirm(BaseModel):
|
||||
"""
|
||||
Confirmation model for password reset.
|
||||
|
||||
This model is used when a user submits a new password along with
|
||||
their reset token. The new password is validated with the same
|
||||
strict rules as user registration.
|
||||
|
||||
Attributes:
|
||||
token: Random secure token received in the password reset email.
|
||||
new_password: The new password that must meet security requirements:
|
||||
- Minimum 8 characters
|
||||
- At least 1 uppercase letter
|
||||
- At least 1 lowercase letter
|
||||
- At least 1 digit
|
||||
- At least 1 special character
|
||||
"""
|
||||
|
||||
token: str
|
||||
new_password: str
|
||||
|
||||
@field_validator('new_password')
|
||||
@classmethod
|
||||
def validate_password(cls, value: str) -> str:
|
||||
"""
|
||||
Validate new password meets security requirements.
|
||||
|
||||
Args:
|
||||
value: The new password to validate.
|
||||
|
||||
Returns:
|
||||
The validated password.
|
||||
|
||||
Raises:
|
||||
ValueError: If password does not meet security requirements.
|
||||
"""
|
||||
return validate_password_strength(value)
|
||||
88
src/my_auth/models/token.py
Normal file
88
src/my_auth/models/token.py
Normal file
@@ -0,0 +1,88 @@
|
||||
"""
|
||||
Token models for authentication module.
|
||||
|
||||
This module defines Pydantic models for token-related operations including
|
||||
JWT payloads, API responses, requests, and database storage.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class TokenPayload(BaseModel):
|
||||
"""
|
||||
JWT token payload structure.
|
||||
|
||||
This model represents the data contained within a JWT access token.
|
||||
It follows the standard JWT claims format.
|
||||
|
||||
Attributes:
|
||||
sub: Subject (user ID).
|
||||
email: User's email address.
|
||||
exp: Expiration timestamp (Unix timestamp).
|
||||
type: Token type identifier (always "access" for access tokens).
|
||||
"""
|
||||
|
||||
sub: str
|
||||
email: str
|
||||
exp: int
|
||||
type: Literal["access"] = "access"
|
||||
|
||||
|
||||
class AccessTokenResponse(BaseModel):
|
||||
"""
|
||||
Response model after successful authentication.
|
||||
|
||||
This model is returned after login or token refresh operations.
|
||||
It contains both access and refresh tokens.
|
||||
|
||||
Attributes:
|
||||
access_token: JWT access token for API authentication.
|
||||
refresh_token: Opaque refresh token for obtaining new access tokens.
|
||||
token_type: OAuth2 token type (always "bearer").
|
||||
"""
|
||||
|
||||
access_token: str
|
||||
refresh_token: str
|
||||
token_type: str = "bearer"
|
||||
|
||||
|
||||
class RefreshTokenRequest(BaseModel):
|
||||
"""
|
||||
Request model for token refresh operation.
|
||||
|
||||
This model is used when a client wants to obtain a new access token
|
||||
using their refresh token.
|
||||
|
||||
Attributes:
|
||||
refresh_token: The refresh token to exchange for a new access token.
|
||||
"""
|
||||
|
||||
refresh_token: str
|
||||
|
||||
|
||||
class TokenData(BaseModel):
|
||||
"""
|
||||
Token storage model for database.
|
||||
|
||||
This model represents tokens that need to be persisted in the database.
|
||||
It uses a discriminator field to distinguish between different token types
|
||||
(refresh tokens and password reset tokens) in a single collection/table.
|
||||
|
||||
Attributes:
|
||||
token: The token string (random secure string for refresh and reset).
|
||||
token_type: Discriminator field ("refresh" or "password_reset").
|
||||
user_id: ID of the user this token belongs to.
|
||||
expires_at: When the token expires.
|
||||
created_at: When the token was created.
|
||||
is_revoked: Whether the token has been revoked (for logout/security).
|
||||
"""
|
||||
|
||||
token: str
|
||||
token_type: Literal["refresh", "password_reset"]
|
||||
user_id: str
|
||||
expires_at: datetime
|
||||
created_at: datetime
|
||||
is_revoked: bool = False
|
||||
189
src/my_auth/models/user.py
Normal file
189
src/my_auth/models/user.py
Normal file
@@ -0,0 +1,189 @@
|
||||
"""
|
||||
User models for authentication module.
|
||||
|
||||
This module defines Pydantic models for user-related operations including
|
||||
creation, updates, database representation, and API responses.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, EmailStr, Field, field_validator
|
||||
|
||||
from my_auth.models.validators import validate_password_strength, validate_username_not_empty
|
||||
|
||||
|
||||
class UserBase(BaseModel):
|
||||
"""
|
||||
Base user model with common fields.
|
||||
|
||||
This model contains the shared fields used across different user representations.
|
||||
It serves as the foundation for other user models.
|
||||
|
||||
Attributes:
|
||||
email: User's email address (unique in database).
|
||||
username: User's display name (required, non-unique).
|
||||
roles: List of user roles (free-form strings, no defaults).
|
||||
user_settings: Dictionary for storing custom user settings.
|
||||
"""
|
||||
|
||||
email: EmailStr
|
||||
username: str
|
||||
roles: list[str] = Field(default_factory=list)
|
||||
user_settings: dict = Field(default_factory=dict)
|
||||
|
||||
|
||||
class UserCreate(UserBase):
|
||||
"""
|
||||
Model for user creation (registration).
|
||||
|
||||
This model extends UserBase with a password field and enforces strict
|
||||
password validation rules.
|
||||
|
||||
Attributes:
|
||||
password: Plain text password that must meet security requirements:
|
||||
- Minimum 8 characters
|
||||
- At least 1 uppercase letter
|
||||
- At least 1 lowercase letter
|
||||
- At least 1 digit
|
||||
- At least 1 special character
|
||||
"""
|
||||
|
||||
password: str
|
||||
|
||||
@field_validator('password')
|
||||
@classmethod
|
||||
def validate_password(cls, value: str) -> Optional[str]:
|
||||
"""
|
||||
Validate password meets security requirements.
|
||||
|
||||
Args:
|
||||
value: The password to validate.
|
||||
|
||||
Returns:
|
||||
The validated password.
|
||||
|
||||
Raises:
|
||||
ValueError: If password does not meet security requirements.
|
||||
"""
|
||||
return validate_password_strength(value)
|
||||
|
||||
@field_validator('username')
|
||||
@classmethod
|
||||
def validate_username(cls, value: str) -> str:
|
||||
"""
|
||||
Validate username is not empty and has reasonable length.
|
||||
|
||||
Args:
|
||||
value: The username to validate.
|
||||
|
||||
Returns:
|
||||
The validated username.
|
||||
|
||||
Raises:
|
||||
ValueError: If username is empty or too long.
|
||||
"""
|
||||
return validate_username_not_empty(value)
|
||||
|
||||
|
||||
class UserUpdate(BaseModel):
|
||||
"""
|
||||
Model for user updates.
|
||||
|
||||
All fields are optional to allow partial updates. Password updates
|
||||
are validated with the same strict rules as UserCreate.
|
||||
|
||||
Attributes:
|
||||
email: Optional new email address.
|
||||
username: Optional new username.
|
||||
password: Optional new password (will be hashed).
|
||||
roles: Optional new roles list.
|
||||
user_settings: Optional new settings dict.
|
||||
is_verified: Optional email verification status.
|
||||
is_active: Optional account active status.
|
||||
"""
|
||||
|
||||
email: Optional[EmailStr] = None
|
||||
username: Optional[str] = None
|
||||
password: Optional[str] = None
|
||||
roles: Optional[list[str]] = None
|
||||
user_settings: Optional[dict] = None
|
||||
is_verified: Optional[bool] = None
|
||||
is_active: Optional[bool] = None
|
||||
|
||||
@field_validator('password')
|
||||
@classmethod
|
||||
def validate_password_strength(cls, value: Optional[str]) -> Optional[str]:
|
||||
"""
|
||||
Validate password meets security requirements if provided.
|
||||
|
||||
Args:
|
||||
value: The password to validate (can be None).
|
||||
|
||||
Returns:
|
||||
The validated password or None.
|
||||
|
||||
Raises:
|
||||
ValueError: If password is provided but does not meet security requirements.
|
||||
"""
|
||||
return validate_password_strength(value)
|
||||
|
||||
@field_validator('username')
|
||||
@classmethod
|
||||
def validate_username(cls, value: Optional[str]) -> Optional[str]:
|
||||
"""
|
||||
Validate username if provided.
|
||||
|
||||
Args:
|
||||
value: The username to validate (can be None).
|
||||
|
||||
Returns:
|
||||
The validated username or None.
|
||||
|
||||
Raises:
|
||||
ValueError: If username is provided but empty or too long.
|
||||
"""
|
||||
return validate_username_not_empty(value)
|
||||
|
||||
|
||||
class UserInDB(UserBase):
|
||||
"""
|
||||
Complete user model as stored in database.
|
||||
|
||||
This model represents the full user entity including security-sensitive
|
||||
fields like hashed_password. It should not be directly exposed via API.
|
||||
|
||||
Attributes:
|
||||
id: Unique user identifier (string for database compatibility).
|
||||
hashed_password: Bcrypt hashed password.
|
||||
is_verified: Whether the user's email has been verified.
|
||||
is_active: Whether the user account is active.
|
||||
created_at: Timestamp of account creation.
|
||||
updated_at: Timestamp of last update.
|
||||
"""
|
||||
|
||||
id: str
|
||||
hashed_password: str
|
||||
is_verified: bool = False
|
||||
is_active: bool = True
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
|
||||
class UserResponse(UserBase):
|
||||
"""
|
||||
User model for API responses.
|
||||
|
||||
This is a safe representation of the user without sensitive fields
|
||||
like hashed_password. Used for API responses where user information
|
||||
needs to be exposed.
|
||||
|
||||
Attributes:
|
||||
id: Unique user identifier.
|
||||
created_at: Timestamp of account creation.
|
||||
updated_at: Timestamp of last update.
|
||||
"""
|
||||
|
||||
id: str
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
57
src/my_auth/models/validators.py
Normal file
57
src/my_auth/models/validators.py
Normal file
@@ -0,0 +1,57 @@
|
||||
from typing import Optional
|
||||
|
||||
def validate_password_strength(password: str) -> Optional[str]:
|
||||
"""
|
||||
Validate password meets security requirements.
|
||||
|
||||
Args:
|
||||
password: The password to validate.
|
||||
|
||||
Returns:
|
||||
The validated password.
|
||||
|
||||
Raises:
|
||||
ValueError: If password does not meet security requirements.
|
||||
"""
|
||||
if password is None:
|
||||
return password
|
||||
|
||||
if len(password) < 8:
|
||||
raise ValueError('Password must be at least 8 characters long')
|
||||
|
||||
if not any(char.isupper() for char in password):
|
||||
raise ValueError('Password must contain at least one uppercase letter')
|
||||
|
||||
if not any(char.islower() for char in password):
|
||||
raise ValueError('Password must contain at least one lowercase letter')
|
||||
|
||||
if not any(char.isdigit() for char in password):
|
||||
raise ValueError('Password must contain at least one digit')
|
||||
|
||||
special_characters = "!@#$%^&*()_+-=[]{}|;:,.<>?/~`"
|
||||
if not any(char in special_characters for char in password):
|
||||
raise ValueError('Password must contain at least one special character')
|
||||
|
||||
return password
|
||||
|
||||
|
||||
def validate_username_not_empty(user_name: str) -> str:
|
||||
"""
|
||||
Validate username is not empty and has reasonable length.
|
||||
|
||||
Args:
|
||||
user_name: The username to validate.
|
||||
|
||||
Returns:
|
||||
The validated username.
|
||||
|
||||
Raises:
|
||||
ValueError: If username is empty or too long.
|
||||
"""
|
||||
if not user_name or not user_name.strip():
|
||||
raise ValueError('Username cannot be empty')
|
||||
|
||||
if len(user_name) > 100:
|
||||
raise ValueError('Username cannot exceed 100 characters')
|
||||
|
||||
return user_name.strip()
|
||||
0
src/my_auth/persistence/__init__.py
Normal file
0
src/my_auth/persistence/__init__.py
Normal file
226
src/my_auth/persistence/base.py
Normal file
226
src/my_auth/persistence/base.py
Normal file
@@ -0,0 +1,226 @@
|
||||
"""
|
||||
Abstract base classes for persistence layer.
|
||||
|
||||
This module defines the interfaces that all database implementations must
|
||||
follow. It provides abstract base classes for user and token repositories,
|
||||
ensuring consistency across different database backends (MongoDB, SQLite,
|
||||
PostgreSQL, custom engines, etc.).
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from ..models.token import TokenData
|
||||
from ..models.user import UserCreate, UserInDB, UserUpdate
|
||||
|
||||
|
||||
class UserRepository(ABC):
|
||||
"""
|
||||
Abstract base class for user persistence operations.
|
||||
|
||||
This interface defines all methods required for user management in the
|
||||
database. Concrete implementations must provide all these methods for
|
||||
their specific database backend.
|
||||
|
||||
All methods should raise appropriate exceptions (UserNotFoundError,
|
||||
UserAlreadyExistsError, etc.) when operations fail.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def create_user(self, user_data: UserCreate, hashed_password: str) -> UserInDB:
|
||||
"""
|
||||
Create a new user in the database.
|
||||
|
||||
Args:
|
||||
user_data: User information from registration.
|
||||
hashed_password: Pre-hashed password (hashing is done by the service layer).
|
||||
|
||||
Returns:
|
||||
The created user with database-assigned ID and timestamps.
|
||||
|
||||
Raises:
|
||||
UserAlreadyExistsError: If email already exists in database.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_user_by_email(self, email: str) -> UserInDB | None:
|
||||
"""
|
||||
Retrieve a user by their email address.
|
||||
|
||||
Args:
|
||||
email: The email address to search for.
|
||||
|
||||
Returns:
|
||||
The user if found, None otherwise.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_user_by_id(self, user_id: str) -> UserInDB | None:
|
||||
"""
|
||||
Retrieve a user by their unique identifier.
|
||||
|
||||
Args:
|
||||
user_id: The unique user identifier.
|
||||
|
||||
Returns:
|
||||
The user if found, None otherwise.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def update_user(self, user_id: str, updates: UserUpdate) -> UserInDB:
|
||||
"""
|
||||
Update an existing user's information.
|
||||
|
||||
Only fields present in the updates object (non-None) should be updated.
|
||||
The updated_at timestamp should be automatically set to the current time.
|
||||
|
||||
Args:
|
||||
user_id: The unique user identifier.
|
||||
updates: Pydantic model containing fields to update.
|
||||
|
||||
Returns:
|
||||
The updated user.
|
||||
|
||||
Raises:
|
||||
UserNotFoundError: If user does not exist.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete_user(self, user_id: str) -> bool:
|
||||
"""
|
||||
Delete a user from the database.
|
||||
|
||||
Implementation can choose between soft delete (setting is_active=False)
|
||||
or hard delete (removing from database). The choice depends on the
|
||||
application's requirements for data retention.
|
||||
|
||||
Args:
|
||||
user_id: The unique user identifier.
|
||||
|
||||
Returns:
|
||||
True if user was deleted, False if user was not found.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def email_exists(self, email: str) -> bool:
|
||||
"""
|
||||
Check if an email address is already registered.
|
||||
|
||||
This is useful for validating registration requests without
|
||||
retrieving the full user object.
|
||||
|
||||
Args:
|
||||
email: The email address to check.
|
||||
|
||||
Returns:
|
||||
True if email exists, False otherwise.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class TokenRepository(ABC):
|
||||
"""
|
||||
Abstract base class for token persistence operations.
|
||||
|
||||
This interface defines all methods required for token management in the
|
||||
database. It handles both refresh tokens and password reset tokens using
|
||||
a discriminator field (token_type).
|
||||
|
||||
All methods should raise appropriate exceptions (InvalidTokenError,
|
||||
ExpiredTokenError, etc.) when operations fail.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def save_token(self, token_data: TokenData) -> None:
|
||||
"""
|
||||
Save a token to the database.
|
||||
|
||||
This is used for both refresh tokens and password reset tokens.
|
||||
The token_type field in TokenData distinguishes between them.
|
||||
|
||||
Args:
|
||||
token_data: Complete token information including type, user_id, expiration.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_token(self, token: str, token_type: str) -> TokenData | None:
|
||||
"""
|
||||
Retrieve a token from the database.
|
||||
|
||||
Args:
|
||||
token: The token string to search for.
|
||||
token_type: Type of token ("refresh" or "password_reset").
|
||||
|
||||
Returns:
|
||||
The token data if found, None otherwise.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def revoke_token(self, token: str) -> bool:
|
||||
"""
|
||||
Revoke a specific token.
|
||||
|
||||
This sets the is_revoked flag to True, preventing the token from
|
||||
being used again without physically deleting it (useful for audit trails).
|
||||
|
||||
Args:
|
||||
token: The token string to revoke.
|
||||
|
||||
Returns:
|
||||
True if token was revoked, False if token was not found.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def revoke_all_user_tokens(self, user_id: str, token_type: str) -> int:
|
||||
"""
|
||||
Revoke all tokens of a specific type for a user.
|
||||
|
||||
This is useful for logout-all-devices functionality or when a user's
|
||||
password is changed (invalidating all refresh tokens).
|
||||
|
||||
Args:
|
||||
user_id: The user whose tokens should be revoked.
|
||||
token_type: Type of tokens to revoke ("refresh" or "password_reset").
|
||||
|
||||
Returns:
|
||||
Number of tokens revoked.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete_expired_tokens(self) -> int:
|
||||
"""
|
||||
Delete all expired tokens from the database.
|
||||
|
||||
This maintenance operation should be called periodically to keep
|
||||
the tokens collection clean. It permanently removes tokens that
|
||||
have passed their expiration time.
|
||||
|
||||
Returns:
|
||||
Number of tokens deleted.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def is_token_valid(self, token: str, token_type: str) -> bool:
|
||||
"""
|
||||
Check if a token is valid (exists, not revoked, not expired).
|
||||
|
||||
This is a convenience method that combines multiple checks into
|
||||
a single boolean result.
|
||||
|
||||
Args:
|
||||
token: The token string to validate.
|
||||
token_type: Type of token ("refresh" or "password_reset").
|
||||
|
||||
Returns:
|
||||
True if token is valid and can be used, False otherwise.
|
||||
"""
|
||||
pass
|
||||
611
src/my_auth/persistence/sqlite.py
Normal file
611
src/my_auth/persistence/sqlite.py
Normal file
@@ -0,0 +1,611 @@
|
||||
"""
|
||||
SQLite implementation of persistence layer.
|
||||
|
||||
This module provides SQLite database implementations for user and token
|
||||
repositories. It uses the standard library sqlite3 module and handles
|
||||
JSON serialization for complex fields.
|
||||
"""
|
||||
|
||||
import json
|
||||
import sqlite3
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from uuid import uuid4
|
||||
|
||||
from .base import UserRepository, TokenRepository
|
||||
from ..models.user import UserCreate, UserInDB, UserUpdate
|
||||
from ..models.token import TokenData
|
||||
from ..exceptions import UserAlreadyExistsError, UserNotFoundError
|
||||
|
||||
|
||||
class SQLiteUserRepository(UserRepository):
|
||||
"""
|
||||
SQLite implementation of UserRepository.
|
||||
|
||||
This implementation uses sqlite3 to manage user data. JSON fields
|
||||
(roles, user_settings) are serialized/deserialized automatically.
|
||||
The database schema is created automatically on initialization.
|
||||
|
||||
Attributes:
|
||||
db_path: Path to the SQLite database file.
|
||||
"""
|
||||
|
||||
def __init__(self, db_path: str):
|
||||
"""
|
||||
Initialize SQLite user repository.
|
||||
|
||||
Creates the users table and indexes if they don't exist.
|
||||
|
||||
Args:
|
||||
db_path: Path to the SQLite database file.
|
||||
"""
|
||||
self.db_path = db_path
|
||||
self._create_tables()
|
||||
|
||||
def _create_tables(self) -> None:
|
||||
"""
|
||||
Create users table and indexes if they don't exist.
|
||||
|
||||
The table schema includes:
|
||||
- id: Primary key (UUID string)
|
||||
- email: Unique, indexed for fast lookups
|
||||
- username: Non-unique display name
|
||||
- hashed_password: Bcrypt hash
|
||||
- roles: JSON array of role strings
|
||||
- user_settings: JSON object for custom settings
|
||||
- is_verified: Boolean (stored as INTEGER)
|
||||
- is_active: Boolean (stored as INTEGER)
|
||||
- created_at: ISO format timestamp
|
||||
- updated_at: ISO format timestamp
|
||||
"""
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Create users table
|
||||
cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS users
|
||||
(
|
||||
id
|
||||
TEXT
|
||||
PRIMARY
|
||||
KEY,
|
||||
email
|
||||
TEXT
|
||||
UNIQUE
|
||||
NOT
|
||||
NULL,
|
||||
username
|
||||
TEXT
|
||||
NOT
|
||||
NULL,
|
||||
hashed_password
|
||||
TEXT
|
||||
NOT
|
||||
NULL,
|
||||
roles
|
||||
TEXT
|
||||
NOT
|
||||
NULL,
|
||||
user_settings
|
||||
TEXT
|
||||
NOT
|
||||
NULL,
|
||||
is_verified
|
||||
INTEGER
|
||||
NOT
|
||||
NULL
|
||||
DEFAULT
|
||||
0,
|
||||
is_active
|
||||
INTEGER
|
||||
NOT
|
||||
NULL
|
||||
DEFAULT
|
||||
1,
|
||||
created_at
|
||||
TEXT
|
||||
NOT
|
||||
NULL,
|
||||
updated_at
|
||||
TEXT
|
||||
NOT
|
||||
NULL
|
||||
)
|
||||
""")
|
||||
|
||||
# Create index on email for fast lookups
|
||||
cursor.execute("""
|
||||
CREATE INDEX IF NOT EXISTS idx_users_email
|
||||
ON users(email)
|
||||
""")
|
||||
|
||||
conn.commit()
|
||||
|
||||
def _row_to_user(self, row: tuple) -> UserInDB:
|
||||
"""
|
||||
Convert a database row to a UserInDB model.
|
||||
|
||||
Args:
|
||||
row: Database row tuple.
|
||||
|
||||
Returns:
|
||||
UserInDB model instance.
|
||||
"""
|
||||
return UserInDB(
|
||||
id=row[0],
|
||||
email=row[1],
|
||||
username=row[2],
|
||||
hashed_password=row[3],
|
||||
roles=json.loads(row[4]),
|
||||
user_settings=json.loads(row[5]),
|
||||
is_verified=bool(row[6]),
|
||||
is_active=bool(row[7]),
|
||||
created_at=datetime.fromisoformat(row[8]),
|
||||
updated_at=datetime.fromisoformat(row[9])
|
||||
)
|
||||
|
||||
def create_user(self, user_data: UserCreate, hashed_password: str) -> UserInDB:
|
||||
"""
|
||||
Create a new user in the database.
|
||||
|
||||
Args:
|
||||
user_data: User information from registration.
|
||||
hashed_password: Pre-hashed password.
|
||||
|
||||
Returns:
|
||||
The created user with generated ID and timestamps.
|
||||
|
||||
Raises:
|
||||
UserAlreadyExistsError: If email already exists.
|
||||
"""
|
||||
if self.email_exists(user_data.email):
|
||||
raise UserAlreadyExistsError(f"User with email {user_data.email} already exists")
|
||||
|
||||
user_id = str(uuid4())
|
||||
now = datetime.utcnow()
|
||||
|
||||
user = UserInDB(
|
||||
id=user_id,
|
||||
email=user_data.email,
|
||||
username=user_data.username,
|
||||
hashed_password=hashed_password,
|
||||
roles=user_data.roles,
|
||||
user_settings=user_data.user_settings,
|
||||
is_verified=False,
|
||||
is_active=True,
|
||||
created_at=now,
|
||||
updated_at=now
|
||||
)
|
||||
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("""
|
||||
INSERT INTO users (id, email, username, hashed_password, roles,
|
||||
user_settings, is_verified, is_active, created_at, updated_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""", (
|
||||
user.id,
|
||||
user.email,
|
||||
user.username,
|
||||
user.hashed_password,
|
||||
json.dumps(user.roles),
|
||||
json.dumps(user.user_settings),
|
||||
int(user.is_verified),
|
||||
int(user.is_active),
|
||||
user.created_at.isoformat(),
|
||||
user.updated_at.isoformat()
|
||||
))
|
||||
conn.commit()
|
||||
|
||||
return user
|
||||
|
||||
def get_user_by_email(self, email: str) -> Optional[UserInDB]:
|
||||
"""
|
||||
Retrieve a user by their email address.
|
||||
|
||||
Args:
|
||||
email: The email address to search for.
|
||||
|
||||
Returns:
|
||||
The user if found, None otherwise.
|
||||
"""
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("""
|
||||
SELECT id,
|
||||
email,
|
||||
username,
|
||||
hashed_password,
|
||||
roles,
|
||||
user_settings,
|
||||
is_verified,
|
||||
is_active,
|
||||
created_at,
|
||||
updated_at
|
||||
FROM users
|
||||
WHERE email = ?
|
||||
""", (email,))
|
||||
|
||||
row = cursor.fetchone()
|
||||
if row:
|
||||
return self._row_to_user(row)
|
||||
return None
|
||||
|
||||
def get_user_by_id(self, user_id: str) -> Optional[UserInDB]:
|
||||
"""
|
||||
Retrieve a user by their unique identifier.
|
||||
|
||||
Args:
|
||||
user_id: The unique user identifier.
|
||||
|
||||
Returns:
|
||||
The user if found, None otherwise.
|
||||
"""
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("""
|
||||
SELECT id,
|
||||
email,
|
||||
username,
|
||||
hashed_password,
|
||||
roles,
|
||||
user_settings,
|
||||
is_verified,
|
||||
is_active,
|
||||
created_at,
|
||||
updated_at
|
||||
FROM users
|
||||
WHERE id = ?
|
||||
""", (user_id,))
|
||||
|
||||
row = cursor.fetchone()
|
||||
if row:
|
||||
return self._row_to_user(row)
|
||||
return None
|
||||
|
||||
def update_user(self, user_id: str, updates: UserUpdate) -> UserInDB:
|
||||
"""
|
||||
Update an existing user's information.
|
||||
|
||||
Only non-None fields in the updates object are applied.
|
||||
The updated_at timestamp is automatically set to current time.
|
||||
|
||||
Args:
|
||||
user_id: The unique user identifier.
|
||||
updates: Pydantic model containing fields to update.
|
||||
|
||||
Returns:
|
||||
The updated user.
|
||||
|
||||
Raises:
|
||||
UserNotFoundError: If user does not exist.
|
||||
"""
|
||||
user = self.get_user_by_id(user_id)
|
||||
if not user:
|
||||
raise UserNotFoundError(f"User with id {user_id} not found")
|
||||
|
||||
# Build update query dynamically based on non-None fields
|
||||
update_fields = []
|
||||
update_values = []
|
||||
|
||||
update_data = updates.model_dump(exclude_unset=True)
|
||||
|
||||
for field, value in update_data.items():
|
||||
if value is not None:
|
||||
if field == "roles":
|
||||
update_fields.append("roles = ?")
|
||||
update_values.append(json.dumps(value))
|
||||
elif field == "user_settings":
|
||||
update_fields.append("user_settings = ?")
|
||||
update_values.append(json.dumps(value))
|
||||
elif field == "password":
|
||||
# Password should be hashed before calling this method
|
||||
# This field should contain hashed_password value
|
||||
update_fields.append("hashed_password = ?")
|
||||
update_values.append(value)
|
||||
elif field in ["is_verified", "is_active"]:
|
||||
update_fields.append(f"{field} = ?")
|
||||
update_values.append(int(value))
|
||||
else:
|
||||
update_fields.append(f"{field} = ?")
|
||||
update_values.append(value)
|
||||
|
||||
if not update_fields:
|
||||
return user
|
||||
|
||||
# Always update the updated_at timestamp
|
||||
update_fields.append("updated_at = ?")
|
||||
update_values.append(datetime.utcnow().isoformat())
|
||||
|
||||
# Add user_id for WHERE clause
|
||||
update_values.append(user_id)
|
||||
|
||||
query = f"UPDATE users SET {', '.join(update_fields)} WHERE id = ?"
|
||||
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(query, update_values)
|
||||
conn.commit()
|
||||
|
||||
# Return updated user
|
||||
updated_user = self.get_user_by_id(user_id)
|
||||
if not updated_user:
|
||||
raise UserNotFoundError(f"User with id {user_id} not found after update")
|
||||
|
||||
return updated_user
|
||||
|
||||
def delete_user(self, user_id: str) -> bool:
|
||||
"""
|
||||
Delete a user from the database (hard delete).
|
||||
|
||||
Args:
|
||||
user_id: The unique user identifier.
|
||||
|
||||
Returns:
|
||||
True if user was deleted, False if user was not found.
|
||||
"""
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("DELETE FROM users WHERE id = ?", (user_id,))
|
||||
conn.commit()
|
||||
return cursor.rowcount > 0
|
||||
|
||||
def email_exists(self, email: str) -> bool:
|
||||
"""
|
||||
Check if an email address is already registered.
|
||||
|
||||
Args:
|
||||
email: The email address to check.
|
||||
|
||||
Returns:
|
||||
True if email exists, False otherwise.
|
||||
"""
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("SELECT 1 FROM users WHERE email = ? LIMIT 1", (email,))
|
||||
return cursor.fetchone() is not None
|
||||
|
||||
|
||||
class SQLiteTokenRepository(TokenRepository):
|
||||
"""
|
||||
SQLite implementation of TokenRepository.
|
||||
|
||||
This implementation manages both refresh tokens and password reset tokens
|
||||
in a single table with a discriminator field (token_type).
|
||||
|
||||
Attributes:
|
||||
db_path: Path to the SQLite database file.
|
||||
"""
|
||||
|
||||
def __init__(self, db_path: str):
|
||||
"""
|
||||
Initialize SQLite token repository.
|
||||
|
||||
Creates the tokens table and indexes if they don't exist.
|
||||
|
||||
Args:
|
||||
db_path: Path to the SQLite database file.
|
||||
"""
|
||||
self.db_path = db_path
|
||||
self._create_tables()
|
||||
|
||||
def _create_tables(self) -> None:
|
||||
"""
|
||||
Create tokens table and indexes if they don't exist.
|
||||
|
||||
The table schema includes:
|
||||
- token: Primary key (random secure string)
|
||||
- token_type: Discriminator ("refresh" or "password_reset")
|
||||
- user_id: Foreign key to users
|
||||
- expires_at: ISO format timestamp
|
||||
- created_at: ISO format timestamp
|
||||
- is_revoked: Boolean (stored as INTEGER)
|
||||
|
||||
Indexes are created for efficient queries.
|
||||
"""
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Create tokens table
|
||||
cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS tokens
|
||||
(
|
||||
token
|
||||
TEXT
|
||||
PRIMARY
|
||||
KEY,
|
||||
token_type
|
||||
TEXT
|
||||
NOT
|
||||
NULL,
|
||||
user_id
|
||||
TEXT
|
||||
NOT
|
||||
NULL,
|
||||
expires_at
|
||||
TEXT
|
||||
NOT
|
||||
NULL,
|
||||
created_at
|
||||
TEXT
|
||||
NOT
|
||||
NULL,
|
||||
is_revoked
|
||||
INTEGER
|
||||
NOT
|
||||
NULL
|
||||
DEFAULT
|
||||
0
|
||||
)
|
||||
""")
|
||||
|
||||
# Create index on user_id for revoking all user tokens
|
||||
cursor.execute("""
|
||||
CREATE INDEX IF NOT EXISTS idx_tokens_user_id
|
||||
ON tokens(user_id)
|
||||
""")
|
||||
|
||||
# Create composite index on token_type and user_id
|
||||
cursor.execute("""
|
||||
CREATE INDEX IF NOT EXISTS idx_tokens_type_user
|
||||
ON tokens(token_type, user_id)
|
||||
""")
|
||||
|
||||
# Create index on expires_at for cleanup operations
|
||||
cursor.execute("""
|
||||
CREATE INDEX IF NOT EXISTS idx_tokens_expires_at
|
||||
ON tokens(expires_at)
|
||||
""")
|
||||
|
||||
conn.commit()
|
||||
|
||||
def _row_to_token(self, row: tuple) -> TokenData:
|
||||
"""
|
||||
Convert a database row to a TokenData model.
|
||||
|
||||
Args:
|
||||
row: Database row tuple.
|
||||
|
||||
Returns:
|
||||
TokenData model instance.
|
||||
"""
|
||||
return TokenData(
|
||||
token=row[0],
|
||||
token_type=row[1],
|
||||
user_id=row[2],
|
||||
expires_at=datetime.fromisoformat(row[3]),
|
||||
created_at=datetime.fromisoformat(row[4]),
|
||||
is_revoked=bool(row[5])
|
||||
)
|
||||
|
||||
def save_token(self, token_data: TokenData) -> None:
|
||||
"""
|
||||
Save a token to the database.
|
||||
|
||||
Args:
|
||||
token_data: Complete token information.
|
||||
"""
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("""
|
||||
INSERT INTO tokens (token, token_type, user_id, expires_at, created_at, is_revoked)
|
||||
VALUES (?, ?, ?, ?, ?, ?)
|
||||
""", (
|
||||
token_data.token,
|
||||
token_data.token_type,
|
||||
token_data.user_id,
|
||||
token_data.expires_at.isoformat(),
|
||||
token_data.created_at.isoformat(),
|
||||
int(token_data.is_revoked)
|
||||
))
|
||||
conn.commit()
|
||||
|
||||
def get_token(self, token: str, token_type: str) -> Optional[TokenData]:
|
||||
"""
|
||||
Retrieve a token from the database.
|
||||
|
||||
Args:
|
||||
token: The token string to search for.
|
||||
token_type: Type of token ("refresh" or "password_reset").
|
||||
|
||||
Returns:
|
||||
The token data if found, None otherwise.
|
||||
"""
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("""
|
||||
SELECT token, token_type, user_id, expires_at, created_at, is_revoked
|
||||
FROM tokens
|
||||
WHERE token = ?
|
||||
AND token_type = ?
|
||||
""", (token, token_type))
|
||||
|
||||
row = cursor.fetchone()
|
||||
if row:
|
||||
return self._row_to_token(row)
|
||||
return None
|
||||
|
||||
def revoke_token(self, token: str) -> bool:
|
||||
"""
|
||||
Revoke a specific token.
|
||||
|
||||
Args:
|
||||
token: The token string to revoke.
|
||||
|
||||
Returns:
|
||||
True if token was revoked, False if token was not found.
|
||||
"""
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("""
|
||||
UPDATE tokens
|
||||
SET is_revoked = 1
|
||||
WHERE token = ?
|
||||
""", (token,))
|
||||
conn.commit()
|
||||
return cursor.rowcount > 0
|
||||
|
||||
def revoke_all_user_tokens(self, user_id: str, token_type: str) -> int:
|
||||
"""
|
||||
Revoke all tokens of a specific type for a user.
|
||||
|
||||
Args:
|
||||
user_id: The user whose tokens should be revoked.
|
||||
token_type: Type of tokens to revoke.
|
||||
|
||||
Returns:
|
||||
Number of tokens revoked.
|
||||
"""
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("""
|
||||
UPDATE tokens
|
||||
SET is_revoked = 1
|
||||
WHERE user_id = ?
|
||||
AND token_type = ?
|
||||
""", (user_id, token_type))
|
||||
conn.commit()
|
||||
return cursor.rowcount
|
||||
|
||||
def delete_expired_tokens(self) -> int:
|
||||
"""
|
||||
Delete all expired tokens from the database.
|
||||
|
||||
Returns:
|
||||
Number of tokens deleted.
|
||||
"""
|
||||
now = datetime.utcnow().isoformat()
|
||||
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("""
|
||||
DELETE
|
||||
FROM tokens
|
||||
WHERE expires_at < ?
|
||||
""", (now,))
|
||||
conn.commit()
|
||||
return cursor.rowcount
|
||||
|
||||
def is_token_valid(self, token: str, token_type: str) -> bool:
|
||||
"""
|
||||
Check if a token is valid (exists, not revoked, not expired).
|
||||
|
||||
Args:
|
||||
token: The token string to validate.
|
||||
token_type: Type of token.
|
||||
|
||||
Returns:
|
||||
True if token is valid, False otherwise.
|
||||
"""
|
||||
token_data = self.get_token(token, token_type)
|
||||
|
||||
if not token_data:
|
||||
return False
|
||||
|
||||
if token_data.is_revoked:
|
||||
return False
|
||||
|
||||
if token_data.expires_at < datetime.now():
|
||||
return False
|
||||
|
||||
return True
|
||||
0
tests/__init__.py
Normal file
0
tests/__init__.py
Normal file
0
tests/core/__init__.py
Normal file
0
tests/core/__init__.py
Normal file
106
tests/core/conftest.py
Normal file
106
tests/core/conftest.py
Normal file
@@ -0,0 +1,106 @@
|
||||
# tests/core/conftest.py
|
||||
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from my_auth.core.password import PasswordManager
|
||||
from my_auth.core.token import TokenManager
|
||||
from src.my_auth.core.auth import AuthService
|
||||
from src.my_auth.models.user import UserCreate
|
||||
from src.my_auth.persistence.sqlite import SQLiteUserRepository, SQLiteTokenRepository
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_user_data_create():
|
||||
"""Provides valid data for creating a test user."""
|
||||
return UserCreate(
|
||||
email="test.service@example.com",
|
||||
username="TestServiceUser",
|
||||
password="ValidPassword123!",
|
||||
roles=["member"],
|
||||
user_settings={"theme": "dark"}
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_user_hashed_password():
|
||||
"""Provides a dummy hashed password (only used internally by service)"""
|
||||
# Note: In service tests, we rely on the service to do the hashing/verification,
|
||||
# but this is kept for completeness if needed elsewhere.
|
||||
return "$2b$12$R.S/XfI2tQYt3Kk.iF1XwOQz0Qe.L0T0mD/O1H8E2V5D4Q6F7G8H9I0"
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def sqlite_db_path(tmp_path_factory):
|
||||
"""
|
||||
Creates a temporary directory and an SQLite file path for the test session.
|
||||
The directory is deleted after the session.
|
||||
"""
|
||||
temp_dir = tmp_path_factory.mktemp("sqlite_auth_service_test")
|
||||
db_file: Path = temp_dir / "auth_service_test.db"
|
||||
|
||||
yield str(db_file)
|
||||
|
||||
# Cleanup phase
|
||||
try:
|
||||
if temp_dir.exists():
|
||||
shutil.rmtree(temp_dir)
|
||||
except OSError as e:
|
||||
print(f"Error during cleanup of temporary DB directory: {e}")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def user_repository(sqlite_db_path: str) -> SQLiteUserRepository:
|
||||
"""Provides a real SQLiteUserRepository instance."""
|
||||
return SQLiteUserRepository(db_path=sqlite_db_path)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def token_repository(sqlite_db_path: str) -> SQLiteTokenRepository:
|
||||
"""Provides a real SQLiteTokenRepository instance."""
|
||||
return SQLiteTokenRepository(db_path=sqlite_db_path)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_password_manager() -> PasswordManager:
|
||||
"""Provides a PasswordManager instance for injection (low rounds for speed)."""
|
||||
mock = MagicMock(spec=PasswordManager)
|
||||
mock.hash_password.return_value = "PREDICTABLE_HASHED_PASSWORD_FOR_TESTING"
|
||||
mock.verify_password.return_value = True
|
||||
return mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_token_manager() -> TokenManager:
|
||||
"""Provides a TokenManager instance for injection (fast expiration settings)."""
|
||||
mock = MagicMock(spec=TokenManager)
|
||||
mock.create_access_token.return_value = "MOCKED_ACCESS_TOKEN"
|
||||
mock.create_refresh_token.return_value = "MOCKED_REFRESH_TOKEN"
|
||||
return mock
|
||||
|
||||
|
||||
# --- Service Fixture ---
|
||||
|
||||
@pytest.fixture
|
||||
def auth_service(user_repository: SQLiteUserRepository,
|
||||
token_repository: SQLiteTokenRepository,
|
||||
mock_password_manager: PasswordManager,
|
||||
mock_token_manager: TokenManager
|
||||
) -> AuthService:
|
||||
"""
|
||||
Provides an AuthService instance initialized with real repositories.
|
||||
"""
|
||||
# NOTE: To test hashing/verification, we must ensure password_hash_rounds is low.
|
||||
# NOTE: To simplify JWT testing, we mock the internal components (hashing/JWT)
|
||||
# as the AuthService shouldn't be responsible for these core algorithms,
|
||||
# only for orchestrating them. If your service integrates them directly,
|
||||
# we'll need to patch them below.
|
||||
return AuthService(
|
||||
user_repository=user_repository,
|
||||
token_repository=token_repository,
|
||||
password_manager=mock_password_manager,
|
||||
token_manager=mock_token_manager,
|
||||
)
|
||||
270
tests/core/test_auth_service.py
Normal file
270
tests/core/test_auth_service.py
Normal file
@@ -0,0 +1,270 @@
|
||||
# tests/core/test_auth_service.py
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from src.my_auth.core.auth import AuthService
|
||||
from src.my_auth.exceptions import UserAlreadyExistsError, InvalidCredentialsError, InvalidTokenError, ExpiredTokenError
|
||||
from src.my_auth.models.token import TokenData
|
||||
from src.my_auth.models.user import UserCreate, UserUpdate
|
||||
|
||||
|
||||
class TestAuthServiceRegisterLogin(object):
|
||||
"""Tests for the user registration and login processes."""
|
||||
|
||||
# The mocks are injected via the auth_service fixture from conftest.py
|
||||
|
||||
def test_register_success(self, auth_service: AuthService,
|
||||
test_user_data_create: UserCreate):
|
||||
"""Success: Registration works and stores the predictable hash."""
|
||||
|
||||
user = auth_service.register(test_user_data_create)
|
||||
|
||||
assert user is not None
|
||||
assert user.hashed_password == "PREDICTABLE_HASHED_PASSWORD_FOR_TESTING"
|
||||
auth_service.password_manager.hash_password.assert_called_once_with(test_user_data_create.password)
|
||||
|
||||
def test_register_failure_if_email_exists(self, auth_service: AuthService,
|
||||
test_user_data_create: UserCreate):
|
||||
"""Failure: Cannot register if the email already exists."""
|
||||
|
||||
auth_service.register(test_user_data_create)
|
||||
|
||||
with pytest.raises(UserAlreadyExistsError):
|
||||
auth_service.register(test_user_data_create)
|
||||
|
||||
def test_login_success(self, auth_service: AuthService, test_user_data_create: UserCreate):
|
||||
"""Success: Logging in with correct credentials generates and saves tokens."""
|
||||
|
||||
# Execute register to insert the user
|
||||
auth_service.register(test_user_data_create)
|
||||
|
||||
# Execute login
|
||||
user, tokens = auth_service.login(test_user_data_create.email, test_user_data_create.password)
|
||||
|
||||
assert user.email == test_user_data_create.email
|
||||
|
||||
auth_service.password_manager.verify_password.assert_called_once()
|
||||
auth_service.token_manager.create_access_token.assert_called_once()
|
||||
assert tokens.access_token == "MOCKED_ACCESS_TOKEN"
|
||||
assert tokens.refresh_token == "MOCKED_REFRESH_TOKEN"
|
||||
|
||||
def test_login_failure_invalid_password(self, auth_service: AuthService, # Removed mock_password_manager injection
|
||||
test_user_data_create: UserCreate):
|
||||
"""Failure: Login fails with InvalidCredentialsError if the password is wrong."""
|
||||
|
||||
# Access the manager via the service instance
|
||||
pm = auth_service.password_manager
|
||||
|
||||
# Setup: Mock hash for registration
|
||||
pm.hash_password.return_value = "HASH"
|
||||
auth_service.register(test_user_data_create)
|
||||
|
||||
# Mock verify for failure
|
||||
pm.verify_password.return_value = False
|
||||
|
||||
with pytest.raises(InvalidCredentialsError):
|
||||
auth_service.login(test_user_data_create.email, "WrongPassword!")
|
||||
|
||||
# Restore the default mock behavior defined in conftest for subsequent tests
|
||||
pm.hash_password.return_value = "PREDICTABLE_HASHED_PASSWORD_FOR_TESTING"
|
||||
pm.verify_password.return_value = True
|
||||
|
||||
def test_login_failure_user_not_found(self, auth_service: AuthService):
|
||||
"""Failure: Login fails if the user does not exist."""
|
||||
with pytest.raises(InvalidCredentialsError):
|
||||
auth_service.login("non.existent@example.com", "AnyPassword")
|
||||
|
||||
|
||||
class TestAuthServiceTokenManagement(object):
|
||||
"""Tests for token-related flows (Refresh, Logout, GetCurrentUser)."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_user_and_token(self, auth_service: AuthService, test_user_data_create: UserCreate):
|
||||
"""
|
||||
Sets up a registered user and an initial set of tokens for management tests.
|
||||
Temporarily overrides manager behavior for setup.
|
||||
"""
|
||||
# Temporarily set up predictable values for the registration/login flow within the fixture setup
|
||||
pm = auth_service.password_manager
|
||||
tm = auth_service.token_manager
|
||||
|
||||
original_hash = pm.hash_password.return_value
|
||||
original_verify = pm.verify_password.return_value
|
||||
original_access = tm.create_access_token.return_value
|
||||
original_refresh = tm.create_refresh_token.return_value
|
||||
|
||||
pm.hash_password.return_value = "HASHED_PASS"
|
||||
pm.verify_password.return_value = True
|
||||
tm.create_access_token.return_value = "SETUP_ACCESS_TOKEN"
|
||||
tm.create_refresh_token.return_value = "SETUP_REFRESH_TOKEN"
|
||||
|
||||
user = auth_service.register(test_user_data_create)
|
||||
_, tokens = auth_service.login(test_user_data_create.email, test_user_data_create.password)
|
||||
|
||||
self.user = user
|
||||
self.refresh_token = tokens.refresh_token
|
||||
self.access_token = tokens.access_token
|
||||
|
||||
# Restore mock values to default conftest behavior for the actual tests
|
||||
pm.hash_password.return_value = original_hash
|
||||
pm.verify_password.return_value = original_verify
|
||||
tm.create_access_token.return_value = original_access
|
||||
tm.create_refresh_token.return_value = original_refresh
|
||||
|
||||
def test_refresh_access_token_success(self, auth_service: AuthService):
|
||||
"""Success: Refreshing an access token works with a valid refresh token."""
|
||||
|
||||
# Access the manager via the service instance
|
||||
tm = auth_service.token_manager
|
||||
|
||||
# Use patch.object on the *instance* for granular control within the test
|
||||
with patch.object(tm, 'create_access_token',
|
||||
return_value="NEW_MOCKED_ACCESS_TOKEN") as mock_create_access:
|
||||
tokens = auth_service.refresh_access_token(self.refresh_token)
|
||||
|
||||
assert tokens.access_token == "NEW_MOCKED_ACCESS_TOKEN"
|
||||
mock_create_access.assert_called_once()
|
||||
|
||||
def test_refresh_access_token_failure_invalid_token(self, auth_service: AuthService):
|
||||
"""Failure: Refreshing fails if the token is invalid (revoked, expired, etc.)."""
|
||||
auth_service.logout(self.refresh_token)
|
||||
|
||||
with pytest.raises(InvalidTokenError):
|
||||
auth_service.refresh_access_token("invalid_token")
|
||||
|
||||
def test_logout_success(self, auth_service: AuthService):
|
||||
"""Success: Logout revokes the specified refresh token."""
|
||||
result = auth_service.logout(self.refresh_token)
|
||||
assert result is True
|
||||
|
||||
with pytest.raises(InvalidTokenError):
|
||||
auth_service.refresh_access_token(self.refresh_token)
|
||||
|
||||
def test_get_current_user_success(self, auth_service: AuthService):
|
||||
"""Success: Getting the current user works by successfully decoding the JWT."""
|
||||
|
||||
# Mock the decoder to simulate a decoded payload
|
||||
with patch.object(auth_service.token_manager, 'decode_access_token',
|
||||
return_value={"sub": self.user.id}) as mock_decode:
|
||||
user = auth_service.get_current_user("dummy_jwt")
|
||||
|
||||
assert user.id == self.user.id
|
||||
mock_decode.assert_called_once()
|
||||
|
||||
def test_get_current_user_failure_invalid_token(self, auth_service: AuthService):
|
||||
"""Failure: Getting the current user fails if the access token is invalid/expired."""
|
||||
|
||||
with patch.object(auth_service.token_manager, 'decode_access_token',
|
||||
side_effect=InvalidTokenError("Invalid signature")):
|
||||
with pytest.raises(InvalidTokenError):
|
||||
auth_service.get_current_user("invalid_access_jwt")
|
||||
|
||||
with patch.object(auth_service.token_manager, 'decode_access_token',
|
||||
side_effect=ExpiredTokenError("Token expired")):
|
||||
with pytest.raises(ExpiredTokenError):
|
||||
auth_service.get_current_user("expired_access_jwt")
|
||||
|
||||
|
||||
class TestAuthServiceResetVerification(object):
|
||||
"""Tests for password reset and email verification flows."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_user(self, auth_service: AuthService, test_user_data_create: UserCreate):
|
||||
"""Sets up a registered user using a mock hash for speed."""
|
||||
|
||||
pm = auth_service.password_manager
|
||||
original_hash = pm.hash_password.return_value
|
||||
|
||||
# Temporarily set hash for setup
|
||||
pm.hash_password.return_value = "HASHED_PASS"
|
||||
user = auth_service.register(test_user_data_create)
|
||||
self.user = user
|
||||
|
||||
# Restore hash mock
|
||||
pm.hash_password.return_value = original_hash
|
||||
|
||||
@patch('src.my_auth.core.email.send_email')
|
||||
def test_request_password_reset_success(self, mock_send_email: MagicMock, auth_service: AuthService):
|
||||
"""Success: Requesting a password reset generates a token and sends an email."""
|
||||
|
||||
tm = auth_service.token_manager
|
||||
with patch.object(tm, 'create_password_reset_token',
|
||||
return_value="MOCKED_RESET_TOKEN") as mock_create_token:
|
||||
token_string = auth_service.request_password_reset(self.user.email)
|
||||
|
||||
assert token_string == "MOCKED_RESET_TOKEN"
|
||||
mock_create_token.assert_called_once()
|
||||
mock_send_email.assert_called_once()
|
||||
|
||||
def test_reset_password_success(self, auth_service: AuthService):
|
||||
"""Success: Resetting the password works with a valid token."""
|
||||
|
||||
# Setup: Manually create a valid reset token
|
||||
auth_service.token_repository.save_token(
|
||||
TokenData(token="valid_reset_token", token_type="password_reset", user_id=self.user.id,
|
||||
expires_at=datetime.now() + timedelta(minutes=10))
|
||||
)
|
||||
|
||||
# Patch the PasswordManager instance to control the hash output
|
||||
pm = auth_service.password_manager
|
||||
with patch.object(pm, 'hash_password',
|
||||
return_value="NEW_HASHED_PASSWORD_FOR_RESET") as mock_hash:
|
||||
new_password = "NewPassword123!"
|
||||
result = auth_service.reset_password("valid_reset_token", new_password)
|
||||
|
||||
assert result is True
|
||||
mock_hash.assert_called_once_with(new_password)
|
||||
|
||||
# Verification: Check that user data was updated
|
||||
updated_user = auth_service.user_repository.get_user_by_id(self.user.id)
|
||||
assert updated_user.hashed_password == "NEW_HASHED_PASSWORD_FOR_RESET"
|
||||
|
||||
@patch('src.my_auth.core.email.send_email')
|
||||
def test_request_email_verification_success(self, mock_send_email: MagicMock, auth_service: AuthService):
|
||||
"""Success: Requesting verification generates a token and sends an email."""
|
||||
|
||||
tm = auth_service.token_manager
|
||||
with patch.object(tm, 'create_email_verification_token',
|
||||
return_value="MOCKED_JWT_VERIFY_TOKEN") as mock_create_token:
|
||||
token_string = auth_service.request_email_verification(self.user.email)
|
||||
|
||||
assert token_string == "MOCKED_JWT_VERIFY_TOKEN"
|
||||
mock_create_token.assert_called_once_with(self.user.email)
|
||||
mock_send_email.assert_called_once()
|
||||
|
||||
def test_verify_email_success(self, auth_service: AuthService):
|
||||
"""Success: Verification updates the user's status."""
|
||||
|
||||
# The token_manager is mocked in conftest, so we must access its real create method
|
||||
# or rely on the mock's return value to get a token string to use in the call.
|
||||
# Since we need a real token for the decode logic to pass, we need to bypass the mock here.
|
||||
|
||||
# We will temporarily use the real TokenManager to create a valid, decodable token.
|
||||
# This requires an *unmocked* token manager instance, which is tricky in this setup.
|
||||
|
||||
# Alternative: Temporarily inject a real TokenManager for this test (or rely on a non-mocked method)
|
||||
|
||||
# Assuming TokenManager.create_email_verification_token can be mocked to return a static string
|
||||
# and TokenManager.decode_email_verification_token can be patched to simulate success.
|
||||
|
||||
# Since the method calls decode_email_verification_token internally, we mock the output of the decode step.
|
||||
|
||||
# Setup: Ensure user is unverified
|
||||
auth_service.user_repository.update_user(self.user.id, UserUpdate(is_verified=False))
|
||||
|
||||
tm = auth_service.token_manager
|
||||
|
||||
# Mock the decode step to ensure it returns the email used for verification
|
||||
with patch.object(tm, 'decode_email_verification_token', return_value=self.user.email) as mock_decode:
|
||||
# Test (we use a dummy token string as the decode step is mocked)
|
||||
result = auth_service.verify_email("dummy_verification_token")
|
||||
|
||||
assert result is True
|
||||
mock_decode.assert_called_once()
|
||||
|
||||
# Verification: User is verified
|
||||
updated_user = auth_service.user_repository.get_user_by_id(self.user.id)
|
||||
assert updated_user.is_verified is True
|
||||
0
tests/persistence/__init__.py
Normal file
0
tests/persistence/__init__.py
Normal file
75
tests/persistence/conftest.py
Normal file
75
tests/persistence/conftest.py
Normal file
@@ -0,0 +1,75 @@
|
||||
from uuid import uuid4
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from my_auth.models.token import TokenData
|
||||
from my_auth.models.user import UserCreate
|
||||
from my_auth.persistence.sqlite import SQLiteUserRepository, SQLiteTokenRepository
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_user_data_create():
|
||||
"""Provides valid data for creating a test user."""
|
||||
return UserCreate(
|
||||
email="test.user@example.com",
|
||||
username="TestUser",
|
||||
password="ValidPassword123!", # Meets all strength criteria
|
||||
roles=["member"],
|
||||
user_settings={"theme": "dark"}
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_user_hashed_password():
|
||||
"""Provides a dummy hashed password for persistence (hashing is tested elsewhere)."""
|
||||
return "$2b$12$R.S/XfI2tQYt3Kk.iF1XwOQz0Qe.L0T0mD/O1H8E2V5D4Q6F7G8H9I0"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_token_data():
|
||||
"""Provides valid data for a refresh token."""
|
||||
now = datetime.now()
|
||||
return TokenData(
|
||||
token=f"opaque_refresh_token_{uuid4()}",
|
||||
token_type="refresh",
|
||||
user_id="user_id_for_token_test",
|
||||
expires_at=now + timedelta(days=7),
|
||||
is_revoked=False,
|
||||
created_at=now
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def sqlite_db_path(tmp_path_factory):
|
||||
"""
|
||||
Creates a temporary directory and an SQLite file path for the test session.
|
||||
The directory and its contents are deleted after the session.
|
||||
"""
|
||||
temp_dir = tmp_path_factory.mktemp("sqlite_auth_test")
|
||||
db_file: Path = temp_dir / "auth_test.db"
|
||||
|
||||
yield str(db_file)
|
||||
|
||||
try:
|
||||
if temp_dir.exists():
|
||||
import shutil
|
||||
shutil.rmtree(temp_dir)
|
||||
except OSError as e:
|
||||
# Handle case where directory might be locked, though rare in tests
|
||||
print(f"Error during cleanup of temporary DB directory: {e}")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def user_repository(sqlite_db_path: str) -> SQLiteUserRepository:
|
||||
"""Provides an instance of SQLiteUserRepository initialized with the in-memory DB."""
|
||||
# Assuming the repository takes the connection object or path (using connection here)
|
||||
return SQLiteUserRepository(sqlite_db_path)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def token_repository(sqlite_db_path: str) -> SQLiteTokenRepository:
|
||||
"""Provides an instance of SQLiteTokenRepository initialized with the in-memory DB."""
|
||||
# Assuming the repository takes the connection object or path (using connection here)
|
||||
return SQLiteTokenRepository(sqlite_db_path)
|
||||
87
tests/persistence/test_sql_token.py
Normal file
87
tests/persistence/test_sql_token.py
Normal file
@@ -0,0 +1,87 @@
|
||||
# tests/persistence/test_sqlite_token.py
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from my_auth.models.token import TokenData
|
||||
from my_auth.persistence.sqlite import SQLiteTokenRepository
|
||||
|
||||
|
||||
def test_i_can_save_and_retrieve_token(token_repository: SQLiteTokenRepository,
|
||||
test_token_data: TokenData):
|
||||
"""Verifies token saving and successful retrieval by token string and type."""
|
||||
|
||||
# 1. Save Token
|
||||
token_repository.save_token(test_token_data)
|
||||
|
||||
# 2. Retrieve Token
|
||||
retrieved_token = token_repository.get_token(test_token_data.token, test_token_data.token_type)
|
||||
|
||||
# Assertions
|
||||
assert retrieved_token is not None
|
||||
assert retrieved_token.token == test_token_data.token
|
||||
assert retrieved_token.user_id == test_token_data.user_id
|
||||
assert retrieved_token.is_revoked is False
|
||||
assert retrieved_token.token_type == test_token_data.token_type
|
||||
|
||||
|
||||
def test_i_can_revoke_token(token_repository: SQLiteTokenRepository,
|
||||
test_token_data: TokenData):
|
||||
"""Verifies a token can be revoked and its revoked status is updated."""
|
||||
|
||||
# Setup: Save the token
|
||||
token_repository.save_token(test_token_data)
|
||||
|
||||
# 1. Revoke the token
|
||||
was_revoked = token_repository.revoke_token(test_token_data.token)
|
||||
assert was_revoked is True
|
||||
|
||||
# 2. Retrieve and check status
|
||||
revoked_token = token_repository.get_token(test_token_data.token, test_token_data.token_type)
|
||||
assert revoked_token is not None
|
||||
assert revoked_token.is_revoked is True
|
||||
|
||||
# 3. Attempt to revoke a non-existent token
|
||||
was_revoked_again = token_repository.revoke_token("non_existent_token")
|
||||
assert was_revoked_again is False
|
||||
|
||||
|
||||
def test_i_can_use_is_token_valid_for_valid_token(token_repository: SQLiteTokenRepository,
|
||||
test_token_data: TokenData):
|
||||
"""Verifies the convenience method returns True for a fresh, unexpired token."""
|
||||
|
||||
token_repository.save_token(test_token_data)
|
||||
|
||||
is_valid = token_repository.is_token_valid(test_token_data.token, test_token_data.token_type)
|
||||
assert is_valid is True
|
||||
|
||||
is_valid = token_repository.is_token_valid("non_existent_token", test_token_data.token_type)
|
||||
assert is_valid is False
|
||||
|
||||
|
||||
def test_i_can_use_is_token_valid_for_revoked_token(token_repository: SQLiteTokenRepository,
|
||||
test_token_data: TokenData):
|
||||
"""Verifies is_token_valid returns False for a token marked as revoked."""
|
||||
|
||||
token_repository.save_token(test_token_data)
|
||||
token_repository.revoke_token(test_token_data.token)
|
||||
|
||||
is_valid = token_repository.is_token_valid(test_token_data.token, test_token_data.token_type)
|
||||
assert is_valid is False
|
||||
|
||||
|
||||
def test_i_can_use_is_token_valid_for_expired_token(token_repository: SQLiteTokenRepository):
|
||||
"""Verifies is_token_valid returns False for a token whose expiration is in the past."""
|
||||
|
||||
expired_token_data = TokenData(
|
||||
token="expired_token_test",
|
||||
token_type="password_reset",
|
||||
user_id="user_id_expired",
|
||||
expires_at=datetime.now() - timedelta(hours=1), # Set expiration to 1 hour ago
|
||||
is_revoked=False,
|
||||
created_at=datetime.now() - timedelta(hours=2)
|
||||
)
|
||||
|
||||
token_repository.save_token(expired_token_data)
|
||||
|
||||
is_valid = token_repository.is_token_valid(expired_token_data.token, expired_token_data.token_type)
|
||||
assert is_valid is False
|
||||
155
tests/persistence/test_sqlite_user.py
Normal file
155
tests/persistence/test_sqlite_user.py
Normal file
@@ -0,0 +1,155 @@
|
||||
# tests/persistence/test_sqlite_user.py
|
||||
|
||||
import pytest
|
||||
import json
|
||||
from datetime import datetime
|
||||
|
||||
from my_auth.persistence.sqlite import SQLiteUserRepository
|
||||
from my_auth.models.user import UserCreate, UserUpdate
|
||||
from my_auth.exceptions import UserAlreadyExistsError, UserNotFoundError
|
||||
|
||||
|
||||
def test_i_can_create_and_retrieve_user_by_email(user_repository: SQLiteUserRepository,
|
||||
test_user_data_create: UserCreate,
|
||||
test_user_hashed_password: str):
|
||||
"""Verifies user creation and successful retrieval by email."""
|
||||
|
||||
# 1. Create User
|
||||
created_user = user_repository.create_user(
|
||||
user_data=test_user_data_create,
|
||||
hashed_password=test_user_hashed_password
|
||||
)
|
||||
|
||||
# Assertions on creation
|
||||
assert created_user is not None
|
||||
assert created_user.email == test_user_data_create.email
|
||||
assert created_user.hashed_password == test_user_hashed_password
|
||||
assert created_user.is_active is True
|
||||
assert created_user.is_verified is False
|
||||
assert created_user.id is not None
|
||||
assert isinstance(created_user.created_at, datetime)
|
||||
|
||||
# 2. Retrieve User
|
||||
retrieved_user = user_repository.get_user_by_email(test_user_data_create.email)
|
||||
|
||||
# Assertions on retrieval
|
||||
assert retrieved_user is not None
|
||||
assert retrieved_user.id == created_user.id
|
||||
assert retrieved_user.username == "TestUser"
|
||||
|
||||
|
||||
def test_i_can_retrieve_user_by_id(user_repository: SQLiteUserRepository,
|
||||
test_user_data_create: UserCreate,
|
||||
test_user_hashed_password: str):
|
||||
"""Verifies user retrieval by unique ID."""
|
||||
created_user = user_repository.create_user(test_user_data_create, test_user_hashed_password)
|
||||
|
||||
retrieved_user = user_repository.get_user_by_id(created_user.id)
|
||||
|
||||
assert retrieved_user is not None
|
||||
assert retrieved_user.email == created_user.email
|
||||
|
||||
|
||||
def test_i_cannot_create_user_if_email_exists(user_repository: SQLiteUserRepository,
|
||||
test_user_data_create: UserCreate,
|
||||
test_user_hashed_password: str):
|
||||
"""Ensures that creating a user with an existing email raises UserAlreadyExistsError."""
|
||||
# First creation should succeed
|
||||
user_repository.create_user(test_user_data_create, test_user_hashed_password)
|
||||
|
||||
# Second creation with same email should fail
|
||||
with pytest.raises(UserAlreadyExistsError):
|
||||
user_repository.create_user(test_user_data_create, test_user_hashed_password)
|
||||
|
||||
|
||||
def test_i_can_check_if_email_exists(user_repository: SQLiteUserRepository,
|
||||
test_user_data_create: UserCreate,
|
||||
test_user_hashed_password: str):
|
||||
"""Verifies the email_exists method returns correct boolean results."""
|
||||
email = test_user_data_create.email
|
||||
non_existent_email = "non.existent@example.com"
|
||||
|
||||
# 1. Check before creation (Should be False)
|
||||
assert user_repository.email_exists(email) is False
|
||||
|
||||
user_repository.create_user(test_user_data_create, test_user_hashed_password)
|
||||
|
||||
# 2. Check after creation (Should be True)
|
||||
assert user_repository.email_exists(email) is True
|
||||
|
||||
# 3. Check for another non-existent email
|
||||
assert user_repository.email_exists(non_existent_email) is False
|
||||
|
||||
|
||||
def test_i_can_update_username_and_roles(user_repository: SQLiteUserRepository,
|
||||
test_user_data_create: UserCreate,
|
||||
test_user_hashed_password: str):
|
||||
"""Tests partial update of user fields (username and roles)."""
|
||||
|
||||
created_user = user_repository.create_user(test_user_data_create, test_user_hashed_password)
|
||||
|
||||
updates = UserUpdate(username="NewUsername", roles=["admin", "staff"])
|
||||
|
||||
updated_user = user_repository.update_user(created_user.id, updates)
|
||||
|
||||
assert updated_user.username == "NewUsername"
|
||||
assert updated_user.roles == ["admin", "staff"]
|
||||
# Check that unrelated fields remain the same
|
||||
assert updated_user.email == created_user.email
|
||||
# Check that update timestamp changed
|
||||
assert updated_user.updated_at > created_user.updated_at
|
||||
assert updated_user.is_verified == created_user.is_verified
|
||||
|
||||
|
||||
def test_i_can_update_is_active_status(user_repository: SQLiteUserRepository,
|
||||
test_user_data_create: UserCreate,
|
||||
test_user_hashed_password: str):
|
||||
"""Tests the specific update of the 'is_active' status."""
|
||||
|
||||
created_user = user_repository.create_user(test_user_data_create, test_user_hashed_password)
|
||||
|
||||
# Deactivate
|
||||
updates_deactivate = UserUpdate(is_active=False)
|
||||
deactivated_user = user_repository.update_user(created_user.id, updates_deactivate)
|
||||
assert deactivated_user.is_active is False
|
||||
|
||||
# Reactivate
|
||||
updates_activate = UserUpdate(is_active=True)
|
||||
reactivated_user = user_repository.update_user(created_user.id, updates_activate)
|
||||
assert reactivated_user.is_active is True
|
||||
|
||||
|
||||
def test_i_cannot_update_non_existent_user(user_repository: SQLiteUserRepository):
|
||||
"""Ensures that updating a user with an unknown ID raises UserNotFoundError."""
|
||||
non_existent_id = "unknown_id_123"
|
||||
updates = UserUpdate(username="Phantom")
|
||||
|
||||
with pytest.raises(UserNotFoundError):
|
||||
user_repository.update_user(non_existent_id, updates)
|
||||
|
||||
|
||||
def test_i_can_delete_user(user_repository: SQLiteUserRepository,
|
||||
test_user_data_create: UserCreate,
|
||||
test_user_hashed_password: str):
|
||||
"""Verifies user deletion and subsequent failure to retrieve."""
|
||||
|
||||
created_user = user_repository.create_user(test_user_data_create, test_user_hashed_password)
|
||||
|
||||
# 1. Delete the user
|
||||
was_deleted = user_repository.delete_user(created_user.id)
|
||||
assert was_deleted is True
|
||||
|
||||
# 2. Verify deletion by attempting retrieval
|
||||
retrieved_user = user_repository.get_user_by_id(created_user.id)
|
||||
assert retrieved_user is None
|
||||
|
||||
# 3. Verify attempting to delete again returns False
|
||||
was_deleted_again = user_repository.delete_user(created_user.id)
|
||||
assert was_deleted_again is False
|
||||
|
||||
|
||||
def test_i_cannot_retrieve_non_existent_user_by_email(user_repository: SQLiteUserRepository):
|
||||
"""Ensures retrieval by email returns None for non-existent email."""
|
||||
|
||||
retrieved_user = user_repository.get_user_by_email("ghost@example.com")
|
||||
assert retrieved_user is None
|
||||
Reference in New Issue
Block a user