384 lines
12 KiB
Python
384 lines
12 KiB
Python
from __future__ import annotations
|
|
|
|
import os
|
|
import time
|
|
from dataclasses import dataclass, field
|
|
from datetime import UTC, datetime
|
|
from functools import lru_cache
|
|
from threading import Lock
|
|
from typing import Any, Mapping
|
|
from uuid import UUID
|
|
|
|
import httpx
|
|
import jwt
|
|
from jwt import InvalidTokenError, PyJWKSet
|
|
from sqlmodel import Session, select
|
|
|
|
from innercontext.models import HouseholdMembership, HouseholdRole, Role, User
|
|
|
|
# Path appended to the issuer URL to build the OIDC discovery document URL
# when OIDC_DISCOVERY_URL is not set.
_DISCOVERY_PATH = "/.well-known/openid-configuration"

# Signing algorithms accepted in an access token's "alg" header; tokens
# declaring anything else are rejected before signature verification.
_SUPPORTED_ALGORITHMS = frozenset(
    ("RS256", "RS384", "RS512", "ES256", "ES384", "ES512")
)
|
|
|
|
|
|
class AuthConfigurationError(RuntimeError):
    """Raised when a required auth environment variable is missing or empty."""
|
|
|
|
|
|
class TokenValidationError(ValueError):
    """Raised when an access token, or the OIDC metadata needed to verify it, is rejected."""
|
|
|
|
|
|
@dataclass(frozen=True, slots=True)
|
|
class AuthSettings:
|
|
issuer: str
|
|
client_id: str
|
|
audiences: tuple[str, ...]
|
|
discovery_url: str
|
|
jwks_url: str | None
|
|
groups_claim: str
|
|
admin_groups: tuple[str, ...]
|
|
member_groups: tuple[str, ...]
|
|
jwks_cache_ttl_seconds: int
|
|
http_timeout_seconds: float
|
|
clock_skew_seconds: int
|
|
|
|
|
|
@dataclass(frozen=True, slots=True)
|
|
class TokenClaims:
|
|
issuer: str
|
|
subject: str
|
|
audience: tuple[str, ...]
|
|
expires_at: datetime
|
|
groups: tuple[str, ...] = ()
|
|
email: str | None = None
|
|
name: str | None = None
|
|
preferred_username: str | None = None
|
|
raw_claims: Mapping[str, Any] = field(default_factory=dict, repr=False)
|
|
|
|
@classmethod
|
|
def from_payload(
|
|
cls, payload: Mapping[str, Any], settings: AuthSettings
|
|
) -> "TokenClaims":
|
|
audience = payload.get("aud")
|
|
if isinstance(audience, str):
|
|
audiences = (audience,)
|
|
elif isinstance(audience, list):
|
|
audiences = tuple(str(item) for item in audience)
|
|
else:
|
|
audiences = ()
|
|
|
|
groups = _normalize_groups(payload.get(settings.groups_claim))
|
|
exp = payload.get("exp")
|
|
if not isinstance(exp, (int, float)):
|
|
raise TokenValidationError("Access token missing exp claim")
|
|
|
|
return cls(
|
|
issuer=str(payload["iss"]),
|
|
subject=str(payload["sub"]),
|
|
audience=audiences,
|
|
expires_at=datetime.fromtimestamp(exp, tz=UTC),
|
|
groups=groups,
|
|
email=_optional_str(payload.get("email")),
|
|
name=_optional_str(payload.get("name")),
|
|
preferred_username=_optional_str(payload.get("preferred_username")),
|
|
raw_claims=dict(payload),
|
|
)
|
|
|
|
|
|
@dataclass(frozen=True, slots=True)
|
|
class IdentityData:
|
|
issuer: str
|
|
subject: str
|
|
email: str | None = None
|
|
name: str | None = None
|
|
preferred_username: str | None = None
|
|
groups: tuple[str, ...] = ()
|
|
|
|
@classmethod
|
|
def from_claims(cls, claims: TokenClaims) -> "IdentityData":
|
|
return cls(
|
|
issuer=claims.issuer,
|
|
subject=claims.subject,
|
|
email=claims.email,
|
|
name=claims.name,
|
|
preferred_username=claims.preferred_username,
|
|
groups=claims.groups,
|
|
)
|
|
|
|
|
|
@dataclass(frozen=True, slots=True)
|
|
class CurrentHouseholdMembership:
|
|
household_id: UUID
|
|
role: HouseholdRole
|
|
|
|
|
|
@dataclass(frozen=True, slots=True)
|
|
class CurrentUser:
|
|
user_id: UUID
|
|
role: Role
|
|
identity: IdentityData
|
|
claims: TokenClaims = field(repr=False)
|
|
household_membership: CurrentHouseholdMembership | None = None
|
|
|
|
|
|
def _split_csv(value: str | None) -> tuple[str, ...]:
    """Split a comma-separated string into trimmed, non-empty items.

    Returns an empty tuple when ``value`` is None.
    """
    if value is None:
        return ()
    trimmed = (piece.strip() for piece in value.split(","))
    return tuple(piece for piece in trimmed if piece)
