"""Tests for OIDC access-token validation and the auth-protected endpoints.

All OIDC network access is replaced by an in-process fake (see ``_mock_oidc``)
and the database is an in-memory SQLite engine (see ``auth_session``).
"""

from __future__ import annotations

import json
from datetime import UTC, datetime, timedelta
from uuid import UUID, uuid4

import jwt
import pytest
from cryptography.hazmat.primitives.asymmetric import rsa
from fastapi import HTTPException
from fastapi.testclient import TestClient
from jwt import algorithms
from sqlmodel import Session, SQLModel, create_engine
from sqlmodel.pool import StaticPool

import db as db_module
from db import get_session
from innercontext.api.auth_deps import require_admin
from innercontext.auth import (
    CurrentHouseholdMembership,
    CurrentUser,
    IdentityData,
    TokenClaims,
    reset_auth_caches,
    validate_access_token,
)
from innercontext.models import (
    Household,
    HouseholdMembership,
    HouseholdRole,
    Role,
    User,
)
from main import app


class _MockResponse:
    """Minimal response object returned by the fake ``httpx.get``.

    Exposes only ``status_code``, ``raise_for_status()`` and ``json()``.
    """

    def __init__(self, payload: dict[str, object], status_code: int = 200):
        self._payload = payload
        self.status_code = status_code

    def raise_for_status(self) -> None:
        """Raise ``RuntimeError`` when the stored status code is 400 or above."""
        if self.status_code >= 400:
            raise RuntimeError(f"unexpected status {self.status_code}")

    def json(self) -> dict[str, object]:
        """Return the payload passed to the constructor."""
        return self._payload


@pytest.fixture()
def auth_env(monkeypatch):
    """Set the ``OIDC_*`` environment variables and reset auth caches.

    Caches are reset both before and after the test so cached state cannot
    carry over between tests.
    """
    monkeypatch.setenv("OIDC_ISSUER", "https://auth.example.test")
    monkeypatch.setenv("OIDC_CLIENT_ID", "innercontext-web")
    monkeypatch.setenv(
        "OIDC_DISCOVERY_URL",
        "https://auth.example.test/.well-known/openid-configuration",
    )
    monkeypatch.setenv("OIDC_ADMIN_GROUPS", "innercontext-admin")
    monkeypatch.setenv("OIDC_MEMBER_GROUPS", "innercontext-member")
    monkeypatch.setenv("OIDC_JWKS_CACHE_TTL_SECONDS", "3600")
    reset_auth_caches()
    yield
    reset_auth_caches()


@pytest.fixture()
def rsa_keypair():
    """Return a freshly generated 2048-bit RSA ``(private_key, public_key)`` pair."""
    private_key = rsa.generate_private_key(
        public_exponent=65537,
        key_size=2048,
    )
    return private_key, private_key.public_key()


@pytest.fixture()
def auth_session(monkeypatch):
    """Yield a ``Session`` bound to a fresh in-memory SQLite engine.

    The engine also replaces ``db.engine`` for the duration of the test.
    """
    engine = create_engine(
        "sqlite://",
        connect_args={"check_same_thread": False},
        # StaticPool keeps a single connection, so every user of the engine
        # sees the same in-memory database.
        poolclass=StaticPool,
    )
    monkeypatch.setattr(db_module, "engine", engine)
    # Imported for its side effect of registering the tables on
    # SQLModel.metadata before create_all() runs.
    import innercontext.models  # noqa: F401

    SQLModel.metadata.create_all(engine)
    try:
        with Session(engine) as session:
            yield session
    finally:
        # Release the pooled SQLite connection instead of leaving it open
        # until garbage collection.
        engine.dispose()


@pytest.fixture()
def auth_client(auth_session):
    """Yield a ``TestClient`` whose ``get_session`` dependency is ``auth_session``."""

    def _override():
        yield auth_session

    app.dependency_overrides[get_session] = _override
    try:
        with TestClient(app) as client:
            yield client
    finally:
        # Always remove the overrides, even if client startup/shutdown raised,
        # so they cannot leak into later tests.
        app.dependency_overrides.clear()


def _public_jwk(public_key, kid: str) -> dict[str, object]:
    """Serialize ``public_key`` as a JWK dict tagged with ``kid``, ``use`` and ``alg``."""
    jwk = json.loads(algorithms.RSAAlgorithm.to_jwk(public_key))
    jwk["kid"] = kid
    jwk["use"] = "sig"
    jwk["alg"] = "RS256"
    return jwk


def _sign_token(private_key, kid: str, **claims_overrides: object) -> str:
    """Return an RS256-signed JWT with default claims, valid for one hour.

    Args:
        private_key: RSA private key used to sign the token.
        kid: Key id written to the JWT header.
        **claims_overrides: Claims that replace or extend the defaults.
    """
    now = datetime.now(UTC)
    payload: dict[str, object] = {
        "iss": "https://auth.example.test",
        "sub": "user-123",
        "aud": "innercontext-web",
        "exp": int((now + timedelta(hours=1)).timestamp()),
        "iat": int(now.timestamp()),
        "groups": ["innercontext-admin"],
        "email": "user@example.test",
        "name": "Inner Context User",
        "preferred_username": "ictx-user",
    }
    payload.update(claims_overrides)
    return jwt.encode(payload, private_key, algorithm="RS256", headers={"kid": kid})


