# tests/core/test_token_manager.py
"""Unit tests for ``TokenManager``: construction, token creation, expiry math and decoding."""
from datetime import datetime, timedelta
from unittest.mock import MagicMock, patch

import pytest
from jose import jwt

from src.my_auth.core.token import TokenManager
from src.my_auth.exceptions import InvalidTokenError, ExpiredTokenError
# NOTE: the `test_user_in_db` fixture used below is not defined in this module;
# it is assumed to be provided elsewhere (e.g. a conftest.py).
from src.my_auth.models.user import UserInDB


@pytest.fixture
def token_manager():
    """Provides a TokenManager instance with known, short expiration times for testing."""
    return TokenManager(
        jwt_secret="TEST_SECRET_KEY",
        jwt_algorithm="HS256",
        access_token_expire_minutes=1,
        refresh_token_expire_days=7,
        password_reset_token_expire_minutes=15,
    )


class TestTokenManagerInitialization:
    """Tests for TokenManager setup and configuration."""

    def test_init_success(self):
        """Should initialize successfully with required parameters and apply defaults."""
        tm = TokenManager(jwt_secret="MySecret")

        assert tm.jwt_secret == "MySecret"
        assert tm.jwt_algorithm == "HS256"
        assert tm.access_token_expire_minutes == 30
        assert tm.refresh_token_expire_days == 7
        assert tm.password_reset_token_expire_minutes == 15

    def test_init_failure_empty_secret(self):
        """Should raise ValueError if JWT secret is empty."""
        with pytest.raises(ValueError, match="JWT secret cannot be empty"):
            TokenManager(jwt_secret="")


class TestTokenCreation:
    """Tests creation methods for different token types."""

    def test_create_access_token_format_and_expiration(self, token_manager: TokenManager, test_user_in_db: UserInDB):
        """Should create a valid JWT with correct payload and expiration."""
        token = token_manager.create_access_token(test_user_in_db)

        # 1. Token is returned in its encoded (string) form.
        assert isinstance(token, str)

        # 2. Decode and check payload content.
        payload = jwt.decode(token, token_manager.jwt_secret, algorithms=[token_manager.jwt_algorithm])
        assert payload["sub"] == test_user_in_db.id
        assert payload["email"] == test_user_in_db.email
        assert payload["type"] == "access"

        # 3. Expiration is within a small window of the expected time.
        # `datetime.now()` is naive local time; `.timestamp()` interprets it as
        # local time, which yields a correct POSIX epoch value to compare with `exp`.
        now = datetime.now()
        expected_exp_dt = now + timedelta(minutes=token_manager.access_token_expire_minutes)
        # Allow +/- 1 second for int truncation and the time elapsed since creation.
        assert abs(payload["exp"] - int(expected_exp_dt.timestamp())) <= 1

    def test_create_refresh_token_format(self, token_manager: TokenManager):
        """Should create a random hex string of length 64."""
        token = token_manager.create_refresh_token()

        assert isinstance(token, str)
        assert len(token) == 64
        assert all(c in '0123456789abcdef' for c in token)

    def test_create_password_reset_token_format(self, token_manager: TokenManager):
        """Should create a random string of length 64."""
        token = token_manager.create_password_reset_token()

        assert isinstance(token, str)
        assert len(token) == 64

    def test_create_email_verification_token_format(self, token_manager: TokenManager):
        """Should create a JWT with email and 'email_verification' type."""
        email = "verify@example.com"
        token = token_manager.create_email_verification_token(email)

        # Decode and check payload content.
        payload = jwt.decode(token, token_manager.jwt_secret, algorithms=[token_manager.jwt_algorithm])
        assert payload["email"] == email
        assert payload["type"] == "email_verification"

        # Expiration check (set to 7 days in the implementation).
        now = datetime.now()
        expected_exp_dt = now + timedelta(days=7)
        assert abs(payload["exp"] - int(expected_exp_dt.timestamp())) <= 1


class TestTokenExpirationCalculations:
    """Tests for token expiration date methods."""

    # `datetime` is patched where the token module looks it up, so that its
    # `now()` returns a fixed instant and the expected value is exact.
    @patch('src.my_auth.core.token.datetime')
    def test_get_refresh_token_expiration(self, mock_datetime, token_manager: TokenManager):
        """Should calculate refresh token expiration correctly."""
        start_time = datetime(2025, 1, 1, 10, 0, 0)
        mock_datetime.now.return_value = start_time

        expected_exp = start_time + timedelta(days=token_manager.refresh_token_expire_days)
        actual_exp = token_manager.get_refresh_token_expiration()

        assert actual_exp == expected_exp

    @patch('src.my_auth.core.token.datetime')
    def test_get_password_reset_token_expiration(self, mock_datetime, token_manager: TokenManager):
        """Should calculate password reset token expiration correctly."""
        start_time = datetime(2025, 1, 1, 10, 0, 0)
        mock_datetime.now.return_value = start_time

        expected_exp = start_time + timedelta(minutes=token_manager.password_reset_token_expire_minutes)
        actual_exp = token_manager.get_password_reset_token_expiration()

        assert actual_exp == expected_exp


