innercontext/backend/innercontext/auth.py

384 lines
12 KiB
Python

from __future__ import annotations
import os
import time
from dataclasses import dataclass, field
from datetime import UTC, datetime
from functools import lru_cache
from threading import Lock
from typing import Any, Mapping
from uuid import UUID
import httpx
import jwt
from jwt import InvalidTokenError, PyJWKSet
from sqlmodel import Session, select
from innercontext.models import HouseholdMembership, HouseholdRole, Role, User
# Path appended to the issuer URL (see get_auth_settings) to build the default
# OIDC discovery document URL when OIDC_DISCOVERY_URL is not set.
_DISCOVERY_PATH = "/.well-known/openid-configuration"
# JWT "alg" header values that validate_access_token accepts; a token whose
# header names any other algorithm is rejected before signature verification.
_SUPPORTED_ALGORITHMS = frozenset(
    {"RS256", "RS384", "RS512", "ES256", "ES384", "ES512"}
)
class AuthConfigurationError(RuntimeError):
    """Raised when the auth-related environment configuration is unusable."""
class TokenValidationError(ValueError):
    """Raised when an access token, or the key material for it, is rejected."""
@dataclass(frozen=True, slots=True)
class AuthSettings:
    """Immutable OIDC configuration, built from environment variables by
    get_auth_settings().
    """

    # Expected "iss" value; passed to jwt.decode as the issuer to enforce.
    issuer: str
    # OIDC client identifier; used as the only audience when none is configured.
    client_id: str
    # Accepted "aud" values passed to jwt.decode.
    audiences: tuple[str, ...]
    # URL of the discovery document from which "jwks_uri" is read.
    discovery_url: str
    # Explicit JWKS URL; when set (non-empty) discovery is skipped.
    jwks_url: str | None
    # Name of the payload claim that carries the caller's groups.
    groups_claim: str
    # Group names that resolve_role maps to Role.ADMIN.
    admin_groups: tuple[str, ...]
    # Group names consulted by resolve_role for Role.MEMBER.
    member_groups: tuple[str, ...]
    # Age (seconds) after which cached JWKS / discovery data is considered stale.
    jwks_cache_ttl_seconds: int
    # Timeout (seconds) passed to httpx for metadata requests.
    http_timeout_seconds: float
    # Leeway (seconds) passed to jwt.decode for time-based claim checks.
    clock_skew_seconds: int
@dataclass(frozen=True, slots=True)
class TokenClaims:
    """Typed, immutable view of the claims carried by an access token.

    ``raw_claims`` keeps a dict copy of the full payload and is excluded from
    ``repr`` output.
    """

    issuer: str
    subject: str
    audience: tuple[str, ...]
    expires_at: datetime
    groups: tuple[str, ...] = ()
    email: str | None = None
    name: str | None = None
    preferred_username: str | None = None
    raw_claims: Mapping[str, Any] = field(default_factory=dict, repr=False)

    @classmethod
    def from_payload(
        cls, payload: Mapping[str, Any], settings: AuthSettings
    ) -> "TokenClaims":
        """Build a ``TokenClaims`` from a decoded JWT payload.

        Args:
            payload: Decoded token payload.
            settings: Auth settings; ``settings.groups_claim`` names the
                payload key that holds the groups.

        Returns:
            The populated ``TokenClaims``; ``expires_at`` is timezone-aware UTC.

        Raises:
            TokenValidationError: If ``iss``, ``sub`` or ``exp`` is missing, or
                ``exp`` is not a number that maps to a representable datetime.
        """
        # Report missing identity claims the same way as a missing "exp",
        # instead of leaking a KeyError to the caller.
        for required in ("iss", "sub"):
            if payload.get(required) is None:
                raise TokenValidationError(f"Access token missing {required} claim")

        # "aud" may be a single string or a list of strings; anything else
        # is treated as "no audience".
        audience = payload.get("aud")
        if isinstance(audience, str):
            audiences: tuple[str, ...] = (audience,)
        elif isinstance(audience, list):
            audiences = tuple(str(item) for item in audience)
        else:
            audiences = ()

        groups = _normalize_groups(payload.get(settings.groups_claim))

        exp = payload.get("exp")
        if not isinstance(exp, (int, float)):
            raise TokenValidationError("Access token missing exp claim")
        try:
            expires_at = datetime.fromtimestamp(exp, tz=UTC)
        except (OverflowError, OSError, ValueError) as exc:
            # Out-of-range or non-finite timestamps cannot be represented.
            raise TokenValidationError(
                "Access token has an invalid exp claim"
            ) from exc

        return cls(
            issuer=str(payload["iss"]),
            subject=str(payload["sub"]),
            audience=audiences,
            expires_at=expires_at,
            groups=groups,
            email=_optional_str(payload.get("email")),
            name=_optional_str(payload.get("name")),
            preferred_username=_optional_str(payload.get("preferred_username")),
            raw_claims=dict(payload),
        )
