From 4782fad5b9610b28664ce59561f91f452e78b48a Mon Sep 17 00:00:00 2001 From: Piotr Oleszczyk Date: Thu, 12 Mar 2026 15:13:55 +0100 Subject: [PATCH] feat(auth): validate Authelia tokens in FastAPI --- backend/innercontext/api/auth.py | 166 +++++++++++ backend/innercontext/api/auth_deps.py | 57 ++++ backend/innercontext/auth.py | 384 ++++++++++++++++++++++++++ backend/main.py | 56 +++- backend/pyproject.toml | 1 + backend/tests/conftest.py | 22 ++ backend/tests/test_auth.py | 275 ++++++++++++++++++ 7 files changed, 953 insertions(+), 8 deletions(-) create mode 100644 backend/innercontext/api/auth.py create mode 100644 backend/innercontext/api/auth_deps.py create mode 100644 backend/innercontext/auth.py create mode 100644 backend/tests/test_auth.py diff --git a/backend/innercontext/api/auth.py b/backend/innercontext/api/auth.py new file mode 100644 index 0000000..877289e --- /dev/null +++ b/backend/innercontext/api/auth.py @@ -0,0 +1,166 @@ +from __future__ import annotations + +from datetime import date, datetime +from uuid import UUID + +from fastapi import APIRouter, Depends, HTTPException, status +from sqlmodel import Field, Session, SQLModel, select + +from db import get_session +from innercontext.api.auth_deps import get_current_user +from innercontext.auth import CurrentUser, IdentityData, sync_current_user +from innercontext.models import HouseholdRole, Role, UserProfile + +router = APIRouter() + + +class SessionSyncRequest(SQLModel): + iss: str | None = None + sub: str | None = None + email: str | None = None + name: str | None = None + preferred_username: str | None = None + groups: list[str] | None = None + + +class AuthHouseholdMembershipPublic(SQLModel): + household_id: UUID + role: HouseholdRole + + +class AuthUserPublic(SQLModel): + id: UUID + role: Role + household_membership: AuthHouseholdMembershipPublic | None = None + + +class AuthIdentityPublic(SQLModel): + issuer: str + subject: str + email: str | None = None + name: str | None = None + preferred_username: str | None = None + groups: list[str] = Field(default_factory=list) + + +class AuthProfilePublic(SQLModel): + id: UUID + user_id: UUID | None + birth_date: date | None = None + sex_at_birth: str | None = None + created_at: datetime + updated_at: datetime + + +class AuthSessionResponse(SQLModel): + user: AuthUserPublic + identity: AuthIdentityPublic + profile: AuthProfilePublic | None = None + + +def _build_identity( + current_user: CurrentUser, + payload: SessionSyncRequest | None, +) -> IdentityData: + if payload is None: + return current_user.identity + + if payload.iss is not None and payload.iss != current_user.identity.issuer: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Session sync issuer does not match bearer token", + ) + if payload.sub is not None and payload.sub != current_user.identity.subject: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Session sync subject does not match bearer token", + ) + + return IdentityData( + issuer=current_user.identity.issuer, + subject=current_user.identity.subject, + email=( + payload.email if payload.email is not None else current_user.identity.email + ), + name=payload.name if payload.name is not None else current_user.identity.name, + preferred_username=( + payload.preferred_username + if payload.preferred_username is not None + else current_user.identity.preferred_username + ), + groups=( + tuple(payload.groups) + if payload.groups is not None + else current_user.identity.groups + ), + ) + + +def _get_profile(session: Session, user_id: UUID) -> UserProfile | None: + return session.exec( + select(UserProfile).where(UserProfile.user_id == user_id) + ).first() + + +def _profile_public(profile: UserProfile | None) -> AuthProfilePublic | None: + if profile is None: + return None + + return AuthProfilePublic( + id=profile.id, + user_id=profile.user_id, + birth_date=profile.birth_date, + sex_at_birth=( + profile.sex_at_birth.value if profile.sex_at_birth is not None else None + ), + created_at=profile.created_at, + updated_at=profile.updated_at, + ) + + +def _response(session: Session, current_user: CurrentUser) -> AuthSessionResponse: + household_membership = None + if current_user.household_membership is not None: + household_membership = AuthHouseholdMembershipPublic( + household_id=current_user.household_membership.household_id, + role=current_user.household_membership.role, + ) + + return AuthSessionResponse( + user=AuthUserPublic( + id=current_user.user_id, + role=current_user.role, + household_membership=household_membership, + ), + identity=AuthIdentityPublic( + issuer=current_user.identity.issuer, + subject=current_user.identity.subject, + email=current_user.identity.email, + name=current_user.identity.name, + preferred_username=current_user.identity.preferred_username, + groups=list(current_user.identity.groups), + ), + profile=_profile_public(_get_profile(session, current_user.user_id)), + ) + + +@router.post("/session/sync", response_model=AuthSessionResponse) +def sync_session( + payload: SessionSyncRequest | None = None, + session: Session = Depends(get_session), + current_user: CurrentUser = Depends(get_current_user), +): + synced_user = sync_current_user( + session, + current_user.claims, + identity=_build_identity(current_user, payload), + ) + return _response(session, synced_user) + + +@router.get("/me", response_model=AuthSessionResponse) +def get_me( + session: Session = Depends(get_session), + current_user: CurrentUser = Depends(get_current_user), +): + return _response(session, current_user) diff --git a/backend/innercontext/api/auth_deps.py b/backend/innercontext/api/auth_deps.py new file mode 100644 index 0000000..a71a57a --- /dev/null +++ b/backend/innercontext/api/auth_deps.py @@ -0,0 +1,57 @@ +from __future__ import annotations + +from typing import Annotated + +from fastapi import Depends, HTTPException, status +from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer +from sqlmodel import Session + +from db import get_session +from innercontext.auth import ( + AuthConfigurationError, + CurrentUser, + TokenValidationError, + sync_current_user, + validate_access_token, +) +from innercontext.models import Role + +_bearer_scheme = HTTPBearer(auto_error=False) + + +def _unauthorized(detail: str) -> HTTPException: + return HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=detail, + headers={"WWW-Authenticate": "Bearer"}, + ) + + +def get_current_user( + credentials: Annotated[ + HTTPAuthorizationCredentials | None, Depends(_bearer_scheme) + ], + session: Session = Depends(get_session), +) -> CurrentUser: + if credentials is None or credentials.scheme.lower() != "bearer": + raise _unauthorized("Missing bearer token") + + try: + claims = validate_access_token(credentials.credentials) + return sync_current_user(session, claims) + except AuthConfigurationError as exc: + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + detail=str(exc), + ) from exc + except TokenValidationError as exc: + raise _unauthorized(str(exc)) from exc + + +def require_admin(current_user: CurrentUser = Depends(get_current_user)) -> CurrentUser: + if current_user.role is not Role.ADMIN: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Admin role required", + ) + return current_user diff --git a/backend/innercontext/auth.py b/backend/innercontext/auth.py new file mode 100644 index 0000000..b672d43 --- /dev/null +++ b/backend/innercontext/auth.py @@ -0,0 +1,384 @@ +from __future__ import annotations + +import os +import time +from dataclasses import dataclass, field +from datetime import UTC, datetime +from functools import lru_cache +from threading import Lock +from typing import Any, Mapping +from uuid import UUID + +import httpx +import jwt +from jwt import InvalidTokenError, PyJWKSet +from sqlmodel import Session, select + +from innercontext.models import HouseholdMembership, HouseholdRole, Role, User + +_DISCOVERY_PATH = "/.well-known/openid-configuration" +_SUPPORTED_ALGORITHMS = frozenset( + {"RS256", "RS384", "RS512", "ES256", "ES384", "ES512"} +) + + +class AuthConfigurationError(RuntimeError): + pass + + +class TokenValidationError(ValueError): + pass + + +@dataclass(frozen=True, slots=True) +class AuthSettings: + issuer: str + client_id: str + audiences: tuple[str, ...] + discovery_url: str + jwks_url: str | None + groups_claim: str + admin_groups: tuple[str, ...] + member_groups: tuple[str, ...] + jwks_cache_ttl_seconds: int + http_timeout_seconds: float + clock_skew_seconds: int + + +@dataclass(frozen=True, slots=True) +class TokenClaims: + issuer: str + subject: str + audience: tuple[str, ...] + expires_at: datetime + groups: tuple[str, ...] = () + email: str | None = None + name: str | None = None + preferred_username: str | None = None + raw_claims: Mapping[str, Any] = field(default_factory=dict, repr=False) + + @classmethod + def from_payload( + cls, payload: Mapping[str, Any], settings: AuthSettings + ) -> "TokenClaims": + audience = payload.get("aud") + if isinstance(audience, str): + audiences = (audience,) + elif isinstance(audience, list): + audiences = tuple(str(item) for item in audience) + else: + audiences = () + + groups = _normalize_groups(payload.get(settings.groups_claim)) + exp = payload.get("exp") + if not isinstance(exp, (int, float)): + raise TokenValidationError("Access token missing exp claim") + + return cls( + issuer=str(payload["iss"]), + subject=str(payload["sub"]), + audience=audiences, + expires_at=datetime.fromtimestamp(exp, tz=UTC), + groups=groups, + email=_optional_str(payload.get("email")), + name=_optional_str(payload.get("name")), + preferred_username=_optional_str(payload.get("preferred_username")), + raw_claims=dict(payload), + ) + + +@dataclass(frozen=True, slots=True) +class IdentityData: + issuer: str + subject: str + email: str | None = None + name: str | None = None + preferred_username: str | None = None + groups: tuple[str, ...] = () + + @classmethod + def from_claims(cls, claims: TokenClaims) -> "IdentityData": + return cls( + issuer=claims.issuer, + subject=claims.subject, + email=claims.email, + name=claims.name, + preferred_username=claims.preferred_username, + groups=claims.groups, + ) + + +@dataclass(frozen=True, slots=True) +class CurrentHouseholdMembership: + household_id: UUID + role: HouseholdRole + + +@dataclass(frozen=True, slots=True) +class CurrentUser: + user_id: UUID + role: Role + identity: IdentityData + claims: TokenClaims = field(repr=False) + household_membership: CurrentHouseholdMembership | None = None + + +def _split_csv(value: str | None) -> tuple[str, ...]: + if value is None: + return () + return tuple(item.strip() for item in value.split(",") if item.strip()) + + +def _optional_str(value: Any) -> str | None: + if value is None: + return None + if isinstance(value, str): + return value + return str(value) + + +def _normalize_groups(value: Any) -> tuple[str, ...]: + if value is None: + return () + if isinstance(value, str): + return (value,) + if isinstance(value, list): + return tuple(str(item) for item in value) + if isinstance(value, tuple): + return tuple(str(item) for item in value) + return (str(value),) + + +def _required_env(name: str) -> str: + value = os.environ.get(name) + if value: + return value + raise AuthConfigurationError(f"Missing required auth environment variable: {name}") + + +@lru_cache +def get_auth_settings() -> AuthSettings: + issuer = _required_env("OIDC_ISSUER") + client_id = _required_env("OIDC_CLIENT_ID") + audiences = _split_csv(os.environ.get("OIDC_AUDIENCE")) or (client_id,) + discovery_url = os.environ.get("OIDC_DISCOVERY_URL") or ( + issuer.rstrip("/") + _DISCOVERY_PATH + ) + + return AuthSettings( + issuer=issuer, + client_id=client_id, + audiences=audiences, + discovery_url=discovery_url, + jwks_url=os.environ.get("OIDC_JWKS_URL"), + groups_claim=os.environ.get("OIDC_GROUPS_CLAIM", "groups"), + admin_groups=_split_csv(os.environ.get("OIDC_ADMIN_GROUPS")), + member_groups=_split_csv(os.environ.get("OIDC_MEMBER_GROUPS")), + jwks_cache_ttl_seconds=int( + os.environ.get("OIDC_JWKS_CACHE_TTL_SECONDS", "300") + ), + http_timeout_seconds=float(os.environ.get("OIDC_HTTP_TIMEOUT_SECONDS", "5")), + clock_skew_seconds=int(os.environ.get("OIDC_CLOCK_SKEW_SECONDS", "30")), + ) + + +class CachedJwksClient: + def __init__(self, settings: AuthSettings): + self._settings = settings + self._lock = Lock() + self._jwks: PyJWKSet | None = None + self._jwks_fetched_at = 0.0 + self._discovery_jwks_url: str | None = None + self._discovery_fetched_at = 0.0 + + def get_signing_key(self, kid: str) -> Any: + with self._lock: + jwks = self._get_jwks_locked() + key = self._find_key(jwks, kid) + if key is not None: + return key + + self._refresh_jwks_locked( + force_discovery_refresh=self._settings.jwks_url is None + ) + if self._jwks is None: + raise TokenValidationError("JWKS cache is empty") + + key = self._find_key(self._jwks, kid) + if key is None: + raise TokenValidationError(f"No signing key found for kid '{kid}'") + return key + + def _get_jwks_locked(self) -> PyJWKSet: + if self._jwks is None or self._is_stale(self._jwks_fetched_at): + self._refresh_jwks_locked(force_discovery_refresh=False) + if self._jwks is None: + raise TokenValidationError("Unable to load JWKS") + return self._jwks + + def _refresh_jwks_locked(self, force_discovery_refresh: bool) -> None: + jwks_url = self._resolve_jwks_url_locked(force_refresh=force_discovery_refresh) + data = self._fetch_json(jwks_url) + try: + self._jwks = PyJWKSet.from_dict(data) + except Exception as exc: + raise TokenValidationError( + "OIDC provider returned an invalid JWKS payload" + ) from exc + self._jwks_fetched_at = time.monotonic() + + def _resolve_jwks_url_locked(self, force_refresh: bool) -> str: + if self._settings.jwks_url: + return self._settings.jwks_url + + if ( + force_refresh + or self._discovery_jwks_url is None + or self._is_stale(self._discovery_fetched_at) + ): + discovery = self._fetch_json(self._settings.discovery_url) + jwks_uri = discovery.get("jwks_uri") + if not isinstance(jwks_uri, str) or not jwks_uri: + raise TokenValidationError("OIDC discovery document missing jwks_uri") + self._discovery_jwks_url = jwks_uri + self._discovery_fetched_at = time.monotonic() + + if self._discovery_jwks_url is None: + raise TokenValidationError("Unable to resolve JWKS URL") + return self._discovery_jwks_url + + def _fetch_json(self, url: str) -> dict[str, Any]: + try: + response = httpx.get(url, timeout=self._settings.http_timeout_seconds) + response.raise_for_status() + except httpx.HTTPError as exc: + raise TokenValidationError( + f"Failed to fetch OIDC metadata from {url}" + ) from exc + + data = response.json() + if not isinstance(data, dict): + raise TokenValidationError( + f"OIDC metadata from {url} must be a JSON object" + ) + return data + + def _is_stale(self, fetched_at: float) -> bool: + return (time.monotonic() - fetched_at) >= self._settings.jwks_cache_ttl_seconds + + @staticmethod + def _find_key(jwks: PyJWKSet, kid: str) -> Any | None: + for jwk in jwks.keys: + if jwk.key_id == kid: + return jwk.key + return None + + +@lru_cache +def get_jwks_client() -> CachedJwksClient: + return CachedJwksClient(get_auth_settings()) + + +def reset_auth_caches() -> None: + get_auth_settings.cache_clear() + get_jwks_client.cache_clear() + + +def validate_access_token(token: str) -> TokenClaims: + settings = get_auth_settings() + + try: + unverified_header = jwt.get_unverified_header(token) + except InvalidTokenError as exc: + raise TokenValidationError("Malformed access token header") from exc + + kid = unverified_header.get("kid") + algorithm = unverified_header.get("alg") + if not isinstance(kid, str) or not kid: + raise TokenValidationError("Access token missing kid header") + if not isinstance(algorithm, str) or algorithm not in _SUPPORTED_ALGORITHMS: + raise TokenValidationError("Access token uses an unsupported signing algorithm") + + signing_key = get_jwks_client().get_signing_key(kid) + + try: + payload = jwt.decode( + token, + key=signing_key, + algorithms=[algorithm], + audience=settings.audiences, + issuer=settings.issuer, + options={"require": ["exp", "iss", "sub"]}, + leeway=settings.clock_skew_seconds, + ) + except InvalidTokenError as exc: + raise TokenValidationError("Invalid access token") from exc + + return TokenClaims.from_payload(payload, settings) + + +def sync_current_user( + session: Session, + claims: TokenClaims, + identity: IdentityData | None = None, +) -> CurrentUser: + effective_identity = identity or IdentityData.from_claims(claims) + statement = select(User).where( + User.oidc_issuer == effective_identity.issuer, + User.oidc_subject == effective_identity.subject, + ) + user = session.exec(statement).first() + existing_role = user.role if user is not None else None + resolved_role = resolve_role(effective_identity.groups, existing_role=existing_role) + needs_commit = False + + if user is None: + user = User( + oidc_issuer=effective_identity.issuer, + oidc_subject=effective_identity.subject, + role=resolved_role, + ) + session.add(user) + needs_commit = True + elif user.role != resolved_role: + user.role = resolved_role + session.add(user) + needs_commit = True + + if needs_commit: + session.commit() + session.refresh(user) + + membership = session.exec( + select(HouseholdMembership).where(HouseholdMembership.user_id == user.id) + ).first() + + household_membership = None + if membership is not None: + household_membership = CurrentHouseholdMembership( + household_id=membership.household_id, + role=membership.role, + ) + + return CurrentUser( + user_id=user.id, + role=user.role, + identity=effective_identity, + claims=claims, + household_membership=household_membership, + ) + + +def resolve_role(groups: tuple[str, ...], existing_role: Role | None = None) -> Role: + settings = get_auth_settings() + if groups: + group_set = set(groups) + if settings.admin_groups and group_set.intersection(settings.admin_groups): + return Role.ADMIN + if settings.member_groups: + if group_set.intersection(settings.member_groups): + return Role.MEMBER + return Role.MEMBER + return Role.MEMBER + + return existing_role or Role.MEMBER diff --git a/backend/main.py b/backend/main.py index 10fb73b..55280aa 100644 --- a/backend/main.py +++ b/backend/main.py @@ -5,13 +5,14 @@ from dotenv import load_dotenv load_dotenv() # load .env before db.py reads DATABASE_URL -from fastapi import FastAPI # noqa: E402 +from fastapi import Depends, FastAPI # noqa: E402 from fastapi.middleware.cors import CORSMiddleware # noqa: E402 from sqlmodel import Session # noqa: E402 from db import create_db_and_tables, engine # noqa: E402 from innercontext.api import ( # noqa: E402 ai_logs, + auth, health, inventory, products, @@ -19,6 +20,7 @@ from innercontext.api import ( # noqa: E402 routines, skincare, ) +from innercontext.api.auth_deps import get_current_user # noqa: E402 from innercontext.services.pricing_jobs import enqueue_pricing_recalc # noqa: E402 @@ -47,13 +49,51 @@ app.add_middleware( allow_headers=["*"], ) -app.include_router(products.router, prefix="/products", tags=["products"]) -app.include_router(inventory.router, prefix="/inventory", tags=["inventory"]) -app.include_router(profile.router, prefix="/profile", tags=["profile"]) -app.include_router(health.router, prefix="/health", tags=["health"]) -app.include_router(routines.router, prefix="/routines", tags=["routines"]) -app.include_router(skincare.router, prefix="/skincare", tags=["skincare"]) -app.include_router(ai_logs.router, prefix="/ai-logs", tags=["ai-logs"]) +protected = [Depends(get_current_user)] + +app.include_router(auth.router, prefix="/auth", tags=["auth"]) +app.include_router( + products.router, + prefix="/products", + tags=["products"], + dependencies=protected, +) +app.include_router( + inventory.router, + prefix="/inventory", + tags=["inventory"], + dependencies=protected, +) +app.include_router( + profile.router, + prefix="/profile", + tags=["profile"], + dependencies=protected, +) +app.include_router( + health.router, + prefix="/health", + tags=["health"], + dependencies=protected, +) +app.include_router( + routines.router, + prefix="/routines", + tags=["routines"], + dependencies=protected, +) +app.include_router( + skincare.router, + prefix="/skincare", + tags=["skincare"], + dependencies=protected, +) +app.include_router( + ai_logs.router, + prefix="/ai-logs", + tags=["ai-logs"], + dependencies=protected, +) @app.get("/health-check") diff --git a/backend/pyproject.toml b/backend/pyproject.toml index eeddb55..6b9a55c 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -8,6 +8,7 @@ dependencies = [ "alembic>=1.14", "fastapi>=0.132.0", "google-genai>=1.65.0", + "pyjwt[crypto]>=2.10.1", "psycopg[binary]>=3.3.3", "python-dotenv>=1.2.1", "python-multipart>=0.0.22", diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py index 3c4f465..e35dfba 100644 --- a/backend/tests/conftest.py +++ b/backend/tests/conftest.py @@ -1,4 +1,6 @@ import os +from datetime import UTC, datetime, timedelta +from uuid import uuid4 # Must be set before importing db (which calls create_engine at module level) os.environ.setdefault("DATABASE_URL", "sqlite://") @@ -10,6 +12,9 @@ from sqlmodel.pool import StaticPool import db as db_module from db import get_session +from innercontext.api.auth_deps import get_current_user +from innercontext.auth import CurrentUser, IdentityData, TokenClaims +from innercontext.models import Role from main import app @@ -38,7 +43,24 @@ def client(session, monkeypatch): def _override(): yield session + def _current_user_override(): + claims = TokenClaims( + issuer="https://auth.test", + subject="test-user", + audience=("innercontext-web",), + expires_at=datetime.now(UTC) + timedelta(hours=1), + groups=("innercontext-admin",), + raw_claims={"iss": "https://auth.test", "sub": "test-user"}, + ) + return CurrentUser( + user_id=uuid4(), + role=Role.ADMIN, + identity=IdentityData.from_claims(claims), + claims=claims, + ) + app.dependency_overrides[get_session] = _override + app.dependency_overrides[get_current_user] = _current_user_override with TestClient(app) as c: yield c app.dependency_overrides.clear() diff --git a/backend/tests/test_auth.py b/backend/tests/test_auth.py new file mode 100644 index 0000000..0ed16a5 --- /dev/null +++ b/backend/tests/test_auth.py @@ -0,0 +1,275 @@ +from __future__ import annotations + +import json +from datetime import UTC, datetime, timedelta +from uuid import UUID, uuid4 + +import jwt +import pytest +from cryptography.hazmat.primitives.asymmetric import rsa +from fastapi import HTTPException +from fastapi.testclient import TestClient +from jwt import algorithms +from sqlmodel import Session, SQLModel, create_engine +from sqlmodel.pool import StaticPool + +import db as db_module +from db import get_session +from innercontext.api.auth_deps import require_admin +from innercontext.auth import ( + CurrentHouseholdMembership, + CurrentUser, + IdentityData, + TokenClaims, + reset_auth_caches, + validate_access_token, +) +from innercontext.models import ( + Household, + HouseholdMembership, + HouseholdRole, + Role, + User, +) +from main import app + + +class _MockResponse: + def __init__(self, payload: dict[str, object], status_code: int = 200): + self._payload = payload + self.status_code = status_code + + def raise_for_status(self) -> None: + if self.status_code >= 400: + raise RuntimeError(f"unexpected status {self.status_code}") + + def json(self) -> dict[str, object]: + return self._payload + + +@pytest.fixture() +def auth_env(monkeypatch): + monkeypatch.setenv("OIDC_ISSUER", "https://auth.example.test") + monkeypatch.setenv("OIDC_CLIENT_ID", "innercontext-web") + monkeypatch.setenv( + "OIDC_DISCOVERY_URL", + "https://auth.example.test/.well-known/openid-configuration", + ) + monkeypatch.setenv("OIDC_ADMIN_GROUPS", "innercontext-admin") + monkeypatch.setenv("OIDC_MEMBER_GROUPS", "innercontext-member") + monkeypatch.setenv("OIDC_JWKS_CACHE_TTL_SECONDS", "3600") + reset_auth_caches() + yield + reset_auth_caches() + + +@pytest.fixture() +def rsa_keypair(): + private_key = rsa.generate_private_key( + public_exponent=65537, + key_size=2048, + ) + return private_key, private_key.public_key() + + +@pytest.fixture() +def auth_session(monkeypatch): + engine = create_engine( + "sqlite://", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + monkeypatch.setattr(db_module, "engine", engine) + import innercontext.models # noqa: F401 + + SQLModel.metadata.create_all(engine) + with Session(engine) as session: + yield session + + +@pytest.fixture() +def auth_client(auth_session): + def _override(): + yield auth_session + + app.dependency_overrides[get_session] = _override + with TestClient(app) as client: + yield client + app.dependency_overrides.clear() + + +def _public_jwk(public_key, kid: str) -> dict[str, object]: + jwk = json.loads(algorithms.RSAAlgorithm.to_jwk(public_key)) + jwk["kid"] = kid + jwk["use"] = "sig" + jwk["alg"] = "RS256" + return jwk + + +def _sign_token(private_key, kid: str, **claims_overrides: object) -> str: + now = datetime.now(UTC) + payload: dict[str, object] = { + "iss": "https://auth.example.test", + "sub": "user-123", + "aud": "innercontext-web", + "exp": int((now + timedelta(hours=1)).timestamp()), + "iat": int(now.timestamp()), + "groups": ["innercontext-admin"], + "email": "user@example.test", + "name": "Inner Context User", + "preferred_username": "ictx-user", + } + payload.update(claims_overrides) + return jwt.encode(payload, private_key, algorithm="RS256", headers={"kid": kid}) + + +def _mock_oidc(monkeypatch, public_key, *, fetch_counts: dict[str, int] | None = None): + def _fake_get(url: str, timeout: float): + if fetch_counts is not None: + fetch_counts[url] = fetch_counts.get(url, 0) + 1 + if url.endswith("/.well-known/openid-configuration"): + return _MockResponse({"jwks_uri": "https://auth.example.test/jwks.json"}) + if url.endswith("/jwks.json"): + return _MockResponse({"keys": [_public_jwk(public_key, "kid-1")]}) + raise AssertionError(f"unexpected URL {url} with timeout {timeout}") + + monkeypatch.setattr("innercontext.auth.httpx.get", _fake_get) + + +def test_validate_access_token_uses_cached_jwks(auth_env, rsa_keypair, monkeypatch): + private_key, public_key = rsa_keypair + fetch_counts: dict[str, int] = {} + _mock_oidc(monkeypatch, public_key, fetch_counts=fetch_counts) + + validate_access_token(_sign_token(private_key, "kid-1", sub="user-a")) + validate_access_token(_sign_token(private_key, "kid-1", sub="user-b")) + + assert ( + fetch_counts["https://auth.example.test/.well-known/openid-configuration"] == 1 + ) + assert fetch_counts["https://auth.example.test/jwks.json"] == 1 + + +@pytest.mark.parametrize( + ("path", "payload"), + [ + ( + "/auth/session/sync", + { + "email": "sync@example.test", + "name": "Synced User", + "preferred_username": "synced-user", + "groups": ["innercontext-admin"], + }, + ), + ("/auth/me", None), + ], + ids=["/auth/session/sync", "/auth/me"], +) +def test_sync_protected_endpoints_create_or_resolve_current_user( + auth_env, + auth_client, + auth_session, + rsa_keypair, + monkeypatch, + path: str, + payload: dict[str, object] | None, +): + private_key, public_key = rsa_keypair + _mock_oidc(monkeypatch, public_key) + token = _sign_token(private_key, "kid-1") + + if path == "/auth/me": + user = User( + oidc_issuer="https://auth.example.test", + oidc_subject="user-123", + role=Role.ADMIN, + ) + auth_session.add(user) + auth_session.commit() + auth_session.refresh(user) + + household = Household() + auth_session.add(household) + auth_session.commit() + auth_session.refresh(household) + + membership = HouseholdMembership( + user_id=user.id, + household_id=household.id, + role=HouseholdRole.OWNER, + ) + auth_session.add(membership) + auth_session.commit() + + response = auth_client.request( + "POST" if path.endswith("sync") else "GET", + path, + headers={"Authorization": f"Bearer {token}"}, + json=payload, + ) + + assert response.status_code == 200 + data = response.json() + assert data["user"]["role"] == "admin" + assert data["identity"]["issuer"] == "https://auth.example.test" + assert data["identity"]["subject"] == "user-123" + + synced_user = auth_session.get(User, UUID(data["user"]["id"])) + assert synced_user is not None + assert synced_user.oidc_issuer == "https://auth.example.test" + assert synced_user.oidc_subject == "user-123" + + if path == "/auth/session/sync": + assert data["identity"]["email"] == "sync@example.test" + assert data["identity"]["groups"] == ["innercontext-admin"] + else: + assert data["user"]["household_membership"]["role"] == "owner" + + +@pytest.mark.parametrize( + "path", + ["/auth/me", "/profile"], + ids=["/auth/me expects 401", "/profile expects 401"], +) +def test_unauthorized_protected_endpoints_return_401(auth_env, auth_client, path: str): + response = auth_client.get(path) + assert response.status_code == 401 + assert response.json()["detail"] == "Missing bearer token" + + +def test_unauthorized_invalid_bearer_token_is_rejected( + auth_env, auth_client, rsa_keypair, monkeypatch +): + _, public_key = rsa_keypair + _mock_oidc(monkeypatch, public_key) + response = auth_client.get( + "/auth/me", + headers={"Authorization": "Bearer not-a-jwt"}, + ) + assert response.status_code == 401 + + +def test_require_admin_raises_for_member(): + claims = TokenClaims( + issuer="https://auth.example.test", + subject="member-1", + audience=("innercontext-web",), + expires_at=datetime.now(UTC) + timedelta(hours=1), + raw_claims={"iss": "https://auth.example.test", "sub": "member-1"}, + ) + current_user = CurrentUser( + user_id=uuid4(), + role=Role.MEMBER, + identity=IdentityData.from_claims(claims), + claims=claims, + household_membership=CurrentHouseholdMembership( + household_id=uuid4(), + role=HouseholdRole.MEMBER, + ), + ) + + with pytest.raises(HTTPException) as exc_info: + require_admin(current_user) + + assert exc_info.value.status_code == 403