def _mock_oidc(monkeypatch, public_key, *, fetch_counts: dict[str, int] | None = None):
    """Patch ``innercontext.auth.httpx.get`` with an in-process OIDC provider.

    The fake serves the discovery document and a JWKS containing
    ``public_key`` under kid ``"kid-1"``. Any other URL raises
    ``AssertionError``.

    Args:
        monkeypatch: The pytest ``monkeypatch`` fixture.
        public_key: RSA public key published in the fake JWKS.
        fetch_counts: Optional dict that is updated in place with the number
            of requests seen per URL.
    """

    def _fake_get(url: str, timeout: float):
        if fetch_counts is not None:
            fetch_counts[url] = fetch_counts.get(url, 0) + 1
        if url.endswith("/.well-known/openid-configuration"):
            return _MockResponse({"jwks_uri": "https://auth.example.test/jwks.json"})
        if url.endswith("/jwks.json"):
            return _MockResponse({"keys": [_public_jwk(public_key, "kid-1")]})
        raise AssertionError(f"unexpected URL {url} with timeout {timeout}")

    monkeypatch.setattr("innercontext.auth.httpx.get", _fake_get)


def test_validate_access_token_uses_cached_jwks(auth_env, rsa_keypair, monkeypatch):
    """Two validations fetch the discovery document and the JWKS once each."""
    private_key, public_key = rsa_keypair
    fetch_counts: dict[str, int] = {}
    _mock_oidc(monkeypatch, public_key, fetch_counts=fetch_counts)

    validate_access_token(_sign_token(private_key, "kid-1", sub="user-a"))
    validate_access_token(_sign_token(private_key, "kid-1", sub="user-b"))

    assert (
        fetch_counts["https://auth.example.test/.well-known/openid-configuration"] == 1
    )
    assert fetch_counts["https://auth.example.test/jwks.json"] == 1


@pytest.mark.parametrize(
    ("path", "payload"),
    [
        (
            "/auth/session/sync",
            {
                "email": "sync@example.test",
                "name": "Synced User",
                "preferred_username": "synced-user",
                "groups": ["innercontext-admin"],
            },
        ),
        ("/auth/me", None),
    ],
    ids=["/auth/session/sync", "/auth/me"],
)
def test_sync_protected_endpoints_create_or_resolve_current_user(
    auth_env,
    auth_client,
    auth_session,
    rsa_keypair,
    monkeypatch,
    path: str,
    payload: dict[str, object] | None,
):
    """Both endpoints return the current user for a valid bearer token.

    ``/auth/session/sync`` is called with POST and the identity payload.
    ``/auth/me`` is called with GET after an admin user, a household and an
    owner membership have been stored.
    """
    private_key, public_key = rsa_keypair
    _mock_oidc(monkeypatch, public_key)
    token = _sign_token(private_key, "kid-1")

    if path == "/auth/me":
        # Seed a user matching the token's issuer/subject, with an owner
        # membership, so the response can be checked for it below.
        user = User(
            oidc_issuer="https://auth.example.test",
            oidc_subject="user-123",
            role=Role.ADMIN,
        )
        auth_session.add(user)
        auth_session.commit()
        auth_session.refresh(user)

        household = Household()
        auth_session.add(household)
        auth_session.commit()
        auth_session.refresh(household)

        membership = HouseholdMembership(
            user_id=user.id,
            household_id=household.id,
            role=HouseholdRole.OWNER,
        )
        auth_session.add(membership)
        auth_session.commit()

    response = auth_client.request(
        "POST" if path.endswith("sync") else "GET",
        path,
        headers={"Authorization": f"Bearer {token}"},
        json=payload,
    )

    assert response.status_code == 200
    data = response.json()
    assert data["user"]["role"] == "admin"
    assert data["identity"]["issuer"] == "https://auth.example.test"
    assert data["identity"]["subject"] == "user-123"

    # The returned user id must refer to a stored row with the token's identity.
    synced_user = auth_session.get(User, UUID(data["user"]["id"]))
    assert synced_user is not None
    assert synced_user.oidc_issuer == "https://auth.example.test"
    assert synced_user.oidc_subject == "user-123"

    if path == "/auth/session/sync":
        assert data["identity"]["email"] == "sync@example.test"
        assert data["identity"]["groups"] == ["innercontext-admin"]
    else:
        assert data["user"]["household_membership"]["role"] == "owner"


@pytest.mark.parametrize(
    "path",
    ["/auth/me", "/profile"],
    ids=["/auth/me expects 401", "/profile expects 401"],
)
def test_unauthorized_protected_endpoints_return_401(auth_env, auth_client, path: str):
    """A request without an Authorization header gets 401 "Missing bearer token"."""
    response = auth_client.get(path)

    assert response.status_code == 401
    assert response.json()["detail"] == "Missing bearer token"


def test_unauthorized_invalid_bearer_token_is_rejected(
    auth_env, auth_client, rsa_keypair, monkeypatch
):
    """A bearer value that is not a JWT gets a 401 response."""
    _, public_key = rsa_keypair
    _mock_oidc(monkeypatch, public_key)

    response = auth_client.get(
        "/auth/me",
        headers={"Authorization": "Bearer not-a-jwt"},
    )

    assert response.status_code == 401


def test_require_admin_raises_for_member():
    """``require_admin`` raises ``HTTPException`` with status 403 for a MEMBER user."""
    claims = TokenClaims(
        issuer="https://auth.example.test",
        subject="member-1",
        audience=("innercontext-web",),
        expires_at=datetime.now(UTC) + timedelta(hours=1),
        raw_claims={"iss": "https://auth.example.test", "sub": "member-1"},
    )
    current_user = CurrentUser(
        user_id=uuid4(),
        role=Role.MEMBER,
        identity=IdentityData.from_claims(claims),
        claims=claims,
        household_membership=CurrentHouseholdMembership(
            household_id=uuid4(),
            role=HouseholdRole.MEMBER,
        ),
    )

    with pytest.raises(HTTPException) as exc_info:
        require_admin(current_user)

    # Checked outside the ``with`` block: statements after the raising call
    # inside it would never run.
    assert exc_info.value.status_code == 403