refactor(api): centralize tenant authorization helpers

This commit is contained in:
Piotr Oleszczyk 2026-03-12 15:26:06 +01:00
parent 4782fad5b9
commit 1f47974f48
3 changed files with 512 additions and 0 deletions

View file

@ -0,0 +1,173 @@
from __future__ import annotations
from typing import TypeVar, cast
from uuid import UUID
from fastapi import HTTPException
from sqlmodel import Session, select
from innercontext.auth import CurrentUser
from innercontext.models import HouseholdMembership, Product, ProductInventory, Role
_T = TypeVar("_T")
def _not_found(model_name: str) -> HTTPException:
return HTTPException(status_code=404, detail=f"{model_name} not found")
def _user_scoped_model_name(model: type[object]) -> str:
return getattr(model, "__name__", str(model))
def _record_user_id(model: type[object], record: object) -> object:
if not hasattr(record, "user_id"):
model_name = _user_scoped_model_name(model)
raise TypeError(f"{model_name} does not expose user_id")
return cast(object, getattr(record, "user_id"))
def _is_admin(current_user: CurrentUser) -> bool:
return current_user.role is Role.ADMIN
def _owner_household_id(session: Session, owner_user_id: UUID) -> UUID | None:
membership = session.exec(
select(HouseholdMembership).where(HouseholdMembership.user_id == owner_user_id)
).first()
if membership is None:
return None
return membership.household_id
def _is_same_household(
session: Session,
owner_user_id: UUID,
current_user: CurrentUser,
) -> bool:
if current_user.household_membership is None:
return False
owner_household_id = _owner_household_id(session, owner_user_id)
return owner_household_id == current_user.household_membership.household_id
def get_owned_or_404(
session: Session,
model: type[_T],
record_id: object,
current_user: CurrentUser,
) -> _T:
obj = session.get(model, record_id)
model_name = _user_scoped_model_name(model)
if obj is None:
raise _not_found(model_name)
if _record_user_id(model, obj) != current_user.user_id:
raise _not_found(model_name)
return obj
def get_owned_or_404_admin_override(
session: Session,
model: type[_T],
record_id: object,
current_user: CurrentUser,
) -> _T:
obj = session.get(model, record_id)
model_name = _user_scoped_model_name(model)
if obj is None:
raise _not_found(model_name)
if _is_admin(current_user):
return obj
if _record_user_id(model, obj) != current_user.user_id:
raise _not_found(model_name)
return obj
def list_owned(
session: Session, model: type[_T], current_user: CurrentUser
) -> list[_T]:
model_name = _user_scoped_model_name(model)
if not hasattr(model, "user_id"):
raise TypeError(f"{model_name} does not expose user_id")
records = cast(list[_T], session.exec(select(model)).all())
return [
record
for record in records
if _record_user_id(model, record) == current_user.user_id
]
def list_owned_admin_override(
session: Session,
model: type[_T],
current_user: CurrentUser,
) -> list[_T]:
if _is_admin(current_user):
statement = select(model)
return cast(list[_T], session.exec(statement).all())
return list_owned(session, model, current_user)
def check_household_inventory_access(
session: Session,
inventory_id: UUID,
current_user: CurrentUser,
) -> ProductInventory:
inventory = session.get(ProductInventory, inventory_id)
if inventory is None:
raise _not_found(ProductInventory.__name__)
if _is_admin(current_user):
return inventory
owner_user_id = inventory.user_id
if owner_user_id == current_user.user_id:
return inventory
if not inventory.is_household_shared or owner_user_id is None:
raise _not_found(ProductInventory.__name__)
if not _is_same_household(session, owner_user_id, current_user):
raise _not_found(ProductInventory.__name__)
return inventory
def can_update_inventory(
session: Session,
inventory_id: UUID,
current_user: CurrentUser,
) -> bool:
inventory = session.get(ProductInventory, inventory_id)
if inventory is None:
return False
if _is_admin(current_user):
return True
return inventory.user_id == current_user.user_id
def is_product_visible(
session: Session, product_id: UUID, current_user: CurrentUser
) -> bool:
product = session.get(Product, product_id)
if product is None:
return False
if _is_admin(current_user):
return True
if product.user_id == current_user.user_id:
return True
if current_user.household_membership is None:
return False
inventories = session.exec(
select(ProductInventory).where(ProductInventory.product_id == product_id)
).all()
for inventory in inventories:
if not inventory.is_household_shared or inventory.user_id is None:
continue
if _is_same_household(session, inventory.user_id, current_user):
return True
return False

View file