@dataclass(frozen=True, slots=True)
class IdentityData:
issuer: str
subject: str
email: str | None = None
name: str | None = None
preferred_username: str | None = None
groups: tuple[str, ...] = ()
@classmethod
def from_claims(cls, claims: TokenClaims) -> "IdentityData":
return cls(
issuer=claims.issuer,
subject=claims.subject,
email=claims.email,
name=claims.name,
preferred_username=claims.preferred_username,
groups=claims.groups,
)
@dataclass(frozen=True, slots=True)
class CurrentHouseholdMembership:
    """Immutable snapshot of a user's household membership, as attached to
    CurrentUser by sync_current_user.
    """

    # Identifier of the household, copied from the HouseholdMembership row.
    household_id: UUID
    # The user's role within that household, copied from the same row.
    role: HouseholdRole
@dataclass(frozen=True, slots=True)
class CurrentUser:
    """Immutable description of the authenticated caller returned by
    sync_current_user.
    """

    # Primary key of the persisted User row.
    user_id: UUID
    # Application role stored on the User row.
    role: Role
    # Identity attributes used to look up / create the user.
    identity: IdentityData
    # Validated token claims; left out of repr output.
    claims: TokenClaims = field(repr=False)
    # Household membership, or None when no membership row was found.
    household_membership: CurrentHouseholdMembership | None = None
def _split_csv(value: str | None) -> tuple[str, ...]:
    """Split a comma-separated string into trimmed, non-empty items.

    ``None`` yields an empty tuple; blank segments are dropped.
    """
    if value is None:
        return ()
    trimmed = (piece.strip() for piece in value.split(","))
    return tuple(piece for piece in trimmed if piece)
def _optional_str(value: Any) -> str | None:
    """Return *value* unchanged when it is ``None`` or a string, else ``str(value)``."""
    if value is None or isinstance(value, str):
        return value
    return str(value)
def _normalize_groups(value: Any) -> tuple[str, ...]:
    """Coerce a groups claim of arbitrary shape into a tuple of strings.

    ``None`` becomes ``()``, a single string becomes a one-element tuple,
    list/tuple items are stringified individually, and any other value is
    stringified as a whole.
    """
    if value is None:
        return ()
    if isinstance(value, str):
        return (value,)
    if isinstance(value, (list, tuple)):
        return tuple(map(str, value))
    return (str(value),)
def _required_env(name: str) -> str:
    """Return the value of environment variable *name*.

    Raises:
        AuthConfigurationError: If the variable is unset or empty.
    """
    value = os.environ.get(name)
    if not value:
        raise AuthConfigurationError(
            f"Missing required auth environment variable: {name}"
        )
    return value
def _int_env(name: str, default: int) -> int:
    """Read environment variable *name* as an int, falling back to *default* when unset."""
    raw = os.environ.get(name)
    if raw is None:
        return default
    try:
        return int(raw)
    except ValueError as exc:
        raise AuthConfigurationError(
            f"Invalid integer value for auth environment variable {name}: {raw!r}"
        ) from exc


def _float_env(name: str, default: float) -> float:
    """Read environment variable *name* as a float, falling back to *default* when unset."""
    raw = os.environ.get(name)
    if raw is None:
        return default
    try:
        return float(raw)
    except ValueError as exc:
        raise AuthConfigurationError(
            f"Invalid numeric value for auth environment variable {name}: {raw!r}"
        ) from exc


