feat(api): enforce ownership across health routines and profile flows
This commit is contained in:
parent
cd8e39939a
commit
ffa3b71309
14 changed files with 1225 additions and 206 deletions
|
|
@ -2,10 +2,13 @@ import json
|
|||
from typing import Any, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlmodel import Session, SQLModel, col, select
|
||||
|
||||
from db import get_session
|
||||
from innercontext.api.auth_deps import get_current_user
|
||||
from innercontext.auth import CurrentUser
|
||||
from innercontext.models.enums import Role
|
||||
from innercontext.models.ai_log import AICallLog
|
||||
|
||||
router = APIRouter()
|
||||
|
|
@ -43,14 +46,33 @@ class AICallLogPublic(SQLModel):
|
|||
error_detail: Optional[str] = None
|
||||
|
||||
|
||||
def _resolve_target_user_id(
|
||||
current_user: CurrentUser,
|
||||
user_id: UUID | None,
|
||||
) -> UUID:
|
||||
if user_id is None:
|
||||
return current_user.user_id
|
||||
if current_user.role is not Role.ADMIN:
|
||||
raise HTTPException(status_code=403, detail="Admin role required")
|
||||
return user_id
|
||||
|
||||
|
||||
@router.get("", response_model=list[AICallLogPublic])
|
||||
def list_ai_logs(
|
||||
endpoint: Optional[str] = None,
|
||||
success: Optional[bool] = None,
|
||||
limit: int = 50,
|
||||
user_id: UUID | None = Query(default=None),
|
||||
session: Session = Depends(get_session),
|
||||
current_user: CurrentUser = Depends(get_current_user),
|
||||
):
|
||||
stmt = select(AICallLog).order_by(col(AICallLog.created_at).desc()).limit(limit)
|
||||
target_user_id = _resolve_target_user_id(current_user, user_id)
|
||||
stmt = (
|
||||
select(AICallLog)
|
||||
.where(AICallLog.user_id == target_user_id)
|
||||
.order_by(col(AICallLog.created_at).desc())
|
||||
.limit(limit)
|
||||
)
|
||||
if endpoint is not None:
|
||||
stmt = stmt.where(AICallLog.endpoint == endpoint)
|
||||
if success is not None:
|
||||
|
|
@ -75,9 +97,17 @@ def list_ai_logs(
|
|||
|
||||
|
||||
@router.get("/{log_id}", response_model=AICallLog)
|
||||
def get_ai_log(log_id: UUID, session: Session = Depends(get_session)):
|
||||
def get_ai_log(
|
||||
log_id: UUID,
|
||||
user_id: UUID | None = Query(default=None),
|
||||
session: Session = Depends(get_session),
|
||||
current_user: CurrentUser = Depends(get_current_user),
|
||||
):
|
||||
target_user_id = _resolve_target_user_id(current_user, user_id)
|
||||
log = session.get(AICallLog, log_id)
|
||||
if log is None:
|
||||
raise HTTPException(status_code=404, detail="Log not found")
|
||||
if log.user_id != target_user_id:
|
||||
raise HTTPException(status_code=404, detail="Log not found")
|
||||
log.tool_trace = _normalize_tool_trace(getattr(log, "tool_trace", None))
|
||||
return log
|
||||
|
|
|
|||
|
|
@ -3,15 +3,17 @@ from datetime import datetime
|
|||
from typing import Optional
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from pydantic import field_validator
|
||||
from sqlalchemy import Integer, cast, func, or_
|
||||
from sqlmodel import Session, SQLModel, col, select
|
||||
|
||||
from db import get_session
|
||||
from innercontext.api.utils import get_or_404
|
||||
from innercontext.api.auth_deps import get_current_user
|
||||
from innercontext.api.utils import get_owned_or_404
|
||||
from innercontext.auth import CurrentUser
|
||||
from innercontext.models import LabResult, MedicationEntry, MedicationUsage
|
||||
from innercontext.models.enums import MedicationKind, ResultFlag
|
||||
from innercontext.models.enums import MedicationKind, ResultFlag, Role
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
|
@ -133,6 +135,34 @@ class LabResultListResponse(SQLModel):
|
|||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _resolve_target_user_id(
|
||||
current_user: CurrentUser,
|
||||
user_id: UUID | None,
|
||||
) -> UUID:
|
||||
if user_id is None:
|
||||
return current_user.user_id
|
||||
if current_user.role is not Role.ADMIN:
|
||||
raise HTTPException(status_code=403, detail="Admin role required")
|
||||
return user_id
|
||||
|
||||
|
||||
def _get_owned_or_admin_override(
|
||||
session: Session,
|
||||
model: type[MedicationEntry] | type[MedicationUsage] | type[LabResult],
|
||||
record_id: UUID,
|
||||
current_user: CurrentUser,
|
||||
user_id: UUID | None,
|
||||
):
|
||||
if user_id is None:
|
||||
return get_owned_or_404(session, model, record_id, current_user)
|
||||
record = session.get(model, record_id)
|
||||
if record is None or record.user_id != _resolve_target_user_id(
|
||||
current_user, user_id
|
||||
):
|
||||
raise HTTPException(status_code=404, detail=f"{model.__name__} not found")
|
||||
return record
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Medication routes
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -142,9 +172,12 @@ class LabResultListResponse(SQLModel):
|
|||
def list_medications(
|
||||
kind: Optional[MedicationKind] = None,
|
||||
product_name: Optional[str] = None,
|
||||
user_id: UUID | None = Query(default=None),
|
||||
session: Session = Depends(get_session),
|
||||
current_user: CurrentUser = Depends(get_current_user),
|
||||
):
|
||||
stmt = select(MedicationEntry)
|
||||
target_user_id = _resolve_target_user_id(current_user, user_id)
|
||||
stmt = select(MedicationEntry).where(MedicationEntry.user_id == target_user_id)
|
||||
if kind is not None:
|
||||
stmt = stmt.where(MedicationEntry.kind == kind)
|
||||
if product_name is not None:
|
||||
|
|
@ -153,8 +186,18 @@ def list_medications(
|
|||
|
||||
|
||||
@router.post("/medications", response_model=MedicationEntry, status_code=201)
|
||||
def create_medication(data: MedicationCreate, session: Session = Depends(get_session)):
|
||||
entry = MedicationEntry(record_id=uuid4(), **data.model_dump())
|
||||
def create_medication(
|
||||
data: MedicationCreate,
|
||||
user_id: UUID | None = Query(default=None),
|
||||
session: Session = Depends(get_session),
|
||||
current_user: CurrentUser = Depends(get_current_user),
|
||||
):
|
||||
target_user_id = _resolve_target_user_id(current_user, user_id)
|
||||
entry = MedicationEntry(
|
||||
record_id=uuid4(),
|
||||
user_id=target_user_id,
|
||||
**data.model_dump(),
|
||||
)
|
||||
session.add(entry)
|
||||
session.commit()
|
||||
session.refresh(entry)
|
||||
|
|
@ -162,17 +205,36 @@ def create_medication(data: MedicationCreate, session: Session = Depends(get_ses
|
|||
|
||||
|
||||
@router.get("/medications/{medication_id}", response_model=MedicationEntry)
|
||||
def get_medication(medication_id: UUID, session: Session = Depends(get_session)):
|
||||
return get_or_404(session, MedicationEntry, medication_id)
|
||||
def get_medication(
|
||||
medication_id: UUID,
|
||||
user_id: UUID | None = Query(default=None),
|
||||
session: Session = Depends(get_session),
|
||||
current_user: CurrentUser = Depends(get_current_user),
|
||||
):
|
||||
return _get_owned_or_admin_override(
|
||||
session,
|
||||
MedicationEntry,
|
||||
medication_id,
|
||||
current_user,
|
||||
user_id,
|
||||
)
|
||||
|
||||
|
||||
@router.patch("/medications/{medication_id}", response_model=MedicationEntry)
|
||||
def update_medication(
|
||||
medication_id: UUID,
|
||||
data: MedicationUpdate,
|
||||
user_id: UUID | None = Query(default=None),
|
||||
session: Session = Depends(get_session),
|
||||
current_user: CurrentUser = Depends(get_current_user),
|
||||
):
|
||||
entry = get_or_404(session, MedicationEntry, medication_id)
|
||||
entry = _get_owned_or_admin_override(
|
||||
session,
|
||||
MedicationEntry,
|
||||
medication_id,
|
||||
current_user,
|
||||
user_id,
|
||||
)
|
||||
for key, value in data.model_dump(exclude_unset=True).items():
|
||||
setattr(entry, key, value)
|
||||
session.add(entry)
|
||||
|
|
@ -182,13 +244,25 @@ def update_medication(
|
|||
|
||||
|
||||
@router.delete("/medications/{medication_id}", status_code=204)
|
||||
def delete_medication(medication_id: UUID, session: Session = Depends(get_session)):
|
||||
entry = get_or_404(session, MedicationEntry, medication_id)
|
||||
def delete_medication(
|
||||
medication_id: UUID,
|
||||
user_id: UUID | None = Query(default=None),
|
||||
session: Session = Depends(get_session),
|
||||
current_user: CurrentUser = Depends(get_current_user),
|
||||
):
|
||||
target_user_id = _resolve_target_user_id(current_user, user_id)
|
||||
entry = _get_owned_or_admin_override(
|
||||
session,
|
||||
MedicationEntry,
|
||||
medication_id,
|
||||
current_user,
|
||||
user_id,
|
||||
)
|
||||
# Delete usages first (no cascade configured at DB level)
|
||||
usages = session.exec(
|
||||
select(MedicationUsage).where(
|
||||
MedicationUsage.medication_record_id == medication_id
|
||||
)
|
||||
select(MedicationUsage)
|
||||
.where(MedicationUsage.medication_record_id == medication_id)
|
||||
.where(MedicationUsage.user_id == target_user_id)
|
||||
).all()
|
||||
for u in usages:
|
||||
session.delete(u)
|
||||
|
|
@ -202,10 +276,24 @@ def delete_medication(medication_id: UUID, session: Session = Depends(get_sessio
|
|||
|
||||
|
||||
@router.get("/medications/{medication_id}/usages", response_model=list[MedicationUsage])
|
||||
def list_usages(medication_id: UUID, session: Session = Depends(get_session)):
|
||||
get_or_404(session, MedicationEntry, medication_id)
|
||||
stmt = select(MedicationUsage).where(
|
||||
MedicationUsage.medication_record_id == medication_id
|
||||
def list_usages(
|
||||
medication_id: UUID,
|
||||
user_id: UUID | None = Query(default=None),
|
||||
session: Session = Depends(get_session),
|
||||
current_user: CurrentUser = Depends(get_current_user),
|
||||
):
|
||||
target_user_id = _resolve_target_user_id(current_user, user_id)
|
||||
_ = _get_owned_or_admin_override(
|
||||
session,
|
||||
MedicationEntry,
|
||||
medication_id,
|
||||
current_user,
|
||||
user_id,
|
||||
)
|
||||
stmt = (
|
||||
select(MedicationUsage)
|
||||
.where(MedicationUsage.medication_record_id == medication_id)
|
||||
.where(MedicationUsage.user_id == target_user_id)
|
||||
)
|
||||
return session.exec(stmt).all()
|
||||
|
||||
|
|
@ -218,11 +306,21 @@ def list_usages(medication_id: UUID, session: Session = Depends(get_session)):
|
|||
def create_usage(
|
||||
medication_id: UUID,
|
||||
data: UsageCreate,
|
||||
user_id: UUID | None = Query(default=None),
|
||||
session: Session = Depends(get_session),
|
||||
current_user: CurrentUser = Depends(get_current_user),
|
||||
):
|
||||
get_or_404(session, MedicationEntry, medication_id)
|
||||
target_user_id = _resolve_target_user_id(current_user, user_id)
|
||||
_ = _get_owned_or_admin_override(
|
||||
session,
|
||||
MedicationEntry,
|
||||
medication_id,
|
||||
current_user,
|
||||
user_id,
|
||||
)
|
||||
usage = MedicationUsage(
|
||||
record_id=uuid4(),
|
||||
user_id=target_user_id,
|
||||
medication_record_id=medication_id,
|
||||
**data.model_dump(),
|
||||
)
|
||||
|
|
@ -236,9 +334,17 @@ def create_usage(
|
|||
def update_usage(
|
||||
usage_id: UUID,
|
||||
data: UsageUpdate,
|
||||
user_id: UUID | None = Query(default=None),
|
||||
session: Session = Depends(get_session),
|
||||
current_user: CurrentUser = Depends(get_current_user),
|
||||
):
|
||||
usage = get_or_404(session, MedicationUsage, usage_id)
|
||||
usage = _get_owned_or_admin_override(
|
||||
session,
|
||||
MedicationUsage,
|
||||
usage_id,
|
||||
current_user,
|
||||
user_id,
|
||||
)
|
||||
for key, value in data.model_dump(exclude_unset=True).items():
|
||||
setattr(usage, key, value)
|
||||
session.add(usage)
|
||||
|
|
@ -248,8 +354,19 @@ def update_usage(
|
|||
|
||||
|
||||
@router.delete("/usages/{usage_id}", status_code=204)
|
||||
def delete_usage(usage_id: UUID, session: Session = Depends(get_session)):
|
||||
usage = get_or_404(session, MedicationUsage, usage_id)
|
||||
def delete_usage(
|
||||
usage_id: UUID,
|
||||
user_id: UUID | None = Query(default=None),
|
||||
session: Session = Depends(get_session),
|
||||
current_user: CurrentUser = Depends(get_current_user),
|
||||
):
|
||||
usage = _get_owned_or_admin_override(
|
||||
session,
|
||||
MedicationUsage,
|
||||
usage_id,
|
||||
current_user,
|
||||
user_id,
|
||||
)
|
||||
session.delete(usage)
|
||||
session.commit()
|
||||
|
||||
|
|
@ -271,29 +388,35 @@ def list_lab_results(
|
|||
latest_only: bool = False,
|
||||
limit: int = Query(default=50, ge=1, le=200),
|
||||
offset: int = Query(default=0, ge=0),
|
||||
user_id: UUID | None = Query(default=None),
|
||||
session: Session = Depends(get_session),
|
||||
current_user: CurrentUser = Depends(get_current_user),
|
||||
):
|
||||
filters = []
|
||||
target_user_id = _resolve_target_user_id(current_user, user_id)
|
||||
|
||||
def _apply_filters(statement):
|
||||
statement = statement.where(col(LabResult.user_id) == target_user_id)
|
||||
if q is not None and q.strip():
|
||||
query = f"%{q.strip()}%"
|
||||
filters.append(
|
||||
statement = statement.where(
|
||||
or_(
|
||||
col(LabResult.test_code).ilike(query),
|
||||
col(LabResult.test_name_original).ilike(query),
|
||||
)
|
||||
)
|
||||
if test_code is not None:
|
||||
filters.append(LabResult.test_code == test_code)
|
||||
statement = statement.where(col(LabResult.test_code) == test_code)
|
||||
if flag is not None:
|
||||
filters.append(LabResult.flag == flag)
|
||||
statement = statement.where(col(LabResult.flag) == flag)
|
||||
if flags:
|
||||
filters.append(col(LabResult.flag).in_(flags))
|
||||
statement = statement.where(col(LabResult.flag).in_(flags))
|
||||
if without_flag:
|
||||
filters.append(col(LabResult.flag).is_(None))
|
||||
statement = statement.where(col(LabResult.flag).is_(None))
|
||||
if from_date is not None:
|
||||
filters.append(LabResult.collected_at >= from_date)
|
||||
statement = statement.where(col(LabResult.collected_at) >= from_date)
|
||||
if to_date is not None:
|
||||
filters.append(LabResult.collected_at <= to_date)
|
||||
statement = statement.where(col(LabResult.collected_at) <= to_date)
|
||||
return statement
|
||||
|
||||
if latest_only:
|
||||
ranked_stmt = select(
|
||||
|
|
@ -309,8 +432,7 @@ def list_lab_results(
|
|||
)
|
||||
.label("rank"),
|
||||
)
|
||||
if filters:
|
||||
ranked_stmt = ranked_stmt.where(*filters)
|
||||
ranked_stmt = _apply_filters(ranked_stmt)
|
||||
|
||||
ranked_subquery = ranked_stmt.subquery()
|
||||
latest_ids = select(ranked_subquery.c.record_id).where(
|
||||
|
|
@ -323,11 +445,8 @@ def list_lab_results(
|
|||
.subquery()
|
||||
)
|
||||
else:
|
||||
stmt = select(LabResult)
|
||||
count_stmt = select(func.count()).select_from(LabResult)
|
||||
if filters:
|
||||
stmt = stmt.where(*filters)
|
||||
count_stmt = count_stmt.where(*filters)
|
||||
stmt = _apply_filters(select(LabResult))
|
||||
count_stmt = _apply_filters(select(func.count()).select_from(LabResult))
|
||||
|
||||
test_code_numeric = cast(
|
||||
func.replace(col(LabResult.test_code), "-", ""),
|
||||
|
|
@ -345,8 +464,18 @@ def list_lab_results(
|
|||
|
||||
|
||||
@router.post("/lab-results", response_model=LabResult, status_code=201)
|
||||
def create_lab_result(data: LabResultCreate, session: Session = Depends(get_session)):
|
||||
result = LabResult(record_id=uuid4(), **data.model_dump())
|
||||
def create_lab_result(
|
||||
data: LabResultCreate,
|
||||
user_id: UUID | None = Query(default=None),
|
||||
session: Session = Depends(get_session),
|
||||
current_user: CurrentUser = Depends(get_current_user),
|
||||
):
|
||||
target_user_id = _resolve_target_user_id(current_user, user_id)
|
||||
result = LabResult(
|
||||
record_id=uuid4(),
|
||||
user_id=target_user_id,
|
||||
**data.model_dump(),
|
||||
)
|
||||
session.add(result)
|
||||
session.commit()
|
||||
session.refresh(result)
|
||||
|
|
@ -354,17 +483,36 @@ def create_lab_result(data: LabResultCreate, session: Session = Depends(get_sess
|
|||
|
||||
|
||||
@router.get("/lab-results/{result_id}", response_model=LabResult)
|
||||
def get_lab_result(result_id: UUID, session: Session = Depends(get_session)):
|
||||
return get_or_404(session, LabResult, result_id)
|
||||
def get_lab_result(
|
||||
result_id: UUID,
|
||||
user_id: UUID | None = Query(default=None),
|
||||
session: Session = Depends(get_session),
|
||||
current_user: CurrentUser = Depends(get_current_user),
|
||||
):
|
||||
return _get_owned_or_admin_override(
|
||||
session,
|
||||
LabResult,
|
||||
result_id,
|
||||
current_user,
|
||||
user_id,
|
||||
)
|
||||
|
||||
|
||||
@router.patch("/lab-results/{result_id}", response_model=LabResult)
|
||||
def update_lab_result(
|
||||
result_id: UUID,
|
||||
data: LabResultUpdate,
|
||||
user_id: UUID | None = Query(default=None),
|
||||
session: Session = Depends(get_session),
|
||||
current_user: CurrentUser = Depends(get_current_user),
|
||||
):
|
||||
result = get_or_404(session, LabResult, result_id)
|
||||
result = _get_owned_or_admin_override(
|
||||
session,
|
||||
LabResult,
|
||||
result_id,
|
||||
current_user,
|
||||
user_id,
|
||||
)
|
||||
for key, value in data.model_dump(exclude_unset=True).items():
|
||||
setattr(result, key, value)
|
||||
session.add(result)
|
||||
|
|
@ -374,7 +522,18 @@ def update_lab_result(
|
|||
|
||||
|
||||
@router.delete("/lab-results/{result_id}", status_code=204)
|
||||
def delete_lab_result(result_id: UUID, session: Session = Depends(get_session)):
|
||||
result = get_or_404(session, LabResult, result_id)
|
||||
def delete_lab_result(
|
||||
result_id: UUID,
|
||||
user_id: UUID | None = Query(default=None),
|
||||
session: Session = Depends(get_session),
|
||||
current_user: CurrentUser = Depends(get_current_user),
|
||||
):
|
||||
result = _get_owned_or_admin_override(
|
||||
session,
|
||||
LabResult,
|
||||
result_id,
|
||||
current_user,
|
||||
user_id,
|
||||
)
|
||||
session.delete(result)
|
||||
session.commit()
|
||||
|
|
|
|||
|
|
@ -2,16 +2,43 @@ from datetime import date
|
|||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import HTTPException
|
||||
from sqlmodel import Session, col, select
|
||||
|
||||
from innercontext.auth import CurrentUser
|
||||
from innercontext.models import Product, UserProfile
|
||||
from innercontext.models.enums import Role
|
||||
|
||||
|
||||
def get_user_profile(session: Session) -> UserProfile | None:
|
||||
def _resolve_target_user_id(
|
||||
current_user: CurrentUser,
|
||||
user_id: UUID | None,
|
||||
) -> UUID:
|
||||
if user_id is None:
|
||||
return current_user.user_id
|
||||
if current_user.role is not Role.ADMIN:
|
||||
raise HTTPException(status_code=403, detail="Admin role required")
|
||||
return user_id
|
||||
|
||||
|
||||
def get_user_profile(
|
||||
session: Session,
|
||||
current_user: CurrentUser | None = None,
|
||||
*,
|
||||
user_id: UUID | None = None,
|
||||
) -> UserProfile | None:
|
||||
if current_user is None:
|
||||
return session.exec(
|
||||
select(UserProfile).order_by(col(UserProfile.created_at).desc())
|
||||
).first()
|
||||
|
||||
target_user_id = _resolve_target_user_id(current_user, user_id)
|
||||
return session.exec(
|
||||
select(UserProfile)
|
||||
.where(UserProfile.user_id == target_user_id)
|
||||
.order_by(col(UserProfile.created_at).desc())
|
||||
).first()
|
||||
|
||||
|
||||
def calculate_age(birth_date: date, reference_date: date) -> int:
|
||||
years = reference_date.year - birth_date.year
|
||||
|
|
@ -20,8 +47,14 @@ def calculate_age(birth_date: date, reference_date: date) -> int:
|
|||
return years
|
||||
|
||||
|
||||
def build_user_profile_context(session: Session, reference_date: date) -> str:
|
||||
profile = get_user_profile(session)
|
||||
def build_user_profile_context(
|
||||
session: Session,
|
||||
reference_date: date,
|
||||
current_user: CurrentUser | None = None,
|
||||
*,
|
||||
user_id: UUID | None = None,
|
||||
) -> str:
|
||||
profile = get_user_profile(session, current_user, user_id=user_id)
|
||||
if profile is None:
|
||||
return "USER PROFILE: no data\n"
|
||||
|
||||
|
|
@ -69,8 +102,9 @@ def build_product_context_summary(product: Product, has_inventory: bool = False)
|
|||
|
||||
# Get effect profile scores if available
|
||||
effects = []
|
||||
if hasattr(product, "effect_profile") and product.effect_profile:
|
||||
profile = product.effect_profile
|
||||
effect_profile = getattr(product, "effect_profile", None)
|
||||
if effect_profile:
|
||||
profile = effect_profile
|
||||
# Only include notable effects (score > 0)
|
||||
# Handle both dict (from DB) and object (from Pydantic)
|
||||
if isinstance(profile, dict):
|
||||
|
|
@ -165,11 +199,12 @@ def build_product_context_detailed(
|
|||
|
||||
# Effect profile
|
||||
effect_profile = None
|
||||
if hasattr(product, "effect_profile") and product.effect_profile:
|
||||
if isinstance(product.effect_profile, dict):
|
||||
effect_profile = product.effect_profile
|
||||
product_effect_profile = getattr(product, "effect_profile", None)
|
||||
if product_effect_profile:
|
||||
if isinstance(product_effect_profile, dict):
|
||||
effect_profile = product_effect_profile
|
||||
else:
|
||||
effect_profile = product.effect_profile.model_dump()
|
||||
effect_profile = product_effect_profile.model_dump()
|
||||
|
||||
# Context rules
|
||||
context_rules = None
|
||||
|
|
|
|||
|
|
@ -1,11 +1,14 @@
|
|||
from datetime import date, datetime
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from sqlmodel import Session, SQLModel
|
||||
|
||||
from db import get_session
|
||||
from innercontext.api.llm_context import get_user_profile
|
||||
from innercontext.api.auth_deps import get_current_user
|
||||
from innercontext.auth import CurrentUser
|
||||
from innercontext.models import SexAtBirth, UserProfile
|
||||
|
||||
router = APIRouter()
|
||||
|
|
@ -25,8 +28,12 @@ class UserProfilePublic(SQLModel):
|
|||
|
||||
|
||||
@router.get("", response_model=UserProfilePublic | None)
|
||||
def get_profile(session: Session = Depends(get_session)):
|
||||
profile = get_user_profile(session)
|
||||
def get_profile(
|
||||
user_id: UUID | None = Query(default=None),
|
||||
session: Session = Depends(get_session),
|
||||
current_user: CurrentUser = Depends(get_current_user),
|
||||
):
|
||||
profile = get_user_profile(session, current_user, user_id=user_id)
|
||||
if profile is None:
|
||||
return None
|
||||
return UserProfilePublic(
|
||||
|
|
@ -39,12 +46,18 @@ def get_profile(session: Session = Depends(get_session)):
|
|||
|
||||
|
||||
@router.patch("", response_model=UserProfilePublic)
|
||||
def upsert_profile(data: UserProfileUpdate, session: Session = Depends(get_session)):
|
||||
profile = get_user_profile(session)
|
||||
def upsert_profile(
|
||||
data: UserProfileUpdate,
|
||||
user_id: UUID | None = Query(default=None),
|
||||
session: Session = Depends(get_session),
|
||||
current_user: CurrentUser = Depends(get_current_user),
|
||||
):
|
||||
target_user_id = user_id if user_id is not None else current_user.user_id
|
||||
profile = get_user_profile(session, current_user, user_id=user_id)
|
||||
payload = data.model_dump(exclude_unset=True)
|
||||
|
||||
if profile is None:
|
||||
profile = UserProfile(**payload)
|
||||
profile = UserProfile(user_id=target_user_id, **payload)
|
||||
else:
|
||||
for key, value in payload.items():
|
||||
setattr(profile, key, value)
|
||||
|
|
|
|||
|
|
@ -5,12 +5,15 @@ from datetime import date, timedelta
|
|||
from typing import Any, Optional
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from google.genai import types as genai_types
|
||||
from pydantic import BaseModel as PydanticBase
|
||||
from sqlalchemy import or_
|
||||
from sqlmodel import Field, Session, SQLModel, col, select
|
||||
|
||||
from db import get_session
|
||||
from innercontext.api.auth_deps import get_current_user
|
||||
from innercontext.api.authz import is_product_visible
|
||||
from innercontext.api.llm_context import (
|
||||
build_products_context_summary_list,
|
||||
build_user_profile_context,
|
||||
|
|
@ -25,7 +28,8 @@ from innercontext.api.product_llm_tools import (
|
|||
build_last_used_on_by_product,
|
||||
build_product_details_tool_handler,
|
||||
)
|
||||
from innercontext.api.utils import get_or_404
|
||||
from innercontext.api.utils import get_owned_or_404
|
||||
from innercontext.auth import CurrentUser
|
||||
from innercontext.llm import (
|
||||
call_gemini,
|
||||
call_gemini_with_function_tools,
|
||||
|
|
@ -33,6 +37,7 @@ from innercontext.llm import (
|
|||
)
|
||||
from innercontext.llm_safety import isolate_user_input, sanitize_user_input
|
||||
from innercontext.models import (
|
||||
HouseholdMembership,
|
||||
GroomingSchedule,
|
||||
Product,
|
||||
ProductInventory,
|
||||
|
|
@ -43,6 +48,7 @@ from innercontext.models import (
|
|||
from innercontext.models.ai_log import AICallLog
|
||||
from innercontext.models.api_metadata import ResponseMetadata, TokenMetrics
|
||||
from innercontext.models.enums import GroomingAction, PartOfDay
|
||||
from innercontext.models.enums import Role
|
||||
from innercontext.validators import BatchValidator, RoutineSuggestionValidator
|
||||
from innercontext.validators.batch_validator import BatchValidationContext
|
||||
from innercontext.validators.routine_validator import RoutineValidationContext
|
||||
|
|
@ -86,6 +92,47 @@ def _build_response_metadata(session: Session, log_id: Any) -> ResponseMetadata
|
|||
router = APIRouter()
|
||||
|
||||
|
||||
def _resolve_target_user_id(
|
||||
current_user: CurrentUser,
|
||||
user_id: UUID | None,
|
||||
) -> UUID:
|
||||
if user_id is None:
|
||||
return current_user.user_id
|
||||
if current_user.role is not Role.ADMIN:
|
||||
raise HTTPException(status_code=403, detail="Admin role required")
|
||||
return user_id
|
||||
|
||||
|
||||
def _shared_household_user_ids(
|
||||
session: Session, current_user: CurrentUser
|
||||
) -> set[UUID]:
|
||||
membership = current_user.household_membership
|
||||
if membership is None:
|
||||
return set()
|
||||
user_ids = session.exec(
|
||||
select(HouseholdMembership.user_id).where(
|
||||
HouseholdMembership.household_id == membership.household_id
|
||||
)
|
||||
).all()
|
||||
return {uid for uid in user_ids if uid != current_user.user_id}
|
||||
|
||||
|
||||
def _get_owned_or_admin_override(
|
||||
session: Session,
|
||||
model: type[Routine] | type[RoutineStep] | type[GroomingSchedule],
|
||||
record_id: UUID,
|
||||
current_user: CurrentUser,
|
||||
user_id: UUID | None,
|
||||
):
|
||||
if user_id is None:
|
||||
return get_owned_or_404(session, model, record_id, current_user)
|
||||
target_user_id = _resolve_target_user_id(current_user, user_id)
|
||||
record = session.get(model, record_id)
|
||||
if record is None or record.user_id != target_user_id:
|
||||
raise HTTPException(status_code=404, detail=f"{model.__name__} not found")
|
||||
return record
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Schemas
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -289,6 +336,7 @@ def _ev(v: object) -> str:
|
|||
|
||||
def _get_recent_skin_snapshot(
|
||||
session: Session,
|
||||
target_user_id: UUID,
|
||||
reference_date: date,
|
||||
window_days: int = HISTORY_WINDOW_DAYS,
|
||||
fallback_days: int = SNAPSHOT_FALLBACK_DAYS,
|
||||
|
|
@ -298,6 +346,7 @@ def _get_recent_skin_snapshot(
|
|||
|
||||
snapshot = session.exec(
|
||||
select(SkinConditionSnapshot)
|
||||
.where(SkinConditionSnapshot.user_id == target_user_id)
|
||||
.where(SkinConditionSnapshot.snapshot_date <= reference_date)
|
||||
.where(SkinConditionSnapshot.snapshot_date >= window_cutoff)
|
||||
.order_by(col(SkinConditionSnapshot.snapshot_date).desc())
|
||||
|
|
@ -307,6 +356,7 @@ def _get_recent_skin_snapshot(
|
|||
|
||||
return session.exec(
|
||||
select(SkinConditionSnapshot)
|
||||
.where(SkinConditionSnapshot.user_id == target_user_id)
|
||||
.where(SkinConditionSnapshot.snapshot_date <= reference_date)
|
||||
.where(SkinConditionSnapshot.snapshot_date >= fallback_cutoff)
|
||||
.order_by(col(SkinConditionSnapshot.snapshot_date).desc())
|
||||
|
|
@ -315,12 +365,14 @@ def _get_recent_skin_snapshot(
|
|||
|
||||
def _get_latest_skin_snapshot_within_days(
|
||||
session: Session,
|
||||
target_user_id: UUID,
|
||||
reference_date: date,
|
||||
max_age_days: int = SNAPSHOT_FALLBACK_DAYS,
|
||||
) -> SkinConditionSnapshot | None:
|
||||
cutoff = reference_date - timedelta(days=max_age_days)
|
||||
return session.exec(
|
||||
select(SkinConditionSnapshot)
|
||||
.where(SkinConditionSnapshot.user_id == target_user_id)
|
||||
.where(SkinConditionSnapshot.snapshot_date <= reference_date)
|
||||
.where(SkinConditionSnapshot.snapshot_date >= cutoff)
|
||||
.order_by(col(SkinConditionSnapshot.snapshot_date).desc())
|
||||
|
|
@ -329,12 +381,14 @@ def _get_latest_skin_snapshot_within_days(
|
|||
|
||||
def _build_skin_context(
|
||||
session: Session,
|
||||
target_user_id: UUID,
|
||||
reference_date: date,
|
||||
window_days: int = HISTORY_WINDOW_DAYS,
|
||||
fallback_days: int = SNAPSHOT_FALLBACK_DAYS,
|
||||
) -> str:
|
||||
snapshot = _get_recent_skin_snapshot(
|
||||
session,
|
||||
target_user_id=target_user_id,
|
||||
reference_date=reference_date,
|
||||
window_days=window_days,
|
||||
fallback_days=fallback_days,
|
||||
|
|
@ -354,10 +408,14 @@ def _build_skin_context(
|
|||
|
||||
|
||||
def _build_grooming_context(
|
||||
session: Session, weekdays: Optional[list[int]] = None
|
||||
session: Session,
|
||||
target_user_id: UUID,
|
||||
weekdays: Optional[list[int]] = None,
|
||||
) -> str:
|
||||
entries = session.exec(
|
||||
select(GroomingSchedule).order_by(col(GroomingSchedule.day_of_week))
|
||||
select(GroomingSchedule)
|
||||
.where(GroomingSchedule.user_id == target_user_id)
|
||||
.order_by(col(GroomingSchedule.day_of_week))
|
||||
).all()
|
||||
if not entries:
|
||||
return "GROOMING SCHEDULE: none\n"
|
||||
|
|
@ -378,11 +436,14 @@ def _build_grooming_context(
|
|||
|
||||
def _build_upcoming_grooming_context(
|
||||
session: Session,
|
||||
target_user_id: UUID,
|
||||
start_date: date,
|
||||
days: int = 7,
|
||||
) -> str:
|
||||
entries = session.exec(
|
||||
select(GroomingSchedule).order_by(col(GroomingSchedule.day_of_week))
|
||||
select(GroomingSchedule)
|
||||
.where(GroomingSchedule.user_id == target_user_id)
|
||||
.order_by(col(GroomingSchedule.day_of_week))
|
||||
).all()
|
||||
if not entries:
|
||||
return f"UPCOMING GROOMING (next {days} days): none\n"
|
||||
|
|
@ -420,12 +481,14 @@ def _build_upcoming_grooming_context(
|
|||
|
||||
def _build_recent_history(
|
||||
session: Session,
|
||||
target_user_id: UUID,
|
||||
reference_date: date,
|
||||
window_days: int = HISTORY_WINDOW_DAYS,
|
||||
) -> str:
|
||||
cutoff = reference_date - timedelta(days=window_days)
|
||||
routines = session.exec(
|
||||
select(Routine)
|
||||
.where(Routine.user_id == target_user_id)
|
||||
.where(Routine.routine_date <= reference_date)
|
||||
.where(Routine.routine_date >= cutoff)
|
||||
.order_by(col(Routine.routine_date).desc())
|
||||
|
|
@ -437,6 +500,7 @@ def _build_recent_history(
|
|||
steps = session.exec(
|
||||
select(RoutineStep)
|
||||
.where(RoutineStep.routine_id == r.id)
|
||||
.where(RoutineStep.user_id == target_user_id)
|
||||
.order_by(col(RoutineStep.order_index))
|
||||
).all()
|
||||
step_names = []
|
||||
|
|
@ -458,10 +522,36 @@ def _build_recent_history(
|
|||
|
||||
def _get_available_products(
|
||||
session: Session,
|
||||
current_user: CurrentUser,
|
||||
time_filter: Optional[str] = None,
|
||||
include_minoxidil: bool = True,
|
||||
) -> list[Product]:
|
||||
stmt = select(Product).where(col(Product.is_tool).is_(False))
|
||||
if current_user.role is not Role.ADMIN:
|
||||
owned_products = session.exec(
|
||||
stmt.where(col(Product.user_id) == current_user.user_id)
|
||||
).all()
|
||||
shared_user_ids = _shared_household_user_ids(session, current_user)
|
||||
shared_product_ids = (
|
||||
session.exec(
|
||||
select(ProductInventory.product_id)
|
||||
.where(col(ProductInventory.is_household_shared).is_(True))
|
||||
.where(col(ProductInventory.user_id).in_(list(shared_user_ids)))
|
||||
.distinct()
|
||||
).all()
|
||||
if shared_user_ids
|
||||
else []
|
||||
)
|
||||
shared_products = (
|
||||
session.exec(stmt.where(col(Product.id).in_(shared_product_ids))).all()
|
||||
if shared_product_ids
|
||||
else []
|
||||
)
|
||||
products_by_id = {p.id: p for p in owned_products}
|
||||
for product in shared_products:
|
||||
products_by_id.setdefault(product.id, product)
|
||||
products = list(products_by_id.values())
|
||||
else:
|
||||
products = session.exec(stmt).all()
|
||||
result: list[Product] = []
|
||||
for p in products:
|
||||
|
|
@ -517,7 +607,9 @@ def _extract_requested_product_ids(
|
|||
|
||||
|
||||
def _get_products_with_inventory(
|
||||
session: Session, product_ids: list[UUID]
|
||||
session: Session,
|
||||
current_user: CurrentUser,
|
||||
product_ids: list[UUID],
|
||||
) -> set[UUID]:
|
||||
"""
|
||||
Return set of product IDs that have active (non-finished) inventory.
|
||||
|
|
@ -527,17 +619,33 @@ def _get_products_with_inventory(
|
|||
if not product_ids:
|
||||
return set()
|
||||
|
||||
inventory_rows = session.exec(
|
||||
stmt = (
|
||||
select(ProductInventory.product_id)
|
||||
.where(col(ProductInventory.product_id).in_(product_ids))
|
||||
.where(col(ProductInventory.finished_at).is_(None))
|
||||
)
|
||||
if current_user.role is not Role.ADMIN:
|
||||
owned_inventory_rows = session.exec(
|
||||
stmt.where(col(ProductInventory.user_id) == current_user.user_id).distinct()
|
||||
).all()
|
||||
shared_user_ids = _shared_household_user_ids(session, current_user)
|
||||
shared_inventory_rows = session.exec(
|
||||
stmt.where(col(ProductInventory.is_household_shared).is_(True))
|
||||
.where(col(ProductInventory.user_id).in_(list(shared_user_ids)))
|
||||
.distinct()
|
||||
).all()
|
||||
|
||||
inventory_rows = set(owned_inventory_rows)
|
||||
inventory_rows.update(shared_inventory_rows)
|
||||
return inventory_rows
|
||||
inventory_rows = session.exec(stmt.distinct()).all()
|
||||
return set(inventory_rows)
|
||||
|
||||
|
||||
def _expand_product_id(session: Session, short_or_full_id: str) -> UUID | None:
|
||||
def _expand_product_id(
|
||||
session: Session,
|
||||
current_user: CurrentUser,
|
||||
short_or_full_id: str,
|
||||
) -> UUID | None:
|
||||
"""
|
||||
Expand 8-char short_id to full UUID, or validate full UUID.
|
||||
|
||||
|
|
@ -558,7 +666,13 @@ def _expand_product_id(session: Session, short_or_full_id: str) -> UUID | None:
|
|||
uuid_obj = UUID(short_or_full_id)
|
||||
# Verify it exists
|
||||
product = session.get(Product, uuid_obj)
|
||||
return uuid_obj if product else None
|
||||
if product is None:
|
||||
return None
|
||||
return (
|
||||
uuid_obj
|
||||
if is_product_visible(session, uuid_obj, current_user)
|
||||
else None
|
||||
)
|
||||
except (ValueError, TypeError):
|
||||
return None
|
||||
|
||||
|
|
@ -567,7 +681,13 @@ def _expand_product_id(session: Session, short_or_full_id: str) -> UUID | None:
|
|||
product = session.exec(
|
||||
select(Product).where(Product.short_id == short_or_full_id)
|
||||
).first()
|
||||
return product.id if product else None
|
||||
if product is None:
|
||||
return None
|
||||
return (
|
||||
product.id
|
||||
if is_product_visible(session, product.id, current_user)
|
||||
else None
|
||||
)
|
||||
|
||||
# Invalid length
|
||||
return None
|
||||
|
|
@ -590,6 +710,17 @@ def _build_day_context(leaving_home: Optional[bool]) -> str:
|
|||
return f"DAY CONTEXT:\n Leaving home: {val}\n"
|
||||
|
||||
|
||||
def _coerce_action_type(value: object) -> GroomingAction | None:
|
||||
if isinstance(value, GroomingAction):
|
||||
return value
|
||||
if isinstance(value, str):
|
||||
try:
|
||||
return GroomingAction(value)
|
||||
except ValueError:
|
||||
return None
|
||||
return None
|
||||
|
||||
|
||||
_ROUTINES_SYSTEM_PROMPT = """\
|
||||
Jesteś ekspertem planowania pielęgnacji.
|
||||
|
||||
|
|
@ -676,9 +807,12 @@ def list_routines(
|
|||
from_date: Optional[date] = None,
|
||||
to_date: Optional[date] = None,
|
||||
part_of_day: Optional[PartOfDay] = None,
|
||||
user_id: UUID | None = Query(default=None),
|
||||
session: Session = Depends(get_session),
|
||||
current_user: CurrentUser = Depends(get_current_user),
|
||||
):
|
||||
stmt = select(Routine)
|
||||
target_user_id = _resolve_target_user_id(current_user, user_id)
|
||||
stmt = select(Routine).where(Routine.user_id == target_user_id)
|
||||
if from_date is not None:
|
||||
stmt = stmt.where(Routine.routine_date >= from_date)
|
||||
if to_date is not None:
|
||||
|
|
@ -688,10 +822,12 @@ def list_routines(
|
|||
routines = session.exec(stmt).all()
|
||||
|
||||
routine_ids = [r.id for r in routines]
|
||||
steps_by_routine: dict = {}
|
||||
steps_by_routine: dict[UUID, list[RoutineStep]] = {}
|
||||
if routine_ids:
|
||||
all_steps = session.exec(
|
||||
select(RoutineStep).where(col(RoutineStep.routine_id).in_(routine_ids))
|
||||
select(RoutineStep)
|
||||
.where(col(RoutineStep.routine_id).in_(routine_ids))
|
||||
.where(RoutineStep.user_id == target_user_id)
|
||||
).all()
|
||||
for step in all_steps:
|
||||
steps_by_routine.setdefault(step.routine_id, []).append(step)
|
||||
|
|
@ -707,8 +843,14 @@ def list_routines(
|
|||
|
||||
|
||||
@router.post("", response_model=Routine, status_code=201)
|
||||
def create_routine(data: RoutineCreate, session: Session = Depends(get_session)):
|
||||
routine = Routine(id=uuid4(), **data.model_dump())
|
||||
def create_routine(
|
||||
data: RoutineCreate,
|
||||
user_id: UUID | None = Query(default=None),
|
||||
session: Session = Depends(get_session),
|
||||
current_user: CurrentUser = Depends(get_current_user),
|
||||
):
|
||||
target_user_id = _resolve_target_user_id(current_user, user_id)
|
||||
routine = Routine(id=uuid4(), user_id=target_user_id, **data.model_dump())
|
||||
session.add(routine)
|
||||
session.commit()
|
||||
session.refresh(routine)
|
||||
|
|
@ -724,19 +866,35 @@ def create_routine(data: RoutineCreate, session: Session = Depends(get_session))
|
|||
def suggest_routine(
|
||||
data: SuggestRoutineRequest,
|
||||
session: Session = Depends(get_session),
|
||||
current_user: CurrentUser = Depends(get_current_user),
|
||||
):
|
||||
target_user_id = current_user.user_id
|
||||
weekday = data.routine_date.weekday()
|
||||
skin_ctx = _build_skin_context(session, reference_date=data.routine_date)
|
||||
profile_ctx = build_user_profile_context(session, reference_date=data.routine_date)
|
||||
skin_ctx = _build_skin_context(
|
||||
session,
|
||||
target_user_id=target_user_id,
|
||||
reference_date=data.routine_date,
|
||||
)
|
||||
profile_ctx = build_user_profile_context(
|
||||
session,
|
||||
reference_date=data.routine_date,
|
||||
current_user=current_user,
|
||||
)
|
||||
upcoming_grooming_ctx = _build_upcoming_grooming_context(
|
||||
session,
|
||||
target_user_id=target_user_id,
|
||||
start_date=data.routine_date,
|
||||
days=7,
|
||||
)
|
||||
history_ctx = _build_recent_history(session, reference_date=data.routine_date)
|
||||
history_ctx = _build_recent_history(
|
||||
session,
|
||||
target_user_id=target_user_id,
|
||||
reference_date=data.routine_date,
|
||||
)
|
||||
day_ctx = _build_day_context(data.leaving_home)
|
||||
available_products = _get_available_products(
|
||||
session,
|
||||
current_user=current_user,
|
||||
time_filter=data.part_of_day.value,
|
||||
include_minoxidil=data.include_minoxidil_beard,
|
||||
)
|
||||
|
|
@ -752,7 +910,9 @@ def suggest_routine(
|
|||
|
||||
# Phase 2: Use tiered context (summary mode for initial prompt)
|
||||
products_with_inventory = _get_products_with_inventory(
|
||||
session, [p.id for p in available_products]
|
||||
session,
|
||||
current_user,
|
||||
[p.id for p in available_products],
|
||||
)
|
||||
products_ctx = build_products_context_summary_list(
|
||||
available_products, products_with_inventory
|
||||
|
|
@ -865,22 +1025,35 @@ def suggest_routine(
|
|||
|
||||
# Translation layer: Expand short_ids (8 chars) to full UUIDs (36 chars)
|
||||
steps = []
|
||||
for s in parsed.get("steps", []):
|
||||
raw_steps = parsed.get("steps", [])
|
||||
if not isinstance(raw_steps, list):
|
||||
raw_steps = []
|
||||
for s in raw_steps:
|
||||
if not isinstance(s, dict):
|
||||
continue
|
||||
product_id_str = s.get("product_id")
|
||||
product_id_uuid = None
|
||||
|
||||
if product_id_str:
|
||||
if isinstance(product_id_str, str) and product_id_str:
|
||||
# Expand short_id or validate full UUID
|
||||
product_id_uuid = _expand_product_id(session, product_id_str)
|
||||
product_id_uuid = _expand_product_id(session, current_user, product_id_str)
|
||||
|
||||
action_type = s.get("action_type")
|
||||
action_notes = s.get("action_notes")
|
||||
region = s.get("region")
|
||||
why_this_step = s.get("why_this_step")
|
||||
optional = s.get("optional")
|
||||
|
||||
steps.append(
|
||||
SuggestedStep(
|
||||
product_id=product_id_uuid,
|
||||
action_type=s.get("action_type") or None,
|
||||
action_notes=s.get("action_notes"),
|
||||
region=s.get("region"),
|
||||
why_this_step=s.get("why_this_step"),
|
||||
optional=s.get("optional"),
|
||||
action_type=_coerce_action_type(action_type),
|
||||
action_notes=action_notes if isinstance(action_notes, str) else None,
|
||||
region=region if isinstance(region, str) else None,
|
||||
why_this_step=(
|
||||
why_this_step if isinstance(why_this_step, str) else None
|
||||
),
|
||||
optional=optional if isinstance(optional, bool) else None,
|
||||
)
|
||||
)
|
||||
|
||||
|
|
@ -904,6 +1077,7 @@ def suggest_routine(
|
|||
# Get skin snapshot for barrier state
|
||||
skin_snapshot = _get_latest_skin_snapshot_within_days(
|
||||
session,
|
||||
target_user_id=target_user_id,
|
||||
reference_date=data.routine_date,
|
||||
)
|
||||
|
||||
|
|
@ -964,7 +1138,9 @@ def suggest_routine(
|
|||
def suggest_batch(
|
||||
data: SuggestBatchRequest,
|
||||
session: Session = Depends(get_session),
|
||||
current_user: CurrentUser = Depends(get_current_user),
|
||||
):
|
||||
target_user_id = current_user.user_id
|
||||
delta = (data.to_date - data.from_date).days + 1
|
||||
if delta > 14:
|
||||
raise HTTPException(
|
||||
|
|
@ -976,18 +1152,37 @@ def suggest_batch(
|
|||
weekdays = list(
|
||||
{(data.from_date + timedelta(days=i)).weekday() for i in range(delta)}
|
||||
)
|
||||
profile_ctx = build_user_profile_context(session, reference_date=data.from_date)
|
||||
skin_ctx = _build_skin_context(session, reference_date=data.from_date)
|
||||
grooming_ctx = _build_grooming_context(session, weekdays=weekdays)
|
||||
history_ctx = _build_recent_history(session, reference_date=data.from_date)
|
||||
profile_ctx = build_user_profile_context(
|
||||
session,
|
||||
reference_date=data.from_date,
|
||||
current_user=current_user,
|
||||
)
|
||||
skin_ctx = _build_skin_context(
|
||||
session,
|
||||
target_user_id=target_user_id,
|
||||
reference_date=data.from_date,
|
||||
)
|
||||
grooming_ctx = _build_grooming_context(
|
||||
session,
|
||||
target_user_id=target_user_id,
|
||||
weekdays=weekdays,
|
||||
)
|
||||
history_ctx = _build_recent_history(
|
||||
session,
|
||||
target_user_id=target_user_id,
|
||||
reference_date=data.from_date,
|
||||
)
|
||||
batch_products = _get_available_products(
|
||||
session,
|
||||
current_user=current_user,
|
||||
include_minoxidil=data.include_minoxidil_beard,
|
||||
)
|
||||
|
||||
# Phase 2: Use tiered context (summary mode for batch planning)
|
||||
products_with_inventory = _get_products_with_inventory(
|
||||
session, [p.id for p in batch_products]
|
||||
session,
|
||||
current_user,
|
||||
[p.id for p in batch_products],
|
||||
)
|
||||
products_ctx = build_products_context_summary_list(
|
||||
batch_products, products_with_inventory
|
||||
|
|
@ -1045,25 +1240,39 @@ def suggest_batch(
|
|||
except json.JSONDecodeError as e:
|
||||
raise HTTPException(status_code=502, detail=f"LLM returned invalid JSON: {e}")
|
||||
|
||||
def _parse_steps(raw_steps: list) -> list[SuggestedStep]:
|
||||
def _parse_steps(raw_steps: list[dict[str, object]]) -> list[SuggestedStep]:
|
||||
"""Parse steps and expand short_ids to full UUIDs."""
|
||||
result = []
|
||||
for s in raw_steps:
|
||||
product_id_str = s.get("product_id")
|
||||
product_id_uuid = None
|
||||
|
||||
if product_id_str:
|
||||
if isinstance(product_id_str, str) and product_id_str:
|
||||
# Translation layer: expand short_id to full UUID
|
||||
product_id_uuid = _expand_product_id(session, product_id_str)
|
||||
product_id_uuid = _expand_product_id(
|
||||
session,
|
||||
current_user,
|
||||
product_id_str,
|
||||
)
|
||||
|
||||
action_type = s.get("action_type")
|
||||
action_notes = s.get("action_notes")
|
||||
region = s.get("region")
|
||||
why_this_step = s.get("why_this_step")
|
||||
optional = s.get("optional")
|
||||
|
||||
result.append(
|
||||
SuggestedStep(
|
||||
product_id=product_id_uuid,
|
||||
action_type=s.get("action_type") or None,
|
||||
action_notes=s.get("action_notes"),
|
||||
region=s.get("region"),
|
||||
why_this_step=s.get("why_this_step"),
|
||||
optional=s.get("optional"),
|
||||
action_type=_coerce_action_type(action_type),
|
||||
action_notes=(
|
||||
action_notes if isinstance(action_notes, str) else None
|
||||
),
|
||||
region=region if isinstance(region, str) else None,
|
||||
why_this_step=(
|
||||
why_this_step if isinstance(why_this_step, str) else None
|
||||
),
|
||||
optional=optional if isinstance(optional, bool) else None,
|
||||
)
|
||||
)
|
||||
return result
|
||||
|
|
@ -1086,6 +1295,7 @@ def suggest_batch(
|
|||
# Get skin snapshot for barrier state
|
||||
skin_snapshot = _get_latest_skin_snapshot_within_days(
|
||||
session,
|
||||
target_user_id=target_user_id,
|
||||
reference_date=data.from_date,
|
||||
)
|
||||
|
||||
|
|
@ -1140,15 +1350,36 @@ def suggest_batch(
|
|||
|
||||
# Grooming-schedule GET must appear before /{routine_id} to avoid being shadowed
|
||||
@router.get("/grooming-schedule", response_model=list[GroomingSchedule])
|
||||
def list_grooming_schedule(session: Session = Depends(get_session)):
|
||||
return session.exec(select(GroomingSchedule)).all()
|
||||
def list_grooming_schedule(
|
||||
user_id: UUID | None = Query(default=None),
|
||||
session: Session = Depends(get_session),
|
||||
current_user: CurrentUser = Depends(get_current_user),
|
||||
):
|
||||
target_user_id = _resolve_target_user_id(current_user, user_id)
|
||||
return session.exec(
|
||||
select(GroomingSchedule).where(GroomingSchedule.user_id == target_user_id)
|
||||
).all()
|
||||
|
||||
|
||||
@router.get("/{routine_id}")
|
||||
def get_routine(routine_id: UUID, session: Session = Depends(get_session)):
|
||||
routine = get_or_404(session, Routine, routine_id)
|
||||
def get_routine(
|
||||
routine_id: UUID,
|
||||
user_id: UUID | None = Query(default=None),
|
||||
session: Session = Depends(get_session),
|
||||
current_user: CurrentUser = Depends(get_current_user),
|
||||
):
|
||||
target_user_id = _resolve_target_user_id(current_user, user_id)
|
||||
routine = _get_owned_or_admin_override(
|
||||
session,
|
||||
Routine,
|
||||
routine_id,
|
||||
current_user,
|
||||
user_id,
|
||||
)
|
||||
steps = session.exec(
|
||||
select(RoutineStep).where(RoutineStep.routine_id == routine_id)
|
||||
select(RoutineStep)
|
||||
.where(RoutineStep.routine_id == routine_id)
|
||||
.where(RoutineStep.user_id == target_user_id)
|
||||
).all()
|
||||
data = routine.model_dump(mode="json")
|
||||
data["steps"] = [step.model_dump(mode="json") for step in steps]
|
||||
|
|
@ -1159,9 +1390,17 @@ def get_routine(routine_id: UUID, session: Session = Depends(get_session)):
|
|||
def update_routine(
|
||||
routine_id: UUID,
|
||||
data: RoutineUpdate,
|
||||
user_id: UUID | None = Query(default=None),
|
||||
session: Session = Depends(get_session),
|
||||
current_user: CurrentUser = Depends(get_current_user),
|
||||
):
|
||||
routine = get_or_404(session, Routine, routine_id)
|
||||
routine = _get_owned_or_admin_override(
|
||||
session,
|
||||
Routine,
|
||||
routine_id,
|
||||
current_user,
|
||||
user_id,
|
||||
)
|
||||
for key, value in data.model_dump(exclude_unset=True).items():
|
||||
setattr(routine, key, value)
|
||||
session.add(routine)
|
||||
|
|
@ -1171,8 +1410,19 @@ def update_routine(
|
|||
|
||||
|
||||
@router.delete("/{routine_id}", status_code=204)
|
||||
def delete_routine(routine_id: UUID, session: Session = Depends(get_session)):
|
||||
routine = get_or_404(session, Routine, routine_id)
|
||||
def delete_routine(
|
||||
routine_id: UUID,
|
||||
user_id: UUID | None = Query(default=None),
|
||||
session: Session = Depends(get_session),
|
||||
current_user: CurrentUser = Depends(get_current_user),
|
||||
):
|
||||
routine = _get_owned_or_admin_override(
|
||||
session,
|
||||
Routine,
|
||||
routine_id,
|
||||
current_user,
|
||||
user_id,
|
||||
)
|
||||
session.delete(routine)
|
||||
session.commit()
|
||||
|
||||
|
|
@ -1186,10 +1436,28 @@ def delete_routine(routine_id: UUID, session: Session = Depends(get_session)):
|
|||
def add_step(
|
||||
routine_id: UUID,
|
||||
data: RoutineStepCreate,
|
||||
user_id: UUID | None = Query(default=None),
|
||||
session: Session = Depends(get_session),
|
||||
current_user: CurrentUser = Depends(get_current_user),
|
||||
):
|
||||
get_or_404(session, Routine, routine_id)
|
||||
step = RoutineStep(id=uuid4(), routine_id=routine_id, **data.model_dump())
|
||||
target_user_id = _resolve_target_user_id(current_user, user_id)
|
||||
_ = _get_owned_or_admin_override(
|
||||
session,
|
||||
Routine,
|
||||
routine_id,
|
||||
current_user,
|
||||
user_id,
|
||||
)
|
||||
if data.product_id and not is_product_visible(
|
||||
session, data.product_id, current_user
|
||||
):
|
||||
raise HTTPException(status_code=404, detail="Product not found")
|
||||
step = RoutineStep(
|
||||
id=uuid4(),
|
||||
user_id=target_user_id,
|
||||
routine_id=routine_id,
|
||||
**data.model_dump(),
|
||||
)
|
||||
session.add(step)
|
||||
session.commit()
|
||||
session.refresh(step)
|
||||
|
|
@ -1200,9 +1468,21 @@ def add_step(
|
|||
def update_step(
|
||||
step_id: UUID,
|
||||
data: RoutineStepUpdate,
|
||||
user_id: UUID | None = Query(default=None),
|
||||
session: Session = Depends(get_session),
|
||||
current_user: CurrentUser = Depends(get_current_user),
|
||||
):
|
||||
step = get_or_404(session, RoutineStep, step_id)
|
||||
step = _get_owned_or_admin_override(
|
||||
session,
|
||||
RoutineStep,
|
||||
step_id,
|
||||
current_user,
|
||||
user_id,
|
||||
)
|
||||
if data.product_id and not is_product_visible(
|
||||
session, data.product_id, current_user
|
||||
):
|
||||
raise HTTPException(status_code=404, detail="Product not found")
|
||||
for key, value in data.model_dump(exclude_unset=True).items():
|
||||
setattr(step, key, value)
|
||||
session.add(step)
|
||||
|
|
@ -1212,8 +1492,19 @@ def update_step(
|
|||
|
||||
|
||||
@router.delete("/steps/{step_id}", status_code=204)
|
||||
def delete_step(step_id: UUID, session: Session = Depends(get_session)):
|
||||
step = get_or_404(session, RoutineStep, step_id)
|
||||
def delete_step(
|
||||
step_id: UUID,
|
||||
user_id: UUID | None = Query(default=None),
|
||||
session: Session = Depends(get_session),
|
||||
current_user: CurrentUser = Depends(get_current_user),
|
||||
):
|
||||
step = _get_owned_or_admin_override(
|
||||
session,
|
||||
RoutineStep,
|
||||
step_id,
|
||||
current_user,
|
||||
user_id,
|
||||
)
|
||||
session.delete(step)
|
||||
session.commit()
|
||||
|
||||
|
|
@ -1225,9 +1516,13 @@ def delete_step(step_id: UUID, session: Session = Depends(get_session)):
|
|||
|
||||
@router.post("/grooming-schedule", response_model=GroomingSchedule, status_code=201)
|
||||
def create_grooming_schedule(
|
||||
data: GroomingScheduleCreate, session: Session = Depends(get_session)
|
||||
data: GroomingScheduleCreate,
|
||||
user_id: UUID | None = Query(default=None),
|
||||
session: Session = Depends(get_session),
|
||||
current_user: CurrentUser = Depends(get_current_user),
|
||||
):
|
||||
entry = GroomingSchedule(id=uuid4(), **data.model_dump())
|
||||
target_user_id = _resolve_target_user_id(current_user, user_id)
|
||||
entry = GroomingSchedule(id=uuid4(), user_id=target_user_id, **data.model_dump())
|
||||
session.add(entry)
|
||||
session.commit()
|
||||
session.refresh(entry)
|
||||
|
|
@ -1238,9 +1533,17 @@ def create_grooming_schedule(
|
|||
def update_grooming_schedule(
|
||||
entry_id: UUID,
|
||||
data: GroomingScheduleUpdate,
|
||||
user_id: UUID | None = Query(default=None),
|
||||
session: Session = Depends(get_session),
|
||||
current_user: CurrentUser = Depends(get_current_user),
|
||||
):
|
||||
entry = get_or_404(session, GroomingSchedule, entry_id)
|
||||
entry = _get_owned_or_admin_override(
|
||||
session,
|
||||
GroomingSchedule,
|
||||
entry_id,
|
||||
current_user,
|
||||
user_id,
|
||||
)
|
||||
for key, value in data.model_dump(exclude_unset=True).items():
|
||||
setattr(entry, key, value)
|
||||
session.add(entry)
|
||||
|
|
@ -1250,7 +1553,18 @@ def update_grooming_schedule(
|
|||
|
||||
|
||||
@router.delete("/grooming-schedule/{entry_id}", status_code=204)
|
||||
def delete_grooming_schedule(entry_id: UUID, session: Session = Depends(get_session)):
|
||||
entry = get_or_404(session, GroomingSchedule, entry_id)
|
||||
def delete_grooming_schedule(
|
||||
entry_id: UUID,
|
||||
user_id: UUID | None = Query(default=None),
|
||||
session: Session = Depends(get_session),
|
||||
current_user: CurrentUser = Depends(get_current_user),
|
||||
):
|
||||
entry = _get_owned_or_admin_override(
|
||||
session,
|
||||
GroomingSchedule,
|
||||
entry_id,
|
||||
current_user,
|
||||
user_id,
|
||||
)
|
||||
session.delete(entry)
|
||||
session.commit()
|
||||
|
|
|
|||
|
|
@ -4,15 +4,17 @@ from datetime import date
|
|||
from typing import Optional
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from fastapi import APIRouter, Depends, File, HTTPException, UploadFile
|
||||
from fastapi import APIRouter, Depends, File, HTTPException, Query, UploadFile
|
||||
from google.genai import types as genai_types
|
||||
from pydantic import BaseModel as PydanticBase
|
||||
from pydantic import ValidationError
|
||||
from sqlmodel import Session, SQLModel, select
|
||||
|
||||
from db import get_session
|
||||
from innercontext.api.auth_deps import get_current_user
|
||||
from innercontext.api.llm_context import build_user_profile_context
|
||||
from innercontext.api.utils import get_or_404
|
||||
from innercontext.api.utils import get_owned_or_404
|
||||
from innercontext.auth import CurrentUser
|
||||
from innercontext.llm import call_gemini, get_extraction_config
|
||||
from innercontext.models import (
|
||||
SkinConditionSnapshot,
|
||||
|
|
@ -26,6 +28,7 @@ from innercontext.models.enums import (
|
|||
SkinTexture,
|
||||
SkinType,
|
||||
)
|
||||
from innercontext.models.enums import Role
|
||||
from innercontext.validators import PhotoValidator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -135,6 +138,34 @@ OUTPUT (all fields optional):
|
|||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _resolve_target_user_id(
|
||||
current_user: CurrentUser,
|
||||
user_id: UUID | None,
|
||||
) -> UUID:
|
||||
if user_id is None:
|
||||
return current_user.user_id
|
||||
if current_user.role is not Role.ADMIN:
|
||||
raise HTTPException(status_code=403, detail="Admin role required")
|
||||
return user_id
|
||||
|
||||
|
||||
def _get_owned_or_admin_override(
|
||||
session: Session,
|
||||
snapshot_id: UUID,
|
||||
current_user: CurrentUser,
|
||||
user_id: UUID | None,
|
||||
) -> SkinConditionSnapshot:
|
||||
if user_id is None:
|
||||
return get_owned_or_404(
|
||||
session, SkinConditionSnapshot, snapshot_id, current_user
|
||||
)
|
||||
target_user_id = _resolve_target_user_id(current_user, user_id)
|
||||
snapshot = session.get(SkinConditionSnapshot, snapshot_id)
|
||||
if snapshot is None or snapshot.user_id != target_user_id:
|
||||
raise HTTPException(status_code=404, detail="SkinConditionSnapshot not found")
|
||||
return snapshot
|
||||
|
||||
|
||||
MAX_IMAGE_BYTES = 5 * 1024 * 1024 # 5 MB
|
||||
|
||||
|
||||
|
|
@ -142,6 +173,7 @@ MAX_IMAGE_BYTES = 5 * 1024 * 1024 # 5 MB
|
|||
async def analyze_skin_photos(
|
||||
photos: list[UploadFile] = File(...),
|
||||
session: Session = Depends(get_session),
|
||||
current_user: CurrentUser = Depends(get_current_user),
|
||||
) -> SkinPhotoAnalysisResponse:
|
||||
if not (1 <= len(photos) <= 3):
|
||||
raise HTTPException(status_code=422, detail="Send between 1 and 3 photos.")
|
||||
|
|
@ -174,7 +206,11 @@ async def analyze_skin_photos(
|
|||
)
|
||||
parts.append(
|
||||
genai_types.Part.from_text(
|
||||
text=build_user_profile_context(session, reference_date=date.today())
|
||||
text=build_user_profile_context(
|
||||
session,
|
||||
reference_date=date.today(),
|
||||
current_user=current_user,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
|
|
@ -224,9 +260,14 @@ def list_snapshots(
|
|||
from_date: Optional[date] = None,
|
||||
to_date: Optional[date] = None,
|
||||
overall_state: Optional[OverallSkinState] = None,
|
||||
user_id: UUID | None = Query(default=None),
|
||||
session: Session = Depends(get_session),
|
||||
current_user: CurrentUser = Depends(get_current_user),
|
||||
):
|
||||
stmt = select(SkinConditionSnapshot)
|
||||
target_user_id = _resolve_target_user_id(current_user, user_id)
|
||||
stmt = select(SkinConditionSnapshot).where(
|
||||
SkinConditionSnapshot.user_id == target_user_id
|
||||
)
|
||||
if from_date is not None:
|
||||
stmt = stmt.where(SkinConditionSnapshot.snapshot_date >= from_date)
|
||||
if to_date is not None:
|
||||
|
|
@ -237,8 +278,18 @@ def list_snapshots(
|
|||
|
||||
|
||||
@router.post("", response_model=SkinConditionSnapshotPublic, status_code=201)
|
||||
def create_snapshot(data: SnapshotCreate, session: Session = Depends(get_session)):
|
||||
snapshot = SkinConditionSnapshot(id=uuid4(), **data.model_dump())
|
||||
def create_snapshot(
|
||||
data: SnapshotCreate,
|
||||
user_id: UUID | None = Query(default=None),
|
||||
session: Session = Depends(get_session),
|
||||
current_user: CurrentUser = Depends(get_current_user),
|
||||
):
|
||||
target_user_id = _resolve_target_user_id(current_user, user_id)
|
||||
snapshot = SkinConditionSnapshot(
|
||||
id=uuid4(),
|
||||
user_id=target_user_id,
|
||||
**data.model_dump(),
|
||||
)
|
||||
session.add(snapshot)
|
||||
session.commit()
|
||||
session.refresh(snapshot)
|
||||
|
|
@ -246,17 +297,34 @@ def create_snapshot(data: SnapshotCreate, session: Session = Depends(get_session
|
|||
|
||||
|
||||
@router.get("/{snapshot_id}", response_model=SkinConditionSnapshotPublic)
|
||||
def get_snapshot(snapshot_id: UUID, session: Session = Depends(get_session)):
|
||||
return get_or_404(session, SkinConditionSnapshot, snapshot_id)
|
||||
def get_snapshot(
|
||||
snapshot_id: UUID,
|
||||
user_id: UUID | None = Query(default=None),
|
||||
session: Session = Depends(get_session),
|
||||
current_user: CurrentUser = Depends(get_current_user),
|
||||
):
|
||||
return _get_owned_or_admin_override(
|
||||
session,
|
||||
snapshot_id,
|
||||
current_user,
|
||||
user_id,
|
||||
)
|
||||
|
||||
|
||||
@router.patch("/{snapshot_id}", response_model=SkinConditionSnapshotPublic)
|
||||
def update_snapshot(
|
||||
snapshot_id: UUID,
|
||||
data: SnapshotUpdate,
|
||||
user_id: UUID | None = Query(default=None),
|
||||
session: Session = Depends(get_session),
|
||||
current_user: CurrentUser = Depends(get_current_user),
|
||||
):
|
||||
snapshot = get_or_404(session, SkinConditionSnapshot, snapshot_id)
|
||||
snapshot = _get_owned_or_admin_override(
|
||||
session,
|
||||
snapshot_id,
|
||||
current_user,
|
||||
user_id,
|
||||
)
|
||||
for key, value in data.model_dump(exclude_unset=True).items():
|
||||
setattr(snapshot, key, value)
|
||||
session.add(snapshot)
|
||||
|
|
@ -266,7 +334,17 @@ def update_snapshot(
|
|||
|
||||
|
||||
@router.delete("/{snapshot_id}", status_code=204)
|
||||
def delete_snapshot(snapshot_id: UUID, session: Session = Depends(get_session)):
|
||||
snapshot = get_or_404(session, SkinConditionSnapshot, snapshot_id)
|
||||
def delete_snapshot(
|
||||
snapshot_id: UUID,
|
||||
user_id: UUID | None = Query(default=None),
|
||||
session: Session = Depends(get_session),
|
||||
current_user: CurrentUser = Depends(get_current_user),
|
||||
):
|
||||
snapshot = _get_owned_or_admin_override(
|
||||
session,
|
||||
snapshot_id,
|
||||
current_user,
|
||||
user_id,
|
||||
)
|
||||
session.delete(snapshot)
|
||||
session.commit()
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
from sqlmodel import Session, col, select
|
||||
|
||||
|
|
@ -66,9 +67,53 @@ def _apply_pricing_snapshot(session: Session, computed_at: datetime) -> int:
|
|||
return len(products)
|
||||
|
||||
|
||||
def _scope_user_id(scope: str) -> UUID | None:
|
||||
prefix = "user:"
|
||||
if not scope.startswith(prefix):
|
||||
return None
|
||||
raw_user_id = scope[len(prefix) :].strip()
|
||||
if not raw_user_id:
|
||||
return None
|
||||
try:
|
||||
return UUID(raw_user_id)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
|
||||
def _apply_pricing_snapshot_for_scope(
|
||||
session: Session,
|
||||
*,
|
||||
computed_at: datetime,
|
||||
scope: str,
|
||||
) -> int:
|
||||
from innercontext.api.products import _compute_pricing_outputs
|
||||
|
||||
scoped_user_id = _scope_user_id(scope)
|
||||
stmt = select(Product)
|
||||
if scoped_user_id is not None:
|
||||
stmt = stmt.where(Product.user_id == scoped_user_id)
|
||||
products = list(session.exec(stmt).all())
|
||||
pricing_outputs = _compute_pricing_outputs(products)
|
||||
|
||||
for product in products:
|
||||
tier, price_per_use_pln, tier_source = pricing_outputs.get(
|
||||
product.id, (None, None, None)
|
||||
)
|
||||
product.price_tier = tier
|
||||
product.price_per_use_pln = price_per_use_pln
|
||||
product.price_tier_source = tier_source
|
||||
product.pricing_computed_at = computed_at
|
||||
|
||||
return len(products)
|
||||
|
||||
|
||||
def process_pricing_job(session: Session, job: PricingRecalcJob) -> int:
|
||||
try:
|
||||
updated_count = _apply_pricing_snapshot(session, computed_at=utc_now())
|
||||
updated_count = _apply_pricing_snapshot_for_scope(
|
||||
session,
|
||||
computed_at=utc_now(),
|
||||
scope=job.scope,
|
||||
)
|
||||
job.status = "succeeded"
|
||||
job.finished_at = utc_now()
|
||||
job.error = None
|
||||
|
|
|
|||
|
|
@ -37,13 +37,7 @@ def session(monkeypatch):
|
|||
|
||||
|
||||
@pytest.fixture()
|
||||
def client(session, monkeypatch):
|
||||
"""TestClient using the per-test session for every request."""
|
||||
|
||||
def _override():
|
||||
yield session
|
||||
|
||||
def _current_user_override():
|
||||
def current_user() -> CurrentUser:
|
||||
claims = TokenClaims(
|
||||
issuer="https://auth.test",
|
||||
subject="test-user",
|
||||
|
|
@ -59,6 +53,17 @@ def client(session, monkeypatch):
|
|||
claims=claims,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def client(session, monkeypatch, current_user):
|
||||
"""TestClient using the per-test session for every request."""
|
||||
|
||||
def _override():
|
||||
yield session
|
||||
|
||||
def _current_user_override():
|
||||
return current_user
|
||||
|
||||
app.dependency_overrides[get_session] = _override
|
||||
app.dependency_overrides[get_current_user] = _current_user_override
|
||||
with TestClient(app) as c:
|
||||
|
|
|
|||
|
|
@ -4,12 +4,13 @@ from typing import Any, cast
|
|||
from innercontext.models.ai_log import AICallLog
|
||||
|
||||
|
||||
def test_list_ai_logs_normalizes_tool_trace_string(client, session):
|
||||
def test_list_ai_logs_normalizes_tool_trace_string(client, session, current_user):
|
||||
log = AICallLog(
|
||||
id=uuid.uuid4(),
|
||||
endpoint="routines/suggest",
|
||||
model="gemini-3-flash-preview",
|
||||
success=True,
|
||||
user_id=current_user.user_id,
|
||||
)
|
||||
log.tool_trace = cast(
|
||||
Any,
|
||||
|
|
@ -26,12 +27,13 @@ def test_list_ai_logs_normalizes_tool_trace_string(client, session):
|
|||
assert data[0]["tool_trace"]["events"][0]["function"] == "get_product_inci"
|
||||
|
||||
|
||||
def test_get_ai_log_normalizes_tool_trace_string(client, session):
|
||||
def test_get_ai_log_normalizes_tool_trace_string(client, session, current_user):
|
||||
log = AICallLog(
|
||||
id=uuid.uuid4(),
|
||||
endpoint="routines/suggest",
|
||||
model="gemini-3-flash-preview",
|
||||
success=True,
|
||||
user_id=current_user.user_id,
|
||||
)
|
||||
log.tool_trace = cast(Any, '{"mode":"function_tools","round":1}')
|
||||
session.add(log)
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ from datetime import date
|
|||
from unittest.mock import patch
|
||||
|
||||
from innercontext.models import Routine, SkinConditionSnapshot
|
||||
from innercontext.models.enums import BarrierState, OverallSkinState
|
||||
from innercontext.models.enums import BarrierState, OverallSkinState, PartOfDay
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Routines
|
||||
|
|
@ -223,13 +223,14 @@ def test_delete_grooming_schedule_not_found(client):
|
|||
assert r.status_code == 404
|
||||
|
||||
|
||||
def test_suggest_routine(client, session):
|
||||
def test_suggest_routine(client, session, current_user):
|
||||
with patch(
|
||||
"innercontext.api.routines.call_gemini_with_function_tools"
|
||||
) as mock_gemini:
|
||||
session.add(
|
||||
SkinConditionSnapshot(
|
||||
id=uuid.uuid4(),
|
||||
user_id=current_user.user_id,
|
||||
snapshot_date=date(2026, 2, 22),
|
||||
overall_state=OverallSkinState.GOOD,
|
||||
hydration_level=4,
|
||||
|
|
@ -272,18 +273,20 @@ def test_suggest_routine(client, session):
|
|||
assert "get_product_details" in kwargs["function_handlers"]
|
||||
|
||||
|
||||
def test_suggest_batch(client, session):
|
||||
def test_suggest_batch(client, session, current_user):
|
||||
with patch("innercontext.api.routines.call_gemini") as mock_gemini:
|
||||
session.add(
|
||||
Routine(
|
||||
id=uuid.uuid4(),
|
||||
user_id=current_user.user_id,
|
||||
routine_date=date(2026, 2, 27),
|
||||
part_of_day="pm",
|
||||
part_of_day=PartOfDay.PM,
|
||||
)
|
||||
)
|
||||
session.add(
|
||||
SkinConditionSnapshot(
|
||||
id=uuid.uuid4(),
|
||||
user_id=current_user.user_id,
|
||||
snapshot_date=date(2026, 2, 20),
|
||||
overall_state=OverallSkinState.GOOD,
|
||||
hydration_level=4,
|
||||
|
|
|
|||
112
backend/tests/test_routines_auth.py
Normal file
112
backend/tests/test_routines_auth.py
Normal file
|
|
@ -0,0 +1,112 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
|
||||
from innercontext.api.auth_deps import get_current_user
|
||||
from innercontext.auth import CurrentUser, IdentityData, TokenClaims
|
||||
from innercontext.models import Role
|
||||
from main import app
|
||||
|
||||
|
||||
def _user(subject: str, *, role: Role = Role.MEMBER) -> CurrentUser:
|
||||
claims = TokenClaims(
|
||||
issuer="https://auth.test",
|
||||
subject=subject,
|
||||
audience=("innercontext-web",),
|
||||
expires_at=datetime.now(UTC) + timedelta(hours=1),
|
||||
raw_claims={"iss": "https://auth.test", "sub": subject},
|
||||
)
|
||||
return CurrentUser(
|
||||
user_id=uuid4(),
|
||||
role=role,
|
||||
identity=IdentityData.from_claims(claims),
|
||||
claims=claims,
|
||||
)
|
||||
|
||||
|
||||
def _set_current_user(user: CurrentUser) -> None:
|
||||
app.dependency_overrides[get_current_user] = lambda: user
|
||||
|
||||
|
||||
def test_suggest_uses_current_user_profile_and_visible_products_only(client):
|
||||
owner = _user("owner")
|
||||
other = _user("other")
|
||||
|
||||
_set_current_user(owner)
|
||||
owner_profile = client.patch(
|
||||
"/profile", json={"birth_date": "1991-01-15", "sex_at_birth": "male"}
|
||||
)
|
||||
owner_product = client.post(
|
||||
"/products",
|
||||
json={
|
||||
"name": "Owner Serum",
|
||||
"brand": "Test",
|
||||
"category": "serum",
|
||||
"recommended_time": "both",
|
||||
"leave_on": True,
|
||||
},
|
||||
)
|
||||
assert owner_profile.status_code == 200
|
||||
assert owner_product.status_code == 201
|
||||
|
||||
_set_current_user(other)
|
||||
other_profile = client.patch(
|
||||
"/profile", json={"birth_date": "1975-06-20", "sex_at_birth": "female"}
|
||||
)
|
||||
other_product = client.post(
|
||||
"/products",
|
||||
json={
|
||||
"name": "Other Serum",
|
||||
"brand": "Test",
|
||||
"category": "serum",
|
||||
"recommended_time": "both",
|
||||
"leave_on": True,
|
||||
},
|
||||
)
|
||||
assert other_profile.status_code == 200
|
||||
assert other_product.status_code == 201
|
||||
|
||||
_set_current_user(owner)
|
||||
|
||||
with patch(
|
||||
"innercontext.api.routines.call_gemini_with_function_tools"
|
||||
) as mock_gemini:
|
||||
mock_response = type(
|
||||
"Response",
|
||||
(),
|
||||
{
|
||||
"text": '{"steps": [{"product_id": null, "action_type": "shaving_razor"}], "reasoning": "ok", "summary": {"primary_goal": "safe", "constraints_applied": [], "confidence": 0.7}}'
|
||||
},
|
||||
)
|
||||
mock_gemini.return_value = (mock_response, None)
|
||||
|
||||
response = client.post(
|
||||
"/routines/suggest",
|
||||
json={
|
||||
"routine_date": "2026-03-05",
|
||||
"part_of_day": "am",
|
||||
"include_minoxidil_beard": False,
|
||||
},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
kwargs = mock_gemini.call_args.kwargs
|
||||
prompt = kwargs["contents"]
|
||||
assert "Birth date: 1991-01-15" in prompt
|
||||
assert "Birth date: 1975-06-20" not in prompt
|
||||
assert "Owner Serum" in prompt
|
||||
assert "Other Serum" not in prompt
|
||||
|
||||
handler = kwargs["function_handlers"]["get_product_details"]
|
||||
payload = handler(
|
||||
{
|
||||
"product_ids": [
|
||||
owner_product.json()["id"],
|
||||
other_product.json()["id"],
|
||||
]
|
||||
}
|
||||
)
|
||||
assert len(payload["products"]) == 1
|
||||
assert payload["products"][0]["name"] == "Owner Serum"
|
||||
|
|
@ -78,17 +78,22 @@ def test_ev():
|
|||
assert _ev("string") == "string"
|
||||
|
||||
|
||||
def test_build_skin_context(session: Session):
|
||||
def test_build_skin_context(session: Session, current_user):
|
||||
# Empty
|
||||
reference_date = date(2026, 3, 10)
|
||||
assert (
|
||||
_build_skin_context(session, reference_date=reference_date)
|
||||
_build_skin_context(
|
||||
session,
|
||||
target_user_id=current_user.user_id,
|
||||
reference_date=reference_date,
|
||||
)
|
||||
== "SKIN CONDITION: no data\n"
|
||||
)
|
||||
|
||||
# With data
|
||||
snap = SkinConditionSnapshot(
|
||||
id=uuid.uuid4(),
|
||||
user_id=current_user.user_id,
|
||||
snapshot_date=reference_date,
|
||||
overall_state=OverallSkinState.GOOD,
|
||||
hydration_level=4,
|
||||
|
|
@ -100,7 +105,11 @@ def test_build_skin_context(session: Session):
|
|||
session.add(snap)
|
||||
session.commit()
|
||||
|
||||
ctx = _build_skin_context(session, reference_date=reference_date)
|
||||
ctx = _build_skin_context(
|
||||
session,
|
||||
target_user_id=current_user.user_id,
|
||||
reference_date=reference_date,
|
||||
)
|
||||
assert "SKIN CONDITION (snapshot from" in ctx
|
||||
assert "Overall state: good" in ctx
|
||||
assert "Hydration: 4/5" in ctx
|
||||
|
|
@ -112,10 +121,12 @@ def test_build_skin_context(session: Session):
|
|||
|
||||
def test_build_skin_context_falls_back_to_recent_snapshot_within_14_days(
|
||||
session: Session,
|
||||
current_user,
|
||||
):
|
||||
reference_date = date(2026, 3, 20)
|
||||
snap = SkinConditionSnapshot(
|
||||
id=uuid.uuid4(),
|
||||
user_id=current_user.user_id,
|
||||
snapshot_date=reference_date - timedelta(days=10),
|
||||
overall_state=OverallSkinState.FAIR,
|
||||
hydration_level=3,
|
||||
|
|
@ -126,16 +137,23 @@ def test_build_skin_context_falls_back_to_recent_snapshot_within_14_days(
|
|||
session.add(snap)
|
||||
session.commit()
|
||||
|
||||
ctx = _build_skin_context(session, reference_date=reference_date)
|
||||
ctx = _build_skin_context(
|
||||
session,
|
||||
target_user_id=current_user.user_id,
|
||||
reference_date=reference_date,
|
||||
)
|
||||
|
||||
assert f"snapshot from {reference_date - timedelta(days=10)}" in ctx
|
||||
assert "Barrier: compromised" in ctx
|
||||
|
||||
|
||||
def test_build_skin_context_ignores_snapshot_older_than_14_days(session: Session):
|
||||
def test_build_skin_context_ignores_snapshot_older_than_14_days(
|
||||
session: Session, current_user
|
||||
):
|
||||
reference_date = date(2026, 3, 20)
|
||||
snap = SkinConditionSnapshot(
|
||||
id=uuid.uuid4(),
|
||||
user_id=current_user.user_id,
|
||||
snapshot_date=reference_date - timedelta(days=15),
|
||||
overall_state=OverallSkinState.FAIR,
|
||||
hydration_level=3,
|
||||
|
|
@ -145,15 +163,20 @@ def test_build_skin_context_ignores_snapshot_older_than_14_days(session: Session
|
|||
session.commit()
|
||||
|
||||
assert (
|
||||
_build_skin_context(session, reference_date=reference_date)
|
||||
_build_skin_context(
|
||||
session,
|
||||
target_user_id=current_user.user_id,
|
||||
reference_date=reference_date,
|
||||
)
|
||||
== "SKIN CONDITION: no data\n"
|
||||
)
|
||||
|
||||
|
||||
def test_get_recent_skin_snapshot_prefers_window_match(session: Session):
|
||||
def test_get_recent_skin_snapshot_prefers_window_match(session: Session, current_user):
|
||||
reference_date = date(2026, 3, 20)
|
||||
older = SkinConditionSnapshot(
|
||||
id=uuid.uuid4(),
|
||||
user_id=current_user.user_id,
|
||||
snapshot_date=reference_date - timedelta(days=10),
|
||||
overall_state=OverallSkinState.POOR,
|
||||
hydration_level=2,
|
||||
|
|
@ -161,6 +184,7 @@ def test_get_recent_skin_snapshot_prefers_window_match(session: Session):
|
|||
)
|
||||
newer = SkinConditionSnapshot(
|
||||
id=uuid.uuid4(),
|
||||
user_id=current_user.user_id,
|
||||
snapshot_date=reference_date - timedelta(days=2),
|
||||
overall_state=OverallSkinState.GOOD,
|
||||
hydration_level=4,
|
||||
|
|
@ -169,7 +193,11 @@ def test_get_recent_skin_snapshot_prefers_window_match(session: Session):
|
|||
session.add_all([older, newer])
|
||||
session.commit()
|
||||
|
||||
snapshot = _get_recent_skin_snapshot(session, reference_date=reference_date)
|
||||
snapshot = _get_recent_skin_snapshot(
|
||||
session,
|
||||
target_user_id=current_user.user_id,
|
||||
reference_date=reference_date,
|
||||
)
|
||||
|
||||
assert snapshot is not None
|
||||
assert snapshot.id == newer.id
|
||||
|
|
@ -177,10 +205,12 @@ def test_get_recent_skin_snapshot_prefers_window_match(session: Session):
|
|||
|
||||
def test_get_latest_skin_snapshot_within_days_uses_latest_within_14_days(
|
||||
session: Session,
|
||||
current_user,
|
||||
):
|
||||
reference_date = date(2026, 3, 20)
|
||||
older = SkinConditionSnapshot(
|
||||
id=uuid.uuid4(),
|
||||
user_id=current_user.user_id,
|
||||
snapshot_date=reference_date - timedelta(days=10),
|
||||
overall_state=OverallSkinState.POOR,
|
||||
hydration_level=2,
|
||||
|
|
@ -188,6 +218,7 @@ def test_get_latest_skin_snapshot_within_days_uses_latest_within_14_days(
|
|||
)
|
||||
newer = SkinConditionSnapshot(
|
||||
id=uuid.uuid4(),
|
||||
user_id=current_user.user_id,
|
||||
snapshot_date=reference_date - timedelta(days=2),
|
||||
overall_state=OverallSkinState.GOOD,
|
||||
hydration_level=4,
|
||||
|
|
@ -198,6 +229,7 @@ def test_get_latest_skin_snapshot_within_days_uses_latest_within_14_days(
|
|||
|
||||
snapshot = _get_latest_skin_snapshot_within_days(
|
||||
session,
|
||||
target_user_id=current_user.user_id,
|
||||
reference_date=reference_date,
|
||||
)
|
||||
|
||||
|
|
@ -205,39 +237,65 @@ def test_get_latest_skin_snapshot_within_days_uses_latest_within_14_days(
|
|||
assert snapshot.id == newer.id
|
||||
|
||||
|
||||
def test_build_grooming_context(session: Session):
|
||||
assert _build_grooming_context(session) == "GROOMING SCHEDULE: none\n"
|
||||
def test_build_grooming_context(session: Session, current_user):
|
||||
assert (
|
||||
_build_grooming_context(session, target_user_id=current_user.user_id)
|
||||
== "GROOMING SCHEDULE: none\n"
|
||||
)
|
||||
|
||||
sch = GroomingSchedule(
|
||||
id=uuid.uuid4(), day_of_week=0, action="shaving_oneblade", notes="Morning"
|
||||
id=uuid.uuid4(),
|
||||
user_id=current_user.user_id,
|
||||
day_of_week=0,
|
||||
action="shaving_oneblade",
|
||||
notes="Morning",
|
||||
)
|
||||
session.add(sch)
|
||||
session.commit()
|
||||
|
||||
ctx = _build_grooming_context(session)
|
||||
ctx = _build_grooming_context(session, target_user_id=current_user.user_id)
|
||||
assert "GROOMING SCHEDULE:" in ctx
|
||||
assert "poniedziałek: shaving_oneblade (Morning)" in ctx
|
||||
|
||||
# Test weekdays filter
|
||||
ctx2 = _build_grooming_context(session, weekdays=[1]) # not monday
|
||||
ctx2 = _build_grooming_context(
|
||||
session,
|
||||
target_user_id=current_user.user_id,
|
||||
weekdays=[1],
|
||||
) # not monday
|
||||
assert "(no entries for specified days)" in ctx2
|
||||
|
||||
|
||||
def test_build_upcoming_grooming_context(session: Session):
|
||||
def test_build_upcoming_grooming_context(session: Session, current_user):
|
||||
assert (
|
||||
_build_upcoming_grooming_context(session, start_date=date(2026, 3, 2), days=7)
|
||||
_build_upcoming_grooming_context(
|
||||
session,
|
||||
target_user_id=current_user.user_id,
|
||||
start_date=date(2026, 3, 2),
|
||||
days=7,
|
||||
)
|
||||
== "UPCOMING GROOMING (next 7 days): none\n"
|
||||
)
|
||||
|
||||
monday = GroomingSchedule(
|
||||
id=uuid.uuid4(), day_of_week=0, action="shaving_oneblade", notes="Morning"
|
||||
id=uuid.uuid4(),
|
||||
user_id=current_user.user_id,
|
||||
day_of_week=0,
|
||||
action="shaving_oneblade",
|
||||
notes="Morning",
|
||||
)
|
||||
wednesday = GroomingSchedule(
|
||||
id=uuid.uuid4(),
|
||||
user_id=current_user.user_id,
|
||||
day_of_week=2,
|
||||
action="dermarolling",
|
||||
)
|
||||
wednesday = GroomingSchedule(id=uuid.uuid4(), day_of_week=2, action="dermarolling")
|
||||
session.add_all([monday, wednesday])
|
||||
session.commit()
|
||||
|
||||
ctx = _build_upcoming_grooming_context(
|
||||
session,
|
||||
target_user_id=current_user.user_id,
|
||||
start_date=date(2026, 3, 2),
|
||||
days=7,
|
||||
)
|
||||
|
|
@ -246,14 +304,23 @@ def test_build_upcoming_grooming_context(session: Session):
|
|||
assert "za 2 dni (2026-03-04, środa): dermarolling" in ctx
|
||||
|
||||
|
||||
def test_build_recent_history(session: Session):
|
||||
def test_build_recent_history(session: Session, current_user):
|
||||
reference_date = date(2026, 3, 10)
|
||||
assert (
|
||||
_build_recent_history(session, reference_date=reference_date)
|
||||
_build_recent_history(
|
||||
session,
|
||||
target_user_id=current_user.user_id,
|
||||
reference_date=reference_date,
|
||||
)
|
||||
== "RECENT ROUTINES: none\n"
|
||||
)
|
||||
|
||||
r = Routine(id=uuid.uuid4(), routine_date=reference_date, part_of_day="am")
|
||||
r = Routine(
|
||||
id=uuid.uuid4(),
|
||||
user_id=current_user.user_id,
|
||||
routine_date=reference_date,
|
||||
part_of_day="am",
|
||||
)
|
||||
session.add(r)
|
||||
p = Product(
|
||||
id=uuid.uuid4(),
|
||||
|
|
@ -268,19 +335,37 @@ def test_build_recent_history(session: Session):
|
|||
session.add(p)
|
||||
session.commit()
|
||||
|
||||
s1 = RoutineStep(id=uuid.uuid4(), routine_id=r.id, order_index=1, product_id=p.id)
|
||||
s1 = RoutineStep(
|
||||
id=uuid.uuid4(),
|
||||
user_id=current_user.user_id,
|
||||
routine_id=r.id,
|
||||
order_index=1,
|
||||
product_id=p.id,
|
||||
)
|
||||
s2 = RoutineStep(
|
||||
id=uuid.uuid4(), routine_id=r.id, order_index=2, action_type="shaving_razor"
|
||||
id=uuid.uuid4(),
|
||||
user_id=current_user.user_id,
|
||||
routine_id=r.id,
|
||||
order_index=2,
|
||||
action_type="shaving_razor",
|
||||
)
|
||||
# Step with non-existent product
|
||||
s3 = RoutineStep(
|
||||
id=uuid.uuid4(), routine_id=r.id, order_index=3, product_id=uuid.uuid4()
|
||||
id=uuid.uuid4(),
|
||||
user_id=current_user.user_id,
|
||||
routine_id=r.id,
|
||||
order_index=3,
|
||||
product_id=uuid.uuid4(),
|
||||
)
|
||||
|
||||
session.add_all([s1, s2, s3])
|
||||
session.commit()
|
||||
|
||||
ctx = _build_recent_history(session, reference_date=reference_date)
|
||||
ctx = _build_recent_history(
|
||||
session,
|
||||
target_user_id=current_user.user_id,
|
||||
reference_date=reference_date,
|
||||
)
|
||||
assert "RECENT ROUTINES:" in ctx
|
||||
assert "AM:" in ctx
|
||||
assert "cleanser [" in ctx
|
||||
|
|
@ -288,31 +373,38 @@ def test_build_recent_history(session: Session):
|
|||
assert "unknown [" in ctx
|
||||
|
||||
|
||||
def test_build_recent_history_uses_reference_window(session: Session):
|
||||
def test_build_recent_history_uses_reference_window(session: Session, current_user):
|
||||
reference_date = date(2026, 3, 10)
|
||||
recent = Routine(
|
||||
id=uuid.uuid4(),
|
||||
user_id=current_user.user_id,
|
||||
routine_date=reference_date - timedelta(days=3),
|
||||
part_of_day="pm",
|
||||
)
|
||||
old = Routine(
|
||||
id=uuid.uuid4(),
|
||||
user_id=current_user.user_id,
|
||||
routine_date=reference_date - timedelta(days=6),
|
||||
part_of_day="am",
|
||||
)
|
||||
session.add_all([recent, old])
|
||||
session.commit()
|
||||
|
||||
ctx = _build_recent_history(session, reference_date=reference_date)
|
||||
ctx = _build_recent_history(
|
||||
session,
|
||||
target_user_id=current_user.user_id,
|
||||
reference_date=reference_date,
|
||||
)
|
||||
|
||||
assert str(recent.routine_date) in ctx
|
||||
assert str(old.routine_date) not in ctx
|
||||
|
||||
|
||||
def test_build_recent_history_excludes_future_routines(session: Session):
|
||||
def test_build_recent_history_excludes_future_routines(session: Session, current_user):
|
||||
reference_date = date(2026, 3, 10)
|
||||
future = Routine(
|
||||
id=uuid.uuid4(),
|
||||
user_id=current_user.user_id,
|
||||
routine_date=reference_date + timedelta(days=1),
|
||||
part_of_day="am",
|
||||
)
|
||||
|
|
@ -320,12 +412,16 @@ def test_build_recent_history_excludes_future_routines(session: Session):
|
|||
session.commit()
|
||||
|
||||
assert (
|
||||
_build_recent_history(session, reference_date=reference_date)
|
||||
_build_recent_history(
|
||||
session,
|
||||
target_user_id=current_user.user_id,
|
||||
reference_date=reference_date,
|
||||
)
|
||||
== "RECENT ROUTINES: none\n"
|
||||
)
|
||||
|
||||
|
||||
def test_build_products_context_summary_list(session: Session):
|
||||
def test_build_products_context_summary_list(session: Session, current_user):
|
||||
p1 = Product(
|
||||
id=uuid.uuid4(),
|
||||
short_id=str(uuid.uuid4())[:8],
|
||||
|
|
@ -336,6 +432,7 @@ def test_build_products_context_summary_list(session: Session):
|
|||
recommended_time="both",
|
||||
leave_on=True,
|
||||
product_effect_profile={},
|
||||
user_id=current_user.user_id,
|
||||
)
|
||||
p2 = Product(
|
||||
id=uuid.uuid4(),
|
||||
|
|
@ -350,11 +447,16 @@ def test_build_products_context_summary_list(session: Session):
|
|||
context_rules={"safe_after_shaving": False},
|
||||
min_interval_hours=12,
|
||||
max_frequency_per_week=7,
|
||||
user_id=current_user.user_id,
|
||||
)
|
||||
session.add_all([p1, p2])
|
||||
session.commit()
|
||||
|
||||
products_am = _get_available_products(session, time_filter="am")
|
||||
products_am = _get_available_products(
|
||||
session,
|
||||
current_user=current_user,
|
||||
time_filter="am",
|
||||
)
|
||||
ctx = build_products_context_summary_list(products_am, {p2.id})
|
||||
|
||||
assert "Regaine Minoxidil" in ctx
|
||||
|
|
@ -375,7 +477,7 @@ def test_build_day_context():
|
|||
assert "Leaving home: no" in _build_day_context(False)
|
||||
|
||||
|
||||
def test_get_available_products_respects_filters(session: Session):
|
||||
def test_get_available_products_respects_filters(session: Session, current_user):
|
||||
regular_med = Product(
|
||||
id=uuid.uuid4(),
|
||||
name="Tretinoin",
|
||||
|
|
@ -385,6 +487,7 @@ def test_get_available_products_respects_filters(session: Session):
|
|||
recommended_time="pm",
|
||||
leave_on=True,
|
||||
product_effect_profile={},
|
||||
user_id=current_user.user_id,
|
||||
)
|
||||
minoxidil_med = Product(
|
||||
id=uuid.uuid4(),
|
||||
|
|
@ -395,6 +498,7 @@ def test_get_available_products_respects_filters(session: Session):
|
|||
recommended_time="both",
|
||||
leave_on=True,
|
||||
product_effect_profile={},
|
||||
user_id=current_user.user_id,
|
||||
)
|
||||
am_product = Product(
|
||||
id=uuid.uuid4(),
|
||||
|
|
@ -404,6 +508,7 @@ def test_get_available_products_respects_filters(session: Session):
|
|||
recommended_time="am",
|
||||
leave_on=True,
|
||||
product_effect_profile={},
|
||||
user_id=current_user.user_id,
|
||||
)
|
||||
pm_product = Product(
|
||||
id=uuid.uuid4(),
|
||||
|
|
@ -413,11 +518,16 @@ def test_get_available_products_respects_filters(session: Session):
|
|||
recommended_time="pm",
|
||||
leave_on=True,
|
||||
product_effect_profile={},
|
||||
user_id=current_user.user_id,
|
||||
)
|
||||
session.add_all([regular_med, minoxidil_med, am_product, pm_product])
|
||||
session.commit()
|
||||
|
||||
am_available = _get_available_products(session, time_filter="am")
|
||||
am_available = _get_available_products(
|
||||
session,
|
||||
current_user=current_user,
|
||||
time_filter="am",
|
||||
)
|
||||
am_names = {p.name for p in am_available}
|
||||
assert "Tretinoin" not in am_names
|
||||
assert "Minoxidil 5%" in am_names
|
||||
|
|
@ -508,7 +618,10 @@ def test_extract_active_names_uses_compact_distinct_names(session: Session):
|
|||
assert names == ["Niacinamide", "Zinc PCA"]
|
||||
|
||||
|
||||
def test_get_available_products_excludes_minoxidil_when_flag_false(session: Session):
|
||||
def test_get_available_products_excludes_minoxidil_when_flag_false(
|
||||
session: Session,
|
||||
current_user,
|
||||
):
|
||||
minoxidil = Product(
|
||||
id=uuid.uuid4(),
|
||||
name="Minoxidil 5%",
|
||||
|
|
@ -518,6 +631,7 @@ def test_get_available_products_excludes_minoxidil_when_flag_false(session: Sess
|
|||
recommended_time="both",
|
||||
leave_on=True,
|
||||
product_effect_profile={},
|
||||
user_id=current_user.user_id,
|
||||
)
|
||||
regular = Product(
|
||||
id=uuid.uuid4(),
|
||||
|
|
@ -527,18 +641,27 @@ def test_get_available_products_excludes_minoxidil_when_flag_false(session: Sess
|
|||
recommended_time="both",
|
||||
leave_on=False,
|
||||
product_effect_profile={},
|
||||
user_id=current_user.user_id,
|
||||
)
|
||||
session.add_all([minoxidil, regular])
|
||||
session.commit()
|
||||
|
||||
# With flag True (default) - minoxidil included
|
||||
products = _get_available_products(session, include_minoxidil=True)
|
||||
products = _get_available_products(
|
||||
session,
|
||||
current_user=current_user,
|
||||
include_minoxidil=True,
|
||||
)
|
||||
names = {p.name for p in products}
|
||||
assert "Minoxidil 5%" in names
|
||||
assert "Cleanser" in names
|
||||
|
||||
# With flag False - minoxidil excluded
|
||||
products = _get_available_products(session, include_minoxidil=False)
|
||||
products = _get_available_products(
|
||||
session,
|
||||
current_user=current_user,
|
||||
include_minoxidil=False,
|
||||
)
|
||||
names = {p.name for p in products}
|
||||
assert "Minoxidil 5%" not in names
|
||||
assert "Cleanser" in names
|
||||
|
|
|
|||
|
|
@ -140,7 +140,7 @@ def test_analyze_photos_includes_user_profile_context(client, monkeypatch):
|
|||
|
||||
def _fake_call_gemini(**kwargs):
|
||||
captured.update(kwargs)
|
||||
return _FakeResponse()
|
||||
return _FakeResponse(), None
|
||||
|
||||
monkeypatch.setattr(skincare_api, "call_gemini", _fake_call_gemini)
|
||||
|
||||
|
|
|
|||
100
backend/tests/test_tenancy_domains.py
Normal file
100
backend/tests/test_tenancy_domains.py
Normal file
|
|
@ -0,0 +1,100 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from uuid import uuid4
|
||||
|
||||
from innercontext.api.auth_deps import get_current_user
|
||||
from innercontext.auth import CurrentUser, IdentityData, TokenClaims
|
||||
from innercontext.models import Role
|
||||
from innercontext.models.ai_log import AICallLog
|
||||
from main import app
|
||||
|
||||
|
||||
def _user(subject: str, *, role: Role = Role.MEMBER) -> CurrentUser:
|
||||
claims = TokenClaims(
|
||||
issuer="https://auth.test",
|
||||
subject=subject,
|
||||
audience=("innercontext-web",),
|
||||
expires_at=datetime.now(UTC) + timedelta(hours=1),
|
||||
raw_claims={"iss": "https://auth.test", "sub": subject},
|
||||
)
|
||||
return CurrentUser(
|
||||
user_id=uuid4(),
|
||||
role=role,
|
||||
identity=IdentityData.from_claims(claims),
|
||||
claims=claims,
|
||||
)
|
||||
|
||||
|
||||
def _set_current_user(user: CurrentUser) -> None:
|
||||
app.dependency_overrides[get_current_user] = lambda: user
|
||||
|
||||
|
||||
def test_profile_health_routines_skincare_ai_logs_are_user_scoped_by_default(
|
||||
client, session
|
||||
):
|
||||
owner = _user("owner")
|
||||
intruder = _user("intruder")
|
||||
|
||||
_set_current_user(owner)
|
||||
profile = client.patch(
|
||||
"/profile", json={"birth_date": "1991-01-15", "sex_at_birth": "male"}
|
||||
)
|
||||
medication = client.post(
|
||||
"/health/medications", json={"kind": "prescription", "product_name": "Owner Rx"}
|
||||
)
|
||||
routine = client.post(
|
||||
"/routines", json={"routine_date": "2026-03-01", "part_of_day": "am"}
|
||||
)
|
||||
snapshot = client.post("/skincare", json={"snapshot_date": "2026-03-01"})
|
||||
log = AICallLog(endpoint="routines/suggest", model="gemini-3-flash-preview")
|
||||
log.user_id = owner.user_id
|
||||
session.add(log)
|
||||
session.commit()
|
||||
session.refresh(log)
|
||||
|
||||
assert profile.status_code == 200
|
||||
assert medication.status_code == 201
|
||||
assert routine.status_code == 201
|
||||
assert snapshot.status_code == 201
|
||||
|
||||
medication_id = medication.json()["record_id"]
|
||||
routine_id = routine.json()["id"]
|
||||
snapshot_id = snapshot.json()["id"]
|
||||
|
||||
_set_current_user(intruder)
|
||||
assert client.get("/profile").json() is None
|
||||
assert client.get("/health/medications").json() == []
|
||||
assert client.get("/routines").json() == []
|
||||
assert client.get("/skincare").json() == []
|
||||
assert client.get("/ai-logs").json() == []
|
||||
|
||||
assert client.get(f"/health/medications/{medication_id}").status_code == 404
|
||||
assert client.get(f"/routines/{routine_id}").status_code == 404
|
||||
assert client.get(f"/skincare/{snapshot_id}").status_code == 404
|
||||
assert client.get(f"/ai-logs/{log.id}").status_code == 404
|
||||
|
||||
|
||||
def test_health_admin_override_requires_explicit_user_id(client):
|
||||
owner = _user("owner")
|
||||
admin = _user("admin", role=Role.ADMIN)
|
||||
|
||||
_set_current_user(owner)
|
||||
created = client.post(
|
||||
"/health/lab-results",
|
||||
json={
|
||||
"collected_at": "2026-03-01T00:00:00",
|
||||
"test_code": "718-7",
|
||||
"test_name_original": "Hemoglobin",
|
||||
},
|
||||
)
|
||||
assert created.status_code == 201
|
||||
|
||||
_set_current_user(admin)
|
||||
default_scope = client.get("/health/lab-results")
|
||||
assert default_scope.status_code == 200
|
||||
assert default_scope.json()["items"] == []
|
||||
|
||||
overridden = client.get(f"/health/lab-results?user_id={owner.user_id}")
|
||||
assert overridden.status_code == 200
|
||||
assert len(overridden.json()["items"]) == 1
|
||||
Loading…
Add table
Add a link
Reference in a new issue