# innercontext/backend/tests/test_auth.py
# (source-viewer metadata: 275 lines, 8.2 KiB, Python)

from __future__ import annotations
import json
from datetime import UTC, datetime, timedelta
from uuid import UUID, uuid4
import jwt
import pytest
from cryptography.hazmat.primitives.asymmetric import rsa
from fastapi import HTTPException
from fastapi.testclient import TestClient
from jwt import algorithms
from sqlmodel import Session, SQLModel, create_engine
from sqlmodel.pool import StaticPool
import db as db_module
from db import get_session
from innercontext.api.auth_deps import require_admin
from innercontext.auth import (
CurrentHouseholdMembership,
CurrentUser,
IdentityData,
TokenClaims,
reset_auth_caches,
validate_access_token,
)
from innercontext.models import (
Household,
HouseholdMembership,
HouseholdRole,
Role,
User,
)
from main import app
class _MockResponse:
def __init__(self, payload: dict[str, object], status_code: int = 200):
self._payload = payload
self.status_code = status_code
def raise_for_status(self) -> None:
if self.status_code >= 400:
raise RuntimeError(f"unexpected status {self.status_code}")
def json(self) -> dict[str, object]:
return self._payload
@pytest.fixture()
def auth_env(monkeypatch):
    """Configure the ``OIDC_*`` environment for the fake provider for one test.

    The auth caches are reset both before and after the test, so cached state
    (e.g. discovery/JWKS lookups) from one test cannot leak into another.
    """
    settings = {
        "OIDC_ISSUER": "https://auth.example.test",
        "OIDC_CLIENT_ID": "innercontext-web",
        "OIDC_DISCOVERY_URL": (
            "https://auth.example.test/.well-known/openid-configuration"
        ),
        "OIDC_ADMIN_GROUPS": "innercontext-admin",
        "OIDC_MEMBER_GROUPS": "innercontext-member",
        "OIDC_JWKS_CACHE_TTL_SECONDS": "3600",
    }
    # monkeypatch restores the previous values automatically on teardown.
    for name, value in settings.items():
        monkeypatch.setenv(name, value)
    reset_auth_caches()
    yield
    reset_auth_caches()
@pytest.fixture()
def rsa_keypair():
    """Return a fresh ``(private_key, public_key)`` RSA-2048 pair for signing tokens."""
    signing_key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
    verifying_key = signing_key.public_key()
    return signing_key, verifying_key
@pytest.fixture()
def auth_session(monkeypatch):
    """Yield a SQLModel session bound to a fresh in-memory SQLite database.

    ``db.engine`` is monkeypatched to the same engine, so code that uses the
    module-level engine sees this database too. ``StaticPool`` together with
    ``check_same_thread=False`` keeps one shared connection, and therefore one
    in-memory database. This is presumably needed because the TestClient may
    drive the app from another thread.

    On teardown the engine is disposed. This releases the pooled connection and
    the in-memory database instead of leaving them alive until garbage
    collection.
    """
    engine = create_engine(
        "sqlite://",
        connect_args={"check_same_thread": False},
        poolclass=StaticPool,
    )
    monkeypatch.setattr(db_module, "engine", engine)
    # Presumably imported for its side effect of registering the model tables
    # on SQLModel.metadata before create_all runs.
    import innercontext.models  # noqa: F401

    try:
        SQLModel.metadata.create_all(engine)
        with Session(engine) as session:
            yield session
    finally:
        engine.dispose()
@pytest.fixture()
def auth_client(auth_session):
    """Yield a ``TestClient`` whose ``get_session`` dependency is the test session.

    Routes therefore read and write the same database the test inspects through
    ``auth_session``. The overrides are cleared in ``finally``. That way, a
    failure while entering the client (e.g. during app startup) cannot leave the
    global ``app`` wired to this test's session for later tests.
    """

    def _override():
        yield auth_session

    app.dependency_overrides[get_session] = _override
    try:
        with TestClient(app) as client:
            yield client
    finally:
        app.dependency_overrides.clear()
