# tests/core/test_token_manager.py
|
|
|
|
from datetime import datetime, timedelta
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
import pytest
|
|
from jose import jwt
|
|
|
|
from myauth.core.token import TokenManager
|
|
from myauth.exceptions import InvalidTokenError, ExpiredTokenError
|
|
from myauth.models.user import UserInDB # Assuming you have a fixture for this
|
|
|
|
|
|
@pytest.fixture
def token_manager():
    """Build the TokenManager shared by the tests, with a fixed secret and explicit lifetimes."""
    # Every setting is spelled out so the tests never depend on TokenManager's defaults.
    settings = {
        "jwt_secret": "TEST_SECRET_KEY",
        "jwt_algorithm": "HS256",
        "access_token_expire_minutes": 1,
        "refresh_token_expire_days": 7,
        "password_reset_token_expire_minutes": 15,
    }
    return TokenManager(**settings)
|
|
|
|
|
|
class TestTokenManagerInitialization:
    """Construction-time behaviour of TokenManager: default settings and secret validation."""

    def test_init_success(self):
        """Supplying only the secret must leave every other setting at its default."""
        manager = TokenManager(jwt_secret="MySecret")

        # Attribute name -> value expected after construction.
        expected_settings = {
            "jwt_secret": "MySecret",
            "jwt_algorithm": "HS256",
            "access_token_expire_minutes": 30,
            "refresh_token_expire_days": 7,
            "password_reset_token_expire_minutes": 15,
        }
        for attribute, expected in expected_settings.items():
            assert getattr(manager, attribute) == expected

    def test_init_failure_empty_secret(self):
        """An empty JWT secret must be rejected with a ValueError."""
        expected_message = "JWT secret cannot be empty"
        with pytest.raises(ValueError, match=expected_message):
            TokenManager(jwt_secret="")
|
|
|
|
|
|
class TestTokenCreation:
    """Tests creation methods for different token types."""

    # Lowercase hexadecimal alphabet that the opaque (non-JWT) tokens are checked against.
    _HEX_ALPHABET = frozenset("0123456789abcdef")

    @staticmethod
    def _assert_exp_in_window(payload, clock_before, clock_after, lifetime):
        """Assert that ``payload["exp"]`` fits a token issued between two clock readings.

        Args:
            payload: Decoded JWT claims; must contain an integer ``exp``.
            clock_before: ``datetime.now()`` sampled before the token was created.
            clock_after: ``datetime.now()`` sampled after the token was created.
            lifetime: ``timedelta`` the token is expected to live for.

        Bracketing the creation call with two readings keeps the check independent
        of how long creation and decoding take. One second of slack is kept on each
        side because ``exp`` is compared as whole seconds.
        """
        earliest = int((clock_before + lifetime).timestamp()) - 1
        latest = int((clock_after + lifetime).timestamp()) + 1
        assert earliest <= payload["exp"] <= latest

    def test_create_access_token_format_and_expiration(self, token_manager: TokenManager, test_user_in_db: UserInDB):
        """Should create a valid JWT with correct payload and expiration."""
        clock_before = datetime.now()
        token = token_manager.create_access_token(test_user_in_db)

        # 1. The encoded token is handed out as text.
        assert isinstance(token, str)

        # 2. Decode and check payload content.
        payload = jwt.decode(token, token_manager.jwt_secret, algorithms=[token_manager.jwt_algorithm])

        assert payload["sub"] == test_user_in_db.id
        assert payload["email"] == test_user_in_db.email
        assert payload["type"] == "access"

        # 3. Expiration must match the configured access-token lifetime.
        self._assert_exp_in_window(
            payload,
            clock_before,
            datetime.now(),
            timedelta(minutes=token_manager.access_token_expire_minutes),
        )

    def test_create_refresh_token_format(self, token_manager: TokenManager):
        """Should create a random hex string of length 64."""
        token = token_manager.create_refresh_token()

        assert isinstance(token, str)
        assert len(token) == 64
        assert set(token) <= self._HEX_ALPHABET

    def test_create_password_reset_token_format(self, token_manager: TokenManager):
        """Should create a random hex string of length 64."""
        token = token_manager.create_password_reset_token()

        assert isinstance(token, str)
        assert len(token) == 64
        # Same alphabet check as the refresh token, so the "hex" part of the
        # contract stated above is actually verified.
        assert set(token) <= self._HEX_ALPHABET

    def test_create_email_verification_token_format(self, token_manager: TokenManager):
        """Should create a JWT with email and 'email_verification' type."""
        email = "verify@example.com"

        clock_before = datetime.now()
        token = token_manager.create_email_verification_token(email)

        # Decode and check payload content.
        payload = jwt.decode(token, token_manager.jwt_secret, algorithms=[token_manager.jwt_algorithm])

        assert payload["email"] == email
        assert payload["type"] == "email_verification"

        # This test expects verification tokens to live for 7 days.
        self._assert_exp_in_window(payload, clock_before, datetime.now(), timedelta(days=7))