class TestTokenDecodingAndValidation:
    """Tests decoding and validation logic for JWT tokens."""

    # --- Access Token Tests ---

    def test_decode_access_token_success(self, token_manager: TokenManager, test_user_in_db: UserInDB):
        """Should successfully decode a valid access token."""
        token = token_manager.create_access_token(test_user_in_db)
        payload = token_manager.decode_access_token(token)

        assert payload.sub == test_user_in_db.id
        assert payload.email == test_user_in_db.email
        assert payload.type == "access"

    def test_i_cannot_decode_expired_access_token(self, token_manager: TokenManager, test_user_in_db: UserInDB):
        """Should raise ExpiredTokenError when decoding an expired token."""
        # Hand-craft an access token whose expiry was one hour ago.
        expired_time = datetime.now() - timedelta(hours=1)
        payload = {
            "sub": test_user_in_db.id,
            "email": test_user_in_db.email,
            "exp": int(expired_time.timestamp()),
            "type": "access",
        }
        expired_token = jwt.encode(
            payload,
            token_manager.jwt_secret,
            algorithm=token_manager.jwt_algorithm,
        )

        with pytest.raises(ExpiredTokenError, match="Access token has expired"):
            token_manager.decode_access_token(expired_token)

    def test_decode_access_token_invalid_signature(self, token_manager: TokenManager, test_user_in_db: UserInDB):
        """Should raise InvalidTokenError if the signature is bad."""
        token = token_manager.create_access_token(test_user_in_db)

        # Tamper with the FIRST character of the signature segment. Do not use
        # the last one: an HS256 signature is 32 bytes, or 43 base64url chars,
        # and the final char holds only 4 data bits plus 2 pad bits.
        # Swapping 'A' <-> 'B' there changes only a pad bit. Lenient base64
        # decoding ignores pad bits, so the token would still verify
        # (flaky ~1/16 of runs). Every bit of the first char is payload data,
        # so any change there really alters the decoded signature.
        header, body, signature = token.split(".")
        tampered_first = 'A' if signature[0] != 'A' else 'B'
        invalid_token = f"{header}.{body}.{tampered_first}{signature[1:]}"

        with pytest.raises(InvalidTokenError, match="Invalid access token"):
            token_manager.decode_access_token(invalid_token)

    def test_decode_access_token_wrong_type(self, token_manager: TokenManager):
        """Should raise InvalidTokenError if token is not 'access' type."""
        # A correctly signed email verification token, decoded as an access token.
        wrong_token = token_manager.create_email_verification_token("wrong@type.com")

        with pytest.raises(InvalidTokenError, match="Invalid token type"):
            token_manager.decode_access_token(wrong_token)

    # --- Email Verification Token Tests ---

    def test_decode_email_verification_token_success(self, token_manager: TokenManager):
        """Should successfully decode a valid email verification token."""
        email = "valid_email@test.com"
        token = token_manager.create_email_verification_token(email)

        decoded_email = token_manager.decode_email_verification_token(token)

        assert decoded_email == email

    def test_decode_email_verification_token_expired(self, token_manager: TokenManager):
        """Should raise ExpiredTokenError if the token is old (7 days set in creation)."""
        # Rather than mocking time past the 7-day window, hand-encode a token
        # that already expired yesterday.
        expired_payload = {
            "email": "old@example.com",
            "exp": int((datetime.now() - timedelta(days=1)).timestamp()),
            "type": "email_verification",
        }
        expired_token = jwt.encode(expired_payload, token_manager.jwt_secret, algorithm=token_manager.jwt_algorithm)

        with pytest.raises(ExpiredTokenError, match="Email verification token has expired"):
            token_manager.decode_email_verification_token(expired_token)

    def test_decode_email_verification_token_wrong_type(self, token_manager: TokenManager, test_user_in_db: UserInDB):
        """Should raise InvalidTokenError if token is not 'email_verification' type."""
        # A correctly signed access token, decoded as an email verification token.
        wrong_token = token_manager.create_access_token(test_user_in_db)

        with pytest.raises(InvalidTokenError, match="Invalid token type"):
            token_manager.decode_email_verification_token(wrong_token)