@lru_cache
def get_auth_settings() -> AuthSettings:
    """Build the ``AuthSettings`` from ``OIDC_*`` environment variables.

    The result is memoised with ``lru_cache``; clear it with
    ``get_auth_settings.cache_clear()`` after changing the environment.

    Required: ``OIDC_ISSUER``, ``OIDC_CLIENT_ID``.
    Optional (default): ``OIDC_AUDIENCE`` (the client id),
    ``OIDC_DISCOVERY_URL`` (issuer + well-known path), ``OIDC_JWKS_URL``
    (unset), ``OIDC_GROUPS_CLAIM`` ("groups"), ``OIDC_ADMIN_GROUPS`` and
    ``OIDC_MEMBER_GROUPS`` (empty), ``OIDC_JWKS_CACHE_TTL_SECONDS`` (300),
    ``OIDC_HTTP_TIMEOUT_SECONDS`` (5), ``OIDC_CLOCK_SKEW_SECONDS`` (30).

    Raises:
        AuthConfigurationError: If a required variable is missing or a
            numeric variable cannot be parsed.
    """
    issuer = _required_env("OIDC_ISSUER")
    client_id = _required_env("OIDC_CLIENT_ID")
    # An unset or blank audience list falls back to the client id.
    audiences = _split_csv(os.environ.get("OIDC_AUDIENCE")) or (client_id,)
    discovery_url = os.environ.get("OIDC_DISCOVERY_URL") or (
        issuer.rstrip("/") + _DISCOVERY_PATH
    )
    return AuthSettings(
        issuer=issuer,
        client_id=client_id,
        audiences=audiences,
        discovery_url=discovery_url,
        jwks_url=os.environ.get("OIDC_JWKS_URL"),
        groups_claim=os.environ.get("OIDC_GROUPS_CLAIM", "groups"),
        admin_groups=_split_csv(os.environ.get("OIDC_ADMIN_GROUPS")),
        member_groups=_split_csv(os.environ.get("OIDC_MEMBER_GROUPS")),
        # Numeric settings surface parse failures as configuration errors
        # rather than a bare ValueError.
        jwks_cache_ttl_seconds=_int_env("OIDC_JWKS_CACHE_TTL_SECONDS", 300),
        http_timeout_seconds=_float_env("OIDC_HTTP_TIMEOUT_SECONDS", 5.0),
        clock_skew_seconds=_int_env("OIDC_CLOCK_SKEW_SECONDS", 30),
    )