|
|
|
|
|
|
class TestTokenExpirationCalculations:
    """Expiration-date helpers of TokenManager, checked against a frozen clock."""

    # The ``datetime`` name inside myauth.core.token is replaced by a mock whose
    # now() returns a fixed instant, so the expected value can be computed exactly.
    @patch('myauth.core.token.datetime')
    def test_get_refresh_token_expiration(self, mock_datetime, token_manager: TokenManager):
        """Refresh-token expiration is the frozen 'now' plus the configured number of days."""
        frozen_now = datetime(2025, 1, 1, 10, 0, 0)
        mock_datetime.now.return_value = frozen_now

        lifetime = timedelta(days=token_manager.refresh_token_expire_days)

        assert token_manager.get_refresh_token_expiration() == frozen_now + lifetime

    @patch('myauth.core.token.datetime')
    def test_get_password_reset_token_expiration(self, mock_datetime, token_manager: TokenManager):
        """Password-reset expiration is the frozen 'now' plus the configured number of minutes."""
        frozen_now = datetime(2025, 1, 1, 10, 0, 0)
        mock_datetime.now.return_value = frozen_now

        lifetime = timedelta(minutes=token_manager.password_reset_token_expire_minutes)

        assert token_manager.get_password_reset_token_expiration() == frozen_now + lifetime
|
|
|
|
|
|
class TestTokenDecodingAndValidation:
    """Tests decoding and validation logic for JWT tokens."""

    @staticmethod
    def _tamper_signature(token):
        """Return ``token`` with the first character of its signature segment replaced.

        The first character is used on purpose. The last base64url character of
        an unpadded signature can carry spare bits that are not part of the
        signature bytes (for a 32-byte HS256 signature, 2 of its 6 bits).
        Swapping a trailing ``A`` for ``B`` changes only those spare bits, and a
        decoder that does not enforce canonical encoding (Python's ``base64``
        module by default) then yields the original signature, so the token
        would still verify. All six bits of the first character are significant,
        so changing it always changes the decoded signature.

        Args:
            token: Compact JWT of the form ``header.payload.signature``.

        Returns:
            The same token with a guaranteed-different signature.
        """
        signed_part, _, signature = token.rpartition(".")
        replacement = "A" if signature[0] != "A" else "B"
        return f"{signed_part}.{replacement}{signature[1:]}"

    # --- Access Token Tests ---

    def test_decode_access_token_success(self, token_manager: TokenManager, test_user_in_db: UserInDB):
        """Should successfully decode a valid access token."""
        token = token_manager.create_access_token(test_user_in_db)

        payload = token_manager.decode_access_token(token)

        assert payload.sub == test_user_in_db.id
        assert payload.email == test_user_in_db.email
        assert payload.type == "access"

    def test_i_cannot_decode_expired_access_token(self, token_manager: TokenManager, test_user_in_db: UserInDB):
        """Should raise ExpiredTokenError when decoding an expired token."""
        # Hand-encode a token whose expiry is already one hour in the past.
        expired_time = datetime.now() - timedelta(hours=1)
        payload = {
            "sub": test_user_in_db.id,
            "email": test_user_in_db.email,
            "exp": int(expired_time.timestamp()),
            "type": "access",
        }
        expired_token = jwt.encode(
            payload,
            token_manager.jwt_secret,
            algorithm=token_manager.jwt_algorithm,
        )

        with pytest.raises(ExpiredTokenError, match="Access token has expired"):
            token_manager.decode_access_token(expired_token)

    def test_decode_access_token_invalid_signature(self, token_manager: TokenManager, test_user_in_db: UserInDB):
        """Should raise InvalidTokenError if the signature is bad."""
        token = token_manager.create_access_token(test_user_in_db)
        invalid_token = self._tamper_signature(token)

        with pytest.raises(InvalidTokenError, match="Invalid access token"):
            token_manager.decode_access_token(invalid_token)

    def test_decode_access_token_wrong_type(self, token_manager: TokenManager):
        """Should raise InvalidTokenError if token is not 'access' type."""
        # Create an email verification token, but try to decode it as an access token.
        wrong_token = token_manager.create_email_verification_token("wrong@type.com")

        with pytest.raises(InvalidTokenError, match="Invalid token type"):
            token_manager.decode_access_token(wrong_token)

    # --- Email Verification Token Tests ---

    def test_decode_email_verification_token_success(self, token_manager: TokenManager):
        """Should successfully decode a valid email verification token."""
        email = "valid_email@test.com"
        token = token_manager.create_email_verification_token(email)

        decoded_email = token_manager.decode_email_verification_token(token)

        assert decoded_email == email

    def test_decode_email_verification_token_expired(self, token_manager: TokenManager):
        """Should raise ExpiredTokenError if the token's expiry lies in the past."""
        # Hand-encode a token that expired yesterday rather than mocking the
        # clock to wait out a real token's lifetime.
        expired_payload = {
            "email": "old@example.com",
            "exp": int((datetime.now() - timedelta(days=1)).timestamp()),
            "type": "email_verification",
        }
        expired_token = jwt.encode(expired_payload, token_manager.jwt_secret, algorithm=token_manager.jwt_algorithm)

        with pytest.raises(ExpiredTokenError, match="Email verification token has expired"):
            token_manager.decode_email_verification_token(expired_token)

    def test_decode_email_verification_token_wrong_type(self, token_manager: TokenManager, test_user_in_db: UserInDB):
        """Should raise InvalidTokenError if token is not 'email_verification' type."""
        # Create an access token, but try to decode it as an email token.
        wrong_token = token_manager.create_access_token(test_user_in_db)

        with pytest.raises(InvalidTokenError, match="Invalid token type"):
            token_manager.decode_email_verification_token(wrong_token)
|