@ -3,6 +3,18 @@ from typing import TypeVar
from fastapi import HTTPException
from sqlmodel import Session
from innercontext.api.authz import (
get_owned_or_404 as authz_get_owned_or_404,
)
from innercontext.api.authz import (
get_owned_or_404_admin_override as authz_get_owned_or_404_admin_override,
)
from innercontext.api.authz import list_owned as authz_list_owned
from innercontext.api.authz import (
list_owned_admin_override as authz_list_owned_admin_override,
)
from innercontext.auth import CurrentUser
_T = TypeVar("_T")
@ -11,3 +23,37 @@ def get_or_404(session: Session, model: type[_T], record_id: object) -> _T:
if obj is None:
raise HTTPException(status_code=404, detail=f"{model.__name__} not found")
return obj
def get_owned_or_404(
session: Session,
model: type[_T],
record_id: object,
current_user: CurrentUser,
) -> _T:
return authz_get_owned_or_404(session, model, record_id, current_user)
def get_owned_or_404_admin_override(
session: Session,
model: type[_T],
record_id: object,
current_user: CurrentUser,
) -> _T:
return authz_get_owned_or_404_admin_override(
session, model, record_id, current_user
)
def list_owned(
session: Session, model: type[_T], current_user: CurrentUser
) -> list[_T]:
return authz_list_owned(session, model, current_user)
def list_owned_admin_override(
session: Session,
model: type[_T],
current_user: CurrentUser,
) -> list[_T]:
return authz_list_owned_admin_override(session, model, current_user)

293
backend/tests/test_authz.py Normal file
View file

