feat(auth): validate Authelia tokens in FastAPI
This commit is contained in:
parent
2704d58673
commit
4782fad5b9
7 changed files with 953 additions and 8 deletions
166
backend/innercontext/api/auth.py
Normal file
166
backend/innercontext/api/auth.py
Normal file
|
|
@ -0,0 +1,166 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import date, datetime
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
|
from sqlmodel import Field, Session, SQLModel, select
|
||||||
|
|
||||||
|
from db import get_session
|
||||||
|
from innercontext.api.auth_deps import get_current_user
|
||||||
|
from innercontext.auth import CurrentUser, IdentityData, sync_current_user
|
||||||
|
from innercontext.models import HouseholdRole, Role, UserProfile
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
|
class SessionSyncRequest(SQLModel):
|
||||||
|
iss: str | None = None
|
||||||
|
sub: str | None = None
|
||||||
|
email: str | None = None
|
||||||
|
name: str | None = None
|
||||||
|
preferred_username: str | None = None
|
||||||
|
groups: list[str] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class AuthHouseholdMembershipPublic(SQLModel):
|
||||||
|
household_id: UUID
|
||||||
|
role: HouseholdRole
|
||||||
|
|
||||||
|
|
||||||
|
class AuthUserPublic(SQLModel):
|
||||||
|
id: UUID
|
||||||
|
role: Role
|
||||||
|
household_membership: AuthHouseholdMembershipPublic | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class AuthIdentityPublic(SQLModel):
|
||||||
|
issuer: str
|
||||||
|
subject: str
|
||||||
|
email: str | None = None
|
||||||
|
name: str | None = None
|
||||||
|
preferred_username: str | None = None
|
||||||
|
groups: list[str] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
class AuthProfilePublic(SQLModel):
|
||||||
|
id: UUID
|
||||||
|
user_id: UUID | None
|
||||||
|
birth_date: date | None = None
|
||||||
|
sex_at_birth: str | None = None
|
||||||
|
created_at: datetime
|
||||||
|
updated_at: datetime
|
||||||
|
|
||||||
|
|
||||||
|
class AuthSessionResponse(SQLModel):
|
||||||
|
user: AuthUserPublic
|
||||||
|
identity: AuthIdentityPublic
|
||||||
|
profile: AuthProfilePublic | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def _build_identity(
|
||||||
|
current_user: CurrentUser,
|
||||||
|
payload: SessionSyncRequest | None,
|
||||||
|
) -> IdentityData:
|
||||||
|
if payload is None:
|
||||||
|
return current_user.identity
|
||||||
|
|
||||||
|
if payload.iss is not None and payload.iss != current_user.identity.issuer:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail="Session sync issuer does not match bearer token",
|
||||||
|
)
|
||||||
|
if payload.sub is not None and payload.sub != current_user.identity.subject:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail="Session sync subject does not match bearer token",
|
||||||
|
)
|
||||||
|
|
||||||
|
return IdentityData(
|
||||||
|
issuer=current_user.identity.issuer,
|
||||||
|
subject=current_user.identity.subject,
|
||||||
|
email=(
|
||||||
|
payload.email if payload.email is not None else current_user.identity.email
|
||||||
|
),
|
||||||
|
name=payload.name if payload.name is not None else current_user.identity.name,
|
||||||
|
preferred_username=(
|
||||||
|
payload.preferred_username
|
||||||
|
if payload.preferred_username is not None
|
||||||
|
else current_user.identity.preferred_username
|
||||||
|
),
|
||||||
|
groups=(
|
||||||
|
tuple(payload.groups)
|
||||||
|
if payload.groups is not None
|
||||||
|
else current_user.identity.groups
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_profile(session: Session, user_id: UUID) -> UserProfile | None:
|
||||||
|
return session.exec(
|
||||||
|
select(UserProfile).where(UserProfile.user_id == user_id)
|
||||||
|
).first()
|
||||||
|
|
||||||
|
|
||||||
|
def _profile_public(profile: UserProfile | None) -> AuthProfilePublic | None:
|
||||||
|
if profile is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return AuthProfilePublic(
|
||||||
|
id=profile.id,
|
||||||
|
user_id=profile.user_id,
|
||||||
|
birth_date=profile.birth_date,
|
||||||
|
sex_at_birth=(
|
||||||
|
profile.sex_at_birth.value if profile.sex_at_birth is not None else None
|
||||||
|
),
|
||||||
|
created_at=profile.created_at,
|
||||||
|
updated_at=profile.updated_at,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _response(session: Session, current_user: CurrentUser) -> AuthSessionResponse:
|
||||||
|
household_membership = None
|
||||||
|
if current_user.household_membership is not None:
|
||||||
|
household_membership = AuthHouseholdMembershipPublic(
|
||||||
|
household_id=current_user.household_membership.household_id,
|
||||||
|
role=current_user.household_membership.role,
|
||||||
|
)
|
||||||
|
|
||||||
|
return AuthSessionResponse(
|
||||||
|
user=AuthUserPublic(
|
||||||
|
id=current_user.user_id,
|
||||||
|
role=current_user.role,
|
||||||
|
household_membership=household_membership,
|
||||||
|
),
|
||||||
|
identity=AuthIdentityPublic(
|
||||||
|
issuer=current_user.identity.issuer,
|
||||||
|
subject=current_user.identity.subject,
|
||||||
|
email=current_user.identity.email,
|
||||||
|
name=current_user.identity.name,
|
||||||
|
preferred_username=current_user.identity.preferred_username,
|
||||||
|
groups=list(current_user.identity.groups),
|
||||||
|
),
|
||||||
|
profile=_profile_public(_get_profile(session, current_user.user_id)),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/session/sync", response_model=AuthSessionResponse)
|
||||||
|
def sync_session(
|
||||||
|
payload: SessionSyncRequest | None = None,
|
||||||
|
session: Session = Depends(get_session),
|
||||||
|
current_user: CurrentUser = Depends(get_current_user),
|
||||||
|
):
|
||||||
|
synced_user = sync_current_user(
|
||||||
|
session,
|
||||||
|
current_user.claims,
|
||||||
|
identity=_build_identity(current_user, payload),
|
||||||
|
)
|
||||||
|
return _response(session, synced_user)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/me", response_model=AuthSessionResponse)
|
||||||
|
def get_me(
|
||||||
|
session: Session = Depends(get_session),
|
||||||
|
current_user: CurrentUser = Depends(get_current_user),
|
||||||
|
):
|
||||||
|
return _response(session, current_user)
|
||||||
57
backend/innercontext/api/auth_deps.py
Normal file
57
backend/innercontext/api/auth_deps.py
Normal file
|
|
@ -0,0 +1,57 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Annotated
|
||||||
|
|
||||||
|
from fastapi import Depends, HTTPException, status
|
||||||
|
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||||
|
from sqlmodel import Session
|
||||||
|
|
||||||
|
from db import get_session
|
||||||
|
from innercontext.auth import (
|
||||||
|
AuthConfigurationError,
|
||||||
|
CurrentUser,
|
||||||
|
TokenValidationError,
|
||||||
|
sync_current_user,
|
||||||
|
validate_access_token,
|
||||||
|
)
|
||||||
|
from innercontext.models import Role
|
||||||
|
|
||||||
|
_bearer_scheme = HTTPBearer(auto_error=False)
|
||||||
|
|
||||||
|
|
||||||
|
def _unauthorized(detail: str) -> HTTPException:
|
||||||
|
return HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail=detail,
|
||||||
|
headers={"WWW-Authenticate": "Bearer"},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_current_user(
|
||||||
|
credentials: Annotated[
|
||||||
|
HTTPAuthorizationCredentials | None, Depends(_bearer_scheme)
|
||||||
|
],
|
||||||
|
session: Session = Depends(get_session),
|
||||||
|
) -> CurrentUser:
|
||||||
|
if credentials is None or credentials.scheme.lower() != "bearer":
|
||||||
|
raise _unauthorized("Missing bearer token")
|
||||||
|
|
||||||
|
try:
|
||||||
|
claims = validate_access_token(credentials.credentials)
|
||||||
|
return sync_current_user(session, claims)
|
||||||
|
except AuthConfigurationError as exc:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||||
|
detail=str(exc),
|
||||||
|
) from exc
|
||||||
|
except TokenValidationError as exc:
|
||||||
|
raise _unauthorized(str(exc)) from exc
|
||||||
|
|
||||||
|
|
||||||
|
def require_admin(current_user: CurrentUser = Depends(get_current_user)) -> CurrentUser:
|
||||||
|
if current_user.role is not Role.ADMIN:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail="Admin role required",
|
||||||
|
)
|
||||||
|
return current_user
|
||||||
384
backend/innercontext/auth.py
Normal file
384
backend/innercontext/auth.py
Normal file
|
|
@ -0,0 +1,384 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
from functools import lru_cache
|
||||||
|
from threading import Lock
|
||||||
|
from typing import Any, Mapping
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
import jwt
|
||||||
|
from jwt import InvalidTokenError, PyJWKSet
|
||||||
|
from sqlmodel import Session, select
|
||||||
|
|
||||||
|
from innercontext.models import HouseholdMembership, HouseholdRole, Role, User
|
||||||
|
|
||||||
|
_DISCOVERY_PATH = "/.well-known/openid-configuration"
|
||||||
|
_SUPPORTED_ALGORITHMS = frozenset(
|
||||||
|
{"RS256", "RS384", "RS512", "ES256", "ES384", "ES512"}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class AuthConfigurationError(RuntimeError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class TokenValidationError(ValueError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, slots=True)
|
||||||
|
class AuthSettings:
|
||||||
|
issuer: str
|
||||||
|
client_id: str
|
||||||
|
audiences: tuple[str, ...]
|
||||||
|
discovery_url: str
|
||||||
|
jwks_url: str | None
|
||||||
|
groups_claim: str
|
||||||
|
admin_groups: tuple[str, ...]
|
||||||
|
member_groups: tuple[str, ...]
|
||||||
|
jwks_cache_ttl_seconds: int
|
||||||
|
http_timeout_seconds: float
|
||||||
|
clock_skew_seconds: int
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, slots=True)
|
||||||
|
class TokenClaims:
|
||||||
|
issuer: str
|
||||||
|
subject: str
|
||||||
|
audience: tuple[str, ...]
|
||||||
|
expires_at: datetime
|
||||||
|
groups: tuple[str, ...] = ()
|
||||||
|
email: str | None = None
|
||||||
|
name: str | None = None
|
||||||
|
preferred_username: str | None = None
|
||||||
|
raw_claims: Mapping[str, Any] = field(default_factory=dict, repr=False)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_payload(
|
||||||
|
cls, payload: Mapping[str, Any], settings: AuthSettings
|
||||||
|
) -> "TokenClaims":
|
||||||
|
audience = payload.get("aud")
|
||||||
|
if isinstance(audience, str):
|
||||||
|
audiences = (audience,)
|
||||||
|
elif isinstance(audience, list):
|
||||||
|
audiences = tuple(str(item) for item in audience)
|
||||||
|
else:
|
||||||
|
audiences = ()
|
||||||
|
|
||||||
|
groups = _normalize_groups(payload.get(settings.groups_claim))
|
||||||
|
exp = payload.get("exp")
|
||||||
|
if not isinstance(exp, (int, float)):
|
||||||
|
raise TokenValidationError("Access token missing exp claim")
|
||||||
|
|
||||||
|
return cls(
|
||||||
|
issuer=str(payload["iss"]),
|
||||||
|
subject=str(payload["sub"]),
|
||||||
|
audience=audiences,
|
||||||
|
expires_at=datetime.fromtimestamp(exp, tz=UTC),
|
||||||
|
groups=groups,
|
||||||
|
email=_optional_str(payload.get("email")),
|
||||||
|
name=_optional_str(payload.get("name")),
|
||||||
|
preferred_username=_optional_str(payload.get("preferred_username")),
|
||||||
|
raw_claims=dict(payload),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, slots=True)
|
||||||
|
class IdentityData:
|
||||||
|
issuer: str
|
||||||
|
subject: str
|
||||||
|
email: str | None = None
|
||||||
|
name: str | None = None
|
||||||
|
preferred_username: str | None = None
|
||||||
|
groups: tuple[str, ...] = ()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_claims(cls, claims: TokenClaims) -> "IdentityData":
|
||||||
|
return cls(
|
||||||
|
issuer=claims.issuer,
|
||||||
|
subject=claims.subject,
|
||||||
|
email=claims.email,
|
||||||
|
name=claims.name,
|
||||||
|
preferred_username=claims.preferred_username,
|
||||||
|
groups=claims.groups,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, slots=True)
|
||||||
|
class CurrentHouseholdMembership:
|
||||||
|
household_id: UUID
|
||||||
|
role: HouseholdRole
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, slots=True)
|
||||||
|
class CurrentUser:
|
||||||
|
user_id: UUID
|
||||||
|
role: Role
|
||||||
|
identity: IdentityData
|
||||||
|
claims: TokenClaims = field(repr=False)
|
||||||
|
household_membership: CurrentHouseholdMembership | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def _split_csv(value: str | None) -> tuple[str, ...]:
|
||||||
|
if value is None:
|
||||||
|
return ()
|
||||||
|
return tuple(item.strip() for item in value.split(",") if item.strip())
|
||||||
|
|
||||||
|
|
||||||
|
def _optional_str(value: Any) -> str | None:
|
||||||
|
if value is None:
|
||||||
|
return None
|
||||||
|
if isinstance(value, str):
|
||||||
|
return value
|
||||||
|
return str(value)
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_groups(value: Any) -> tuple[str, ...]:
|
||||||
|
if value is None:
|
||||||
|
return ()
|
||||||
|
if isinstance(value, str):
|
||||||
|
return (value,)
|
||||||
|
if isinstance(value, list):
|
||||||
|
return tuple(str(item) for item in value)
|
||||||
|
if isinstance(value, tuple):
|
||||||
|
return tuple(str(item) for item in value)
|
||||||
|
return (str(value),)
|
||||||
|
|
||||||
|
|
||||||
|
def _required_env(name: str) -> str:
|
||||||
|
value = os.environ.get(name)
|
||||||
|
if value:
|
||||||
|
return value
|
||||||
|
raise AuthConfigurationError(f"Missing required auth environment variable: {name}")
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache
|
||||||
|
def get_auth_settings() -> AuthSettings:
|
||||||
|
issuer = _required_env("OIDC_ISSUER")
|
||||||
|
client_id = _required_env("OIDC_CLIENT_ID")
|
||||||
|
audiences = _split_csv(os.environ.get("OIDC_AUDIENCE")) or (client_id,)
|
||||||
|
discovery_url = os.environ.get("OIDC_DISCOVERY_URL") or (
|
||||||
|
issuer.rstrip("/") + _DISCOVERY_PATH
|
||||||
|
)
|
||||||
|
|
||||||
|
return AuthSettings(
|
||||||
|
issuer=issuer,
|
||||||
|
client_id=client_id,
|
||||||
|
audiences=audiences,
|
||||||
|
discovery_url=discovery_url,
|
||||||
|
jwks_url=os.environ.get("OIDC_JWKS_URL"),
|
||||||
|
groups_claim=os.environ.get("OIDC_GROUPS_CLAIM", "groups"),
|
||||||
|
admin_groups=_split_csv(os.environ.get("OIDC_ADMIN_GROUPS")),
|
||||||
|
member_groups=_split_csv(os.environ.get("OIDC_MEMBER_GROUPS")),
|
||||||
|
jwks_cache_ttl_seconds=int(
|
||||||
|
os.environ.get("OIDC_JWKS_CACHE_TTL_SECONDS", "300")
|
||||||
|
),
|
||||||
|
http_timeout_seconds=float(os.environ.get("OIDC_HTTP_TIMEOUT_SECONDS", "5")),
|
||||||
|
clock_skew_seconds=int(os.environ.get("OIDC_CLOCK_SKEW_SECONDS", "30")),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class CachedJwksClient:
|
||||||
|
def __init__(self, settings: AuthSettings):
|
||||||
|
self._settings = settings
|
||||||
|
self._lock = Lock()
|
||||||
|
self._jwks: PyJWKSet | None = None
|
||||||
|
self._jwks_fetched_at = 0.0
|
||||||
|
self._discovery_jwks_url: str | None = None
|
||||||
|
self._discovery_fetched_at = 0.0
|
||||||
|
|
||||||
|
def get_signing_key(self, kid: str) -> Any:
|
||||||
|
with self._lock:
|
||||||
|
jwks = self._get_jwks_locked()
|
||||||
|
key = self._find_key(jwks, kid)
|
||||||
|
if key is not None:
|
||||||
|
return key
|
||||||
|
|
||||||
|
self._refresh_jwks_locked(
|
||||||
|
force_discovery_refresh=self._settings.jwks_url is None
|
||||||
|
)
|
||||||
|
if self._jwks is None:
|
||||||
|
raise TokenValidationError("JWKS cache is empty")
|
||||||
|
|
||||||
|
key = self._find_key(self._jwks, kid)
|
||||||
|
if key is None:
|
||||||
|
raise TokenValidationError(f"No signing key found for kid '{kid}'")
|
||||||
|
return key
|
||||||
|
|
||||||
|
def _get_jwks_locked(self) -> PyJWKSet:
|
||||||
|
if self._jwks is None or self._is_stale(self._jwks_fetched_at):
|
||||||
|
self._refresh_jwks_locked(force_discovery_refresh=False)
|
||||||
|
if self._jwks is None:
|
||||||
|
raise TokenValidationError("Unable to load JWKS")
|
||||||
|
return self._jwks
|
||||||
|
|
||||||
|
def _refresh_jwks_locked(self, force_discovery_refresh: bool) -> None:
|
||||||
|
jwks_url = self._resolve_jwks_url_locked(force_refresh=force_discovery_refresh)
|
||||||
|
data = self._fetch_json(jwks_url)
|
||||||
|
try:
|
||||||
|
self._jwks = PyJWKSet.from_dict(data)
|
||||||
|
except Exception as exc:
|
||||||
|
raise TokenValidationError(
|
||||||
|
"OIDC provider returned an invalid JWKS payload"
|
||||||
|
) from exc
|
||||||
|
self._jwks_fetched_at = time.monotonic()
|
||||||
|
|
||||||
|
def _resolve_jwks_url_locked(self, force_refresh: bool) -> str:
|
||||||
|
if self._settings.jwks_url:
|
||||||
|
return self._settings.jwks_url
|
||||||
|
|
||||||
|
if (
|
||||||
|
force_refresh
|
||||||
|
or self._discovery_jwks_url is None
|
||||||
|
or self._is_stale(self._discovery_fetched_at)
|
||||||
|
):
|
||||||
|
discovery = self._fetch_json(self._settings.discovery_url)
|
||||||
|
jwks_uri = discovery.get("jwks_uri")
|
||||||
|
if not isinstance(jwks_uri, str) or not jwks_uri:
|
||||||
|
raise TokenValidationError("OIDC discovery document missing jwks_uri")
|
||||||
|
self._discovery_jwks_url = jwks_uri
|
||||||
|
self._discovery_fetched_at = time.monotonic()
|
||||||
|
|
||||||
|
if self._discovery_jwks_url is None:
|
||||||
|
raise TokenValidationError("Unable to resolve JWKS URL")
|
||||||
|
return self._discovery_jwks_url
|
||||||
|
|
||||||
|
def _fetch_json(self, url: str) -> dict[str, Any]:
|
||||||
|
try:
|
||||||
|
response = httpx.get(url, timeout=self._settings.http_timeout_seconds)
|
||||||
|
response.raise_for_status()
|
||||||
|
except httpx.HTTPError as exc:
|
||||||
|
raise TokenValidationError(
|
||||||
|
f"Failed to fetch OIDC metadata from {url}"
|
||||||
|
) from exc
|
||||||
|
|
||||||
|
data = response.json()
|
||||||
|
if not isinstance(data, dict):
|
||||||
|
raise TokenValidationError(
|
||||||
|
f"OIDC metadata from {url} must be a JSON object"
|
||||||
|
)
|
||||||
|
return data
|
||||||
|
|
||||||
|
def _is_stale(self, fetched_at: float) -> bool:
|
||||||
|
return (time.monotonic() - fetched_at) >= self._settings.jwks_cache_ttl_seconds
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _find_key(jwks: PyJWKSet, kid: str) -> Any | None:
|
||||||
|
for jwk in jwks.keys:
|
||||||
|
if jwk.key_id == kid:
|
||||||
|
return jwk.key
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache
|
||||||
|
def get_jwks_client() -> CachedJwksClient:
|
||||||
|
return CachedJwksClient(get_auth_settings())
|
||||||
|
|
||||||
|
|
||||||
|
def reset_auth_caches() -> None:
|
||||||
|
get_auth_settings.cache_clear()
|
||||||
|
get_jwks_client.cache_clear()
|
||||||
|
|
||||||
|
|
||||||
|
def validate_access_token(token: str) -> TokenClaims:
|
||||||
|
settings = get_auth_settings()
|
||||||
|
|
||||||
|
try:
|
||||||
|
unverified_header = jwt.get_unverified_header(token)
|
||||||
|
except InvalidTokenError as exc:
|
||||||
|
raise TokenValidationError("Malformed access token header") from exc
|
||||||
|
|
||||||
|
kid = unverified_header.get("kid")
|
||||||
|
algorithm = unverified_header.get("alg")
|
||||||
|
if not isinstance(kid, str) or not kid:
|
||||||
|
raise TokenValidationError("Access token missing kid header")
|
||||||
|
if not isinstance(algorithm, str) or algorithm not in _SUPPORTED_ALGORITHMS:
|
||||||
|
raise TokenValidationError("Access token uses an unsupported signing algorithm")
|
||||||
|
|
||||||
|
signing_key = get_jwks_client().get_signing_key(kid)
|
||||||
|
|
||||||
|
try:
|
||||||
|
payload = jwt.decode(
|
||||||
|
token,
|
||||||
|
key=signing_key,
|
||||||
|
algorithms=[algorithm],
|
||||||
|
audience=settings.audiences,
|
||||||
|
issuer=settings.issuer,
|
||||||
|
options={"require": ["exp", "iss", "sub"]},
|
||||||
|
leeway=settings.clock_skew_seconds,
|
||||||
|
)
|
||||||
|
except InvalidTokenError as exc:
|
||||||
|
raise TokenValidationError("Invalid access token") from exc
|
||||||
|
|
||||||
|
return TokenClaims.from_payload(payload, settings)
|
||||||
|
|
||||||
|
|
||||||
|
def sync_current_user(
|
||||||
|
session: Session,
|
||||||
|
claims: TokenClaims,
|
||||||
|
identity: IdentityData | None = None,
|
||||||
|
) -> CurrentUser:
|
||||||
|
effective_identity = identity or IdentityData.from_claims(claims)
|
||||||
|
statement = select(User).where(
|
||||||
|
User.oidc_issuer == effective_identity.issuer,
|
||||||
|
User.oidc_subject == effective_identity.subject,
|
||||||
|
)
|
||||||
|
user = session.exec(statement).first()
|
||||||
|
existing_role = user.role if user is not None else None
|
||||||
|
resolved_role = resolve_role(effective_identity.groups, existing_role=existing_role)
|
||||||
|
needs_commit = False
|
||||||
|
|
||||||
|
if user is None:
|
||||||
|
user = User(
|
||||||
|
oidc_issuer=effective_identity.issuer,
|
||||||
|
oidc_subject=effective_identity.subject,
|
||||||
|
role=resolved_role,
|
||||||
|
)
|
||||||
|
session.add(user)
|
||||||
|
needs_commit = True
|
||||||
|
elif user.role != resolved_role:
|
||||||
|
user.role = resolved_role
|
||||||
|
session.add(user)
|
||||||
|
needs_commit = True
|
||||||
|
|
||||||
|
if needs_commit:
|
||||||
|
session.commit()
|
||||||
|
session.refresh(user)
|
||||||
|
|
||||||
|
membership = session.exec(
|
||||||
|
select(HouseholdMembership).where(HouseholdMembership.user_id == user.id)
|
||||||
|
).first()
|
||||||
|
|
||||||
|
household_membership = None
|
||||||
|
if membership is not None:
|
||||||
|
household_membership = CurrentHouseholdMembership(
|
||||||
|
household_id=membership.household_id,
|
||||||
|
role=membership.role,
|
||||||
|
)
|
||||||
|
|
||||||
|
return CurrentUser(
|
||||||
|
user_id=user.id,
|
||||||
|
role=user.role,
|
||||||
|
identity=effective_identity,
|
||||||
|
claims=claims,
|
||||||
|
household_membership=household_membership,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_role(groups: tuple[str, ...], existing_role: Role | None = None) -> Role:
|
||||||
|
settings = get_auth_settings()
|
||||||
|
if groups:
|
||||||
|
group_set = set(groups)
|
||||||
|
if settings.admin_groups and group_set.intersection(settings.admin_groups):
|
||||||
|
return Role.ADMIN
|
||||||
|
if settings.member_groups:
|
||||||
|
if group_set.intersection(settings.member_groups):
|
||||||
|
return Role.MEMBER
|
||||||
|
return Role.MEMBER
|
||||||
|
return Role.MEMBER
|
||||||
|
|
||||||
|
return existing_role or Role.MEMBER
|
||||||
|
|
@ -5,13 +5,14 @@ from dotenv import load_dotenv
|
||||||
|
|
||||||
load_dotenv() # load .env before db.py reads DATABASE_URL
|
load_dotenv() # load .env before db.py reads DATABASE_URL
|
||||||
|
|
||||||
from fastapi import FastAPI # noqa: E402
|
from fastapi import Depends, FastAPI # noqa: E402
|
||||||
from fastapi.middleware.cors import CORSMiddleware # noqa: E402
|
from fastapi.middleware.cors import CORSMiddleware # noqa: E402
|
||||||
from sqlmodel import Session # noqa: E402
|
from sqlmodel import Session # noqa: E402
|
||||||
|
|
||||||
from db import create_db_and_tables, engine # noqa: E402
|
from db import create_db_and_tables, engine # noqa: E402
|
||||||
from innercontext.api import ( # noqa: E402
|
from innercontext.api import ( # noqa: E402
|
||||||
ai_logs,
|
ai_logs,
|
||||||
|
auth,
|
||||||
health,
|
health,
|
||||||
inventory,
|
inventory,
|
||||||
products,
|
products,
|
||||||
|
|
@ -19,6 +20,7 @@ from innercontext.api import ( # noqa: E402
|
||||||
routines,
|
routines,
|
||||||
skincare,
|
skincare,
|
||||||
)
|
)
|
||||||
|
from innercontext.api.auth_deps import get_current_user # noqa: E402
|
||||||
from innercontext.services.pricing_jobs import enqueue_pricing_recalc # noqa: E402
|
from innercontext.services.pricing_jobs import enqueue_pricing_recalc # noqa: E402
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -47,13 +49,51 @@ app.add_middleware(
|
||||||
allow_headers=["*"],
|
allow_headers=["*"],
|
||||||
)
|
)
|
||||||
|
|
||||||
app.include_router(products.router, prefix="/products", tags=["products"])
|
protected = [Depends(get_current_user)]
|
||||||
app.include_router(inventory.router, prefix="/inventory", tags=["inventory"])
|
|
||||||
app.include_router(profile.router, prefix="/profile", tags=["profile"])
|
app.include_router(auth.router, prefix="/auth", tags=["auth"])
|
||||||
app.include_router(health.router, prefix="/health", tags=["health"])
|
app.include_router(
|
||||||
app.include_router(routines.router, prefix="/routines", tags=["routines"])
|
products.router,
|
||||||
app.include_router(skincare.router, prefix="/skincare", tags=["skincare"])
|
prefix="/products",
|
||||||
app.include_router(ai_logs.router, prefix="/ai-logs", tags=["ai-logs"])
|
tags=["products"],
|
||||||
|
dependencies=protected,
|
||||||
|
)
|
||||||
|
app.include_router(
|
||||||
|
inventory.router,
|
||||||
|
prefix="/inventory",
|
||||||
|
tags=["inventory"],
|
||||||
|
dependencies=protected,
|
||||||
|
)
|
||||||
|
app.include_router(
|
||||||
|
profile.router,
|
||||||
|
prefix="/profile",
|
||||||
|
tags=["profile"],
|
||||||
|
dependencies=protected,
|
||||||
|
)
|
||||||
|
app.include_router(
|
||||||
|
health.router,
|
||||||
|
prefix="/health",
|
||||||
|
tags=["health"],
|
||||||
|
dependencies=protected,
|
||||||
|
)
|
||||||
|
app.include_router(
|
||||||
|
routines.router,
|
||||||
|
prefix="/routines",
|
||||||
|
tags=["routines"],
|
||||||
|
dependencies=protected,
|
||||||
|
)
|
||||||
|
app.include_router(
|
||||||
|
skincare.router,
|
||||||
|
prefix="/skincare",
|
||||||
|
tags=["skincare"],
|
||||||
|
dependencies=protected,
|
||||||
|
)
|
||||||
|
app.include_router(
|
||||||
|
ai_logs.router,
|
||||||
|
prefix="/ai-logs",
|
||||||
|
tags=["ai-logs"],
|
||||||
|
dependencies=protected,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@app.get("/health-check")
|
@app.get("/health-check")
|
||||||
|
|
|
||||||
|
|
@ -8,6 +8,7 @@ dependencies = [
|
||||||
"alembic>=1.14",
|
"alembic>=1.14",
|
||||||
"fastapi>=0.132.0",
|
"fastapi>=0.132.0",
|
||||||
"google-genai>=1.65.0",
|
"google-genai>=1.65.0",
|
||||||
|
"pyjwt[crypto]>=2.10.1",
|
||||||
"psycopg[binary]>=3.3.3",
|
"psycopg[binary]>=3.3.3",
|
||||||
"python-dotenv>=1.2.1",
|
"python-dotenv>=1.2.1",
|
||||||
"python-multipart>=0.0.22",
|
"python-multipart>=0.0.22",
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,6 @@
|
||||||
import os
|
import os
|
||||||
|
from datetime import UTC, datetime, timedelta
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
# Must be set before importing db (which calls create_engine at module level)
|
# Must be set before importing db (which calls create_engine at module level)
|
||||||
os.environ.setdefault("DATABASE_URL", "sqlite://")
|
os.environ.setdefault("DATABASE_URL", "sqlite://")
|
||||||
|
|
@ -10,6 +12,9 @@ from sqlmodel.pool import StaticPool
|
||||||
|
|
||||||
import db as db_module
|
import db as db_module
|
||||||
from db import get_session
|
from db import get_session
|
||||||
|
from innercontext.api.auth_deps import get_current_user
|
||||||
|
from innercontext.auth import CurrentUser, IdentityData, TokenClaims
|
||||||
|
from innercontext.models import Role
|
||||||
from main import app
|
from main import app
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -38,7 +43,24 @@ def client(session, monkeypatch):
|
||||||
def _override():
|
def _override():
|
||||||
yield session
|
yield session
|
||||||
|
|
||||||
|
def _current_user_override():
|
||||||
|
claims = TokenClaims(
|
||||||
|
issuer="https://auth.test",
|
||||||
|
subject="test-user",
|
||||||
|
audience=("innercontext-web",),
|
||||||
|
expires_at=datetime.now(UTC) + timedelta(hours=1),
|
||||||
|
groups=("innercontext-admin",),
|
||||||
|
raw_claims={"iss": "https://auth.test", "sub": "test-user"},
|
||||||
|
)
|
||||||
|
return CurrentUser(
|
||||||
|
user_id=uuid4(),
|
||||||
|
role=Role.ADMIN,
|
||||||
|
identity=IdentityData.from_claims(claims),
|
||||||
|
claims=claims,
|
||||||
|
)
|
||||||
|
|
||||||
app.dependency_overrides[get_session] = _override
|
app.dependency_overrides[get_session] = _override
|
||||||
|
app.dependency_overrides[get_current_user] = _current_user_override
|
||||||
with TestClient(app) as c:
|
with TestClient(app) as c:
|
||||||
yield c
|
yield c
|
||||||
app.dependency_overrides.clear()
|
app.dependency_overrides.clear()
|
||||||
|
|
|
||||||
275
backend/tests/test_auth.py
Normal file
275
backend/tests/test_auth.py
Normal file
|
|
@ -0,0 +1,275 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
from datetime import UTC, datetime, timedelta
|
||||||
|
from uuid import UUID, uuid4
|
||||||
|
|
||||||
|
import jwt
|
||||||
|
import pytest
|
||||||
|
from cryptography.hazmat.primitives.asymmetric import rsa
|
||||||
|
from fastapi import HTTPException
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
from jwt import algorithms
|
||||||
|
from sqlmodel import Session, SQLModel, create_engine
|
||||||
|
from sqlmodel.pool import StaticPool
|
||||||
|
|
||||||
|
import db as db_module
|
||||||
|
from db import get_session
|
||||||
|
from innercontext.api.auth_deps import require_admin
|
||||||
|
from innercontext.auth import (
|
||||||
|
CurrentHouseholdMembership,
|
||||||
|
CurrentUser,
|
||||||
|
IdentityData,
|
||||||
|
TokenClaims,
|
||||||
|
reset_auth_caches,
|
||||||
|
validate_access_token,
|
||||||
|
)
|
||||||
|
from innercontext.models import (
|
||||||
|
Household,
|
||||||
|
HouseholdMembership,
|
||||||
|
HouseholdRole,
|
||||||
|
Role,
|
||||||
|
User,
|
||||||
|
)
|
||||||
|
from main import app
|
||||||
|
|
||||||
|
|
||||||
|
class _MockResponse:
|
||||||
|
def __init__(self, payload: dict[str, object], status_code: int = 200):
|
||||||
|
self._payload = payload
|
||||||
|
self.status_code = status_code
|
||||||
|
|
||||||
|
def raise_for_status(self) -> None:
|
||||||
|
if self.status_code >= 400:
|
||||||
|
raise RuntimeError(f"unexpected status {self.status_code}")
|
||||||
|
|
||||||
|
def json(self) -> dict[str, object]:
|
||||||
|
return self._payload
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def auth_env(monkeypatch):
|
||||||
|
monkeypatch.setenv("OIDC_ISSUER", "https://auth.example.test")
|
||||||
|
monkeypatch.setenv("OIDC_CLIENT_ID", "innercontext-web")
|
||||||
|
monkeypatch.setenv(
|
||||||
|
"OIDC_DISCOVERY_URL",
|
||||||
|
"https://auth.example.test/.well-known/openid-configuration",
|
||||||
|
)
|
||||||
|
monkeypatch.setenv("OIDC_ADMIN_GROUPS", "innercontext-admin")
|
||||||
|
monkeypatch.setenv("OIDC_MEMBER_GROUPS", "innercontext-member")
|
||||||
|
monkeypatch.setenv("OIDC_JWKS_CACHE_TTL_SECONDS", "3600")
|
||||||
|
reset_auth_caches()
|
||||||
|
yield
|
||||||
|
reset_auth_caches()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def rsa_keypair():
|
||||||
|
private_key = rsa.generate_private_key(
|
||||||
|
public_exponent=65537,
|
||||||
|
key_size=2048,
|
||||||
|
)
|
||||||
|
return private_key, private_key.public_key()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def auth_session(monkeypatch):
|
||||||
|
engine = create_engine(
|
||||||
|
"sqlite://",
|
||||||
|
connect_args={"check_same_thread": False},
|
||||||
|
poolclass=StaticPool,
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(db_module, "engine", engine)
|
||||||
|
import innercontext.models # noqa: F401
|
||||||
|
|
||||||
|
SQLModel.metadata.create_all(engine)
|
||||||
|
with Session(engine) as session:
|
||||||
|
yield session
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def auth_client(auth_session):
|
||||||
|
def _override():
|
||||||
|
yield auth_session
|
||||||
|
|
||||||
|
app.dependency_overrides[get_session] = _override
|
||||||
|
with TestClient(app) as client:
|
||||||
|
yield client
|
||||||
|
app.dependency_overrides.clear()
|
||||||
|
|
||||||
|
|
||||||
|
def _public_jwk(public_key, kid: str) -> dict[str, object]:
|
||||||
|
jwk = json.loads(algorithms.RSAAlgorithm.to_jwk(public_key))
|
||||||
|
jwk["kid"] = kid
|
||||||
|
jwk["use"] = "sig"
|
||||||
|
jwk["alg"] = "RS256"
|
||||||
|
return jwk
|
||||||
|
|
||||||
|
|
||||||
|
def _sign_token(private_key, kid: str, **claims_overrides: object) -> str:
|
||||||
|
now = datetime.now(UTC)
|
||||||
|
payload: dict[str, object] = {
|
||||||
|
"iss": "https://auth.example.test",
|
||||||
|
"sub": "user-123",
|
||||||
|
"aud": "innercontext-web",
|
||||||
|
"exp": int((now + timedelta(hours=1)).timestamp()),
|
||||||
|
"iat": int(now.timestamp()),
|
||||||
|
"groups": ["innercontext-admin"],
|
||||||
|
"email": "user@example.test",
|
||||||
|
"name": "Inner Context User",
|
||||||
|
"preferred_username": "ictx-user",
|
||||||
|
}
|
||||||
|
payload.update(claims_overrides)
|
||||||
|
return jwt.encode(payload, private_key, algorithm="RS256", headers={"kid": kid})
|
||||||
|
|
||||||
|
|
||||||
|
def _mock_oidc(monkeypatch, public_key, *, fetch_counts: dict[str, int] | None = None):
|
||||||
|
def _fake_get(url: str, timeout: float):
|
||||||
|
if fetch_counts is not None:
|
||||||
|
fetch_counts[url] = fetch_counts.get(url, 0) + 1
|
||||||
|
if url.endswith("/.well-known/openid-configuration"):
|
||||||
|
return _MockResponse({"jwks_uri": "https://auth.example.test/jwks.json"})
|
||||||
|
if url.endswith("/jwks.json"):
|
||||||
|
return _MockResponse({"keys": [_public_jwk(public_key, "kid-1")]})
|
||||||
|
raise AssertionError(f"unexpected URL {url} with timeout {timeout}")
|
||||||
|
|
||||||
|
monkeypatch.setattr("innercontext.auth.httpx.get", _fake_get)
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_access_token_uses_cached_jwks(auth_env, rsa_keypair, monkeypatch):
|
||||||
|
private_key, public_key = rsa_keypair
|
||||||
|
fetch_counts: dict[str, int] = {}
|
||||||
|
_mock_oidc(monkeypatch, public_key, fetch_counts=fetch_counts)
|
||||||
|
|
||||||
|
validate_access_token(_sign_token(private_key, "kid-1", sub="user-a"))
|
||||||
|
validate_access_token(_sign_token(private_key, "kid-1", sub="user-b"))
|
||||||
|
|
||||||
|
assert (
|
||||||
|
fetch_counts["https://auth.example.test/.well-known/openid-configuration"] == 1
|
||||||
|
)
|
||||||
|
assert fetch_counts["https://auth.example.test/jwks.json"] == 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
("path", "payload"),
|
||||||
|
[
|
||||||
|
(
|
||||||
|
"/auth/session/sync",
|
||||||
|
{
|
||||||
|
"email": "sync@example.test",
|
||||||
|
"name": "Synced User",
|
||||||
|
"preferred_username": "synced-user",
|
||||||
|
"groups": ["innercontext-admin"],
|
||||||
|
},
|
||||||
|
),
|
||||||
|
("/auth/me", None),
|
||||||
|
],
|
||||||
|
ids=["/auth/session/sync", "/auth/me"],
|
||||||
|
)
|
||||||
|
def test_sync_protected_endpoints_create_or_resolve_current_user(
|
||||||
|
auth_env,
|
||||||
|
auth_client,
|
||||||
|
auth_session,
|
||||||
|
rsa_keypair,
|
||||||
|
monkeypatch,
|
||||||
|
path: str,
|
||||||
|
payload: dict[str, object] | None,
|
||||||
|
):
|
||||||
|
private_key, public_key = rsa_keypair
|
||||||
|
_mock_oidc(monkeypatch, public_key)
|
||||||
|
token = _sign_token(private_key, "kid-1")
|
||||||
|
|
||||||
|
if path == "/auth/me":
|
||||||
|
user = User(
|
||||||
|
oidc_issuer="https://auth.example.test",
|
||||||
|
oidc_subject="user-123",
|
||||||
|
role=Role.ADMIN,
|
||||||
|
)
|
||||||
|
auth_session.add(user)
|
||||||
|
auth_session.commit()
|
||||||
|
auth_session.refresh(user)
|
||||||
|
|
||||||
|
household = Household()
|
||||||
|
auth_session.add(household)
|
||||||
|
auth_session.commit()
|
||||||
|
auth_session.refresh(household)
|
||||||
|
|
||||||
|
membership = HouseholdMembership(
|
||||||
|
user_id=user.id,
|
||||||
|
household_id=household.id,
|
||||||
|
role=HouseholdRole.OWNER,
|
||||||
|
)
|
||||||
|
auth_session.add(membership)
|
||||||
|
auth_session.commit()
|
||||||
|
|
||||||
|
response = auth_client.request(
|
||||||
|
"POST" if path.endswith("sync") else "GET",
|
||||||
|
path,
|
||||||
|
headers={"Authorization": f"Bearer {token}"},
|
||||||
|
json=payload,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert data["user"]["role"] == "admin"
|
||||||
|
assert data["identity"]["issuer"] == "https://auth.example.test"
|
||||||
|
assert data["identity"]["subject"] == "user-123"
|
||||||
|
|
||||||
|
synced_user = auth_session.get(User, UUID(data["user"]["id"]))
|
||||||
|
assert synced_user is not None
|
||||||
|
assert synced_user.oidc_issuer == "https://auth.example.test"
|
||||||
|
assert synced_user.oidc_subject == "user-123"
|
||||||
|
|
||||||
|
if path == "/auth/session/sync":
|
||||||
|
assert data["identity"]["email"] == "sync@example.test"
|
||||||
|
assert data["identity"]["groups"] == ["innercontext-admin"]
|
||||||
|
else:
|
||||||
|
assert data["user"]["household_membership"]["role"] == "owner"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"path",
|
||||||
|
["/auth/me", "/profile"],
|
||||||
|
ids=["/auth/me expects 401", "/profile expects 401"],
|
||||||
|
)
|
||||||
|
def test_unauthorized_protected_endpoints_return_401(auth_env, auth_client, path: str):
|
||||||
|
response = auth_client.get(path)
|
||||||
|
assert response.status_code == 401
|
||||||
|
assert response.json()["detail"] == "Missing bearer token"
|
||||||
|
|
||||||
|
|
||||||
|
def test_unauthorized_invalid_bearer_token_is_rejected(
|
||||||
|
auth_env, auth_client, rsa_keypair, monkeypatch
|
||||||
|
):
|
||||||
|
_, public_key = rsa_keypair
|
||||||
|
_mock_oidc(monkeypatch, public_key)
|
||||||
|
response = auth_client.get(
|
||||||
|
"/auth/me",
|
||||||
|
headers={"Authorization": "Bearer not-a-jwt"},
|
||||||
|
)
|
||||||
|
assert response.status_code == 401
|
||||||
|
|
||||||
|
|
||||||
|
def test_require_admin_raises_for_member():
|
||||||
|
claims = TokenClaims(
|
||||||
|
issuer="https://auth.example.test",
|
||||||
|
subject="member-1",
|
||||||
|
audience=("innercontext-web",),
|
||||||
|
expires_at=datetime.now(UTC) + timedelta(hours=1),
|
||||||
|
raw_claims={"iss": "https://auth.example.test", "sub": "member-1"},
|
||||||
|
)
|
||||||
|
current_user = CurrentUser(
|
||||||
|
user_id=uuid4(),
|
||||||
|
role=Role.MEMBER,
|
||||||
|
identity=IdentityData.from_claims(claims),
|
||||||
|
claims=claims,
|
||||||
|
household_membership=CurrentHouseholdMembership(
|
||||||
|
household_id=uuid4(),
|
||||||
|
role=HouseholdRole.MEMBER,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(HTTPException) as exc_info:
|
||||||
|
require_admin(current_user)
|
||||||
|
|
||||||
|
assert exc_info.value.status_code == 403
|
||||||
Loading…
Add table
Add a link
Reference in a new issue