feat(auth): validate Authelia tokens in FastAPI
This commit is contained in:
parent
2704d58673
commit
4782fad5b9
7 changed files with 953 additions and 8 deletions
275
backend/tests/test_auth.py
Normal file
275
backend/tests/test_auth.py
Normal file
|
|
@ -0,0 +1,275 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import jwt
|
||||
import pytest
|
||||
from cryptography.hazmat.primitives.asymmetric import rsa
|
||||
from fastapi import HTTPException
|
||||
from fastapi.testclient import TestClient
|
||||
from jwt import algorithms
|
||||
from sqlmodel import Session, SQLModel, create_engine
|
||||
from sqlmodel.pool import StaticPool
|
||||
|
||||
import db as db_module
|
||||
from db import get_session
|
||||
from innercontext.api.auth_deps import require_admin
|
||||
from innercontext.auth import (
|
||||
CurrentHouseholdMembership,
|
||||
CurrentUser,
|
||||
IdentityData,
|
||||
TokenClaims,
|
||||
reset_auth_caches,
|
||||
validate_access_token,
|
||||
)
|
||||
from innercontext.models import (
|
||||
Household,
|
||||
HouseholdMembership,
|
||||
HouseholdRole,
|
||||
Role,
|
||||
User,
|
||||
)
|
||||
from main import app
|
||||
|
||||
|
||||
class _MockResponse:
|
||||
def __init__(self, payload: dict[str, object], status_code: int = 200):
|
||||
self._payload = payload
|
||||
self.status_code = status_code
|
||||
|
||||
def raise_for_status(self) -> None:
|
||||
if self.status_code >= 400:
|
||||
raise RuntimeError(f"unexpected status {self.status_code}")
|
||||
|
||||
def json(self) -> dict[str, object]:
|
||||
return self._payload
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def auth_env(monkeypatch):
|
||||
monkeypatch.setenv("OIDC_ISSUER", "https://auth.example.test")
|
||||
monkeypatch.setenv("OIDC_CLIENT_ID", "innercontext-web")
|
||||
monkeypatch.setenv(
|
||||
"OIDC_DISCOVERY_URL",
|
||||
"https://auth.example.test/.well-known/openid-configuration",
|
||||
)
|
||||
monkeypatch.setenv("OIDC_ADMIN_GROUPS", "innercontext-admin")
|
||||
monkeypatch.setenv("OIDC_MEMBER_GROUPS", "innercontext-member")
|
||||
monkeypatch.setenv("OIDC_JWKS_CACHE_TTL_SECONDS", "3600")
|
||||
reset_auth_caches()
|
||||
yield
|
||||
reset_auth_caches()
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def rsa_keypair():
|
||||
private_key = rsa.generate_private_key(
|
||||
public_exponent=65537,
|
||||
key_size=2048,
|
||||
)
|
||||
return private_key, private_key.public_key()
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def auth_session(monkeypatch):
|
||||
engine = create_engine(
|
||||
"sqlite://",
|
||||
connect_args={"check_same_thread": False},
|
||||
poolclass=StaticPool,
|
||||
)
|
||||
monkeypatch.setattr(db_module, "engine", engine)
|
||||
import innercontext.models # noqa: F401
|
||||
|
||||
SQLModel.metadata.create_all(engine)
|
||||
with Session(engine) as session:
|
||||
yield session
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def auth_client(auth_session):
|
||||
def _override():
|
||||
yield auth_session
|
||||
|
||||
app.dependency_overrides[get_session] = _override
|
||||
with TestClient(app) as client:
|
||||
yield client
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
def _public_jwk(public_key, kid: str) -> dict[str, object]:
|
||||
jwk = json.loads(algorithms.RSAAlgorithm.to_jwk(public_key))
|
||||
jwk["kid"] = kid
|
||||
jwk["use"] = "sig"
|
||||
jwk["alg"] = "RS256"
|
||||
return jwk
|
||||
|
||||
|
||||
def _sign_token(private_key, kid: str, **claims_overrides: object) -> str:
|
||||
now = datetime.now(UTC)
|
||||
payload: dict[str, object] = {
|
||||
"iss": "https://auth.example.test",
|
||||
"sub": "user-123",
|
||||
"aud": "innercontext-web",
|
||||
"exp": int((now + timedelta(hours=1)).timestamp()),
|
||||
"iat": int(now.timestamp()),
|
||||
"groups": ["innercontext-admin"],
|
||||
"email": "user@example.test",
|
||||
"name": "Inner Context User",
|
||||
"preferred_username": "ictx-user",
|
||||
}
|
||||
payload.update(claims_overrides)
|
||||
return jwt.encode(payload, private_key, algorithm="RS256", headers={"kid": kid})
|
||||
|
||||
|
||||
def _mock_oidc(monkeypatch, public_key, *, fetch_counts: dict[str, int] | None = None):
|
||||
def _fake_get(url: str, timeout: float):
|
||||
if fetch_counts is not None:
|
||||
fetch_counts[url] = fetch_counts.get(url, 0) + 1
|
||||
if url.endswith("/.well-known/openid-configuration"):
|
||||
return _MockResponse({"jwks_uri": "https://auth.example.test/jwks.json"})
|
||||
if url.endswith("/jwks.json"):
|
||||
return _MockResponse({"keys": [_public_jwk(public_key, "kid-1")]})
|
||||
raise AssertionError(f"unexpected URL {url} with timeout {timeout}")
|
||||
|
||||
monkeypatch.setattr("innercontext.auth.httpx.get", _fake_get)
|
||||
|
||||
|
||||
def test_validate_access_token_uses_cached_jwks(auth_env, rsa_keypair, monkeypatch):
|
||||
private_key, public_key = rsa_keypair
|
||||
fetch_counts: dict[str, int] = {}
|
||||
_mock_oidc(monkeypatch, public_key, fetch_counts=fetch_counts)
|
||||
|
||||
validate_access_token(_sign_token(private_key, "kid-1", sub="user-a"))
|
||||
validate_access_token(_sign_token(private_key, "kid-1", sub="user-b"))
|
||||
|
||||
assert (
|
||||
fetch_counts["https://auth.example.test/.well-known/openid-configuration"] == 1
|
||||
)
|
||||
assert fetch_counts["https://auth.example.test/jwks.json"] == 1
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("path", "payload"),
|
||||
[
|
||||
(
|
||||
"/auth/session/sync",
|
||||
{
|
||||
"email": "sync@example.test",
|
||||
"name": "Synced User",
|
||||
"preferred_username": "synced-user",
|
||||
"groups": ["innercontext-admin"],
|
||||
},
|
||||
),
|
||||
("/auth/me", None),
|
||||
],
|
||||
ids=["/auth/session/sync", "/auth/me"],
|
||||
)
|
||||
def test_sync_protected_endpoints_create_or_resolve_current_user(
|
||||
auth_env,
|
||||
auth_client,
|
||||
auth_session,
|
||||
rsa_keypair,
|
||||
monkeypatch,
|
||||
path: str,
|
||||
payload: dict[str, object] | None,
|
||||
):
|
||||
private_key, public_key = rsa_keypair
|
||||
_mock_oidc(monkeypatch, public_key)
|
||||
token = _sign_token(private_key, "kid-1")
|
||||
|
||||
if path == "/auth/me":
|
||||
user = User(
|
||||
oidc_issuer="https://auth.example.test",
|
||||
oidc_subject="user-123",
|
||||
role=Role.ADMIN,
|
||||
)
|
||||
auth_session.add(user)
|
||||
auth_session.commit()
|
||||
auth_session.refresh(user)
|
||||
|
||||
household = Household()
|
||||
auth_session.add(household)
|
||||
auth_session.commit()
|
||||
auth_session.refresh(household)
|
||||
|
||||
membership = HouseholdMembership(
|
||||
user_id=user.id,
|
||||
household_id=household.id,
|
||||
role=HouseholdRole.OWNER,
|
||||
)
|
||||
auth_session.add(membership)
|
||||
auth_session.commit()
|
||||
|
||||
response = auth_client.request(
|
||||
"POST" if path.endswith("sync") else "GET",
|
||||
path,
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
json=payload,
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["user"]["role"] == "admin"
|
||||
assert data["identity"]["issuer"] == "https://auth.example.test"
|
||||
assert data["identity"]["subject"] == "user-123"
|
||||
|
||||
synced_user = auth_session.get(User, UUID(data["user"]["id"]))
|
||||
assert synced_user is not None
|
||||
assert synced_user.oidc_issuer == "https://auth.example.test"
|
||||
assert synced_user.oidc_subject == "user-123"
|
||||
|
||||
if path == "/auth/session/sync":
|
||||
assert data["identity"]["email"] == "sync@example.test"
|
||||
assert data["identity"]["groups"] == ["innercontext-admin"]
|
||||
else:
|
||||
assert data["user"]["household_membership"]["role"] == "owner"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"path",
|
||||
["/auth/me", "/profile"],
|
||||
ids=["/auth/me expects 401", "/profile expects 401"],
|
||||
)
|
||||
def test_unauthorized_protected_endpoints_return_401(auth_env, auth_client, path: str):
|
||||
response = auth_client.get(path)
|
||||
assert response.status_code == 401
|
||||
assert response.json()["detail"] == "Missing bearer token"
|
||||
|
||||
|
||||
def test_unauthorized_invalid_bearer_token_is_rejected(
|
||||
auth_env, auth_client, rsa_keypair, monkeypatch
|
||||
):
|
||||
_, public_key = rsa_keypair
|
||||
_mock_oidc(monkeypatch, public_key)
|
||||
response = auth_client.get(
|
||||
"/auth/me",
|
||||
headers={"Authorization": "Bearer not-a-jwt"},
|
||||
)
|
||||
assert response.status_code == 401
|
||||
|
||||
|
||||
def test_require_admin_raises_for_member():
|
||||
claims = TokenClaims(
|
||||
issuer="https://auth.example.test",
|
||||
subject="member-1",
|
||||
audience=("innercontext-web",),
|
||||
expires_at=datetime.now(UTC) + timedelta(hours=1),
|
||||
raw_claims={"iss": "https://auth.example.test", "sub": "member-1"},
|
||||
)
|
||||
current_user = CurrentUser(
|
||||
user_id=uuid4(),
|
||||
role=Role.MEMBER,
|
||||
identity=IdentityData.from_claims(claims),
|
||||
claims=claims,
|
||||
household_membership=CurrentHouseholdMembership(
|
||||
household_id=uuid4(),
|
||||
role=HouseholdRole.MEMBER,
|
||||
),
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
require_admin(current_user)
|
||||
|
||||
assert exc_info.value.status_code == 403
|
||||
Loading…
Add table
Add a link
Reference in a new issue