def _public_jwk(public_key, kid: str) -> dict[str, object]:
    """Serialize *public_key* as a JWK dict marked for RS256 signatures with key id *kid*."""
    serialized = algorithms.RSAAlgorithm.to_jwk(public_key)
    jwk: dict[str, object] = json.loads(serialized)
    jwk.update(kid=kid, use="sig", alg="RS256")
    return jwk
def _sign_token(private_key, kid: str, **claims_overrides: object) -> str:
    """Return an RS256 JWT signed with *private_key*, carrying *kid* in its header.

    The default claims describe subject ``user-123`` in the ``innercontext-admin``
    group. The token is issued by the fake provider for the ``innercontext-web``
    audience and expires one hour from now. Each keyword argument replaces (or
    adds) the claim of the same name.
    """
    issued_at = datetime.now(UTC)
    expires_at = issued_at + timedelta(hours=1)
    claims: dict[str, object] = {
        "iss": "https://auth.example.test",
        "sub": "user-123",
        "aud": "innercontext-web",
        "exp": int(expires_at.timestamp()),
        "iat": int(issued_at.timestamp()),
        "groups": ["innercontext-admin"],
        "email": "user@example.test",
        "name": "Inner Context User",
        "preferred_username": "ictx-user",
        # Later entries win, so overrides replace the defaults above.
        **claims_overrides,
    }
    return jwt.encode(claims, private_key, algorithm="RS256", headers={"kid": kid})
def _mock_oidc(monkeypatch, public_key, *, fetch_counts: dict[str, int] | None = None):
    """Patch ``innercontext.auth.httpx.get`` to act as a fake OIDC provider.

    The fake serves two things:

    - a discovery document whose ``jwks_uri`` points at ``/jwks.json``;
    - a JWKS holding *public_key* under kid ``kid-1``.

    Any other URL raises ``AssertionError``. If *fetch_counts* is given, it is
    updated in place with the number of requests per URL.
    """
    discovery_suffix = "/.well-known/openid-configuration"
    jwks_suffix = "/jwks.json"

    def _fake_get(url: str, timeout: float):
        if fetch_counts is not None:
            fetch_counts[url] = 1 + fetch_counts.get(url, 0)
        if url.endswith(discovery_suffix):
            body = {"jwks_uri": "https://auth.example.test/jwks.json"}
        elif url.endswith(jwks_suffix):
            # Built per request, mirroring a provider that re-serves its keys.
            body = {"keys": [_public_jwk(public_key, "kid-1")]}
        else:
            raise AssertionError(f"unexpected URL {url} with timeout {timeout}")
        return _MockResponse(body)

    monkeypatch.setattr("innercontext.auth.httpx.get", _fake_get)
def test_validate_access_token_uses_cached_jwks(auth_env, rsa_keypair, monkeypatch):
    """Validating two tokens fetches the discovery document and JWKS once each."""
    private_key, public_key = rsa_keypair
    fetch_counts: dict[str, int] = {}
    _mock_oidc(monkeypatch, public_key, fetch_counts=fetch_counts)

    for subject in ("user-a", "user-b"):
        validate_access_token(_sign_token(private_key, "kid-1", sub=subject))

    discovery_url = "https://auth.example.test/.well-known/openid-configuration"
    jwks_url = "https://auth.example.test/jwks.json"
    assert fetch_counts[discovery_url] == 1
    assert fetch_counts[jwks_url] == 1
