feat(auth): validate Authelia tokens in FastAPI

This commit is contained in:
Piotr Oleszczyk 2026-03-12 15:13:55 +01:00
parent 2704d58673
commit 4782fad5b9
7 changed files with 953 additions and 8 deletions

View file

@ -0,0 +1,166 @@
from __future__ import annotations
from datetime import date, datetime
from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException, status
from sqlmodel import Field, Session, SQLModel, select
from db import get_session
from innercontext.api.auth_deps import get_current_user
from innercontext.auth import CurrentUser, IdentityData, sync_current_user
from innercontext.models import HouseholdRole, Role, UserProfile
router = APIRouter()
class SessionSyncRequest(SQLModel):
iss: str | None = None
sub: str | None = None
email: str | None = None
name: str | None = None
preferred_username: str | None = None
groups: list[str] | None = None
class AuthHouseholdMembershipPublic(SQLModel):
household_id: UUID
role: HouseholdRole
class AuthUserPublic(SQLModel):
id: UUID
role: Role
household_membership: AuthHouseholdMembershipPublic | None = None
class AuthIdentityPublic(SQLModel):
issuer: str
subject: str
email: str | None = None
name: str | None = None
preferred_username: str | None = None
groups: list[str] = Field(default_factory=list)
class AuthProfilePublic(SQLModel):
id: UUID
user_id: UUID | None
birth_date: date | None = None
sex_at_birth: str | None = None
created_at: datetime
updated_at: datetime
class AuthSessionResponse(SQLModel):
user: AuthUserPublic
identity: AuthIdentityPublic
profile: AuthProfilePublic | None = None
def _build_identity(
current_user: CurrentUser,
payload: SessionSyncRequest | None,
) -> IdentityData:
if payload is None:
return current_user.identity
if payload.iss is not None and payload.iss != current_user.identity.issuer:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Session sync issuer does not match bearer token",
)
if payload.sub is not None and payload.sub != current_user.identity.subject:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Session sync subject does not match bearer token",
)
return IdentityData(
issuer=current_user.identity.issuer,
subject=current_user.identity.subject,
email=(
payload.email if payload.email is not None else current_user.identity.email
),
name=payload.name if payload.name is not None else current_user.identity.name,
preferred_username=(
payload.preferred_username
if payload.preferred_username is not None
else current_user.identity.preferred_username
),
groups=(
tuple(payload.groups)
if payload.groups is not None
else current_user.identity.groups
),
)
def _get_profile(session: Session, user_id: UUID) -> UserProfile | None:
return session.exec(
select(UserProfile).where(UserProfile.user_id == user_id)
).first()
def _profile_public(profile: UserProfile | None) -> AuthProfilePublic | None:
if profile is None:
return None
return AuthProfilePublic(
id=profile.id,
user_id=profile.user_id,
birth_date=profile.birth_date,
sex_at_birth=(
profile.sex_at_birth.value if profile.sex_at_birth is not None else None
),
created_at=profile.created_at,
updated_at=profile.updated_at,
)
def _response(session: Session, current_user: CurrentUser) -> AuthSessionResponse:
household_membership = None
if current_user.household_membership is not None:
household_membership = AuthHouseholdMembershipPublic(
household_id=current_user.household_membership.household_id,
role=current_user.household_membership.role,
)
return AuthSessionResponse(
user=AuthUserPublic(
id=current_user.user_id,
role=current_user.role,
household_membership=household_membership,
),
identity=AuthIdentityPublic(
issuer=current_user.identity.issuer,
subject=current_user.identity.subject,
email=current_user.identity.email,
name=current_user.identity.name,
preferred_username=current_user.identity.preferred_username,
groups=list(current_user.identity.groups),
),
profile=_profile_public(_get_profile(session, current_user.user_id)),
)
@router.post("/session/sync", response_model=AuthSessionResponse)
def sync_session(
payload: SessionSyncRequest | None = None,
session: Session = Depends(get_session),
current_user: CurrentUser = Depends(get_current_user),
):
synced_user = sync_current_user(
session,
current_user.claims,
identity=_build_identity(current_user, payload),
)
return _response(session, synced_user)
@router.get("/me", response_model=AuthSessionResponse)
def get_me(
session: Session = Depends(get_session),
current_user: CurrentUser = Depends(get_current_user),
):
return _response(session, current_user)

