refactor(api): centralize tenant authorization helpers

This commit is contained in:
Piotr Oleszczyk 2026-03-12 15:26:06 +01:00
parent 4782fad5b9
commit 1f47974f48
3 changed files with 512 additions and 0 deletions

View file

@ -0,0 +1,173 @@
from __future__ import annotations
from typing import TypeVar, cast
from uuid import UUID
from fastapi import HTTPException
from sqlmodel import Session, select
from innercontext.auth import CurrentUser
from innercontext.models import HouseholdMembership, Product, ProductInventory, Role
_T = TypeVar("_T")
def _not_found(model_name: str) -> HTTPException:
return HTTPException(status_code=404, detail=f"{model_name} not found")
def _user_scoped_model_name(model: type[object]) -> str:
return getattr(model, "__name__", str(model))
def _record_user_id(model: type[object], record: object) -> object:
if not hasattr(record, "user_id"):
model_name = _user_scoped_model_name(model)
raise TypeError(f"{model_name} does not expose user_id")
return cast(object, getattr(record, "user_id"))
def _is_admin(current_user: CurrentUser) -> bool:
return current_user.role is Role.ADMIN
def _owner_household_id(session: Session, owner_user_id: UUID) -> UUID | None:
membership = session.exec(
select(HouseholdMembership).where(HouseholdMembership.user_id == owner_user_id)
).first()
if membership is None:
return None
return membership.household_id
def _is_same_household(
session: Session,
owner_user_id: UUID,
current_user: CurrentUser,
) -> bool:
if current_user.household_membership is None:
return False
owner_household_id = _owner_household_id(session, owner_user_id)
return owner_household_id == current_user.household_membership.household_id
def get_owned_or_404(
session: Session,
model: type[_T],
record_id: object,
current_user: CurrentUser,
) -> _T:
obj = session.get(model, record_id)
model_name = _user_scoped_model_name(model)
if obj is None:
raise _not_found(model_name)
if _record_user_id(model, obj) != current_user.user_id:
raise _not_found(model_name)
return obj
def get_owned_or_404_admin_override(
session: Session,
model: type[_T],
record_id: object,
current_user: CurrentUser,
) -> _T:
obj = session.get(model, record_id)
model_name = _user_scoped_model_name(model)
if obj is None:
raise _not_found(model_name)
if _is_admin(current_user):
return obj
if _record_user_id(model, obj) != current_user.user_id:
raise _not_found(model_name)
return obj
def list_owned(
session: Session, model: type[_T], current_user: CurrentUser
) -> list[_T]:
model_name = _user_scoped_model_name(model)
if not hasattr(model, "user_id"):
raise TypeError(f"{model_name} does not expose user_id")
records = cast(list[_T], session.exec(select(model)).all())
return [
record
for record in records
if _record_user_id(model, record) == current_user.user_id
]
def list_owned_admin_override(
session: Session,
model: type[_T],
current_user: CurrentUser,
) -> list[_T]:
if _is_admin(current_user):
statement = select(model)
return cast(list[_T], session.exec(statement).all())
return list_owned(session, model, current_user)
def check_household_inventory_access(
session: Session,
inventory_id: UUID,
current_user: CurrentUser,
) -> ProductInventory:
inventory = session.get(ProductInventory, inventory_id)
if inventory is None:
raise _not_found(ProductInventory.__name__)
if _is_admin(current_user):
return inventory
owner_user_id = inventory.user_id
if owner_user_id == current_user.user_id:
return inventory
if not inventory.is_household_shared or owner_user_id is None:
raise _not_found(ProductInventory.__name__)
if not _is_same_household(session, owner_user_id, current_user):
raise _not_found(ProductInventory.__name__)
return inventory
def can_update_inventory(
session: Session,
inventory_id: UUID,
current_user: CurrentUser,
) -> bool:
inventory = session.get(ProductInventory, inventory_id)
if inventory is None:
return False
if _is_admin(current_user):
return True
return inventory.user_id == current_user.user_id
def is_product_visible(
session: Session, product_id: UUID, current_user: CurrentUser
) -> bool:
product = session.get(Product, product_id)
if product is None:
return False
if _is_admin(current_user):
return True
if product.user_id == current_user.user_id:
return True
if current_user.household_membership is None:
return False
inventories = session.exec(
select(ProductInventory).where(ProductInventory.product_id == product_id)
).all()
for inventory in inventories:
if not inventory.is_household_shared or inventory.user_id is None:
continue
if _is_same_household(session, inventory.user_id, current_user):
return True
return False

View file

@ -3,6 +3,18 @@ from typing import TypeVar
from fastapi import HTTPException
from sqlmodel import Session
from innercontext.api.authz import (
get_owned_or_404 as authz_get_owned_or_404,
)
from innercontext.api.authz import (
get_owned_or_404_admin_override as authz_get_owned_or_404_admin_override,
)
from innercontext.api.authz import list_owned as authz_list_owned
from innercontext.api.authz import (
list_owned_admin_override as authz_list_owned_admin_override,
)
from innercontext.auth import CurrentUser
_T = TypeVar("_T")
@ -11,3 +23,37 @@ def get_or_404(session: Session, model: type[_T], record_id: object) -> _T:
if obj is None:
raise HTTPException(status_code=404, detail=f"{model.__name__} not found")
return obj
def get_owned_or_404(
session: Session,
model: type[_T],
record_id: object,
current_user: CurrentUser,
) -> _T:
return authz_get_owned_or_404(session, model, record_id, current_user)
def get_owned_or_404_admin_override(
session: Session,
model: type[_T],
record_id: object,
current_user: CurrentUser,
) -> _T:
return authz_get_owned_or_404_admin_override(
session, model, record_id, current_user
)
def list_owned(
session: Session, model: type[_T], current_user: CurrentUser
) -> list[_T]:
return authz_list_owned(session, model, current_user)
def list_owned_admin_override(
session: Session,
model: type[_T],
current_user: CurrentUser,
) -> list[_T]:
return authz_list_owned_admin_override(session, model, current_user)