|
|
|
|
|
|
def _optional_str(value: Any) -> str | None:
    """Return ``value`` unchanged if it is None or a str, otherwise ``str(value)``."""
    if value is None or isinstance(value, str):
        return value
    return str(value)
|
|
|
|
|
|
def _normalize_groups(value: Any) -> tuple[str, ...]:
    """Coerce a raw groups claim into a tuple of strings.

    None yields an empty tuple; a list or tuple has each element converted
    with ``str``; a single string is wrapped as-is; any other value is
    stringified and wrapped in a one-element tuple.
    """
    if value is None:
        return ()
    if isinstance(value, (list, tuple)):
        return tuple(str(member) for member in value)
    single = value if isinstance(value, str) else str(value)
    return (single,)
|
|
|
|
|
|
def _required_env(name: str) -> str:
    """Return the value of environment variable ``name``.

    Raises:
        AuthConfigurationError: If the variable is unset or empty.
    """
    value = os.environ.get(name)
    if not value:
        raise AuthConfigurationError(
            f"Missing required auth environment variable: {name}"
        )
    return value
|
|
|
|
|
|
@lru_cache
def get_auth_settings() -> AuthSettings:
    """Build the auth settings from ``OIDC_*`` environment variables.

    ``OIDC_ISSUER`` and ``OIDC_CLIENT_ID`` are required. The audience list
    falls back to the client id and the discovery URL falls back to the
    issuer plus the well-known path. The result is memoized by ``lru_cache``,
    so the environment is only read again after the cache is cleared.

    Raises:
        AuthConfigurationError: If a required variable is missing or empty.
        ValueError: If a numeric variable cannot be parsed.
    """
    env = os.environ
    issuer = _required_env("OIDC_ISSUER")
    client_id = _required_env("OIDC_CLIENT_ID")

    configured_audiences = _split_csv(env.get("OIDC_AUDIENCE"))
    audiences = configured_audiences if configured_audiences else (client_id,)

    discovery_url = env.get("OIDC_DISCOVERY_URL")
    if not discovery_url:
        discovery_url = issuer.rstrip("/") + _DISCOVERY_PATH

    jwks_url = env.get("OIDC_JWKS_URL")
    groups_claim = env.get("OIDC_GROUPS_CLAIM", "groups")
    admin_groups = _split_csv(env.get("OIDC_ADMIN_GROUPS"))
    member_groups = _split_csv(env.get("OIDC_MEMBER_GROUPS"))
    cache_ttl = int(env.get("OIDC_JWKS_CACHE_TTL_SECONDS", "300"))
    http_timeout = float(env.get("OIDC_HTTP_TIMEOUT_SECONDS", "5"))
    clock_skew = int(env.get("OIDC_CLOCK_SKEW_SECONDS", "30"))

    return AuthSettings(
        issuer=issuer,
        client_id=client_id,
        audiences=audiences,
        discovery_url=discovery_url,
        jwks_url=jwks_url,
        groups_claim=groups_claim,
        admin_groups=admin_groups,
        member_groups=member_groups,
        jwks_cache_ttl_seconds=cache_ttl,
        http_timeout_seconds=http_timeout,
        clock_skew_seconds=clock_skew,
    )