class CachedJwksClient:
    """Fetches and caches the OIDC provider's JSON Web Key Set.

    The key set (and, when no explicit JWKS URL is configured, the JWKS URL
    read from the discovery document) is cached for
    ``settings.jwks_cache_ttl_seconds``. ``get_signing_key`` holds an internal
    lock for its whole duration, including any HTTP requests it triggers.
    """

    def __init__(self, settings: AuthSettings):
        self._settings = settings
        self._lock = Lock()
        self._jwks: PyJWKSet | None = None
        self._jwks_fetched_at = 0.0
        self._discovery_jwks_url: str | None = None
        self._discovery_fetched_at = 0.0

    def get_signing_key(self, kid: str) -> Any:
        """Return the key whose key id equals *kid*.

        If the cached key set does not contain *kid*, the key set is fetched
        again once (re-reading the discovery document when the JWKS URL comes
        from discovery) before giving up.

        Raises:
            TokenValidationError: If the key set cannot be loaded or no key
                matches *kid*.
        """
        with self._lock:
            jwks = self._get_jwks_locked()
            key = self._find_key(jwks, kid)
            if key is not None:
                return key
            # NOTE(review): every unknown kid triggers a refetch while the
            # lock is held; consider rate-limiting miss-driven refreshes.
            # Use the same truthiness test as _resolve_jwks_url_locked so an
            # empty configured URL is treated like "not configured".
            self._refresh_jwks_locked(
                force_discovery_refresh=not self._settings.jwks_url
            )
            if self._jwks is None:
                raise TokenValidationError("JWKS cache is empty")
            key = self._find_key(self._jwks, kid)
            if key is None:
                raise TokenValidationError(f"No signing key found for kid '{kid}'")
            return key

    def _get_jwks_locked(self) -> PyJWKSet:
        """Return the cached key set, refreshing it first when missing or stale."""
        if self._jwks is None or self._is_stale(self._jwks_fetched_at):
            self._refresh_jwks_locked(force_discovery_refresh=False)
        if self._jwks is None:
            raise TokenValidationError("Unable to load JWKS")
        return self._jwks

    def _refresh_jwks_locked(self, force_discovery_refresh: bool) -> None:
        """Download the key set and replace the cache.

        The previous cache is kept untouched if the download or parsing fails.
        """
        jwks_url = self._resolve_jwks_url_locked(force_refresh=force_discovery_refresh)
        data = self._fetch_json(jwks_url)
        try:
            self._jwks = PyJWKSet.from_dict(data)
        except Exception as exc:
            raise TokenValidationError(
                "OIDC provider returned an invalid JWKS payload"
            ) from exc
        self._jwks_fetched_at = time.monotonic()

    def _resolve_jwks_url_locked(self, force_refresh: bool) -> str:
        """Return the JWKS URL: the configured one, else ``jwks_uri`` from discovery."""
        if self._settings.jwks_url:
            return self._settings.jwks_url
        if (
            force_refresh
            or self._discovery_jwks_url is None
            or self._is_stale(self._discovery_fetched_at)
        ):
            discovery = self._fetch_json(self._settings.discovery_url)
            jwks_uri = discovery.get("jwks_uri")
            if not isinstance(jwks_uri, str) or not jwks_uri:
                raise TokenValidationError("OIDC discovery document missing jwks_uri")
            self._discovery_jwks_url = jwks_uri
            self._discovery_fetched_at = time.monotonic()
        if self._discovery_jwks_url is None:
            raise TokenValidationError("Unable to resolve JWKS URL")
        return self._discovery_jwks_url

    def _fetch_json(self, url: str) -> dict[str, Any]:
        """GET *url* and return its body as a JSON object.

        Raises:
            TokenValidationError: On transport/HTTP errors, a body that is not
                valid JSON, or a JSON value that is not an object.
        """
        try:
            response = httpx.get(url, timeout=self._settings.http_timeout_seconds)
            response.raise_for_status()
        except httpx.HTTPError as exc:
            raise TokenValidationError(
                f"Failed to fetch OIDC metadata from {url}"
            ) from exc
        try:
            data = response.json()
        except ValueError as exc:
            # JSON decode failures are ValueError subclasses, not httpx.HTTPError.
            raise TokenValidationError(
                f"OIDC metadata from {url} is not valid JSON"
            ) from exc
        if not isinstance(data, dict):
            raise TokenValidationError(
                f"OIDC metadata from {url} must be a JSON object"
            )
        return data

    def _is_stale(self, fetched_at: float) -> bool:
        """Return True when *fetched_at* (monotonic seconds) is at least one TTL old."""
        return (time.monotonic() - fetched_at) >= self._settings.jwks_cache_ttl_seconds

    @staticmethod
    def _find_key(jwks: PyJWKSet, kid: str) -> Any | None:
        """Return the key of the first entry in *jwks* whose ``key_id`` is *kid*, else None."""
        for jwk in jwks.keys:
            if jwk.key_id == kid:
                return jwk.key
        return None
@lru_cache
def get_jwks_client() -> CachedJwksClient:
    """Return the memoised ``CachedJwksClient`` built from the current auth settings."""
    settings = get_auth_settings()
    return CachedJwksClient(settings)
def reset_auth_caches() -> None:
    """Drop the memoised auth settings and JWKS client so they are rebuilt on next use."""
    for cached_factory in (get_auth_settings, get_jwks_client):
        cached_factory.cache_clear()