View file

@ -0,0 +1,57 @@
from __future__ import annotations
from typing import Annotated
from fastapi import Depends, HTTPException, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from sqlmodel import Session
from db import get_session
from innercontext.auth import (
AuthConfigurationError,
CurrentUser,
TokenValidationError,
sync_current_user,
validate_access_token,
)
from innercontext.models import Role
_bearer_scheme = HTTPBearer(auto_error=False)
def _unauthorized(detail: str) -> HTTPException:
return HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=detail,
headers={"WWW-Authenticate": "Bearer"},
)
def get_current_user(
credentials: Annotated[
HTTPAuthorizationCredentials | None, Depends(_bearer_scheme)
],
session: Session = Depends(get_session),
) -> CurrentUser:
if credentials is None or credentials.scheme.lower() != "bearer":
raise _unauthorized("Missing bearer token")
try:
claims = validate_access_token(credentials.credentials)
return sync_current_user(session, claims)
except AuthConfigurationError as exc:
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail=str(exc),
) from exc
except TokenValidationError as exc:
raise _unauthorized(str(exc)) from exc
def require_admin(current_user: CurrentUser = Depends(get_current_user)) -> CurrentUser:
if current_user.role is not Role.ADMIN:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Admin role required",
)
return current_user

View file

@ -0,0 +1,384 @@
from __future__ import annotations
import os
import time
from dataclasses import dataclass, field
from datetime import UTC, datetime
from functools import lru_cache
from threading import Lock
from typing import Any, Mapping
from uuid import UUID
import httpx
import jwt
from jwt import InvalidTokenError, PyJWKSet
from sqlmodel import Session, select
from innercontext.models import HouseholdMembership, HouseholdRole, Role, User
_DISCOVERY_PATH = "/.well-known/openid-configuration"
_SUPPORTED_ALGORITHMS = frozenset(
{"RS256", "RS384", "RS512", "ES256", "ES384", "ES512"}
)
class AuthConfigurationError(RuntimeError):
pass
class TokenValidationError(ValueError):
pass
@dataclass(frozen=True, slots=True)
class AuthSettings:
issuer: str
client_id: str
audiences: tuple[str, ...]
discovery_url: str
jwks_url: str | None
groups_claim: str
admin_groups: tuple[str, ...]
member_groups: tuple[str, ...]
jwks_cache_ttl_seconds: int
http_timeout_seconds: float
clock_skew_seconds: int
@dataclass(frozen=True, slots=True)
class TokenClaims:
issuer: str
subject: str
audience: tuple[str, ...]
expires_at: datetime
groups: tuple[str, ...] = ()
email: str | None = None
name: str | None = None
preferred_username: str | None = None
raw_claims: Mapping[str, Any] = field(default_factory=dict, repr=False)
@classmethod
def from_payload(
cls, payload: Mapping[str, Any], settings: AuthSettings
) -> "TokenClaims":
audience = payload.get("aud")
if isinstance(audience, str):
audiences = (audience,)
elif isinstance(audience, list):
audiences = tuple(str(item) for item in audience)
else:
audiences = ()
groups = _normalize_groups(payload.get(settings.groups_claim))
exp = payload.get("exp")
if not isinstance(exp, (int, float)):
raise TokenValidationError("Access token missing exp claim")
return cls(
issuer=str(payload["iss"]),
subject=str(payload["sub"]),
audience=audiences,
expires_at=datetime.fromtimestamp(exp, tz=UTC),
groups=groups,
email=_optional_str(payload.get("email")),
name=_optional_str(payload.get("name")),
preferred_username=_optional_str(payload.get("preferred_username")),
raw_claims=dict(payload),
)
@dataclass(frozen=True, slots=True)
class IdentityData:
issuer: str
subject: str
email: str | None = None
name: str | None = None
preferred_username: str | None = None
groups: tuple[str, ...] = ()
@classmethod
def from_claims(cls, claims: TokenClaims) -> "IdentityData":
return cls(
issuer=claims.issuer,
subject=claims.subject,
email=claims.email,
name=claims.name,
preferred_username=claims.preferred_username,
groups=claims.groups,
)
@dataclass(frozen=True, slots=True)
class CurrentHouseholdMembership:
household_id: UUID
role: HouseholdRole
@dataclass(frozen=True, slots=True)
class CurrentUser:
user_id: UUID
role: Role
identity: IdentityData
claims: TokenClaims = field(repr=False)
household_membership: CurrentHouseholdMembership | None = None
def _split_csv(value: str | None) -> tuple[str, ...]:
if value is None:
return ()
return tuple(item.strip() for item in value.split(",") if item.strip())
def _optional_str(value: Any) -> str | None:
if value is None:
return None
if isinstance(value, str):
return value
return str(value)
def _normalize_groups(value: Any) -> tuple[str, ...]:
if value is None:
return ()
if isinstance(value, str):
return (value,)
if isinstance(value, list):
return tuple(str(item) for item in value)
if isinstance(value, tuple):
return tuple(str(item) for item in value)
return (str(value),)
def _required_env(name: str) -> str:
value = os.environ.get(name)
if value:
return value
raise AuthConfigurationError(f"Missing required auth environment variable: {name}")
@lru_cache
def get_auth_settings() -> AuthSettings:
issuer = _required_env("OIDC_ISSUER")
client_id = _required_env("OIDC_CLIENT_ID")
audiences = _split_csv(os.environ.get("OIDC_AUDIENCE")) or (client_id,)
discovery_url = os.environ.get("OIDC_DISCOVERY_URL") or (
issuer.rstrip("/") + _DISCOVERY_PATH
)
return AuthSettings(
issuer=issuer,
client_id=client_id,
audiences=audiences,
discovery_url=discovery_url,
jwks_url=os.environ.get("OIDC_JWKS_URL"),
groups_claim=os.environ.get("OIDC_GROUPS_CLAIM", "groups"),
admin_groups=_split_csv(os.environ.get("OIDC_ADMIN_GROUPS")),
member_groups=_split_csv(os.environ.get("OIDC_MEMBER_GROUPS")),
jwks_cache_ttl_seconds=int(
os.environ.get("OIDC_JWKS_CACHE_TTL_SECONDS", "300")
),
http_timeout_seconds=float(os.environ.get("OIDC_HTTP_TIMEOUT_SECONDS", "5")),
clock_skew_seconds=int(os.environ.get("OIDC_CLOCK_SKEW_SECONDS", "30")),
)
class CachedJwksClient:
def __init__(self, settings: AuthSettings):
self._settings = settings
self._lock = Lock()
self._jwks: PyJWKSet | None = None
self._jwks_fetched_at = 0.0
self._discovery_jwks_url: str | None = None
self._discovery_fetched_at = 0.0
def get_signing_key(self, kid: str) -> Any:
with self._lock:
jwks = self._get_jwks_locked()
key = self._find_key(jwks, kid)
if key is not None:
return key
self._refresh_jwks_locked(
force_discovery_refresh=self._settings.jwks_url is None
)
if self._jwks is None:
raise TokenValidationError("JWKS cache is empty")
key = self._find_key(self._jwks, kid)
if key is None:
raise TokenValidationError(f"No signing key found for kid '{kid}'")
return key
def _get_jwks_locked(self) -> PyJWKSet:
if self._jwks is None or self._is_stale(self._jwks_fetched_at):
self._refresh_jwks_locked(force_discovery_refresh=False)
if self._jwks is None:
raise TokenValidationError("Unable to load JWKS")
return self._jwks
def _refresh_jwks_locked(self, force_discovery_refresh: bool) -> None:
jwks_url = self._resolve_jwks_url_locked(force_refresh=force_discovery_refresh)
data = self._fetch_json(jwks_url)
try:
self._jwks = PyJWKSet.from_dict(data)
except Exception as exc:
raise TokenValidationError(
"OIDC provider returned an invalid JWKS payload"
) from exc
self._jwks_fetched_at = time.monotonic()
def _resolve_jwks_url_locked(self, force_refresh: bool) -> str:
if self._settings.jwks_url:
return self._settings.jwks_url
if (
force_refresh
or self._discovery_jwks_url is None
or self._is_stale(self._discovery_fetched_at)
):
discovery = self._fetch_json(self._settings.discovery_url)
jwks_uri = discovery.get("jwks_uri")
if not isinstance(jwks_uri, str) or not jwks_uri:
raise TokenValidationError("OIDC discovery document missing jwks_uri")
self._discovery_jwks_url = jwks_uri
self._discovery_fetched_at = time.monotonic()
if self._discovery_jwks_url is None:
raise TokenValidationError("Unable to resolve JWKS URL")
return self._discovery_jwks_url
def _fetch_json(self, url: str) -> dict[str, Any]:
try:
response = httpx.get(url, timeout=self._settings.http_timeout_seconds)
response.raise_for_status()
except httpx.HTTPError as exc:
raise TokenValidationError(
f"Failed to fetch OIDC metadata from {url}"
) from exc
data = response.json()
if not isinstance(data, dict):
raise TokenValidationError(
f"OIDC metadata from {url} must be a JSON object"
)
return data
def _is_stale(self, fetched_at: float) -> bool:
return (time.monotonic() - fetched_at) >= self._settings.jwks_cache_ttl_seconds
@staticmethod
def _find_key(jwks: PyJWKSet, kid: str) -> Any | None:
for jwk in jwks.keys:
if jwk.key_id == kid:
return jwk.key
return None
@lru_cache
def get_jwks_client() -> CachedJwksClient:
return CachedJwksClient(get_auth_settings())
def reset_auth_caches() -> None:
get_auth_settings.cache_clear()
get_jwks_client.cache_clear()
def validate_access_token(token: str) -> TokenClaims:
settings = get_auth_settings()
try:
unverified_header = jwt.get_unverified_header(token)
except InvalidTokenError as exc:
raise TokenValidationError("Malformed access token header") from exc
kid = unverified_header.get("kid")
algorithm = unverified_header.get("alg")
if not isinstance(kid, str) or not kid:
raise TokenValidationError("Access token missing kid header")
if not isinstance(algorithm, str) or algorithm not in _SUPPORTED_ALGORITHMS:
raise TokenValidationError("Access token uses an unsupported signing algorithm")
signing_key = get_jwks_client().get_signing_key(kid)
try:
payload = jwt.decode(
token,
key=signing_key,
algorithms=[algorithm],
audience=settings.audiences,
issuer=settings.issuer,
options={"require": ["exp", "iss", "sub"]},
leeway=settings.clock_skew_seconds,
)
except InvalidTokenError as exc:
raise TokenValidationError("Invalid access token") from exc
return TokenClaims.from_payload(payload, settings)
def sync_current_user(
session: Session,
claims: TokenClaims,
identity: IdentityData | None = None,
) -> CurrentUser:
effective_identity = identity or IdentityData.from_claims(claims)
statement = select(User).where(
User.oidc_issuer == effective_identity.issuer,
User.oidc_subject == effective_identity.subject,
)
user = session.exec(statement).first()
existing_role = user.role if user is not None else None
resolved_role = resolve_role(effective_identity.groups, existing_role=existing_role)
needs_commit = False
if user is None:
user = User(
oidc_issuer=effective_identity.issuer,
oidc_subject=effective_identity.subject,
role=resolved_role,
)
session.add(user)
needs_commit = True
elif user.role != resolved_role:
user.role = resolved_role
session.add(user)
needs_commit = True
if needs_commit:
session.commit()
session.refresh(user)
membership = session.exec(
select(HouseholdMembership).where(HouseholdMembership.user_id == user.id)
).first()
household_membership = None
if membership is not None:
household_membership = CurrentHouseholdMembership(
household_id=membership.household_id,
role=membership.role,
)
return CurrentUser(
user_id=user.id,
role=user.role,
identity=effective_identity,
claims=claims,
household_membership=household_membership,
)
def resolve_role(groups: tuple[str, ...], existing_role: Role | None = None) -> Role:
settings = get_auth_settings()
if groups:
group_set = set(groups)
if settings.admin_groups and group_set.intersection(settings.admin_groups):
return Role.ADMIN
if settings.member_groups:
if group_set.intersection(settings.member_groups):
return Role.MEMBER
return Role.MEMBER
return Role.MEMBER
return existing_role or Role.MEMBER