|
|
|
|
|
|
class CachedJwksClient:
    """Fetch and cache the OIDC provider's JWKS for signature verification.

    The JWKS URL is ``settings.jwks_url`` when set; otherwise it is read from
    the ``jwks_uri`` field of the discovery document. Both the key set and
    the discovered URL are refetched once they are older than
    ``settings.jwks_cache_ttl_seconds``. ``get_signing_key`` holds a lock for
    its whole duration, including any HTTP requests it triggers.
    """

    def __init__(self, settings: AuthSettings):
        self._settings = settings
        self._lock = Lock()
        self._jwks: PyJWKSet | None = None
        self._jwks_fetched_at = 0.0
        self._discovery_jwks_url: str | None = None
        self._discovery_fetched_at = 0.0

    def get_signing_key(self, kid: str) -> Any:
        """Return the key whose key id equals ``kid``.

        If the cached key set does not contain ``kid``, the key set is
        refetched once (re-reading the discovery document too when no
        explicit JWKS URL is configured) before giving up.

        Raises:
            TokenValidationError: If the JWKS cannot be loaded or no key
                matches ``kid`` even after the refetch.
        """
        with self._lock:
            jwks = self._get_jwks_locked()
            key = self._find_key(jwks, kid)
            if key is not None:
                return key

            # Unknown kid: the provider may have rotated keys, so refetch.
            # NOTE(review): every lookup of an unknown kid reaches this
            # refetch; nothing here rate-limits it.
            self._refresh_jwks_locked(
                force_discovery_refresh=self._settings.jwks_url is None
            )
            if self._jwks is None:
                raise TokenValidationError("JWKS cache is empty")

            key = self._find_key(self._jwks, kid)
            if key is None:
                raise TokenValidationError(f"No signing key found for kid '{kid}'")
            return key

    def _get_jwks_locked(self) -> PyJWKSet:
        """Return the cached key set, fetching it first if missing or stale."""
        if self._jwks is None or self._is_stale(self._jwks_fetched_at):
            self._refresh_jwks_locked(force_discovery_refresh=False)
        if self._jwks is None:
            raise TokenValidationError("Unable to load JWKS")
        return self._jwks

    def _refresh_jwks_locked(self, force_discovery_refresh: bool) -> None:
        """Download the JWKS, replace the cached key set and record the fetch time."""
        jwks_url = self._resolve_jwks_url_locked(force_refresh=force_discovery_refresh)
        data = self._fetch_json(jwks_url)
        try:
            self._jwks = PyJWKSet.from_dict(data)
        except Exception as exc:
            raise TokenValidationError(
                "OIDC provider returned an invalid JWKS payload"
            ) from exc
        self._jwks_fetched_at = time.monotonic()

    def _resolve_jwks_url_locked(self, force_refresh: bool) -> str:
        """Return the JWKS URL, consulting the discovery document when needed.

        An explicitly configured ``jwks_url`` always wins and skips
        discovery. Otherwise the discovery document is fetched when
        ``force_refresh`` is true, nothing is cached yet, or the cached URL
        is stale.
        """
        if self._settings.jwks_url:
            return self._settings.jwks_url

        if (
            force_refresh
            or self._discovery_jwks_url is None
            or self._is_stale(self._discovery_fetched_at)
        ):
            discovery = self._fetch_json(self._settings.discovery_url)
            jwks_uri = discovery.get("jwks_uri")
            if not isinstance(jwks_uri, str) or not jwks_uri:
                raise TokenValidationError("OIDC discovery document missing jwks_uri")
            self._discovery_jwks_url = jwks_uri
            self._discovery_fetched_at = time.monotonic()

        if self._discovery_jwks_url is None:
            raise TokenValidationError("Unable to resolve JWKS URL")
        return self._discovery_jwks_url

    def _fetch_json(self, url: str) -> dict[str, Any]:
        """GET ``url`` and return its body as a JSON object.

        Raises:
            TokenValidationError: On a transport error, a non-success status,
                a body that is not valid JSON, or JSON that is not an object.
        """
        try:
            response = httpx.get(url, timeout=self._settings.http_timeout_seconds)
            response.raise_for_status()
        except httpx.HTTPError as exc:
            raise TokenValidationError(
                f"Failed to fetch OIDC metadata from {url}"
            ) from exc

        # A 2xx response can still carry a non-JSON body (for example an HTML
        # error page); surface that as a validation failure instead of a raw
        # decode error.
        try:
            data = response.json()
        except ValueError as exc:
            raise TokenValidationError(
                f"OIDC metadata from {url} is not valid JSON"
            ) from exc

        if not isinstance(data, dict):
            raise TokenValidationError(
                f"OIDC metadata from {url} must be a JSON object"
            )
        return data

    def _is_stale(self, fetched_at: float) -> bool:
        """Return True when ``fetched_at`` is at least one cache TTL in the past."""
        return (time.monotonic() - fetched_at) >= self._settings.jwks_cache_ttl_seconds

    @staticmethod
    def _find_key(jwks: PyJWKSet, kid: str) -> Any | None:
        """Return the key of the first entry in ``jwks`` whose id is ``kid``, else None."""
        for jwk in jwks.keys:
            if jwk.key_id == kid:
                return jwk.key
        return None