def validate_access_token(token: str) -> TokenClaims:
    """Verify a JWT access token and return its claims.

    The unverified header supplies the key id and algorithm; the algorithm
    must be one of ``_SUPPORTED_ALGORITHMS``. The signature, audience, issuer
    and required claims (``exp``, ``iss``, ``sub``) are then checked by
    ``jwt.decode`` using the key looked up by key id.

    Raises:
        TokenValidationError: If the header is malformed, the key id or
            algorithm is unacceptable, no signing key is found, or
            verification fails.
    """
    auth_settings = get_auth_settings()

    try:
        header = jwt.get_unverified_header(token)
    except InvalidTokenError as exc:
        raise TokenValidationError("Malformed access token header") from exc

    key_id = header.get("kid")
    alg = header.get("alg")
    if not (isinstance(key_id, str) and key_id):
        raise TokenValidationError("Access token missing kid header")
    if not (isinstance(alg, str) and alg in _SUPPORTED_ALGORITHMS):
        raise TokenValidationError("Access token uses an unsupported signing algorithm")

    verification_key = get_jwks_client().get_signing_key(key_id)
    decode_kwargs = {
        "key": verification_key,
        "algorithms": [alg],
        "audience": auth_settings.audiences,
        "issuer": auth_settings.issuer,
        "options": {"require": ["exp", "iss", "sub"]},
        "leeway": auth_settings.clock_skew_seconds,
    }
    try:
        verified_payload = jwt.decode(token, **decode_kwargs)
    except InvalidTokenError as exc:
        raise TokenValidationError("Invalid access token") from exc

    return TokenClaims.from_payload(verified_payload, auth_settings)
def sync_current_user(
    session: Session,
    claims: TokenClaims,
    identity: IdentityData | None = None,
) -> CurrentUser:
    """Ensure a ``User`` row exists for the caller and return a ``CurrentUser``.

    The user is looked up by (issuer, subject). A missing user is created with
    the role resolved from the identity's groups; an existing user's role is
    updated when the resolved role differs. The session is committed (and the
    user refreshed) only when something changed. The first household
    membership found for the user, if any, is attached to the result.

    Args:
        session: Database session used for the lookup and any write.
        claims: Validated token claims, stored on the returned object.
        identity: Identity to use instead of the one derived from *claims*.
    """
    effective_identity = identity
    if not effective_identity:
        effective_identity = IdentityData.from_claims(claims)

    lookup = select(User).where(
        User.oidc_issuer == effective_identity.issuer,
        User.oidc_subject == effective_identity.subject,
    )
    user = session.exec(lookup).first()

    if user is None:
        # First sight of this identity: create the row with the resolved role.
        user = User(
            oidc_issuer=effective_identity.issuer,
            oidc_subject=effective_identity.subject,
            role=resolve_role(effective_identity.groups, existing_role=None),
        )
        changed = True
    else:
        target_role = resolve_role(
            effective_identity.groups, existing_role=user.role
        )
        changed = user.role != target_role
        if changed:
            user.role = target_role

    if changed:
        session.add(user)
        session.commit()
        session.refresh(user)

    membership_row = session.exec(
        select(HouseholdMembership).where(HouseholdMembership.user_id == user.id)
    ).first()
    household_membership = (
        CurrentHouseholdMembership(
            household_id=membership_row.household_id,
            role=membership_row.role,
        )
        if membership_row is not None
        else None
    )

    return CurrentUser(
        user_id=user.id,
        role=user.role,
        identity=effective_identity,
        claims=claims,
        household_membership=household_membership,
    )
def resolve_role(groups: tuple[str, ...], existing_role: Role | None = None) -> Role:
    """Map identity-provider groups to an application role.

    Args:
        groups: Group names taken from the caller's identity.
        existing_role: Role already stored for the user, if any.

    Returns:
        ``Role.ADMIN`` when *groups* contains any configured admin group;
        ``Role.MEMBER`` for any other non-empty *groups*; when *groups* is
        empty, *existing_role* if set, otherwise ``Role.MEMBER``.
    """
    settings = get_auth_settings()
    if not groups:
        # No group information: keep whatever role is already stored.
        return existing_role or Role.MEMBER

    group_set = set(groups)
    if settings.admin_groups and group_set.intersection(settings.admin_groups):
        return Role.ADMIN
    # NOTE(review): settings.member_groups does not influence the outcome;
    # every caller with groups but no admin group resolves to MEMBER, exactly
    # as the previous nested branches did. Confirm whether callers outside the
    # configured member groups were meant to be treated differently.
    return Role.MEMBER