View file

@ -5,13 +5,14 @@ from dotenv import load_dotenv
load_dotenv() # load .env before db.py reads DATABASE_URL
from fastapi import FastAPI # noqa: E402
from fastapi import Depends, FastAPI # noqa: E402
from fastapi.middleware.cors import CORSMiddleware # noqa: E402
from sqlmodel import Session # noqa: E402
from db import create_db_and_tables, engine # noqa: E402
from innercontext.api import ( # noqa: E402
ai_logs,
auth,
health,
inventory,
products,
@ -19,6 +20,7 @@ from innercontext.api import ( # noqa: E402
routines,
skincare,
)
from innercontext.api.auth_deps import get_current_user # noqa: E402
from innercontext.services.pricing_jobs import enqueue_pricing_recalc # noqa: E402
@ -47,13 +49,51 @@ app.add_middleware(
allow_headers=["*"],
)
app.include_router(products.router, prefix="/products", tags=["products"])
app.include_router(inventory.router, prefix="/inventory", tags=["inventory"])
app.include_router(profile.router, prefix="/profile", tags=["profile"])
app.include_router(health.router, prefix="/health", tags=["health"])
app.include_router(routines.router, prefix="/routines", tags=["routines"])
app.include_router(skincare.router, prefix="/skincare", tags=["skincare"])
app.include_router(ai_logs.router, prefix="/ai-logs", tags=["ai-logs"])
protected = [Depends(get_current_user)]
app.include_router(auth.router, prefix="/auth", tags=["auth"])
app.include_router(
products.router,
prefix="/products",
tags=["products"],
dependencies=protected,
)
app.include_router(
inventory.router,
prefix="/inventory",
tags=["inventory"],
dependencies=protected,
)
app.include_router(
profile.router,
prefix="/profile",
tags=["profile"],
dependencies=protected,
)
app.include_router(
health.router,
prefix="/health",
tags=["health"],
dependencies=protected,
)
app.include_router(
routines.router,
prefix="/routines",
tags=["routines"],
dependencies=protected,
)
app.include_router(
skincare.router,
prefix="/skincare",
tags=["skincare"],
dependencies=protected,
)
app.include_router(
ai_logs.router,
prefix="/ai-logs",
tags=["ai-logs"],
dependencies=protected,
)
@app.get("/health-check")

View file

@ -8,6 +8,7 @@ dependencies = [
"alembic>=1.14",
"fastapi>=0.132.0",
"google-genai>=1.65.0",
"pyjwt[crypto]>=2.10.1",
"psycopg[binary]>=3.3.3",
"python-dotenv>=1.2.1",
"python-multipart>=0.0.22",

View file

@ -1,4 +1,6 @@
import os
from datetime import UTC, datetime, timedelta
from uuid import uuid4
# Must be set before importing db (which calls create_engine at module level)
os.environ.setdefault("DATABASE_URL", "sqlite://")
@ -10,6 +12,9 @@ from sqlmodel.pool import StaticPool
import db as db_module
from db import get_session
from innercontext.api.auth_deps import get_current_user
from innercontext.auth import CurrentUser, IdentityData, TokenClaims
from innercontext.models import Role
from main import app
@ -38,7 +43,24 @@ def client(session, monkeypatch):
def _override():
yield session
def _current_user_override():
claims = TokenClaims(
issuer="https://auth.test",
subject="test-user",
audience=("innercontext-web",),
expires_at=datetime.now(UTC) + timedelta(hours=1),
groups=("innercontext-admin",),
raw_claims={"iss": "https://auth.test", "sub": "test-user"},
)
return CurrentUser(
user_id=uuid4(),
role=Role.ADMIN,
identity=IdentityData.from_claims(claims),
claims=claims,
)
app.dependency_overrides[get_session] = _override
app.dependency_overrides[get_current_user] = _current_user_override
with TestClient(app) as c:
yield c
app.dependency_overrides.clear()

275
backend/tests/test_auth.py Normal file
View file

@ -0,0 +1,275 @@
from __future__ import annotations
import json
from datetime import UTC, datetime, timedelta
from uuid import UUID, uuid4
import jwt
import pytest
from cryptography.hazmat.primitives.asymmetric import rsa
from fastapi import HTTPException
from fastapi.testclient import TestClient
from jwt import algorithms
from sqlmodel import Session, SQLModel, create_engine
from sqlmodel.pool import StaticPool
import db as db_module
from db import get_session
from innercontext.api.auth_deps import require_admin
from innercontext.auth import (
CurrentHouseholdMembership,
CurrentUser,
IdentityData,
TokenClaims,
reset_auth_caches,
validate_access_token,
)
from innercontext.models import (
Household,
HouseholdMembership,
HouseholdRole,
Role,
User,
)
from main import app
class _MockResponse:
def __init__(self, payload: dict[str, object], status_code: int = 200):
self._payload = payload
self.status_code = status_code
def raise_for_status(self) -> None:
if self.status_code >= 400:
raise RuntimeError(f"unexpected status {self.status_code}")
def json(self) -> dict[str, object]:
return self._payload
@pytest.fixture()
def auth_env(monkeypatch):
monkeypatch.setenv("OIDC_ISSUER", "https://auth.example.test")
monkeypatch.setenv("OIDC_CLIENT_ID", "innercontext-web")
monkeypatch.setenv(
"OIDC_DISCOVERY_URL",
"https://auth.example.test/.well-known/openid-configuration",
)
monkeypatch.setenv("OIDC_ADMIN_GROUPS", "innercontext-admin")
monkeypatch.setenv("OIDC_MEMBER_GROUPS", "innercontext-member")
monkeypatch.setenv("OIDC_JWKS_CACHE_TTL_SECONDS", "3600")
reset_auth_caches()
yield
reset_auth_caches()
@pytest.fixture()
def rsa_keypair():
private_key = rsa.generate_private_key(
public_exponent=65537,
key_size=2048,
)
return private_key, private_key.public_key()
@pytest.fixture()
def auth_session(monkeypatch):
engine = create_engine(
"sqlite://",
connect_args={"check_same_thread": False},
poolclass=StaticPool,
)
monkeypatch.setattr(db_module, "engine", engine)
import innercontext.models # noqa: F401
SQLModel.metadata.create_all(engine)
with Session(engine) as session:
yield session
@pytest.fixture()
def auth_client(auth_session):
def _override():
yield auth_session
app.dependency_overrides[get_session] = _override
with TestClient(app) as client:
yield client
app.dependency_overrides.clear()
def _public_jwk(public_key, kid: str) -> dict[str, object]:
jwk = json.loads(algorithms.RSAAlgorithm.to_jwk(public_key))
jwk["kid"] = kid
jwk["use"] = "sig"
jwk["alg"] = "RS256"
return jwk
def _sign_token(private_key, kid: str, **claims_overrides: object) -> str:
now = datetime.now(UTC)
payload: dict[str, object] = {
"iss": "https://auth.example.test",
"sub": "user-123",
"aud": "innercontext-web",
"exp": int((now + timedelta(hours=1)).timestamp()),
"iat": int(now.timestamp()),
"groups": ["innercontext-admin"],
"email": "user@example.test",
"name": "Inner Context User",
"preferred_username": "ictx-user",
}
payload.update(claims_overrides)
return jwt.encode(payload, private_key, algorithm="RS256", headers={"kid": kid})
def _mock_oidc(monkeypatch, public_key, *, fetch_counts: dict[str, int] | None = None):
def _fake_get(url: str, timeout: float):
if fetch_counts is not None:
fetch_counts[url] = fetch_counts.get(url, 0) + 1
if url.endswith("/.well-known/openid-configuration"):
return _MockResponse({"jwks_uri": "https://auth.example.test/jwks.json"})
if url.endswith("/jwks.json"):
return _MockResponse({"keys": [_public_jwk(public_key, "kid-1")]})
raise AssertionError(f"unexpected URL {url} with timeout {timeout}")
monkeypatch.setattr("innercontext.auth.httpx.get", _fake_get)
def test_validate_access_token_uses_cached_jwks(auth_env, rsa_keypair, monkeypatch):
private_key, public_key = rsa_keypair
fetch_counts: dict[str, int] = {}
_mock_oidc(monkeypatch, public_key, fetch_counts=fetch_counts)
validate_access_token(_sign_token(private_key, "kid-1", sub="user-a"))
validate_access_token(_sign_token(private_key, "kid-1", sub="user-b"))
assert (
fetch_counts["https://auth.example.test/.well-known/openid-configuration"] == 1
)
assert fetch_counts["https://auth.example.test/jwks.json"] == 1
@pytest.mark.parametrize(
("path", "payload"),
[
(
"/auth/session/sync",
{
"email": "sync@example.test",
"name": "Synced User",
"preferred_username": "synced-user",
"groups": ["innercontext-admin"],
},
),
("/auth/me", None),
],
ids=["/auth/session/sync", "/auth/me"],
)
def test_sync_protected_endpoints_create_or_resolve_current_user(
auth_env,
auth_client,
auth_session,
rsa_keypair,
monkeypatch,
path: str,
payload: dict[str, object] | None,
):
private_key, public_key = rsa_keypair
_mock_oidc(monkeypatch, public_key)
token = _sign_token(private_key, "kid-1")
if path == "/auth/me":
user = User(
oidc_issuer="https://auth.example.test",
oidc_subject="user-123",
role=Role.ADMIN,
)
auth_session.add(user)
auth_session.commit()
auth_session.refresh(user)
household = Household()
auth_session.add(household)
auth_session.commit()
auth_session.refresh(household)
membership = HouseholdMembership(
user_id=user.id,
household_id=household.id,
role=HouseholdRole.OWNER,
)
auth_session.add(membership)
auth_session.commit()
response = auth_client.request(
"POST" if path.endswith("sync") else "GET",
path,
headers={"Authorization": f"Bearer {token}"},
json=payload,
)
assert response.status_code == 200
data = response.json()
assert data["user"]["role"] == "admin"
assert data["identity"]["issuer"] == "https://auth.example.test"
assert data["identity"]["subject"] == "user-123"
synced_user = auth_session.get(User, UUID(data["user"]["id"]))
assert synced_user is not None
assert synced_user.oidc_issuer == "https://auth.example.test"
assert synced_user.oidc_subject == "user-123"
if path == "/auth/session/sync":
assert data["identity"]["email"] == "sync@example.test"
assert data["identity"]["groups"] == ["innercontext-admin"]
else:
assert data["user"]["household_membership"]["role"] == "owner"
@pytest.mark.parametrize(
"path",
["/auth/me", "/profile"],
ids=["/auth/me expects 401", "/profile expects 401"],
)
def test_unauthorized_protected_endpoints_return_401(auth_env, auth_client, path: str):
response = auth_client.get(path)
assert response.status_code == 401
assert response.json()["detail"] == "Missing bearer token"
def test_unauthorized_invalid_bearer_token_is_rejected(
auth_env, auth_client, rsa_keypair, monkeypatch
):
_, public_key = rsa_keypair
_mock_oidc(monkeypatch, public_key)
response = auth_client.get(
"/auth/me",
headers={"Authorization": "Bearer not-a-jwt"},
)
assert response.status_code == 401
def test_require_admin_raises_for_member():
claims = TokenClaims(
issuer="https://auth.example.test",
subject="member-1",
audience=("innercontext-web",),
expires_at=datetime.now(UTC) + timedelta(hours=1),
raw_claims={"iss": "https://auth.example.test", "sub": "member-1"},
)
current_user = CurrentUser(
user_id=uuid4(),
role=Role.MEMBER,
identity=IdentityData.from_claims(claims),
claims=claims,
household_membership=CurrentHouseholdMembership(
household_id=uuid4(),
role=HouseholdRole.MEMBER,
),
)
with pytest.raises(HTTPException) as exc_info:
require_admin(current_user)
assert exc_info.value.status_code == 403