@ -0,0 +1,293 @@
from __future__ import annotations
from datetime import UTC, datetime, timedelta
from uuid import UUID, uuid4
import pytest
from fastapi import HTTPException
from sqlmodel import Session
from innercontext.api.authz import (
can_update_inventory,
check_household_inventory_access,
get_owned_or_404,
get_owned_or_404_admin_override,
is_product_visible,
list_owned,
list_owned_admin_override,
)
from innercontext.auth import (
CurrentHouseholdMembership,
CurrentUser,
IdentityData,
TokenClaims,
)
from innercontext.models import (
Household,
HouseholdMembership,
HouseholdRole,
DayTime,
MedicationEntry,
MedicationKind,
Product,
ProductCategory,
ProductInventory,
Role,
)
def _claims(subject: str) -> TokenClaims:
return TokenClaims(
issuer="https://auth.example.test",
subject=subject,
audience=("innercontext-web",),
expires_at=datetime.now(UTC) + timedelta(hours=1),
raw_claims={"iss": "https://auth.example.test", "sub": subject},
)
def _current_user(
user_id: UUID,
*,
role: Role = Role.MEMBER,
household_id: UUID | None = None,
) -> CurrentUser:
claims = _claims(str(user_id))
membership = None
if household_id is not None:
membership = CurrentHouseholdMembership(
household_id=household_id,
role=HouseholdRole.MEMBER,
)
return CurrentUser(
user_id=user_id,
role=role,
identity=IdentityData.from_claims(claims),
claims=claims,
household_membership=membership,
)
def _create_household(session: Session) -> Household:
household = Household()
session.add(household)
session.commit()
session.refresh(household)
return household
def _create_membership(
session: Session, user_id: UUID, household_id: UUID
) -> HouseholdMembership:
membership = HouseholdMembership(user_id=user_id, household_id=household_id)
session.add(membership)
session.commit()
session.refresh(membership)
return membership
def _create_medication(session: Session, user_id: UUID) -> MedicationEntry:
entry = MedicationEntry(
user_id=user_id,
kind=MedicationKind.PRESCRIPTION,
product_name="Test medication",
)
session.add(entry)
session.commit()
session.refresh(entry)
return entry
def _create_product(session: Session, user_id: UUID, short_id: str) -> Product:
product = Product(
user_id=user_id,
short_id=short_id,
name="Shared product",
brand="Test brand",
category=ProductCategory.MOISTURIZER,
recommended_time=DayTime.BOTH,
leave_on=True,
)
setattr(product, "product_effect_profile", {})
session.add(product)
session.commit()
session.refresh(product)
return product
def _create_inventory(
session: Session,
*,
user_id: UUID,
product_id: UUID,
is_household_shared: bool,
) -> ProductInventory:
inventory = ProductInventory(
user_id=user_id,
product_id=product_id,
is_household_shared=is_household_shared,
)
session.add(inventory)
session.commit()
session.refresh(inventory)
return inventory
def test_owner_helpers_return_only_owned_records(session: Session):
owner_id = uuid4()
other_id = uuid4()
owner_user = _current_user(owner_id)
owner_entry = _create_medication(session, owner_id)
_ = _create_medication(session, other_id)
fetched = get_owned_or_404(
session, MedicationEntry, owner_entry.record_id, owner_user
)
owned_entries = list_owned(session, MedicationEntry, owner_user)
assert fetched.record_id == owner_entry.record_id
assert len(owned_entries) == 1
assert owned_entries[0].user_id == owner_id
def test_admin_helpers_allow_admin_override_for_lookup_and_list(session: Session):
owner_id = uuid4()
admin_user = _current_user(uuid4(), role=Role.ADMIN)
owner_entry = _create_medication(session, owner_id)
fetched = get_owned_or_404_admin_override(
session,
MedicationEntry,
owner_entry.record_id,
admin_user,
)
listed = list_owned_admin_override(session, MedicationEntry, admin_user)
assert fetched.record_id == owner_entry.record_id
assert len(listed) == 1
def test_owner_denied_for_non_owned_lookup_returns_404(session: Session):
owner_id = uuid4()
intruder = _current_user(uuid4())
owner_entry = _create_medication(session, owner_id)
with pytest.raises(HTTPException) as exc_info:
_ = get_owned_or_404(session, MedicationEntry, owner_entry.record_id, intruder)
assert exc_info.value.status_code == 404
def test_household_shared_inventory_access_allows_same_household_member(
session: Session,
):
owner_id = uuid4()
household_member_id = uuid4()
household = _create_household(session)
_ = _create_membership(session, owner_id, household.id)
_ = _create_membership(session, household_member_id, household.id)
product = _create_product(session, owner_id, short_id="abcd0001")
inventory = _create_inventory(
session,
user_id=owner_id,
product_id=product.id,
is_household_shared=True,
)
current_user = _current_user(household_member_id, household_id=household.id)
fetched = check_household_inventory_access(session, inventory.id, current_user)
assert fetched.id == inventory.id
def test_household_shared_inventory_denied_for_cross_household_member(session: Session):
owner_id = uuid4()
outsider_id = uuid4()
owner_household = _create_household(session)
outsider_household = _create_household(session)
_ = _create_membership(session, owner_id, owner_household.id)
_ = _create_membership(session, outsider_id, outsider_household.id)
product = _create_product(session, owner_id, short_id="abcd0002")
inventory = _create_inventory(
session,
user_id=owner_id,
product_id=product.id,
is_household_shared=True,
)
outsider = _current_user(outsider_id, household_id=outsider_household.id)
with pytest.raises(HTTPException) as exc_info:
_ = check_household_inventory_access(session, inventory.id, outsider)
assert exc_info.value.status_code == 404
def test_household_inventory_update_rules_owner_admin_and_member(session: Session):
owner_id = uuid4()
member_id = uuid4()
household = _create_household(session)
_ = _create_membership(session, owner_id, household.id)
_ = _create_membership(session, member_id, household.id)
product = _create_product(session, owner_id, short_id="abcd0003")
inventory = _create_inventory(
session,
user_id=owner_id,
product_id=product.id,
is_household_shared=True,
)
owner = _current_user(owner_id, household_id=household.id)
admin = _current_user(uuid4(), role=Role.ADMIN)
member = _current_user(member_id, household_id=household.id)
assert can_update_inventory(session, inventory.id, owner) is True
assert can_update_inventory(session, inventory.id, admin) is True
assert can_update_inventory(session, inventory.id, member) is False
def test_product_visibility_for_owner_admin_and_household_shared(session: Session):
owner_id = uuid4()
member_id = uuid4()
household = _create_household(session)
_ = _create_membership(session, owner_id, household.id)
_ = _create_membership(session, member_id, household.id)
product = _create_product(session, owner_id, short_id="abcd0004")
_ = _create_inventory(
session,
user_id=owner_id,
product_id=product.id,
is_household_shared=True,
)
owner = _current_user(owner_id, household_id=household.id)
admin = _current_user(uuid4(), role=Role.ADMIN)
member = _current_user(member_id, household_id=household.id)
assert is_product_visible(session, product.id, owner) is True
assert is_product_visible(session, product.id, admin) is True
assert is_product_visible(session, product.id, member) is True
def test_product_visibility_denied_for_cross_household_member(session: Session):
owner_id = uuid4()
outsider_id = uuid4()
owner_household = _create_household(session)
outsider_household = _create_household(session)
_ = _create_membership(session, owner_id, owner_household.id)
_ = _create_membership(session, outsider_id, outsider_household.id)
product = _create_product(session, owner_id, short_id="abcd0005")
_ = _create_inventory(
session,
user_id=owner_id,
product_id=product.id,
is_household_shared=True,
)
outsider = _current_user(outsider_id, household_id=outsider_household.id)
assert is_product_visible(session, product.id, outsider) is False