@pytest.mark.parametrize(
    ("path", "payload"),
    [
        (
            "/auth/session/sync",
            {
                "email": "sync@example.test",
                "name": "Synced User",
                "preferred_username": "synced-user",
                "groups": ["innercontext-admin"],
            },
        ),
        ("/auth/me", None),
    ],
    ids=["/auth/session/sync", "/auth/me"],
)
def test_sync_protected_endpoints_create_or_resolve_current_user(
    auth_env,
    auth_client,
    auth_session,
    rsa_keypair,
    monkeypatch,
    path: str,
    payload: dict[str, object] | None,
):
    """Both auth endpoints return the user matching the token's issuer/subject.

    For ``/auth/session/sync`` (POST with a profile payload) no user is seeded,
    so the endpoint must create one. For ``/auth/me`` (GET) a user owning a
    household is seeded first, and the response must report that membership.
    """
    private_key, public_key = rsa_keypair
    _mock_oidc(monkeypatch, public_key)
    bearer = _sign_token(private_key, "kid-1")

    if path == "/auth/me":
        owner = User(
            oidc_issuer="https://auth.example.test",
            oidc_subject="user-123",
            role=Role.ADMIN,
        )
        auth_session.add(owner)
        auth_session.commit()
        auth_session.refresh(owner)

        home = Household()
        auth_session.add(home)
        auth_session.commit()
        auth_session.refresh(home)

        auth_session.add(
            HouseholdMembership(
                user_id=owner.id,
                household_id=home.id,
                role=HouseholdRole.OWNER,
            )
        )
        auth_session.commit()

    method = "POST" if path.endswith("sync") else "GET"
    response = auth_client.request(
        method,
        path,
        headers={"Authorization": f"Bearer {bearer}"},
        json=payload,
    )
    assert response.status_code == 200

    body = response.json()
    identity = body["identity"]
    assert body["user"]["role"] == "admin"
    assert identity["issuer"] == "https://auth.example.test"
    assert identity["subject"] == "user-123"

    # The user the endpoint reported must also exist in the database.
    stored = auth_session.get(User, UUID(body["user"]["id"]))
    assert stored is not None
    assert stored.oidc_issuer == "https://auth.example.test"
    assert stored.oidc_subject == "user-123"

    if path == "/auth/session/sync":
        assert identity["email"] == "sync@example.test"
        assert identity["groups"] == ["innercontext-admin"]
    else:
        assert body["user"]["household_membership"]["role"] == "owner"
@pytest.mark.parametrize(
    "path",
    ["/auth/me", "/profile"],
    ids=["/auth/me expects 401", "/profile expects 401"],
)
def test_unauthorized_protected_endpoints_return_401(auth_env, auth_client, path: str):
    """Requests without an ``Authorization`` header are rejected with 401."""
    resp = auth_client.get(path)
    assert resp.status_code == 401
    detail = resp.json()["detail"]
    assert detail == "Missing bearer token"
def test_unauthorized_invalid_bearer_token_is_rejected(
    auth_env, auth_client, rsa_keypair, monkeypatch
):
    """A bearer value that is not a JWT at all is answered with 401."""
    public_key = rsa_keypair[1]
    _mock_oidc(monkeypatch, public_key)
    auth_header = {"Authorization": "Bearer not-a-jwt"}
    response = auth_client.get("/auth/me", headers=auth_header)
    assert response.status_code == 401
def test_require_admin_raises_for_member():
    """``require_admin`` rejects a user with the member role with HTTP 403."""
    issuer = "https://auth.example.test"
    subject = "member-1"
    claims = TokenClaims(
        issuer=issuer,
        subject=subject,
        audience=("innercontext-web",),
        expires_at=datetime.now(UTC) + timedelta(hours=1),
        raw_claims={"iss": issuer, "sub": subject},
    )
    membership = CurrentHouseholdMembership(
        household_id=uuid4(),
        role=HouseholdRole.MEMBER,
    )
    member = CurrentUser(
        user_id=uuid4(),
        role=Role.MEMBER,
        identity=IdentityData.from_claims(claims),
        claims=claims,
        household_membership=membership,
    )

    with pytest.raises(HTTPException) as exc_info:
        require_admin(member)
    assert exc_info.value.status_code == 403