|
|
|
|
|
|
@lru_cache
def get_jwks_client() -> CachedJwksClient:
    """Return the memoized JWKS client built from the current auth settings."""
    settings = get_auth_settings()
    return CachedJwksClient(settings)
|
|
|
|
|
|
def reset_auth_caches() -> None:
    """Clear the memoized auth settings and JWKS client so both are rebuilt on next use."""
    for cached_factory in (get_auth_settings, get_jwks_client):
        cached_factory.cache_clear()
|
|
|
|
|
|
def validate_access_token(token: str) -> TokenClaims:
    """Verify ``token`` and return its claims.

    The unverified header must name a key id and one of the supported
    signing algorithms. The matching key is then looked up through the JWKS
    client and the token is decoded against the configured issuer and
    audiences, with ``exp``, ``iss`` and ``sub`` required.

    Raises:
        TokenValidationError: If the header is malformed, the key id or
            algorithm is missing or unsupported, no signing key is found, or
            decoding fails.
    """
    settings = get_auth_settings()

    try:
        header = jwt.get_unverified_header(token)
    except InvalidTokenError as exc:
        raise TokenValidationError("Malformed access token header") from exc

    kid = header.get("kid")
    algorithm = header.get("alg")
    if not (isinstance(kid, str) and kid):
        raise TokenValidationError("Access token missing kid header")
    if not (isinstance(algorithm, str) and algorithm in _SUPPORTED_ALGORITHMS):
        raise TokenValidationError("Access token uses an unsupported signing algorithm")

    jwks_client = get_jwks_client()
    signing_key = jwks_client.get_signing_key(kid)

    required_claims = ["exp", "iss", "sub"]
    try:
        payload = jwt.decode(
            token,
            key=signing_key,
            algorithms=[algorithm],
            audience=settings.audiences,
            issuer=settings.issuer,
            options={"require": required_claims},
            leeway=settings.clock_skew_seconds,
        )
    except InvalidTokenError as exc:
        raise TokenValidationError("Invalid access token") from exc

    return TokenClaims.from_payload(payload, settings)
|
|
|
|
|
|
def sync_current_user(
    session: Session,
    claims: TokenClaims,
    identity: IdentityData | None = None,
) -> CurrentUser:
    """Find or create the local user for an identity and keep its role in sync.

    The user is looked up by (issuer, subject). A missing user is created
    with the resolved role; an existing user whose stored role differs is
    updated. A commit and refresh happen only when something changed. The
    first household membership found for the user, if any, is attached.

    Args:
        session: Database session used for the queries and the commit.
        claims: Validated token claims; also the source of the identity
            when ``identity`` is not given.
        identity: Optional identity that overrides the one derived from
            ``claims``.

    Returns:
        A ``CurrentUser`` describing the synchronized user.
    """
    who = identity if identity else IdentityData.from_claims(claims)

    lookup = select(User).where(
        User.oidc_issuer == who.issuer,
        User.oidc_subject == who.subject,
    )
    user = session.exec(lookup).first()

    stored_role = None if user is None else user.role
    target_role = resolve_role(who.groups, existing_role=stored_role)

    changed = True
    if user is None:
        user = User(
            oidc_issuer=who.issuer,
            oidc_subject=who.subject,
            role=target_role,
        )
    elif user.role != target_role:
        user.role = target_role
    else:
        changed = False

    if changed:
        session.add(user)
        session.commit()
        session.refresh(user)

    membership_query = select(HouseholdMembership).where(
        HouseholdMembership.user_id == user.id
    )
    membership = session.exec(membership_query).first()
    household_membership = (
        None
        if membership is None
        else CurrentHouseholdMembership(
            household_id=membership.household_id,
            role=membership.role,
        )
    )

    return CurrentUser(
        user_id=user.id,
        role=user.role,
        identity=who,
        claims=claims,
        household_membership=household_membership,
    )
|
|
|
|
|
|
def resolve_role(groups: tuple[str, ...], existing_role: Role | None = None) -> Role:
    """Map a user's groups to an application role.

    With no groups, the existing role is kept (``Role.MEMBER`` when there is
    none). With groups, membership in any configured admin group yields
    ``Role.ADMIN``; every other case yields ``Role.MEMBER``.
    """
    settings = get_auth_settings()
    if not groups:
        return existing_role or Role.MEMBER

    group_set = set(groups)
    if settings.admin_groups and group_set.intersection(settings.admin_groups):
        return Role.ADMIN

    # NOTE(review): settings.member_groups does not influence the outcome;
    # any non-admin token that carries groups resolves to MEMBER whether or
    # not it matches a member group. Confirm that this is intended.
    return Role.MEMBER
|