refactor(api): centralize tenant authorization helpers
This commit is contained in:
parent
4782fad5b9
commit
1f47974f48
3 changed files with 512 additions and 0 deletions
173
backend/innercontext/api/authz.py
Normal file
173
backend/innercontext/api/authz.py
Normal file
|
|
@ -0,0 +1,173 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from typing import TypeVar, cast
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import HTTPException
|
||||
from sqlmodel import Session, select
|
||||
|
||||
from innercontext.auth import CurrentUser
|
||||
from innercontext.models import HouseholdMembership, Product, ProductInventory, Role
|
||||
|
||||
_T = TypeVar("_T")
|
||||
|
||||
|
||||
def _not_found(model_name: str) -> HTTPException:
|
||||
return HTTPException(status_code=404, detail=f"{model_name} not found")
|
||||
|
||||
|
||||
def _user_scoped_model_name(model: type[object]) -> str:
|
||||
return getattr(model, "__name__", str(model))
|
||||
|
||||
|
||||
def _record_user_id(model: type[object], record: object) -> object:
|
||||
if not hasattr(record, "user_id"):
|
||||
model_name = _user_scoped_model_name(model)
|
||||
raise TypeError(f"{model_name} does not expose user_id")
|
||||
return cast(object, getattr(record, "user_id"))
|
||||
|
||||
|
||||
def _is_admin(current_user: CurrentUser) -> bool:
|
||||
return current_user.role is Role.ADMIN
|
||||
|
||||
|
||||
def _owner_household_id(session: Session, owner_user_id: UUID) -> UUID | None:
|
||||
membership = session.exec(
|
||||
select(HouseholdMembership).where(HouseholdMembership.user_id == owner_user_id)
|
||||
).first()
|
||||
if membership is None:
|
||||
return None
|
||||
return membership.household_id
|
||||
|
||||
|
||||
def _is_same_household(
|
||||
session: Session,
|
||||
owner_user_id: UUID,
|
||||
current_user: CurrentUser,
|
||||
) -> bool:
|
||||
if current_user.household_membership is None:
|
||||
return False
|
||||
owner_household_id = _owner_household_id(session, owner_user_id)
|
||||
return owner_household_id == current_user.household_membership.household_id
|
||||
|
||||
|
||||
def get_owned_or_404(
|
||||
session: Session,
|
||||
model: type[_T],
|
||||
record_id: object,
|
||||
current_user: CurrentUser,
|
||||
) -> _T:
|
||||
obj = session.get(model, record_id)
|
||||
model_name = _user_scoped_model_name(model)
|
||||
if obj is None:
|
||||
raise _not_found(model_name)
|
||||
if _record_user_id(model, obj) != current_user.user_id:
|
||||
raise _not_found(model_name)
|
||||
return obj
|
||||
|
||||
|
||||
def get_owned_or_404_admin_override(
|
||||
session: Session,
|
||||
model: type[_T],
|
||||
record_id: object,
|
||||
current_user: CurrentUser,
|
||||
) -> _T:
|
||||
obj = session.get(model, record_id)
|
||||
model_name = _user_scoped_model_name(model)
|
||||
if obj is None:
|
||||
raise _not_found(model_name)
|
||||
if _is_admin(current_user):
|
||||
return obj
|
||||
if _record_user_id(model, obj) != current_user.user_id:
|
||||
raise _not_found(model_name)
|
||||
return obj
|
||||
|
||||
|
||||
def list_owned(
|
||||
session: Session, model: type[_T], current_user: CurrentUser
|
||||
) -> list[_T]:
|
||||
model_name = _user_scoped_model_name(model)
|
||||
if not hasattr(model, "user_id"):
|
||||
raise TypeError(f"{model_name} does not expose user_id")
|
||||
records = cast(list[_T], session.exec(select(model)).all())
|
||||
return [
|
||||
record
|
||||
for record in records
|
||||
if _record_user_id(model, record) == current_user.user_id
|
||||
]
|
||||
|
||||
|
||||
def list_owned_admin_override(
|
||||
session: Session,
|
||||
model: type[_T],
|
||||
current_user: CurrentUser,
|
||||
) -> list[_T]:
|
||||
if _is_admin(current_user):
|
||||
statement = select(model)
|
||||
return cast(list[_T], session.exec(statement).all())
|
||||
return list_owned(session, model, current_user)
|
||||
|
||||
|
||||
def check_household_inventory_access(
|
||||
session: Session,
|
||||
inventory_id: UUID,
|
||||
current_user: CurrentUser,
|
||||
) -> ProductInventory:
|
||||
inventory = session.get(ProductInventory, inventory_id)
|
||||
if inventory is None:
|
||||
raise _not_found(ProductInventory.__name__)
|
||||
|
||||
if _is_admin(current_user):
|
||||
return inventory
|
||||
|
||||
owner_user_id = inventory.user_id
|
||||
if owner_user_id == current_user.user_id:
|
||||
return inventory
|
||||
|
||||
if not inventory.is_household_shared or owner_user_id is None:
|
||||
raise _not_found(ProductInventory.__name__)
|
||||
|
||||
if not _is_same_household(session, owner_user_id, current_user):
|
||||
raise _not_found(ProductInventory.__name__)
|
||||
|
||||
return inventory
|
||||
|
||||
|
||||
def can_update_inventory(
|
||||
session: Session,
|
||||
inventory_id: UUID,
|
||||
current_user: CurrentUser,
|
||||
) -> bool:
|
||||
inventory = session.get(ProductInventory, inventory_id)
|
||||
if inventory is None:
|
||||
return False
|
||||
if _is_admin(current_user):
|
||||
return True
|
||||
return inventory.user_id == current_user.user_id
|
||||
|
||||
|
||||
def is_product_visible(
|
||||
session: Session, product_id: UUID, current_user: CurrentUser
|
||||
) -> bool:
|
||||
product = session.get(Product, product_id)
|
||||
if product is None:
|
||||
return False
|
||||
|
||||
if _is_admin(current_user):
|
||||
return True
|
||||
|
||||
if product.user_id == current_user.user_id:
|
||||
return True
|
||||
|
||||
if current_user.household_membership is None:
|
||||
return False
|
||||
|
||||
inventories = session.exec(
|
||||
select(ProductInventory).where(ProductInventory.product_id == product_id)
|
||||
).all()
|
||||
for inventory in inventories:
|
||||
if not inventory.is_household_shared or inventory.user_id is None:
|
||||
continue
|
||||
if _is_same_household(session, inventory.user_id, current_user):
|
||||
return True
|
||||
return False
|
||||
|
|
@ -3,6 +3,18 @@ from typing import TypeVar
|
|||
from fastapi import HTTPException
|
||||
from sqlmodel import Session
|
||||
|
||||
from innercontext.api.authz import (
|
||||
get_owned_or_404 as authz_get_owned_or_404,
|
||||
)
|
||||
from innercontext.api.authz import (
|
||||
get_owned_or_404_admin_override as authz_get_owned_or_404_admin_override,
|
||||
)
|
||||
from innercontext.api.authz import list_owned as authz_list_owned
|
||||
from innercontext.api.authz import (
|
||||
list_owned_admin_override as authz_list_owned_admin_override,
|
||||
)
|
||||
from innercontext.auth import CurrentUser
|
||||
|
||||
_T = TypeVar("_T")
|
||||
|
||||
|
||||
|
|
@ -11,3 +23,37 @@ def get_or_404(session: Session, model: type[_T], record_id: object) -> _T:
|
|||
if obj is None:
|
||||
raise HTTPException(status_code=404, detail=f"{model.__name__} not found")
|
||||
return obj
|
||||
|
||||
|
||||
def get_owned_or_404(
|
||||
session: Session,
|
||||
model: type[_T],
|
||||
record_id: object,
|
||||
current_user: CurrentUser,
|
||||
) -> _T:
|
||||
return authz_get_owned_or_404(session, model, record_id, current_user)
|
||||
|
||||
|
||||
def get_owned_or_404_admin_override(
|
||||
session: Session,
|
||||
model: type[_T],
|
||||
record_id: object,
|
||||
current_user: CurrentUser,
|
||||
) -> _T:
|
||||
return authz_get_owned_or_404_admin_override(
|
||||
session, model, record_id, current_user
|
||||
)
|
||||
|
||||
|
||||
def list_owned(
|
||||
session: Session, model: type[_T], current_user: CurrentUser
|
||||
) -> list[_T]:
|
||||
return authz_list_owned(session, model, current_user)
|
||||
|
||||
|
||||
def list_owned_admin_override(
|
||||
session: Session,
|
||||
model: type[_T],
|
||||
current_user: CurrentUser,
|
||||
) -> list[_T]:
|
||||
return authz_list_owned_admin_override(session, model, current_user)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue