Added user admin auto creation for Sqlite
This commit is contained in:
@@ -351,3 +351,8 @@ pytest tests/
|
||||
## License
|
||||
|
||||
MIT
|
||||
|
||||
## Release History
|
||||
|
||||
* 0.1.0 - Initial Release
|
||||
* 0.2.0 - Added admin user auto creation
|
||||
|
||||
@@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "myauth"
|
||||
version = "0.1.0"
|
||||
version = "0.2.0"
|
||||
description = "A reusable, modular authentication system for FastAPI applications with pluggable database backends."
|
||||
readme = "README.md"
|
||||
authors = [
|
||||
|
||||
@@ -4,7 +4,7 @@ FastAPI routes for authentication module.
|
||||
This module provides ready-to-use FastAPI routes for all authentication
|
||||
operations. Routes are organized in an APIRouter with /auth prefix.
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import Depends, HTTPException, status, FastAPI
|
||||
@@ -375,5 +375,6 @@ def create_auth_app(auth_service: AuthService) -> FastAPI:
|
||||
HTTPException: 401 if token is invalid or expired.
|
||||
"""
|
||||
return current_user
|
||||
|
||||
|
||||
return auth_app
|
||||
|
||||
@@ -5,442 +5,510 @@ This module provides the main authentication service that orchestrates
|
||||
all authentication operations including registration, login, token management,
|
||||
password reset, and email verification.
|
||||
"""
|
||||
|
||||
import os
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from .password import PasswordManager
|
||||
from .token import TokenManager
|
||||
from ..persistence.base import UserRepository, TokenRepository
|
||||
from ..models.user import UserCreate, UserInDB, UserUpdate
|
||||
from ..models.token import AccessTokenResponse, TokenData
|
||||
from ..emailing.base import EmailService
|
||||
from ..exceptions import (
|
||||
InvalidCredentialsError,
|
||||
UserNotFoundError,
|
||||
AccountDisabledError,
|
||||
ExpiredTokenError,
|
||||
InvalidTokenError,
|
||||
RevokedTokenError
|
||||
InvalidCredentialsError,
|
||||
UserNotFoundError,
|
||||
AccountDisabledError,
|
||||
ExpiredTokenError,
|
||||
InvalidTokenError,
|
||||
RevokedTokenError
|
||||
)
|
||||
from ..models.token import AccessTokenResponse, TokenData
|
||||
from ..models.user import UserCreate, UserInDB, UserUpdate, UserCreateNoValidation
|
||||
from ..persistence.base import UserRepository, TokenRepository
|
||||
|
||||
|
||||
class AuthService:
|
||||
"""
|
||||
Main authentication service.
|
||||
|
||||
This service orchestrates all authentication-related operations by
|
||||
coordinating between password management, token management, and
|
||||
persistence layers.
|
||||
|
||||
Attributes:
|
||||
user_repository: Repository for user persistence operations.
|
||||
token_repository: Repository for token persistence operations.
|
||||
password_manager: Manager for password hashing and verification.
|
||||
token_manager: Manager for token creation and validation.
|
||||
email_service: Optional service for sending emails.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
user_repository: UserRepository,
|
||||
token_repository: TokenRepository,
|
||||
password_manager: PasswordManager,
|
||||
token_manager: TokenManager,
|
||||
email_service: Optional[EmailService] = None
|
||||
):
|
||||
"""
|
||||
Main authentication service.
|
||||
Initialize the authentication service.
|
||||
|
||||
This service orchestrates all authentication-related operations by
|
||||
coordinating between password management, token management, and
|
||||
persistence layers.
|
||||
|
||||
Attributes:
|
||||
user_repository: Repository for user persistence operations.
|
||||
token_repository: Repository for token persistence operations.
|
||||
Args:
|
||||
user_repository: Repository for user persistence.
|
||||
token_repository: Repository for token persistence.
|
||||
password_manager: Manager for password hashing and verification.
|
||||
token_manager: Manager for token creation and validation.
|
||||
email_service: Optional service for sending emails.
|
||||
email_service: Optional service for sending emails (password reset, verification).
|
||||
"""
|
||||
self.user_repository = user_repository
|
||||
self.token_repository = token_repository
|
||||
self.password_manager = password_manager
|
||||
self.token_manager = token_manager
|
||||
self.email_service = email_service
|
||||
|
||||
def create_admin_if_needed(self, admin_email: str = None, admin_username: str = None, admin_password: str = None):
|
||||
"""
|
||||
Create the admin user if it does not exist. This function checks the current number
|
||||
of users in the system. If there are no users present, it creates a user with
|
||||
administrative privileges using the provided credentials or environment variables
|
||||
as defaults.
|
||||
|
||||
:param admin_email: The email of the admin user. Defaults to "AUTH_ADMIN_EMAIL"
|
||||
environment variable or "admin@myauth.com" if not provided.
|
||||
:type admin_email: str, optional
|
||||
:param admin_username: The username of the admin user. Defaults to
|
||||
"AUTH_ADMIN_USERNAME" environment variable or "admin" if not provided.
|
||||
:type admin_username: str, optional
|
||||
:param admin_password: The password of the admin user. Defaults to
|
||||
"AUTH_ADMIN_PASSWORD" environment variable or "admin" if not provided.
|
||||
:type admin_password: str, optional
|
||||
:return: True if an admin user is created, otherwise False.
|
||||
:rtype: bool
|
||||
"""
|
||||
# create the admin user if it doesn't exist
|
||||
nb_users = self.count_users()
|
||||
if nb_users == 0:
|
||||
admin_email = admin_email or os.getenv("AUTH_ADMIN_EMAIL", "admin@myauth.com")
|
||||
admin_username = admin_username or os.getenv("AUTH_ADMIN_USERNAME", "admin")
|
||||
admin_password = admin_password or os.getenv("AUTH_ADMIN_PASSWORD", "admin")
|
||||
|
||||
admin_user = UserCreateNoValidation(
|
||||
email=admin_email,
|
||||
username=admin_username,
|
||||
password=admin_password,
|
||||
roles=["admin"]
|
||||
)
|
||||
self.register(admin_user)
|
||||
return True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
user_repository: UserRepository,
|
||||
token_repository: TokenRepository,
|
||||
password_manager: PasswordManager,
|
||||
token_manager: TokenManager,
|
||||
email_service: Optional[EmailService] = None
|
||||
):
|
||||
"""
|
||||
Initialize the authentication service.
|
||||
|
||||
Args:
|
||||
user_repository: Repository for user persistence.
|
||||
token_repository: Repository for token persistence.
|
||||
password_manager: Manager for password hashing and verification.
|
||||
token_manager: Manager for token creation and validation.
|
||||
email_service: Optional service for sending emails (password reset, verification).
|
||||
"""
|
||||
self.user_repository = user_repository
|
||||
self.token_repository = token_repository
|
||||
self.password_manager = password_manager
|
||||
self.token_manager = token_manager
|
||||
self.email_service = email_service
|
||||
return False
|
||||
|
||||
def register(self, user_data: UserCreate | UserCreateNoValidation) -> UserInDB:
|
||||
"""
|
||||
Register a new user.
|
||||
|
||||
def register(self, user_data: UserCreate) -> UserInDB:
|
||||
"""
|
||||
Register a new user.
|
||||
|
||||
This method creates a new user account with hashed password.
|
||||
The user's email is initially unverified.
|
||||
|
||||
Args:
|
||||
user_data: User registration data including password.
|
||||
|
||||
Returns:
|
||||
The created user (without password).
|
||||
|
||||
Raises:
|
||||
UserAlreadyExistsError: If email is already registered.
|
||||
|
||||
Example:
|
||||
>>> user_data = UserCreate(
|
||||
... email="user@example.com",
|
||||
... username="john_doe",
|
||||
... password="SecurePass123!"
|
||||
... )
|
||||
>>> user = auth_service.register(user_data)
|
||||
"""
|
||||
hashed_password = self.password_manager.hash_password(user_data.password)
|
||||
user = self.user_repository.create_user(user_data, hashed_password)
|
||||
return user
|
||||
This method creates a new user account with hashed password.
|
||||
The user's email is initially unverified.
|
||||
|
||||
def login(self, email: str, password: str) -> tuple[UserInDB, AccessTokenResponse]:
|
||||
"""
|
||||
Authenticate a user and create tokens.
|
||||
Args:
|
||||
user_data: User registration data including password.
|
||||
|
||||
This method verifies credentials, checks account status, and generates
|
||||
both access and refresh tokens.
|
||||
Returns:
|
||||
The created user (without password).
|
||||
|
||||
Args:
|
||||
email: User's email address.
|
||||
password: User's plain text password.
|
||||
|
||||
Returns:
|
||||
Tuple of (user, tokens) where tokens contains access_token and refresh_token.
|
||||
|
||||
Raises:
|
||||
InvalidCredentialsError: If email or password is incorrect.
|
||||
AccountDisabledError: If user account is disabled.
|
||||
|
||||
Example:
|
||||
>>> user, tokens = auth_service.login("user@example.com", "password")
|
||||
>>> print(tokens.access_token)
|
||||
"""
|
||||
# Get user by email
|
||||
user = self.user_repository.get_user_by_email(email)
|
||||
if not user:
|
||||
raise InvalidCredentialsError()
|
||||
Raises:
|
||||
UserAlreadyExistsError: If email is already registered.
|
||||
|
||||
# Verify password
|
||||
if not self.password_manager.verify_password(password, user.hashed_password):
|
||||
raise InvalidCredentialsError()
|
||||
|
||||
# Check if account is active
|
||||
if not user.is_active:
|
||||
raise AccountDisabledError()
|
||||
|
||||
# Create tokens
|
||||
access_token = self.token_manager.create_access_token(user)
|
||||
refresh_token = self.token_manager.create_refresh_token()
|
||||
|
||||
# Store refresh token in database
|
||||
token_data = TokenData(
|
||||
token=refresh_token,
|
||||
token_type="refresh",
|
||||
user_id=user.id,
|
||||
expires_at=self.token_manager.get_refresh_token_expiration(),
|
||||
created_at=datetime.utcnow(),
|
||||
is_revoked=False
|
||||
)
|
||||
self.token_repository.save_token(token_data)
|
||||
|
||||
tokens = AccessTokenResponse(
|
||||
access_token=access_token,
|
||||
refresh_token=refresh_token,
|
||||
token_type="bearer"
|
||||
)
|
||||
|
||||
return user, tokens
|
||||
Example:
|
||||
>>> user_data = UserCreate(
|
||||
... email="user@example.com",
|
||||
... username="john_doe",
|
||||
... password="SecurePass123!"
|
||||
... )
|
||||
>>> user = auth_service.register(user_data)
|
||||
"""
|
||||
hashed_password = self.password_manager.hash_password(user_data.password)
|
||||
user = self.user_repository.create_user(user_data, hashed_password)
|
||||
return user
|
||||
|
||||
def login(self, email: str, password: str) -> tuple[UserInDB, AccessTokenResponse]:
|
||||
"""
|
||||
Authenticate a user and create tokens.
|
||||
|
||||
def refresh_access_token(self, refresh_token: str) -> AccessTokenResponse:
|
||||
"""
|
||||
Create a new access token using a refresh token.
|
||||
|
||||
This method validates the refresh token and generates a new access token
|
||||
without requiring the user to re-enter their password.
|
||||
|
||||
Args:
|
||||
refresh_token: The refresh token to exchange.
|
||||
|
||||
Returns:
|
||||
New access and refresh tokens.
|
||||
|
||||
Raises:
|
||||
InvalidTokenError: If refresh token is invalid.
|
||||
ExpiredTokenError: If refresh token has expired.
|
||||
RevokedTokenError: If refresh token has been revoked.
|
||||
UserNotFoundError: If user no longer exists.
|
||||
AccountDisabledError: If user account is disabled.
|
||||
|
||||
Example:
|
||||
>>> tokens = auth_service.refresh_access_token(old_refresh_token)
|
||||
>>> print(tokens.access_token)
|
||||
"""
|
||||
# Validate refresh token
|
||||
token_data = self.token_repository.get_token(refresh_token, "refresh")
|
||||
if not token_data:
|
||||
raise InvalidTokenError("Invalid refresh token")
|
||||
|
||||
if token_data.is_revoked:
|
||||
raise RevokedTokenError()
|
||||
|
||||
if token_data.expires_at < datetime.utcnow():
|
||||
raise ExpiredTokenError()
|
||||
|
||||
# Get user
|
||||
user = self.user_repository.get_user_by_id(token_data.user_id)
|
||||
if not user:
|
||||
raise UserNotFoundError()
|
||||
|
||||
if not user.is_active:
|
||||
raise AccountDisabledError()
|
||||
|
||||
# Create new tokens
|
||||
access_token = self.token_manager.create_access_token(user)
|
||||
new_refresh_token = self.token_manager.create_refresh_token()
|
||||
|
||||
# Revoke old refresh token
|
||||
self.token_repository.revoke_token(refresh_token)
|
||||
|
||||
# Store new refresh token
|
||||
new_token_data = TokenData(
|
||||
token=new_refresh_token,
|
||||
token_type="refresh",
|
||||
user_id=user.id,
|
||||
expires_at=self.token_manager.get_refresh_token_expiration(),
|
||||
created_at=datetime.utcnow(),
|
||||
is_revoked=False
|
||||
)
|
||||
self.token_repository.save_token(new_token_data)
|
||||
|
||||
return AccessTokenResponse(
|
||||
access_token=access_token,
|
||||
refresh_token=new_refresh_token,
|
||||
token_type="bearer"
|
||||
)
|
||||
This method verifies credentials, checks account status, and generates
|
||||
both access and refresh tokens.
|
||||
|
||||
def logout(self, refresh_token: str) -> bool:
|
||||
"""
|
||||
Logout a user by revoking their refresh token.
|
||||
Args:
|
||||
email: User's email address.
|
||||
password: User's plain text password.
|
||||
|
||||
This prevents the refresh token from being used to obtain new
|
||||
access tokens. The current access token will remain valid until
|
||||
it expires naturally.
|
||||
Returns:
|
||||
Tuple of (user, tokens) where tokens contains access_token and refresh_token.
|
||||
|
||||
Args:
|
||||
refresh_token: The refresh token to revoke.
|
||||
|
||||
Returns:
|
||||
True if logout was successful, False if token not found.
|
||||
|
||||
Example:
|
||||
>>> auth_service.logout(refresh_token)
|
||||
True
|
||||
"""
|
||||
return self.token_repository.revoke_token(refresh_token)
|
||||
Raises:
|
||||
InvalidCredentialsError: If email or password is incorrect.
|
||||
AccountDisabledError: If user account is disabled.
|
||||
|
||||
Example:
|
||||
>>> user, tokens = auth_service.login("user@example.com", "password")
|
||||
>>> print(tokens.access_token)
|
||||
"""
|
||||
# Get user by email
|
||||
user = self.user_repository.get_user_by_email(email)
|
||||
if not user:
|
||||
raise InvalidCredentialsError()
|
||||
|
||||
def request_password_reset(self, email: str) -> str:
|
||||
"""
|
||||
Generate a password reset token for a user.
|
||||
|
||||
This method creates a secure token that can be sent to the user
|
||||
via email to reset their password. If an email service is configured,
|
||||
the email is sent automatically. Otherwise, the token is returned
|
||||
for manual handling.
|
||||
|
||||
Args:
|
||||
email: User's email address.
|
||||
|
||||
Returns:
|
||||
The password reset token to be sent via email.
|
||||
|
||||
Raises:
|
||||
UserNotFoundError: If email is not registered.
|
||||
|
||||
Example:
|
||||
>>> token = auth_service.request_password_reset("user@example.com")
|
||||
>>> # Token is automatically sent via email if service is configured
|
||||
"""
|
||||
user = self.user_repository.get_user_by_email(email)
|
||||
if not user:
|
||||
raise UserNotFoundError(f"No user found with email {email}")
|
||||
|
||||
# Create reset token
|
||||
reset_token = self.token_manager.create_password_reset_token()
|
||||
|
||||
# Store token in database
|
||||
token_data = TokenData(
|
||||
token=reset_token,
|
||||
token_type="password_reset",
|
||||
user_id=user.id,
|
||||
expires_at=self.token_manager.get_password_reset_token_expiration(),
|
||||
created_at=datetime.utcnow(),
|
||||
is_revoked=False
|
||||
)
|
||||
self.token_repository.save_token(token_data)
|
||||
|
||||
# Send email if service is configured
|
||||
if self.email_service:
|
||||
self.email_service.send_password_reset_email(email, reset_token)
|
||||
|
||||
return reset_token
|
||||
# Verify password
|
||||
if not self.password_manager.verify_password(password, user.hashed_password):
|
||||
raise InvalidCredentialsError()
|
||||
|
||||
def reset_password(self, token: str, new_password: str) -> bool:
|
||||
"""
|
||||
Reset a user's password using a reset token.
|
||||
|
||||
This method validates the reset token and updates the user's password.
|
||||
All existing refresh tokens for the user are revoked for security.
|
||||
|
||||
Args:
|
||||
token: Password reset token.
|
||||
new_password: New plain text password (will be hashed).
|
||||
|
||||
Returns:
|
||||
True if password was reset successfully.
|
||||
|
||||
Raises:
|
||||
InvalidTokenError: If reset token is invalid.
|
||||
ExpiredTokenError: If reset token has expired.
|
||||
RevokedTokenError: If reset token has been used.
|
||||
UserNotFoundError: If user no longer exists.
|
||||
|
||||
Example:
|
||||
>>> auth_service.reset_password(token, "NewSecurePass123!")
|
||||
True
|
||||
"""
|
||||
# Validate reset token
|
||||
token_data = self.token_repository.get_token(token, "password_reset")
|
||||
if not token_data:
|
||||
raise InvalidTokenError("Invalid password reset token")
|
||||
|
||||
if token_data.is_revoked:
|
||||
raise RevokedTokenError("Password reset token has already been used")
|
||||
|
||||
if token_data.expires_at < datetime.utcnow():
|
||||
raise ExpiredTokenError("Password reset token has expired")
|
||||
|
||||
# Get user
|
||||
user = self.user_repository.get_user_by_id(token_data.user_id)
|
||||
if not user:
|
||||
raise UserNotFoundError()
|
||||
|
||||
# Hash new password
|
||||
hashed_password = self.password_manager.hash_password(new_password)
|
||||
|
||||
# Update user password
|
||||
updates = UserUpdate(password=hashed_password)
|
||||
self.user_repository.update_user(user.id, updates)
|
||||
|
||||
# Revoke the reset token
|
||||
self.token_repository.revoke_token(token)
|
||||
|
||||
# Revoke all user's refresh tokens for security
|
||||
self.token_repository.revoke_all_user_tokens(user.id, "refresh")
|
||||
|
||||
return True
|
||||
# Check if account is active
|
||||
if not user.is_active:
|
||||
raise AccountDisabledError()
|
||||
|
||||
def request_email_verification(self, email: str) -> str:
|
||||
"""
|
||||
Generate an email verification token for a user.
|
||||
|
||||
This method creates a JWT token that can be sent to the user
|
||||
to verify their email address. If an email service is configured,
|
||||
the email is sent automatically. Otherwise, the token is returned
|
||||
for manual handling.
|
||||
|
||||
Args:
|
||||
email: User's email address.
|
||||
|
||||
Returns:
|
||||
The email verification token (JWT) to be sent via email.
|
||||
|
||||
Raises:
|
||||
UserNotFoundError: If email is not registered.
|
||||
|
||||
Example:
|
||||
>>> token = auth_service.request_email_verification("user@example.com")
|
||||
>>> # Token is automatically sent via email if service is configured
|
||||
"""
|
||||
user = self.user_repository.get_user_by_email(email)
|
||||
if not user:
|
||||
raise UserNotFoundError(f"No user found with email {email}")
|
||||
|
||||
verification_token = self.token_manager.create_email_verification_token(email)
|
||||
|
||||
# Send email if service is configured
|
||||
if self.email_service:
|
||||
self.email_service.send_verification_email(email, verification_token)
|
||||
|
||||
return verification_token
|
||||
# Create tokens
|
||||
access_token = self.token_manager.create_access_token(user)
|
||||
refresh_token = self.token_manager.create_refresh_token()
|
||||
|
||||
def verify_email(self, token: str) -> bool:
|
||||
"""
|
||||
Verify a user's email address using a verification token.
|
||||
|
||||
This method decodes the JWT token, extracts the email, and marks
|
||||
the user's email as verified.
|
||||
|
||||
Args:
|
||||
token: Email verification token (JWT).
|
||||
|
||||
Returns:
|
||||
True if email was verified successfully.
|
||||
|
||||
Raises:
|
||||
InvalidTokenError: If verification token is invalid.
|
||||
ExpiredTokenError: If verification token has expired.
|
||||
UserNotFoundError: If user no longer exists.
|
||||
|
||||
Example:
|
||||
>>> auth_service.verify_email(token)
|
||||
True
|
||||
"""
|
||||
# Decode and validate token
|
||||
email = self.token_manager.decode_email_verification_token(token)
|
||||
|
||||
# Get user
|
||||
user = self.user_repository.get_user_by_email(email)
|
||||
if not user:
|
||||
raise UserNotFoundError(f"No user found with email {email}")
|
||||
|
||||
# Update user's verified status
|
||||
updates = UserUpdate(is_verified=True)
|
||||
self.user_repository.update_user(user.id, updates)
|
||||
|
||||
return True
|
||||
# Store refresh token in database
|
||||
token_data = TokenData(
|
||||
token=refresh_token,
|
||||
token_type="refresh",
|
||||
user_id=user.id,
|
||||
expires_at=self.token_manager.get_refresh_token_expiration(),
|
||||
created_at=datetime.utcnow(),
|
||||
is_revoked=False
|
||||
)
|
||||
self.token_repository.save_token(token_data)
|
||||
|
||||
def get_current_user(self, access_token: str) -> UserInDB:
|
||||
"""
|
||||
Retrieve the current user from an access token.
|
||||
tokens = AccessTokenResponse(
|
||||
access_token=access_token,
|
||||
refresh_token=refresh_token,
|
||||
token_type="bearer"
|
||||
)
|
||||
|
||||
return user, tokens
|
||||
|
||||
def refresh_access_token(self, refresh_token: str) -> AccessTokenResponse:
|
||||
"""
|
||||
Create a new access token using a refresh token.
|
||||
|
||||
This method validates the refresh token and generates a new access token
|
||||
without requiring the user to re-enter their password.
|
||||
|
||||
Args:
|
||||
refresh_token: The refresh token to exchange.
|
||||
|
||||
This method decodes and validates the JWT access token and
|
||||
retrieves the corresponding user from the database.
|
||||
Returns:
|
||||
New access and refresh tokens.
|
||||
|
||||
Args:
|
||||
access_token: JWT access token.
|
||||
|
||||
Returns:
|
||||
The user associated with the token.
|
||||
|
||||
Raises:
|
||||
InvalidTokenError: If access token is invalid.
|
||||
ExpiredTokenError: If access token has expired.
|
||||
UserNotFoundError: If user no longer exists.
|
||||
AccountDisabledError: If user account is disabled.
|
||||
|
||||
Example:
|
||||
>>> user = auth_service.get_current_user(access_token)
|
||||
>>> print(user.email)
|
||||
"""
|
||||
# Decode and validate token
|
||||
token_payload = self.token_manager.decode_access_token(access_token)
|
||||
Raises:
|
||||
InvalidTokenError: If refresh token is invalid.
|
||||
ExpiredTokenError: If refresh token has expired.
|
||||
RevokedTokenError: If refresh token has been revoked.
|
||||
UserNotFoundError: If user no longer exists.
|
||||
AccountDisabledError: If user account is disabled.
|
||||
|
||||
# Get user
|
||||
user = self.user_repository.get_user_by_id(token_payload.sub)
|
||||
if not user:
|
||||
raise UserNotFoundError()
|
||||
Example:
|
||||
>>> tokens = auth_service.refresh_access_token(old_refresh_token)
|
||||
>>> print(tokens.access_token)
|
||||
"""
|
||||
# Validate refresh token
|
||||
token_data = self.token_repository.get_token(refresh_token, "refresh")
|
||||
if not token_data:
|
||||
raise InvalidTokenError("Invalid refresh token")
|
||||
|
||||
if token_data.is_revoked:
|
||||
raise RevokedTokenError()
|
||||
|
||||
if token_data.expires_at < datetime.utcnow():
|
||||
raise ExpiredTokenError()
|
||||
|
||||
# Get user
|
||||
user = self.user_repository.get_user_by_id(token_data.user_id)
|
||||
if not user:
|
||||
raise UserNotFoundError()
|
||||
|
||||
if not user.is_active:
|
||||
raise AccountDisabledError()
|
||||
|
||||
# Create new tokens
|
||||
access_token = self.token_manager.create_access_token(user)
|
||||
new_refresh_token = self.token_manager.create_refresh_token()
|
||||
|
||||
# Revoke old refresh token
|
||||
self.token_repository.revoke_token(refresh_token)
|
||||
|
||||
# Store new refresh token
|
||||
new_token_data = TokenData(
|
||||
token=new_refresh_token,
|
||||
token_type="refresh",
|
||||
user_id=user.id,
|
||||
expires_at=self.token_manager.get_refresh_token_expiration(),
|
||||
created_at=datetime.utcnow(),
|
||||
is_revoked=False
|
||||
)
|
||||
self.token_repository.save_token(new_token_data)
|
||||
|
||||
return AccessTokenResponse(
|
||||
access_token=access_token,
|
||||
refresh_token=new_refresh_token,
|
||||
token_type="bearer"
|
||||
)
|
||||
|
||||
def logout(self, refresh_token: str) -> bool:
|
||||
"""
|
||||
Logout a user by revoking their refresh token.
|
||||
|
||||
This prevents the refresh token from being used to obtain new
|
||||
access tokens. The current access token will remain valid until
|
||||
it expires naturally.
|
||||
|
||||
Args:
|
||||
refresh_token: The refresh token to revoke.
|
||||
|
||||
if not user.is_active:
|
||||
raise AccountDisabledError()
|
||||
Returns:
|
||||
True if logout was successful, False if token not found.
|
||||
|
||||
return user
|
||||
Example:
|
||||
>>> auth_service.logout(refresh_token)
|
||||
True
|
||||
"""
|
||||
return self.token_repository.revoke_token(refresh_token)
|
||||
|
||||
def request_password_reset(self, email: str) -> str:
|
||||
"""
|
||||
Generate a password reset token for a user.
|
||||
|
||||
This method creates a secure token that can be sent to the user
|
||||
via email to reset their password. If an email service is configured,
|
||||
the email is sent automatically. Otherwise, the token is returned
|
||||
for manual handling.
|
||||
|
||||
Args:
|
||||
email: User's email address.
|
||||
|
||||
Returns:
|
||||
The password reset token to be sent via email.
|
||||
|
||||
Raises:
|
||||
UserNotFoundError: If email is not registered.
|
||||
|
||||
Example:
|
||||
>>> token = auth_service.request_password_reset("user@example.com")
|
||||
>>> # Token is automatically sent via email if service is configured
|
||||
"""
|
||||
user = self.user_repository.get_user_by_email(email)
|
||||
if not user:
|
||||
raise UserNotFoundError(f"No user found with email {email}")
|
||||
|
||||
# Create reset token
|
||||
reset_token = self.token_manager.create_password_reset_token()
|
||||
|
||||
# Store token in database
|
||||
token_data = TokenData(
|
||||
token=reset_token,
|
||||
token_type="password_reset",
|
||||
user_id=user.id,
|
||||
expires_at=self.token_manager.get_password_reset_token_expiration(),
|
||||
created_at=datetime.utcnow(),
|
||||
is_revoked=False
|
||||
)
|
||||
self.token_repository.save_token(token_data)
|
||||
|
||||
# Send email if service is configured
|
||||
if self.email_service:
|
||||
self.email_service.send_password_reset_email(email, reset_token)
|
||||
|
||||
return reset_token
|
||||
|
||||
def reset_password(self, token: str, new_password: str) -> bool:
|
||||
"""
|
||||
Reset a user's password using a reset token.
|
||||
|
||||
This method validates the reset token and updates the user's password.
|
||||
All existing refresh tokens for the user are revoked for security.
|
||||
|
||||
Args:
|
||||
token: Password reset token.
|
||||
new_password: New plain text password (will be hashed).
|
||||
|
||||
Returns:
|
||||
True if password was reset successfully.
|
||||
|
||||
Raises:
|
||||
InvalidTokenError: If reset token is invalid.
|
||||
ExpiredTokenError: If reset token has expired.
|
||||
RevokedTokenError: If reset token has been used.
|
||||
UserNotFoundError: If user no longer exists.
|
||||
|
||||
Example:
|
||||
>>> auth_service.reset_password(token, "NewSecurePass123!")
|
||||
True
|
||||
"""
|
||||
# Validate reset token
|
||||
token_data = self.token_repository.get_token(token, "password_reset")
|
||||
if not token_data:
|
||||
raise InvalidTokenError("Invalid password reset token")
|
||||
|
||||
if token_data.is_revoked:
|
||||
raise RevokedTokenError("Password reset token has already been used")
|
||||
|
||||
if token_data.expires_at < datetime.utcnow():
|
||||
raise ExpiredTokenError("Password reset token has expired")
|
||||
|
||||
# Get user
|
||||
user = self.user_repository.get_user_by_id(token_data.user_id)
|
||||
if not user:
|
||||
raise UserNotFoundError()
|
||||
|
||||
# Hash new password
|
||||
hashed_password = self.password_manager.hash_password(new_password)
|
||||
|
||||
# Update user password
|
||||
updates = UserUpdate(password=hashed_password)
|
||||
self.user_repository.update_user(user.id, updates)
|
||||
|
||||
# Revoke the reset token
|
||||
self.token_repository.revoke_token(token)
|
||||
|
||||
# Revoke all user's refresh tokens for security
|
||||
self.token_repository.revoke_all_user_tokens(user.id, "refresh")
|
||||
|
||||
return True
|
||||
|
||||
def request_email_verification(self, email: str) -> str:
|
||||
"""
|
||||
Generate an email verification token for a user.
|
||||
|
||||
This method creates a JWT token that can be sent to the user
|
||||
to verify their email address. If an email service is configured,
|
||||
the email is sent automatically. Otherwise, the token is returned
|
||||
for manual handling.
|
||||
|
||||
Args:
|
||||
email: User's email address.
|
||||
|
||||
Returns:
|
||||
The email verification token (JWT) to be sent via email.
|
||||
|
||||
Raises:
|
||||
UserNotFoundError: If email is not registered.
|
||||
|
||||
Example:
|
||||
>>> token = auth_service.request_email_verification("user@example.com")
|
||||
>>> # Token is automatically sent via email if service is configured
|
||||
"""
|
||||
user = self.user_repository.get_user_by_email(email)
|
||||
if not user:
|
||||
raise UserNotFoundError(f"No user found with email {email}")
|
||||
|
||||
verification_token = self.token_manager.create_email_verification_token(email)
|
||||
|
||||
# Send email if service is configured
|
||||
if self.email_service:
|
||||
self.email_service.send_verification_email(email, verification_token)
|
||||
|
||||
return verification_token
|
||||
|
||||
def verify_email(self, token: str) -> bool:
|
||||
"""
|
||||
Verify a user's email address using a verification token.
|
||||
|
||||
This method decodes the JWT token, extracts the email, and marks
|
||||
the user's email as verified.
|
||||
|
||||
Args:
|
||||
token: Email verification token (JWT).
|
||||
|
||||
Returns:
|
||||
True if email was verified successfully.
|
||||
|
||||
Raises:
|
||||
InvalidTokenError: If verification token is invalid.
|
||||
ExpiredTokenError: If verification token has expired.
|
||||
UserNotFoundError: If user no longer exists.
|
||||
|
||||
Example:
|
||||
>>> auth_service.verify_email(token)
|
||||
True
|
||||
"""
|
||||
# Decode and validate token
|
||||
email = self.token_manager.decode_email_verification_token(token)
|
||||
|
||||
# Get user
|
||||
user = self.user_repository.get_user_by_email(email)
|
||||
if not user:
|
||||
raise UserNotFoundError(f"No user found with email {email}")
|
||||
|
||||
# Update user's verified status
|
||||
updates = UserUpdate(is_verified=True)
|
||||
self.user_repository.update_user(user.id, updates)
|
||||
|
||||
return True
|
||||
|
||||
def get_current_user(self, access_token: str) -> UserInDB:
|
||||
"""
|
||||
Retrieve the current user from an access token.
|
||||
|
||||
This method decodes and validates the JWT access token and
|
||||
retrieves the corresponding user from the database.
|
||||
|
||||
Args:
|
||||
access_token: JWT access token.
|
||||
|
||||
Returns:
|
||||
The user associated with the token.
|
||||
|
||||
Raises:
|
||||
InvalidTokenError: If access token is invalid.
|
||||
ExpiredTokenError: If access token has expired.
|
||||
UserNotFoundError: If user no longer exists.
|
||||
AccountDisabledError: If user account is disabled.
|
||||
|
||||
Example:
|
||||
>>> user = auth_service.get_current_user(access_token)
|
||||
>>> print(user.email)
|
||||
"""
|
||||
# Decode and validate token
|
||||
token_payload = self.token_manager.decode_access_token(access_token)
|
||||
|
||||
# Get user
|
||||
user = self.user_repository.get_user_by_id(token_payload.sub)
|
||||
if not user:
|
||||
raise UserNotFoundError()
|
||||
|
||||
if not user.is_active:
|
||||
raise AccountDisabledError()
|
||||
|
||||
return user
|
||||
|
||||
def count_users(self) -> int:
|
||||
"""
|
||||
Counts the total number of users.
|
||||
|
||||
This method retrieves and returns the total count of users
|
||||
from the user repository. It provides functionality for
|
||||
fetching user count stored in the underlying repository of
|
||||
users.
|
||||
|
||||
:return: The total count of users.
|
||||
:rtype: int
|
||||
"""
|
||||
return self.user_repository.count_users()
|
||||
|
||||
def list_users(self, skip: int = 0, limit: int = 100):
|
||||
"""
|
||||
Lists users from the user repository with optional pagination.
|
||||
|
||||
This method retrieves a list of users, allowing optional pagination
|
||||
by specifying the number of records to skip and the maximum number
|
||||
of users to retrieve.
|
||||
|
||||
:param skip: The number of users to skip in the result set. Defaults to 0.
|
||||
:type skip: int
|
||||
:param limit: The maximum number of users to retrieve. Defaults to 100.
|
||||
:type limit: int
|
||||
:return: A list of users retrieved from the repository.
|
||||
:rtype: list
|
||||
"""
|
||||
return self.user_repository.list_users(skip, limit)
|
||||
|
||||
@@ -33,6 +33,16 @@ class UserBase(BaseModel):
|
||||
user_settings: dict = Field(default_factory=dict)
|
||||
|
||||
|
||||
class UserCreateNoValidation(UserBase):
|
||||
"""
|
||||
Model for user creation (registration).
|
||||
|
||||
This model extends UserBase with a password field
|
||||
"""
|
||||
|
||||
password: str
|
||||
|
||||
|
||||
class UserCreate(UserBase):
|
||||
"""
|
||||
Model for user creation (registration).
|
||||
|
||||
@@ -120,6 +120,31 @@ class UserRepository(ABC):
|
||||
True if email exists, False otherwise.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def list_users(self, skip: int = 0, limit: int = 100) -> list[UserInDB]:
|
||||
"""
|
||||
List users with pagination.
|
||||
|
||||
Args:
|
||||
skip (int): Number of users to skip (default: 0)
|
||||
limit (int): Maximum number of users to return (default: 100)
|
||||
|
||||
Returns:
|
||||
List[UserInDB]: List of users
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def count_users(self) -> int:
|
||||
"""
|
||||
Count total number of users.
|
||||
|
||||
Returns:
|
||||
int: Total number of users in system
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
|
||||
class TokenRepository(ABC):
|
||||
|
||||
@@ -364,7 +364,51 @@ class SQLiteUserRepository(UserRepository):
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("SELECT 1 FROM users WHERE email = ? LIMIT 1", (email,))
|
||||
return cursor.fetchone() is not None
|
||||
|
||||
def list_users(self, skip: int = 0, limit: int = 100) -> list[UserInDB]:
|
||||
"""
|
||||
List users with pagination.
|
||||
|
||||
Args:
|
||||
skip: Number of users to skip (default: 0).
|
||||
limit: Maximum number of users to return (default: 100).
|
||||
|
||||
Returns:
|
||||
List of users.
|
||||
"""
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("""
|
||||
SELECT id,
|
||||
email,
|
||||
username,
|
||||
hashed_password,
|
||||
roles,
|
||||
user_settings,
|
||||
is_verified,
|
||||
is_active,
|
||||
created_at,
|
||||
updated_at
|
||||
FROM users
|
||||
ORDER BY created_at DESC LIMIT ?
|
||||
OFFSET ?
|
||||
""", (limit, skip))
|
||||
|
||||
rows = cursor.fetchall()
|
||||
return [self._row_to_user(row) for row in rows]
|
||||
|
||||
def count_users(self) -> int:
|
||||
"""
|
||||
Count total number of users.
|
||||
|
||||
Returns:
|
||||
Total number of users in system.
|
||||
"""
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("SELECT COUNT(*) FROM users")
|
||||
result = cursor.fetchone()
|
||||
return result[0] if result else 0
|
||||
|
||||
class SQLiteTokenRepository(TokenRepository):
|
||||
"""
|
||||
|
||||
@@ -77,6 +77,178 @@ class TestAuthServiceRegisterLogin(object):
|
||||
"""Failure: Login fails if the user does not exist."""
|
||||
with pytest.raises(InvalidCredentialsError):
|
||||
auth_service.login("non.existent@example.com", "AnyPassword")
|
||||
|
||||
def test_create_admin_if_needed_success_with_custom_credentials(
|
||||
self,
|
||||
auth_service: AuthService
|
||||
):
|
||||
"""Success: Admin is created with custom credentials when no users exist."""
|
||||
|
||||
# Arrange
|
||||
custom_email = "custom.admin@example.com"
|
||||
custom_username = "custom_admin"
|
||||
custom_password = "CustomAdminPass123!"
|
||||
|
||||
# Act
|
||||
result = auth_service.create_admin_if_needed(
|
||||
admin_email=custom_email,
|
||||
admin_username=custom_username,
|
||||
admin_password=custom_password
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
|
||||
# Verify admin user was created
|
||||
admin_user = auth_service.user_repository.get_user_by_email(custom_email)
|
||||
assert admin_user is not None
|
||||
assert admin_user.email == custom_email
|
||||
assert admin_user.username == custom_username
|
||||
assert "admin" in admin_user.roles
|
||||
|
||||
# Verify password was hashed
|
||||
auth_service.password_manager.hash_password.assert_called()
|
||||
|
||||
def test_create_admin_if_needed_success_with_default_credentials(
|
||||
self,
|
||||
auth_service: AuthService,
|
||||
monkeypatch
|
||||
):
|
||||
"""Success: Admin is created with default credentials from environment variables."""
|
||||
|
||||
# Arrange
|
||||
monkeypatch.setenv("AUTH_ADMIN_EMAIL", "env.admin@example.com")
|
||||
monkeypatch.setenv("AUTH_ADMIN_USERNAME", "env_admin")
|
||||
monkeypatch.setenv("AUTH_ADMIN_PASSWORD", "EnvAdminPass123!")
|
||||
|
||||
# Act
|
||||
result = auth_service.create_admin_if_needed()
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
|
||||
# Verify admin user was created with env variables
|
||||
admin_user = auth_service.user_repository.get_user_by_email("env.admin@example.com")
|
||||
assert admin_user is not None
|
||||
assert admin_user.email == "env.admin@example.com"
|
||||
assert admin_user.username == "env_admin"
|
||||
assert "admin" in admin_user.roles
|
||||
|
||||
def test_create_admin_if_needed_success_with_hardcoded_defaults(
|
||||
self,
|
||||
auth_service: AuthService,
|
||||
monkeypatch
|
||||
):
|
||||
"""Success: Admin is created with hardcoded defaults when no env vars or params provided."""
|
||||
|
||||
# Arrange - Clear any existing env variables
|
||||
monkeypatch.delenv("AUTH_ADMIN_EMAIL", raising=False)
|
||||
monkeypatch.delenv("AUTH_ADMIN_USERNAME", raising=False)
|
||||
monkeypatch.delenv("AUTH_ADMIN_PASSWORD", raising=False)
|
||||
|
||||
# Act
|
||||
result = auth_service.create_admin_if_needed()
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
|
||||
# Verify admin user was created with hardcoded defaults
|
||||
admin_user = auth_service.user_repository.get_user_by_email("admin@myauth.com")
|
||||
assert admin_user is not None
|
||||
assert admin_user.email == "admin@myauth.com"
|
||||
assert admin_user.username == "admin"
|
||||
assert "admin" in admin_user.roles
|
||||
|
||||
def test_create_admin_if_needed_no_creation_when_users_exist(
|
||||
self,
|
||||
auth_service: AuthService,
|
||||
test_user_data_create: UserCreate
|
||||
):
|
||||
"""Failure: Admin is not created when users already exist in the system."""
|
||||
|
||||
# Arrange - Create a regular user first
|
||||
auth_service.register(test_user_data_create)
|
||||
|
||||
# Act
|
||||
result = auth_service.create_admin_if_needed(
|
||||
admin_email="should.not.be.created@example.com",
|
||||
admin_username="should_not_exist",
|
||||
admin_password="ShouldNotExist123!"
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
# Verify admin user was NOT created
|
||||
admin_user = auth_service.user_repository.get_user_by_email(
|
||||
"should.not.be.created@example.com"
|
||||
)
|
||||
assert admin_user is None
|
||||
|
||||
# Verify only the original user exists
|
||||
assert auth_service.count_users() == 1
|
||||
|
||||
def test_create_admin_if_needed_parameters_override_env_variables(
|
||||
self,
|
||||
auth_service: AuthService,
|
||||
monkeypatch
|
||||
):
|
||||
"""Success: Parameters take precedence over environment variables."""
|
||||
|
||||
# Arrange
|
||||
monkeypatch.setenv("AUTH_ADMIN_EMAIL", "env.admin@example.com")
|
||||
monkeypatch.setenv("AUTH_ADMIN_USERNAME", "env_admin")
|
||||
monkeypatch.setenv("AUTH_ADMIN_PASSWORD", "EnvAdminPass123!")
|
||||
|
||||
param_email = "param.admin@example.com"
|
||||
param_username = "param_admin"
|
||||
param_password = "ParamAdminPass123!"
|
||||
|
||||
# Act
|
||||
result = auth_service.create_admin_if_needed(
|
||||
admin_email=param_email,
|
||||
admin_username=param_username,
|
||||
admin_password=param_password
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
|
||||
# Verify parameters were used, not env variables
|
||||
admin_user = auth_service.user_repository.get_user_by_email(param_email)
|
||||
assert admin_user is not None
|
||||
assert admin_user.email == param_email
|
||||
assert admin_user.username == param_username
|
||||
|
||||
# Verify env admin was NOT created
|
||||
env_admin = auth_service.user_repository.get_user_by_email("env.admin@example.com")
|
||||
assert env_admin is None
|
||||
|
||||
def test_create_admin_if_needed_mixed_parameters_and_env(
|
||||
self,
|
||||
auth_service: AuthService,
|
||||
monkeypatch
|
||||
):
|
||||
"""Success: Partial parameters combine with environment variables."""
|
||||
|
||||
# Arrange
|
||||
monkeypatch.setenv("AUTH_ADMIN_EMAIL", "env.admin@example.com")
|
||||
monkeypatch.setenv("AUTH_ADMIN_USERNAME", "env_admin")
|
||||
monkeypatch.setenv("AUTH_ADMIN_PASSWORD", "EnvAdminPass123!")
|
||||
|
||||
# Act - Only provide email as parameter
|
||||
result = auth_service.create_admin_if_needed(
|
||||
admin_email="partial.admin@example.com"
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
|
||||
# Verify email from parameter, username and password from env
|
||||
admin_user = auth_service.user_repository.get_user_by_email("partial.admin@example.com")
|
||||
assert admin_user is not None
|
||||
assert admin_user.email == "partial.admin@example.com"
|
||||
assert admin_user.username == "env_admin"
|
||||
|
||||
|
||||
class TestAuthServiceTokenManagement(object):
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
# tests/persistence/test_sqlite_user.py
|
||||
|
||||
import pytest
|
||||
import json
|
||||
from datetime import datetime
|
||||
|
||||
from myauth.persistence.sqlite import SQLiteUserRepository
|
||||
from myauth.models.user import UserCreate, UserUpdate
|
||||
import pytest
|
||||
|
||||
from myauth.exceptions import UserAlreadyExistsError, UserNotFoundError
|
||||
from myauth.models.user import UserCreate, UserUpdate
|
||||
from myauth.persistence.sqlite import SQLiteUserRepository
|
||||
|
||||
|
||||
def test_i_can_create_and_retrieve_user_by_email(user_repository: SQLiteUserRepository,
|
||||
@@ -152,4 +152,139 @@ def test_i_cannot_retrieve_non_existent_user_by_email(user_repository: SQLiteUse
|
||||
"""Ensures retrieval by email returns None for non-existent email."""
|
||||
|
||||
retrieved_user = user_repository.get_user_by_email("ghost@example.com")
|
||||
assert retrieved_user is None
|
||||
assert retrieved_user is None
|
||||
|
||||
|
||||
def test_i_can_list_users_with_pagination(user_repository: SQLiteUserRepository,
|
||||
test_user_hashed_password: str):
|
||||
"""Verifies that list_users returns paginated results correctly."""
|
||||
|
||||
# Create multiple users
|
||||
for i in range(5):
|
||||
user_data = UserCreate(
|
||||
email=f"user{i}@example.com",
|
||||
username=f"User{i}",
|
||||
password="#Password123",
|
||||
roles=["user"],
|
||||
user_settings={}
|
||||
)
|
||||
user_repository.create_user(user_data, test_user_hashed_password)
|
||||
|
||||
# Test: Get first 3 users
|
||||
users_page1 = user_repository.list_users(skip=0, limit=3)
|
||||
assert len(users_page1) == 3
|
||||
|
||||
# Test: Get next 2 users
|
||||
users_page2 = user_repository.list_users(skip=3, limit=3)
|
||||
assert len(users_page2) == 2
|
||||
|
||||
# Test: Verify no duplicates between pages
|
||||
page1_ids = {user.id for user in users_page1}
|
||||
page2_ids = {user.id for user in users_page2}
|
||||
assert len(page1_ids.intersection(page2_ids)) == 0
|
||||
|
||||
|
||||
def test_i_can_list_users_with_default_pagination(user_repository: SQLiteUserRepository,
|
||||
test_user_data_create: UserCreate,
|
||||
test_user_hashed_password: str):
|
||||
"""Verifies that list_users works with default parameters."""
|
||||
|
||||
# Create 2 users
|
||||
user_repository.create_user(test_user_data_create, test_user_hashed_password)
|
||||
|
||||
user_data2 = UserCreate(
|
||||
email="user2@example.com",
|
||||
username="User2",
|
||||
password="#Password123",
|
||||
roles=["user"],
|
||||
user_settings={}
|
||||
)
|
||||
user_repository.create_user(user_data2, test_user_hashed_password)
|
||||
|
||||
# Test: Default parameters (skip=0, limit=100)
|
||||
users = user_repository.list_users()
|
||||
assert len(users) == 2
|
||||
assert all(isinstance(user.created_at, datetime) for user in users)
|
||||
|
||||
|
||||
def test_i_get_empty_list_when_no_users_exist(user_repository: SQLiteUserRepository):
|
||||
"""Verifies that list_users returns an empty list when no users exist."""
|
||||
|
||||
users = user_repository.list_users()
|
||||
assert users == []
|
||||
assert isinstance(users, list)
|
||||
|
||||
|
||||
def test_i_can_skip_beyond_available_users(user_repository: SQLiteUserRepository,
|
||||
test_user_data_create: UserCreate,
|
||||
test_user_hashed_password: str):
|
||||
"""Verifies that skipping beyond available users returns an empty list."""
|
||||
|
||||
user_repository.create_user(test_user_data_create, test_user_hashed_password)
|
||||
|
||||
# Skip beyond the only user
|
||||
users = user_repository.list_users(skip=10, limit=10)
|
||||
assert users == []
|
||||
|
||||
|
||||
def test_i_can_count_users(user_repository: SQLiteUserRepository,
|
||||
test_user_hashed_password: str):
|
||||
"""Verifies that count_users returns the correct number of users."""
|
||||
|
||||
# Initial count should be 0
|
||||
assert user_repository.count_users() == 0
|
||||
|
||||
# Create first user
|
||||
user_data1 = UserCreate(
|
||||
email="user1@example.com",
|
||||
username="User1",
|
||||
password="#Password123",
|
||||
roles=["user"],
|
||||
user_settings={}
|
||||
)
|
||||
user_repository.create_user(user_data1, test_user_hashed_password)
|
||||
assert user_repository.count_users() == 1
|
||||
|
||||
# Create second user
|
||||
user_data2 = UserCreate(
|
||||
email="user2@example.com",
|
||||
username="User2",
|
||||
password="#Password123",
|
||||
roles=["user"],
|
||||
user_settings={}
|
||||
)
|
||||
user_repository.create_user(user_data2, test_user_hashed_password)
|
||||
assert user_repository.count_users() == 2
|
||||
|
||||
|
||||
def test_i_get_zero_count_when_no_users_exist(user_repository: SQLiteUserRepository):
|
||||
"""Verifies that count_users returns 0 when the database is empty."""
|
||||
|
||||
count = user_repository.count_users()
|
||||
assert count == 0
|
||||
assert isinstance(count, int)
|
||||
|
||||
|
||||
def test_list_users_returns_correct_user_structure(user_repository: SQLiteUserRepository,
|
||||
test_user_data_create: UserCreate,
|
||||
test_user_hashed_password: str):
|
||||
"""Verifies that list_users returns UserInDB objects with all fields."""
|
||||
|
||||
user_repository.create_user(test_user_data_create, test_user_hashed_password)
|
||||
|
||||
users = user_repository.list_users()
|
||||
|
||||
assert len(users) == 1
|
||||
user = users[0]
|
||||
|
||||
# Verify all fields are present and correct type
|
||||
assert user.id is not None
|
||||
assert user.email == test_user_data_create.email
|
||||
assert user.username == test_user_data_create.username
|
||||
assert user.hashed_password == test_user_hashed_password
|
||||
assert isinstance(user.roles, list)
|
||||
assert isinstance(user.user_settings, dict)
|
||||
assert isinstance(user.is_verified, bool)
|
||||
assert isinstance(user.is_active, bool)
|
||||
assert isinstance(user.created_at, datetime)
|
||||
assert isinstance(user.updated_at, datetime)
|
||||
|
||||
Reference in New Issue
Block a user