diff --git a/backend/innercontext/api/ai_logs.py b/backend/innercontext/api/ai_logs.py index 040d47e..b407006 100644 --- a/backend/innercontext/api/ai_logs.py +++ b/backend/innercontext/api/ai_logs.py @@ -2,10 +2,13 @@ import json from typing import Any, Optional from uuid import UUID -from fastapi import APIRouter, Depends, HTTPException +from fastapi import APIRouter, Depends, HTTPException, Query from sqlmodel import Session, SQLModel, col, select from db import get_session +from innercontext.api.auth_deps import get_current_user +from innercontext.auth import CurrentUser +from innercontext.models.enums import Role from innercontext.models.ai_log import AICallLog router = APIRouter() @@ -43,14 +46,33 @@ class AICallLogPublic(SQLModel): error_detail: Optional[str] = None +def _resolve_target_user_id( + current_user: CurrentUser, + user_id: UUID | None, +) -> UUID: + if user_id is None: + return current_user.user_id + if current_user.role is not Role.ADMIN: + raise HTTPException(status_code=403, detail="Admin role required") + return user_id + + @router.get("", response_model=list[AICallLogPublic]) def list_ai_logs( endpoint: Optional[str] = None, success: Optional[bool] = None, limit: int = 50, + user_id: UUID | None = Query(default=None), session: Session = Depends(get_session), + current_user: CurrentUser = Depends(get_current_user), ): - stmt = select(AICallLog).order_by(col(AICallLog.created_at).desc()).limit(limit) + target_user_id = _resolve_target_user_id(current_user, user_id) + stmt = ( + select(AICallLog) + .where(AICallLog.user_id == target_user_id) + .order_by(col(AICallLog.created_at).desc()) + .limit(limit) + ) if endpoint is not None: stmt = stmt.where(AICallLog.endpoint == endpoint) if success is not None: @@ -75,9 +97,17 @@ def list_ai_logs( @router.get("/{log_id}", response_model=AICallLog) -def get_ai_log(log_id: UUID, session: Session = Depends(get_session)): +def get_ai_log( + log_id: UUID, + user_id: UUID | None = Query(default=None), + session: Session = Depends(get_session), + current_user: CurrentUser = Depends(get_current_user), +): + target_user_id = _resolve_target_user_id(current_user, user_id) log = session.get(AICallLog, log_id) if log is None: raise HTTPException(status_code=404, detail="Log not found") + if log.user_id != target_user_id: + raise HTTPException(status_code=404, detail="Log not found") log.tool_trace = _normalize_tool_trace(getattr(log, "tool_trace", None)) return log diff --git a/backend/innercontext/api/health.py b/backend/innercontext/api/health.py index d3e8064..9f2334e 100644 --- a/backend/innercontext/api/health.py +++ b/backend/innercontext/api/health.py @@ -3,15 +3,17 @@ from datetime import datetime from typing import Optional from uuid import UUID, uuid4 -from fastapi import APIRouter, Depends, Query +from fastapi import APIRouter, Depends, HTTPException, Query from pydantic import field_validator from sqlalchemy import Integer, cast, func, or_ from sqlmodel import Session, SQLModel, col, select from db import get_session -from innercontext.api.utils import get_or_404 +from innercontext.api.auth_deps import get_current_user +from innercontext.api.utils import get_owned_or_404 +from innercontext.auth import CurrentUser from innercontext.models import LabResult, MedicationEntry, MedicationUsage -from innercontext.models.enums import MedicationKind, ResultFlag +from innercontext.models.enums import MedicationKind, ResultFlag, Role router = APIRouter() @@ -133,6 +135,34 @@ class LabResultListResponse(SQLModel): # --------------------------------------------------------------------------- +def _resolve_target_user_id( + current_user: CurrentUser, + user_id: UUID | None, +) -> UUID: + if user_id is None: + return current_user.user_id + if current_user.role is not Role.ADMIN: + raise HTTPException(status_code=403, detail="Admin role required") + return user_id + + +def _get_owned_or_admin_override( + session: Session, + model: type[MedicationEntry] | type[MedicationUsage] | type[LabResult], + record_id: UUID, + current_user: CurrentUser, + user_id: UUID | None, +): + if user_id is None: + return get_owned_or_404(session, model, record_id, current_user) + record = session.get(model, record_id) + if record is None or record.user_id != _resolve_target_user_id( + current_user, user_id + ): + raise HTTPException(status_code=404, detail=f"{model.__name__} not found") + return record + + # --------------------------------------------------------------------------- # Medication routes # --------------------------------------------------------------------------- @@ -142,9 +172,12 @@ class LabResultListResponse(SQLModel): def list_medications( kind: Optional[MedicationKind] = None, product_name: Optional[str] = None, + user_id: UUID | None = Query(default=None), session: Session = Depends(get_session), + current_user: CurrentUser = Depends(get_current_user), ): - stmt = select(MedicationEntry) + target_user_id = _resolve_target_user_id(current_user, user_id) + stmt = select(MedicationEntry).where(MedicationEntry.user_id == target_user_id) if kind is not None: stmt = stmt.where(MedicationEntry.kind == kind) if product_name is not None: @@ -153,8 +186,18 @@ def list_medications( @router.post("/medications", response_model=MedicationEntry, status_code=201) -def create_medication(data: MedicationCreate, session: Session = Depends(get_session)): - entry = MedicationEntry(record_id=uuid4(), **data.model_dump()) +def create_medication( + data: MedicationCreate, + user_id: UUID | None = Query(default=None), + session: Session = Depends(get_session), + current_user: CurrentUser = Depends(get_current_user), +): + target_user_id = _resolve_target_user_id(current_user, user_id) + entry = MedicationEntry( + record_id=uuid4(), + user_id=target_user_id, + **data.model_dump(), + ) session.add(entry) session.commit() session.refresh(entry) @@ -162,17 +205,36 @@ def create_medication(data: MedicationCreate, session: Session = Depends(get_ses @router.get("/medications/{medication_id}", response_model=MedicationEntry) -def get_medication(medication_id: UUID, session: Session = Depends(get_session)): - return get_or_404(session, MedicationEntry, medication_id) +def get_medication( + medication_id: UUID, + user_id: UUID | None = Query(default=None), + session: Session = Depends(get_session), + current_user: CurrentUser = Depends(get_current_user), +): + return _get_owned_or_admin_override( + session, + MedicationEntry, + medication_id, + current_user, + user_id, + ) @router.patch("/medications/{medication_id}", response_model=MedicationEntry) def update_medication( medication_id: UUID, data: MedicationUpdate, + user_id: UUID | None = Query(default=None), session: Session = Depends(get_session), + current_user: CurrentUser = Depends(get_current_user), ): - entry = get_or_404(session, MedicationEntry, medication_id) + entry = _get_owned_or_admin_override( + session, + MedicationEntry, + medication_id, + current_user, + user_id, + ) for key, value in data.model_dump(exclude_unset=True).items(): setattr(entry, key, value) session.add(entry) @@ -182,13 +244,25 @@ def update_medication( @router.delete("/medications/{medication_id}", status_code=204) -def delete_medication(medication_id: UUID, session: Session = Depends(get_session)): - entry = get_or_404(session, MedicationEntry, medication_id) +def delete_medication( + medication_id: UUID, + user_id: UUID | None = Query(default=None), + session: Session = Depends(get_session), + current_user: CurrentUser = Depends(get_current_user), +): + target_user_id = _resolve_target_user_id(current_user, user_id) + entry = _get_owned_or_admin_override( + session, + MedicationEntry, + medication_id, + current_user, + user_id, + ) # Delete usages first (no cascade configured at DB level) usages = session.exec( - select(MedicationUsage).where( - MedicationUsage.medication_record_id == medication_id - ) + select(MedicationUsage) + .where(MedicationUsage.medication_record_id == medication_id) + .where(MedicationUsage.user_id == target_user_id) ).all() for u in usages: session.delete(u) @@ -202,10 +276,24 @@ def delete_medication(medication_id: UUID, session: Session = Depends(get_sessio @router.get("/medications/{medication_id}/usages", response_model=list[MedicationUsage]) -def list_usages(medication_id: UUID, session: Session = Depends(get_session)): - get_or_404(session, MedicationEntry, medication_id) - stmt = select(MedicationUsage).where( - MedicationUsage.medication_record_id == medication_id +def list_usages( + medication_id: UUID, + user_id: UUID | None = Query(default=None), + session: Session = Depends(get_session), + current_user: CurrentUser = Depends(get_current_user), +): + target_user_id = _resolve_target_user_id(current_user, user_id) + _ = _get_owned_or_admin_override( + session, + MedicationEntry, + medication_id, + current_user, + user_id, + ) + stmt = ( + select(MedicationUsage) + .where(MedicationUsage.medication_record_id == medication_id) + .where(MedicationUsage.user_id == target_user_id) ) return session.exec(stmt).all() @@ -218,11 +306,21 @@ def list_usages(medication_id: UUID, session: Session = Depends(get_session)): def create_usage( medication_id: UUID, data: UsageCreate, + user_id: UUID | None = Query(default=None), session: Session = Depends(get_session), + current_user: CurrentUser = Depends(get_current_user), ): - get_or_404(session, MedicationEntry, medication_id) + target_user_id = _resolve_target_user_id(current_user, user_id) + _ = _get_owned_or_admin_override( + session, + MedicationEntry, + medication_id, + current_user, + user_id, + ) usage = MedicationUsage( record_id=uuid4(), + user_id=target_user_id, medication_record_id=medication_id, **data.model_dump(), ) @@ -236,9 +334,17 @@ def create_usage( def update_usage( usage_id: UUID, data: UsageUpdate, + user_id: UUID | None = Query(default=None), session: Session = Depends(get_session), + current_user: CurrentUser = Depends(get_current_user), ): - usage = get_or_404(session, MedicationUsage, usage_id) + usage = _get_owned_or_admin_override( + session, + MedicationUsage, + usage_id, + current_user, + user_id, + ) for key, value in data.model_dump(exclude_unset=True).items(): setattr(usage, key, value) session.add(usage) @@ -248,8 +354,19 @@ def update_usage( @router.delete("/usages/{usage_id}", status_code=204) -def delete_usage(usage_id: UUID, session: Session = Depends(get_session)): - usage = get_or_404(session, MedicationUsage, usage_id) +def delete_usage( + usage_id: UUID, + user_id: UUID | None = Query(default=None), + session: Session = Depends(get_session), + current_user: CurrentUser = Depends(get_current_user), +): + usage = _get_owned_or_admin_override( + session, + MedicationUsage, + usage_id, + current_user, + user_id, + ) session.delete(usage) session.commit() @@ -271,29 +388,35 @@ def list_lab_results( latest_only: bool = False, limit: int = Query(default=50, ge=1, le=200), offset: int = Query(default=0, ge=0), + user_id: UUID | None = Query(default=None), session: Session = Depends(get_session), + current_user: CurrentUser = Depends(get_current_user), ): - filters = [] - if q is not None and q.strip(): - query = f"%{q.strip()}%" - filters.append( - or_( - col(LabResult.test_code).ilike(query), - col(LabResult.test_name_original).ilike(query), + target_user_id = _resolve_target_user_id(current_user, user_id) + + def _apply_filters(statement): + statement = statement.where(col(LabResult.user_id) == target_user_id) + if q is not None and q.strip(): + query = f"%{q.strip()}%" + statement = statement.where( + or_( + col(LabResult.test_code).ilike(query), + col(LabResult.test_name_original).ilike(query), + ) ) - ) - if test_code is not None: - filters.append(LabResult.test_code == test_code) - if flag is not None: - filters.append(LabResult.flag == flag) - if flags: - filters.append(col(LabResult.flag).in_(flags)) - if without_flag: - filters.append(col(LabResult.flag).is_(None)) - if from_date is not None: - filters.append(LabResult.collected_at >= from_date) - if to_date is not None: - filters.append(LabResult.collected_at <= to_date) + if test_code is not None: + statement = statement.where(col(LabResult.test_code) == test_code) + if flag is not None: + statement = statement.where(col(LabResult.flag) == flag) + if flags: + statement = statement.where(col(LabResult.flag).in_(flags)) + if without_flag: + statement = statement.where(col(LabResult.flag).is_(None)) + if from_date is not None: + statement = statement.where(col(LabResult.collected_at) >= from_date) + if to_date is not None: + statement = statement.where(col(LabResult.collected_at) <= to_date) + return statement if latest_only: ranked_stmt = select( @@ -309,8 +432,7 @@ def list_lab_results( ) .label("rank"), ) - if filters: - ranked_stmt = ranked_stmt.where(*filters) + ranked_stmt = _apply_filters(ranked_stmt) ranked_subquery = ranked_stmt.subquery() latest_ids = select(ranked_subquery.c.record_id).where( @@ -323,11 +445,8 @@ def list_lab_results( .subquery() ) else: - stmt = select(LabResult) - count_stmt = select(func.count()).select_from(LabResult) - if filters: - stmt = stmt.where(*filters) - count_stmt = count_stmt.where(*filters) + stmt = _apply_filters(select(LabResult)) + count_stmt = _apply_filters(select(func.count()).select_from(LabResult)) test_code_numeric = cast( func.replace(col(LabResult.test_code), "-", ""), @@ -345,8 +464,18 @@ def list_lab_results( @router.post("/lab-results", response_model=LabResult, status_code=201) -def create_lab_result(data: LabResultCreate, session: Session = Depends(get_session)): - result = LabResult(record_id=uuid4(), **data.model_dump()) +def create_lab_result( + data: LabResultCreate, + user_id: UUID | None = Query(default=None), + session: Session = Depends(get_session), + current_user: CurrentUser = Depends(get_current_user), +): + target_user_id = _resolve_target_user_id(current_user, user_id) + result = LabResult( + record_id=uuid4(), + user_id=target_user_id, + **data.model_dump(), + ) session.add(result) session.commit() session.refresh(result) @@ -354,17 +483,36 @@ def create_lab_result(data: LabResultCreate, session: Session = Depends(get_sess @router.get("/lab-results/{result_id}", response_model=LabResult) -def get_lab_result(result_id: UUID, session: Session = Depends(get_session)): - return get_or_404(session, LabResult, result_id) +def get_lab_result( + result_id: UUID, + user_id: UUID | None = Query(default=None), + session: Session = Depends(get_session), + current_user: CurrentUser = Depends(get_current_user), +): + return _get_owned_or_admin_override( + session, + LabResult, + result_id, + current_user, + user_id, + ) @router.patch("/lab-results/{result_id}", response_model=LabResult) def update_lab_result( result_id: UUID, data: LabResultUpdate, + user_id: UUID | None = Query(default=None), session: Session = Depends(get_session), + current_user: CurrentUser = Depends(get_current_user), ): - result = get_or_404(session, LabResult, result_id) + result = _get_owned_or_admin_override( + session, + LabResult, + result_id, + current_user, + user_id, + ) for key, value in data.model_dump(exclude_unset=True).items(): setattr(result, key, value) session.add(result) @@ -374,7 +522,18 @@ def update_lab_result( @router.delete("/lab-results/{result_id}", status_code=204) -def delete_lab_result(result_id: UUID, session: Session = Depends(get_session)): - result = get_or_404(session, LabResult, result_id) +def delete_lab_result( + result_id: UUID, + user_id: UUID | None = Query(default=None), + session: Session = Depends(get_session), + current_user: CurrentUser = Depends(get_current_user), +): + result = _get_owned_or_admin_override( + session, + LabResult, + result_id, + current_user, + user_id, + ) session.delete(result) session.commit() diff --git a/backend/innercontext/api/llm_context.py b/backend/innercontext/api/llm_context.py index 6ffc68e..40e7dcf 100644 --- a/backend/innercontext/api/llm_context.py +++ b/backend/innercontext/api/llm_context.py @@ -2,14 +2,41 @@ from datetime import date from typing import Any from uuid import UUID +from fastapi import HTTPException from sqlmodel import Session, col, select +from innercontext.auth import CurrentUser from innercontext.models import Product, UserProfile +from innercontext.models.enums import Role -def get_user_profile(session: Session) -> UserProfile | None: +def _resolve_target_user_id( + current_user: CurrentUser, + user_id: UUID | None, +) -> UUID: + if user_id is None: + return current_user.user_id + if current_user.role is not Role.ADMIN: + raise HTTPException(status_code=403, detail="Admin role required") + return user_id + + +def get_user_profile( + session: Session, + current_user: CurrentUser | None = None, + *, + user_id: UUID | None = None, +) -> UserProfile | None: + if current_user is None: + return session.exec( + select(UserProfile).order_by(col(UserProfile.created_at).desc()) + ).first() + + target_user_id = _resolve_target_user_id(current_user, user_id) return session.exec( - select(UserProfile).order_by(col(UserProfile.created_at).desc()) + select(UserProfile) + .where(UserProfile.user_id == target_user_id) + .order_by(col(UserProfile.created_at).desc()) ).first() @@ -20,8 +47,14 @@ def calculate_age(birth_date: date, reference_date: date) -> int: return years -def build_user_profile_context(session: Session, reference_date: date) -> str: - profile = get_user_profile(session) +def build_user_profile_context( + session: Session, + reference_date: date, + current_user: CurrentUser | None = None, + *, + user_id: UUID | None = None, +) -> str: + profile = get_user_profile(session, current_user, user_id=user_id) if profile is None: return "USER PROFILE: no data\n" @@ -69,8 +102,9 @@ def build_product_context_summary(product: Product, has_inventory: bool = False) # Get effect profile scores if available effects = [] - if hasattr(product, "effect_profile") and product.effect_profile: - profile = product.effect_profile + effect_profile = getattr(product, "effect_profile", None) + if effect_profile: + profile = effect_profile # Only include notable effects (score > 0) # Handle both dict (from DB) and object (from Pydantic) if isinstance(profile, dict): @@ -165,11 +199,12 @@ def build_product_context_detailed( # Effect profile effect_profile = None - if hasattr(product, "effect_profile") and product.effect_profile: - if isinstance(product.effect_profile, dict): - effect_profile = product.effect_profile + product_effect_profile = getattr(product, "effect_profile", None) + if product_effect_profile: + if isinstance(product_effect_profile, dict): + effect_profile = product_effect_profile else: - effect_profile = product.effect_profile.model_dump() + effect_profile = product_effect_profile.model_dump() # Context rules context_rules = None diff --git a/backend/innercontext/api/profile.py b/backend/innercontext/api/profile.py index 52e8e14..ebadbd4 100644 --- a/backend/innercontext/api/profile.py +++ b/backend/innercontext/api/profile.py @@ -1,11 +1,14 @@ from datetime import date, datetime from typing import Optional +from uuid import UUID -from fastapi import APIRouter, Depends +from fastapi import APIRouter, Depends, Query from sqlmodel import Session, SQLModel from db import get_session from innercontext.api.llm_context import get_user_profile +from innercontext.api.auth_deps import get_current_user +from innercontext.auth import CurrentUser from innercontext.models import SexAtBirth, UserProfile router = APIRouter() @@ -25,8 +28,12 @@ class UserProfilePublic(SQLModel): @router.get("", response_model=UserProfilePublic | None) -def get_profile(session: Session = Depends(get_session)): - profile = get_user_profile(session) +def get_profile( + user_id: UUID | None = Query(default=None), + session: Session = Depends(get_session), + current_user: CurrentUser = Depends(get_current_user), +): + profile = get_user_profile(session, current_user, user_id=user_id) if profile is None: return None return UserProfilePublic( @@ -39,12 +46,18 @@ def get_profile(session: Session = Depends(get_session)): @router.patch("", response_model=UserProfilePublic) -def upsert_profile(data: UserProfileUpdate, session: Session = Depends(get_session)): - profile = get_user_profile(session) +def upsert_profile( + data: UserProfileUpdate, + user_id: UUID | None = Query(default=None), + session: Session = Depends(get_session), + current_user: CurrentUser = Depends(get_current_user), +): + target_user_id = user_id if user_id is not None else current_user.user_id + profile = get_user_profile(session, current_user, user_id=user_id) payload = data.model_dump(exclude_unset=True) if profile is None: - profile = UserProfile(**payload) + profile = UserProfile(user_id=target_user_id, **payload) else: for key, value in payload.items(): setattr(profile, key, value) diff --git a/backend/innercontext/api/routines.py b/backend/innercontext/api/routines.py index 9f98b27..3b825ff 100644 --- a/backend/innercontext/api/routines.py +++ b/backend/innercontext/api/routines.py @@ -5,12 +5,15 @@ from datetime import date, timedelta from typing import Any, Optional from uuid import UUID, uuid4 -from fastapi import APIRouter, Depends, HTTPException +from fastapi import APIRouter, Depends, HTTPException, Query from google.genai import types as genai_types from pydantic import BaseModel as PydanticBase +from sqlalchemy import or_ from sqlmodel import Field, Session, SQLModel, col, select from db import get_session +from innercontext.api.auth_deps import get_current_user +from innercontext.api.authz import is_product_visible from innercontext.api.llm_context import ( build_products_context_summary_list, build_user_profile_context, @@ -25,7 +28,8 @@ from innercontext.api.product_llm_tools import ( build_last_used_on_by_product, build_product_details_tool_handler, ) -from innercontext.api.utils import get_or_404 +from innercontext.api.utils import get_owned_or_404 +from innercontext.auth import CurrentUser from innercontext.llm import ( call_gemini, call_gemini_with_function_tools, @@ -33,6 +37,7 @@ from innercontext.llm import ( ) from innercontext.llm_safety import isolate_user_input, sanitize_user_input from innercontext.models import ( + HouseholdMembership, GroomingSchedule, Product, ProductInventory, @@ -43,6 +48,7 @@ from innercontext.models import ( from innercontext.models.ai_log import AICallLog from innercontext.models.api_metadata import ResponseMetadata, TokenMetrics from innercontext.models.enums import GroomingAction, PartOfDay +from innercontext.models.enums import Role from innercontext.validators import BatchValidator, RoutineSuggestionValidator from innercontext.validators.batch_validator import BatchValidationContext from innercontext.validators.routine_validator import RoutineValidationContext @@ -86,6 +92,47 @@ def _build_response_metadata(session: Session, log_id: Any) -> ResponseMetadata router = APIRouter() +def _resolve_target_user_id( + current_user: CurrentUser, + user_id: UUID | None, +) -> UUID: + if user_id is None: + return current_user.user_id + if current_user.role is not Role.ADMIN: + raise HTTPException(status_code=403, detail="Admin role required") + return user_id + + +def _shared_household_user_ids( + session: Session, current_user: CurrentUser +) -> set[UUID]: + membership = current_user.household_membership + if membership is None: + return set() + user_ids = session.exec( + select(HouseholdMembership.user_id).where( + HouseholdMembership.household_id == membership.household_id + ) + ).all() + return {uid for uid in user_ids if uid != current_user.user_id} + + +def _get_owned_or_admin_override( + session: Session, + model: type[Routine] | type[RoutineStep] | type[GroomingSchedule], + record_id: UUID, + current_user: CurrentUser, + user_id: UUID | None, +): + if user_id is None: + return get_owned_or_404(session, model, record_id, current_user) + target_user_id = _resolve_target_user_id(current_user, user_id) + record = session.get(model, record_id) + if record is None or record.user_id != target_user_id: + raise HTTPException(status_code=404, detail=f"{model.__name__} not found") + return record + + # --------------------------------------------------------------------------- # Schemas # --------------------------------------------------------------------------- @@ -289,6 +336,7 @@ def _ev(v: object) -> str: def _get_recent_skin_snapshot( session: Session, + target_user_id: UUID, reference_date: date, window_days: int = HISTORY_WINDOW_DAYS, fallback_days: int = SNAPSHOT_FALLBACK_DAYS, @@ -298,6 +346,7 @@ def _get_recent_skin_snapshot( snapshot = session.exec( select(SkinConditionSnapshot) + .where(SkinConditionSnapshot.user_id == target_user_id) .where(SkinConditionSnapshot.snapshot_date <= reference_date) .where(SkinConditionSnapshot.snapshot_date >= window_cutoff) .order_by(col(SkinConditionSnapshot.snapshot_date).desc()) @@ -307,6 +356,7 @@ def _get_recent_skin_snapshot( return session.exec( select(SkinConditionSnapshot) + .where(SkinConditionSnapshot.user_id == target_user_id) .where(SkinConditionSnapshot.snapshot_date <= reference_date) .where(SkinConditionSnapshot.snapshot_date >= fallback_cutoff) .order_by(col(SkinConditionSnapshot.snapshot_date).desc()) @@ -315,12 +365,14 @@ def _get_recent_skin_snapshot( def _get_latest_skin_snapshot_within_days( session: Session, + target_user_id: UUID, reference_date: date, max_age_days: int = SNAPSHOT_FALLBACK_DAYS, ) -> SkinConditionSnapshot | None: cutoff = reference_date - timedelta(days=max_age_days) return session.exec( select(SkinConditionSnapshot) + .where(SkinConditionSnapshot.user_id == target_user_id) .where(SkinConditionSnapshot.snapshot_date <= reference_date) .where(SkinConditionSnapshot.snapshot_date >= cutoff) .order_by(col(SkinConditionSnapshot.snapshot_date).desc()) @@ -329,12 +381,14 @@ def _get_latest_skin_snapshot_within_days( def _build_skin_context( session: Session, + target_user_id: UUID, reference_date: date, window_days: int = HISTORY_WINDOW_DAYS, fallback_days: int = SNAPSHOT_FALLBACK_DAYS, ) -> str: snapshot = _get_recent_skin_snapshot( session, + target_user_id=target_user_id, reference_date=reference_date, window_days=window_days, fallback_days=fallback_days, @@ -354,10 +408,14 @@ def _build_skin_context( def _build_grooming_context( - session: Session, weekdays: Optional[list[int]] = None + session: Session, + target_user_id: UUID, + weekdays: Optional[list[int]] = None, ) -> str: entries = session.exec( - select(GroomingSchedule).order_by(col(GroomingSchedule.day_of_week)) + select(GroomingSchedule) + .where(GroomingSchedule.user_id == target_user_id) + .order_by(col(GroomingSchedule.day_of_week)) ).all() if not entries: return "GROOMING SCHEDULE: none\n" @@ -378,11 +436,14 @@ def _build_grooming_context( def _build_upcoming_grooming_context( session: Session, + target_user_id: UUID, start_date: date, days: int = 7, ) -> str: entries = session.exec( - select(GroomingSchedule).order_by(col(GroomingSchedule.day_of_week)) + select(GroomingSchedule) + .where(GroomingSchedule.user_id == target_user_id) + .order_by(col(GroomingSchedule.day_of_week)) ).all() if not entries: return f"UPCOMING GROOMING (next {days} days): none\n" @@ -420,12 +481,14 @@ def _build_upcoming_grooming_context( def _build_recent_history( session: Session, + target_user_id: UUID, reference_date: date, window_days: int = HISTORY_WINDOW_DAYS, ) -> str: cutoff = reference_date - timedelta(days=window_days) routines = session.exec( select(Routine) + .where(Routine.user_id == target_user_id) .where(Routine.routine_date <= reference_date) .where(Routine.routine_date >= cutoff) .order_by(col(Routine.routine_date).desc()) @@ -437,6 +500,7 @@ def _build_recent_history( steps = session.exec( select(RoutineStep) .where(RoutineStep.routine_id == r.id) + .where(RoutineStep.user_id == target_user_id) .order_by(col(RoutineStep.order_index)) ).all() step_names = [] @@ -458,11 +522,37 @@ def _build_recent_history( def _get_available_products( session: Session, + current_user: CurrentUser, time_filter: Optional[str] = None, include_minoxidil: bool = True, ) -> list[Product]: stmt = select(Product).where(col(Product.is_tool).is_(False)) - products = session.exec(stmt).all() + if current_user.role is not Role.ADMIN: + owned_products = session.exec( + stmt.where(col(Product.user_id) == current_user.user_id) + ).all() + shared_user_ids = _shared_household_user_ids(session, current_user) + shared_product_ids = ( + session.exec( + select(ProductInventory.product_id) + .where(col(ProductInventory.is_household_shared).is_(True)) + .where(col(ProductInventory.user_id).in_(list(shared_user_ids))) + .distinct() + ).all() + if shared_user_ids + else [] + ) + shared_products = ( + session.exec(stmt.where(col(Product.id).in_(shared_product_ids))).all() + if shared_product_ids + else [] + ) + products_by_id = {p.id: p for p in owned_products} + for product in shared_products: + products_by_id.setdefault(product.id, product) + products = list(products_by_id.values()) + else: + products = session.exec(stmt).all() result: list[Product] = [] for p in products: if p.is_medication and not _is_minoxidil_product(p): @@ -517,7 +607,9 @@ def _extract_requested_product_ids( def _get_products_with_inventory( - session: Session, product_ids: list[UUID] + session: Session, + current_user: CurrentUser, + product_ids: list[UUID], ) -> set[UUID]: """ Return set of product IDs that have active (non-finished) inventory. @@ -527,17 +619,33 @@ def _get_products_with_inventory( if not product_ids: return set() - inventory_rows = session.exec( + stmt = ( select(ProductInventory.product_id) .where(col(ProductInventory.product_id).in_(product_ids)) .where(col(ProductInventory.finished_at).is_(None)) - .distinct() - ).all() - + ) + if current_user.role is not Role.ADMIN: + owned_inventory_rows = session.exec( + stmt.where(col(ProductInventory.user_id) == current_user.user_id).distinct() + ).all() + shared_user_ids = _shared_household_user_ids(session, current_user) + shared_inventory_rows = session.exec( + stmt.where(col(ProductInventory.is_household_shared).is_(True)) + .where(col(ProductInventory.user_id).in_(list(shared_user_ids))) + .distinct() + ).all() + inventory_rows = set(owned_inventory_rows) + inventory_rows.update(shared_inventory_rows) + return inventory_rows + inventory_rows = session.exec(stmt.distinct()).all() return set(inventory_rows) -def _expand_product_id(session: Session, short_or_full_id: str) -> UUID | None: +def _expand_product_id( + session: Session, + current_user: CurrentUser, + short_or_full_id: str, +) -> UUID | None: """ Expand 8-char short_id to full UUID, or validate full UUID. @@ -558,7 +666,13 @@ def _expand_product_id(session: Session, short_or_full_id: str) -> UUID | None: uuid_obj = UUID(short_or_full_id) # Verify it exists product = session.get(Product, uuid_obj) - return uuid_obj if product else None + if product is None: + return None + return ( + uuid_obj + if is_product_visible(session, uuid_obj, current_user) + else None + ) except (ValueError, TypeError): return None @@ -567,7 +681,13 @@ def _expand_product_id(session: Session, short_or_full_id: str) -> UUID | None: product = session.exec( select(Product).where(Product.short_id == short_or_full_id) ).first() - return product.id if product else None + if product is None: + return None + return ( + product.id + if is_product_visible(session, product.id, current_user) + else None + ) # Invalid length return None @@ -590,6 +710,17 @@ def _build_day_context(leaving_home: Optional[bool]) -> str: return f"DAY CONTEXT:\n Leaving home: {val}\n" +def _coerce_action_type(value: object) -> GroomingAction | None: + if isinstance(value, GroomingAction): + return value + if isinstance(value, str): + try: + return GroomingAction(value) + except ValueError: + return None + return None + + _ROUTINES_SYSTEM_PROMPT = """\ Jesteś ekspertem planowania pielęgnacji. @@ -676,9 +807,12 @@ def list_routines( from_date: Optional[date] = None, to_date: Optional[date] = None, part_of_day: Optional[PartOfDay] = None, + user_id: UUID | None = Query(default=None), session: Session = Depends(get_session), + current_user: CurrentUser = Depends(get_current_user), ): - stmt = select(Routine) + target_user_id = _resolve_target_user_id(current_user, user_id) + stmt = select(Routine).where(Routine.user_id == target_user_id) if from_date is not None: stmt = stmt.where(Routine.routine_date >= from_date) if to_date is not None: @@ -688,10 +822,12 @@ def list_routines( routines = session.exec(stmt).all() routine_ids = [r.id for r in routines] - steps_by_routine: dict = {} + steps_by_routine: dict[UUID, list[RoutineStep]] = {} if routine_ids: all_steps = session.exec( - select(RoutineStep).where(col(RoutineStep.routine_id).in_(routine_ids)) + select(RoutineStep) + .where(col(RoutineStep.routine_id).in_(routine_ids)) + .where(RoutineStep.user_id == target_user_id) ).all() for step in all_steps: steps_by_routine.setdefault(step.routine_id, []).append(step) @@ -707,8 +843,14 @@ def list_routines( @router.post("", response_model=Routine, status_code=201) -def create_routine(data: RoutineCreate, session: Session = Depends(get_session)): - routine = Routine(id=uuid4(), **data.model_dump()) +def create_routine( + data: RoutineCreate, + user_id: UUID | None = Query(default=None), + session: Session = Depends(get_session), + current_user: CurrentUser = Depends(get_current_user), +): + target_user_id = _resolve_target_user_id(current_user, user_id) + routine = Routine(id=uuid4(), user_id=target_user_id, **data.model_dump()) session.add(routine) session.commit() session.refresh(routine) @@ -724,19 +866,35 @@ def create_routine(data: RoutineCreate, session: Session = Depends(get_session)) def suggest_routine( data: SuggestRoutineRequest, session: Session = Depends(get_session), + current_user: CurrentUser = Depends(get_current_user), ): + target_user_id = current_user.user_id weekday = data.routine_date.weekday() - skin_ctx = _build_skin_context(session, reference_date=data.routine_date) - profile_ctx = build_user_profile_context(session, reference_date=data.routine_date) + skin_ctx = _build_skin_context( + session, + target_user_id=target_user_id, + reference_date=data.routine_date, + ) + profile_ctx = build_user_profile_context( + session, + reference_date=data.routine_date, + current_user=current_user, + ) upcoming_grooming_ctx = _build_upcoming_grooming_context( session, + target_user_id=target_user_id, start_date=data.routine_date, days=7, ) - history_ctx = _build_recent_history(session, reference_date=data.routine_date) + history_ctx = _build_recent_history( + session, + target_user_id=target_user_id, + reference_date=data.routine_date, + ) day_ctx = _build_day_context(data.leaving_home) available_products = _get_available_products( session, + current_user=current_user, time_filter=data.part_of_day.value, include_minoxidil=data.include_minoxidil_beard, ) @@ -752,7 +910,9 @@ def suggest_routine( # Phase 2: Use tiered context (summary mode for initial prompt) products_with_inventory = _get_products_with_inventory( - session, [p.id for p in available_products] + session, + current_user, + [p.id for p in available_products], ) products_ctx = build_products_context_summary_list( available_products, products_with_inventory @@ -865,22 +1025,35 @@ def suggest_routine( # Translation layer: Expand short_ids (8 chars) to full UUIDs (36 chars) steps = [] - for s in parsed.get("steps", []): + raw_steps = parsed.get("steps", []) + if not isinstance(raw_steps, list): + raw_steps = [] + for s in raw_steps: + if not isinstance(s, dict): + continue product_id_str = s.get("product_id") product_id_uuid = None - if product_id_str: + if isinstance(product_id_str, str) and product_id_str: # Expand short_id or validate full UUID - product_id_uuid = _expand_product_id(session, product_id_str) + product_id_uuid = _expand_product_id(session, current_user, product_id_str) + + action_type = s.get("action_type") + action_notes = s.get("action_notes") + region = s.get("region") + why_this_step = s.get("why_this_step") + optional = s.get("optional") steps.append( SuggestedStep( product_id=product_id_uuid, - action_type=s.get("action_type") or None, - action_notes=s.get("action_notes"), - region=s.get("region"), - why_this_step=s.get("why_this_step"), - optional=s.get("optional"), + action_type=_coerce_action_type(action_type), + action_notes=action_notes if isinstance(action_notes, str) else None, + region=region if isinstance(region, str) else None, + why_this_step=( + why_this_step if isinstance(why_this_step, str) else None + ), + optional=optional if isinstance(optional, bool) else None, ) ) @@ -904,6 +1077,7 @@ def suggest_routine( # Get skin snapshot for barrier state skin_snapshot = _get_latest_skin_snapshot_within_days( session, + target_user_id=target_user_id, reference_date=data.routine_date, ) @@ -964,7 +1138,9 @@ def suggest_routine( def suggest_batch( data: SuggestBatchRequest, session: Session = Depends(get_session), + current_user: CurrentUser = Depends(get_current_user), ): + target_user_id = current_user.user_id delta = (data.to_date - data.from_date).days + 1 if delta > 14: raise HTTPException( @@ -976,18 +1152,37 @@ def suggest_batch( weekdays = list( {(data.from_date + timedelta(days=i)).weekday() for i in range(delta)} ) - profile_ctx = build_user_profile_context(session, reference_date=data.from_date) - skin_ctx = _build_skin_context(session, reference_date=data.from_date) - grooming_ctx = _build_grooming_context(session, weekdays=weekdays) - history_ctx = _build_recent_history(session, reference_date=data.from_date) + profile_ctx = build_user_profile_context( + session, + reference_date=data.from_date, + current_user=current_user, + ) + skin_ctx = _build_skin_context( + session, + target_user_id=target_user_id, + reference_date=data.from_date, + ) + grooming_ctx = _build_grooming_context( + session, + target_user_id=target_user_id, + weekdays=weekdays, + ) + history_ctx = _build_recent_history( + session, + target_user_id=target_user_id, + reference_date=data.from_date, + ) batch_products = _get_available_products( session, + current_user=current_user, include_minoxidil=data.include_minoxidil_beard, ) # Phase 2: Use tiered context (summary mode for batch planning) products_with_inventory = _get_products_with_inventory( - session, [p.id for p in batch_products] + session, + current_user, + [p.id for p in batch_products], ) products_ctx = build_products_context_summary_list( batch_products, products_with_inventory @@ -1045,25 +1240,39 @@ def suggest_batch( except json.JSONDecodeError as e: raise HTTPException(status_code=502, detail=f"LLM returned invalid JSON: {e}") - def _parse_steps(raw_steps: list) -> list[SuggestedStep]: + def _parse_steps(raw_steps: list[dict[str, object]]) -> list[SuggestedStep]: """Parse steps and expand short_ids to full UUIDs.""" result = [] for s in raw_steps: product_id_str = s.get("product_id") product_id_uuid = None - if product_id_str: + if isinstance(product_id_str, str) and product_id_str: # Translation layer: expand short_id to full UUID - product_id_uuid = _expand_product_id(session, product_id_str) + product_id_uuid = _expand_product_id( + session, + current_user, + product_id_str, + ) + + action_type = s.get("action_type") + action_notes = s.get("action_notes") + region = s.get("region") + why_this_step = s.get("why_this_step") + optional = s.get("optional") result.append( SuggestedStep( product_id=product_id_uuid, - action_type=s.get("action_type") or None, - action_notes=s.get("action_notes"), - region=s.get("region"), - why_this_step=s.get("why_this_step"), - optional=s.get("optional"), + action_type=_coerce_action_type(action_type), + action_notes=( + action_notes if isinstance(action_notes, str) else None + ), + region=region if isinstance(region, str) else None, + why_this_step=( + why_this_step if isinstance(why_this_step, str) else None + ), + optional=optional if isinstance(optional, bool) else None, ) ) return result @@ -1086,6 +1295,7 @@ def suggest_batch( # Get skin snapshot for barrier state skin_snapshot = _get_latest_skin_snapshot_within_days( session, + target_user_id=target_user_id, reference_date=data.from_date, ) @@ -1140,15 +1350,36 @@ def suggest_batch( # Grooming-schedule GET must appear before /{routine_id} to avoid being shadowed @router.get("/grooming-schedule", response_model=list[GroomingSchedule]) -def list_grooming_schedule(session: Session = Depends(get_session)): - return session.exec(select(GroomingSchedule)).all() +def list_grooming_schedule( + user_id: UUID | None = Query(default=None), + session: Session = Depends(get_session), + current_user: CurrentUser = Depends(get_current_user), +): + target_user_id = _resolve_target_user_id(current_user, user_id) + return session.exec( + select(GroomingSchedule).where(GroomingSchedule.user_id == target_user_id) + ).all() @router.get("/{routine_id}") -def get_routine(routine_id: UUID, session: Session = Depends(get_session)): - routine = get_or_404(session, Routine, routine_id) +def get_routine( + routine_id: UUID, + user_id: UUID | None = Query(default=None), + session: Session = Depends(get_session), + current_user: CurrentUser = Depends(get_current_user), +): + target_user_id = _resolve_target_user_id(current_user, user_id) + routine = _get_owned_or_admin_override( + session, + Routine, + routine_id, + current_user, + user_id, + ) steps = session.exec( - select(RoutineStep).where(RoutineStep.routine_id == routine_id) + select(RoutineStep) + .where(RoutineStep.routine_id == routine_id) + .where(RoutineStep.user_id == target_user_id) ).all() data = routine.model_dump(mode="json") data["steps"] = [step.model_dump(mode="json") for step in steps] @@ -1159,9 +1390,17 @@ def get_routine(routine_id: UUID, session: Session = Depends(get_session)): def update_routine( routine_id: UUID, data: RoutineUpdate, + user_id: UUID | None = Query(default=None), session: Session = Depends(get_session), + current_user: CurrentUser = Depends(get_current_user), ): - routine = get_or_404(session, Routine, routine_id) + routine = _get_owned_or_admin_override( + session, + Routine, + routine_id, + current_user, + user_id, + ) for key, value in data.model_dump(exclude_unset=True).items(): setattr(routine, key, value) session.add(routine) @@ -1171,8 +1410,19 @@ def update_routine( @router.delete("/{routine_id}", status_code=204) -def delete_routine(routine_id: UUID, session: Session = Depends(get_session)): - routine = get_or_404(session, Routine, routine_id) +def delete_routine( + routine_id: UUID, + user_id: UUID | None = Query(default=None), + session: Session = Depends(get_session), + current_user: CurrentUser = Depends(get_current_user), +): + routine = _get_owned_or_admin_override( + session, + Routine, + routine_id, + current_user, + user_id, + ) session.delete(routine) session.commit() @@ -1186,10 +1436,28 @@ def delete_routine(routine_id: UUID, session: Session = Depends(get_session)): def add_step( routine_id: UUID, data: RoutineStepCreate, + user_id: UUID | None = Query(default=None), session: Session = Depends(get_session), + current_user: CurrentUser = Depends(get_current_user), ): - get_or_404(session, Routine, routine_id) - step = RoutineStep(id=uuid4(), routine_id=routine_id, **data.model_dump()) + target_user_id = _resolve_target_user_id(current_user, user_id) + _ = _get_owned_or_admin_override( + session, + Routine, + routine_id, + current_user, + user_id, + ) + if data.product_id and not is_product_visible( + session, data.product_id, current_user + ): + raise HTTPException(status_code=404, detail="Product not found") + step = RoutineStep( + id=uuid4(), + user_id=target_user_id, + routine_id=routine_id, + **data.model_dump(), + ) session.add(step) session.commit() session.refresh(step) @@ -1200,9 +1468,21 @@ def add_step( def update_step( step_id: UUID, data: RoutineStepUpdate, + user_id: UUID | None = Query(default=None), session: Session = Depends(get_session), + current_user: CurrentUser = Depends(get_current_user), ): - step = get_or_404(session, RoutineStep, step_id) + step = _get_owned_or_admin_override( + session, + RoutineStep, + step_id, + current_user, + user_id, + ) + if data.product_id and not is_product_visible( + session, data.product_id, current_user + ): + raise HTTPException(status_code=404, detail="Product not found") for key, value in data.model_dump(exclude_unset=True).items(): setattr(step, key, value) session.add(step) @@ -1212,8 +1492,19 @@ def update_step( @router.delete("/steps/{step_id}", status_code=204) -def delete_step(step_id: UUID, session: Session = Depends(get_session)): - step = get_or_404(session, RoutineStep, step_id) +def delete_step( + step_id: UUID, + user_id: UUID | None = Query(default=None), + session: Session = Depends(get_session), + current_user: CurrentUser = Depends(get_current_user), +): + step = _get_owned_or_admin_override( + session, + RoutineStep, + step_id, + current_user, + user_id, + ) session.delete(step) session.commit() @@ -1225,9 +1516,13 @@ def delete_step(step_id: UUID, session: Session = Depends(get_session)): @router.post("/grooming-schedule", response_model=GroomingSchedule, status_code=201) def create_grooming_schedule( - data: GroomingScheduleCreate, session: Session = Depends(get_session) + data: GroomingScheduleCreate, + user_id: UUID | None = Query(default=None), + session: Session = Depends(get_session), + current_user: CurrentUser = Depends(get_current_user), ): - entry = GroomingSchedule(id=uuid4(), **data.model_dump()) + target_user_id = _resolve_target_user_id(current_user, user_id) + entry = GroomingSchedule(id=uuid4(), user_id=target_user_id, **data.model_dump()) session.add(entry) session.commit() session.refresh(entry) @@ -1238,9 +1533,17 @@ def create_grooming_schedule( def update_grooming_schedule( entry_id: UUID, data: GroomingScheduleUpdate, + user_id: UUID | None = Query(default=None), session: Session = Depends(get_session), + current_user: CurrentUser = Depends(get_current_user), ): - entry = get_or_404(session, GroomingSchedule, entry_id) + entry = _get_owned_or_admin_override( + session, + GroomingSchedule, + entry_id, + current_user, + user_id, + ) for key, value in data.model_dump(exclude_unset=True).items(): setattr(entry, key, value) session.add(entry) @@ -1250,7 +1553,18 @@ def update_grooming_schedule( @router.delete("/grooming-schedule/{entry_id}", status_code=204) -def delete_grooming_schedule(entry_id: UUID, session: Session = Depends(get_session)): - entry = get_or_404(session, GroomingSchedule, entry_id) +def delete_grooming_schedule( + entry_id: UUID, + user_id: UUID | None = Query(default=None), + session: Session = Depends(get_session), + current_user: CurrentUser = Depends(get_current_user), +): + entry = _get_owned_or_admin_override( + session, + GroomingSchedule, + entry_id, + current_user, + user_id, + ) session.delete(entry) session.commit() diff --git a/backend/innercontext/api/skincare.py b/backend/innercontext/api/skincare.py index 730db1e..4984bf9 100644 --- a/backend/innercontext/api/skincare.py +++ b/backend/innercontext/api/skincare.py @@ -4,15 +4,17 @@ from datetime import date from typing import Optional from uuid import UUID, uuid4 -from fastapi import APIRouter, Depends, File, HTTPException, UploadFile +from fastapi import APIRouter, Depends, File, HTTPException, Query, UploadFile from google.genai import types as genai_types from pydantic import BaseModel as PydanticBase from pydantic import ValidationError from sqlmodel import Session, SQLModel, select from db import get_session +from innercontext.api.auth_deps import get_current_user from innercontext.api.llm_context import build_user_profile_context -from innercontext.api.utils import get_or_404 +from innercontext.api.utils import get_owned_or_404 +from innercontext.auth import CurrentUser from innercontext.llm import call_gemini, get_extraction_config from innercontext.models import ( SkinConditionSnapshot, @@ -26,6 +28,7 @@ from innercontext.models.enums import ( SkinTexture, SkinType, ) +from innercontext.models.enums import Role from innercontext.validators import PhotoValidator logger = logging.getLogger(__name__) @@ -135,6 +138,34 @@ OUTPUT (all fields optional): # --------------------------------------------------------------------------- +def _resolve_target_user_id( + current_user: CurrentUser, + user_id: UUID | None, +) -> UUID: + if user_id is None: + return current_user.user_id + if current_user.role is not Role.ADMIN: + raise HTTPException(status_code=403, detail="Admin role required") + return user_id + + +def _get_owned_or_admin_override( + session: Session, + snapshot_id: UUID, + current_user: CurrentUser, + user_id: UUID | None, +) -> SkinConditionSnapshot: + if user_id is None: + return get_owned_or_404( + session, SkinConditionSnapshot, snapshot_id, current_user + ) + target_user_id = _resolve_target_user_id(current_user, user_id) + snapshot = session.get(SkinConditionSnapshot, snapshot_id) + if snapshot is None or snapshot.user_id != target_user_id: + raise HTTPException(status_code=404, detail="SkinConditionSnapshot not found") + return snapshot + + MAX_IMAGE_BYTES = 5 * 1024 * 1024 # 5 MB @@ -142,6 +173,7 @@ MAX_IMAGE_BYTES = 5 * 1024 * 1024 # 5 MB async def analyze_skin_photos( photos: list[UploadFile] = File(...), session: Session = Depends(get_session), + current_user: CurrentUser = Depends(get_current_user), ) -> SkinPhotoAnalysisResponse: if not (1 <= len(photos) <= 3): raise HTTPException(status_code=422, detail="Send between 1 and 3 photos.") @@ -174,7 +206,11 @@ async def analyze_skin_photos( ) parts.append( genai_types.Part.from_text( - text=build_user_profile_context(session, reference_date=date.today()) + text=build_user_profile_context( + session, + reference_date=date.today(), + current_user=current_user, + ) ) ) @@ -224,9 +260,14 @@ def list_snapshots( from_date: Optional[date] = None, to_date: Optional[date] = None, overall_state: Optional[OverallSkinState] = None, + user_id: UUID | None = Query(default=None), session: Session = Depends(get_session), + current_user: CurrentUser = Depends(get_current_user), ): - stmt = select(SkinConditionSnapshot) + target_user_id = _resolve_target_user_id(current_user, user_id) + stmt = select(SkinConditionSnapshot).where( + SkinConditionSnapshot.user_id == target_user_id + ) if from_date is not None: stmt = stmt.where(SkinConditionSnapshot.snapshot_date >= from_date) if to_date is not None: @@ -237,8 +278,18 @@ def list_snapshots( @router.post("", response_model=SkinConditionSnapshotPublic, status_code=201) -def create_snapshot(data: SnapshotCreate, session: Session = Depends(get_session)): - snapshot = SkinConditionSnapshot(id=uuid4(), **data.model_dump()) +def create_snapshot( + data: SnapshotCreate, + user_id: UUID | None = Query(default=None), + session: Session = Depends(get_session), + current_user: CurrentUser = Depends(get_current_user), +): + target_user_id = _resolve_target_user_id(current_user, user_id) + snapshot = SkinConditionSnapshot( + id=uuid4(), + user_id=target_user_id, + **data.model_dump(), + ) session.add(snapshot) session.commit() session.refresh(snapshot) @@ -246,17 +297,34 @@ def create_snapshot(data: SnapshotCreate, session: Session = Depends(get_session @router.get("/{snapshot_id}", response_model=SkinConditionSnapshotPublic) -def get_snapshot(snapshot_id: UUID, session: Session = Depends(get_session)): - return get_or_404(session, SkinConditionSnapshot, snapshot_id) +def get_snapshot( + snapshot_id: UUID, + user_id: UUID | None = Query(default=None), + session: Session = Depends(get_session), + current_user: CurrentUser = Depends(get_current_user), +): + return _get_owned_or_admin_override( + session, + snapshot_id, + current_user, + user_id, + ) @router.patch("/{snapshot_id}", response_model=SkinConditionSnapshotPublic) def update_snapshot( snapshot_id: UUID, data: SnapshotUpdate, + user_id: UUID | None = Query(default=None), session: Session = Depends(get_session), + current_user: CurrentUser = Depends(get_current_user), ): - snapshot = get_or_404(session, SkinConditionSnapshot, snapshot_id) + snapshot = _get_owned_or_admin_override( + session, + snapshot_id, + current_user, + user_id, + ) for key, value in data.model_dump(exclude_unset=True).items(): setattr(snapshot, key, value) session.add(snapshot) @@ -266,7 +334,17 @@ def update_snapshot( @router.delete("/{snapshot_id}", status_code=204) -def delete_snapshot(snapshot_id: UUID, session: Session = Depends(get_session)): - snapshot = get_or_404(session, SkinConditionSnapshot, snapshot_id) +def delete_snapshot( + snapshot_id: UUID, + user_id: UUID | None = Query(default=None), + session: Session = Depends(get_session), + current_user: CurrentUser = Depends(get_current_user), +): + snapshot = _get_owned_or_admin_override( + session, + snapshot_id, + current_user, + user_id, + ) session.delete(snapshot) session.commit() diff --git a/backend/innercontext/services/pricing_jobs.py b/backend/innercontext/services/pricing_jobs.py index 9e9c9dd..9d6a24e 100644 --- a/backend/innercontext/services/pricing_jobs.py +++ b/backend/innercontext/services/pricing_jobs.py @@ -1,4 +1,5 @@ from datetime import datetime +from uuid import UUID from sqlmodel import Session, col, select @@ -66,9 +67,53 @@ def _apply_pricing_snapshot(session: Session, computed_at: datetime) -> int: return len(products) +def _scope_user_id(scope: str) -> UUID | None: + prefix = "user:" + if not scope.startswith(prefix): + return None + raw_user_id = scope[len(prefix) :].strip() + if not raw_user_id: + return None + try: + return UUID(raw_user_id) + except ValueError: + return None + + +def _apply_pricing_snapshot_for_scope( + session: Session, + *, + computed_at: datetime, + scope: str, +) -> int: + from innercontext.api.products import _compute_pricing_outputs + + scoped_user_id = _scope_user_id(scope) + stmt = select(Product) + if scoped_user_id is not None: + stmt = stmt.where(Product.user_id == scoped_user_id) + products = list(session.exec(stmt).all()) + pricing_outputs = _compute_pricing_outputs(products) + + for product in products: + tier, price_per_use_pln, tier_source = pricing_outputs.get( + product.id, (None, None, None) + ) + product.price_tier = tier + product.price_per_use_pln = price_per_use_pln + product.price_tier_source = tier_source + product.pricing_computed_at = computed_at + + return len(products) + + def process_pricing_job(session: Session, job: PricingRecalcJob) -> int: try: - updated_count = _apply_pricing_snapshot(session, computed_at=utc_now()) + updated_count = _apply_pricing_snapshot_for_scope( + session, + computed_at=utc_now(), + scope=job.scope, + ) job.status = "succeeded" job.finished_at = utc_now() job.error = None diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py index e35dfba..b8831bb 100644 --- a/backend/tests/conftest.py +++ b/backend/tests/conftest.py @@ -37,27 +37,32 @@ def session(monkeypatch): @pytest.fixture() -def client(session, monkeypatch): +def current_user() -> CurrentUser: + claims = TokenClaims( + issuer="https://auth.test", + subject="test-user", + audience=("innercontext-web",), + expires_at=datetime.now(UTC) + timedelta(hours=1), + groups=("innercontext-admin",), + raw_claims={"iss": "https://auth.test", "sub": "test-user"}, + ) + return CurrentUser( + user_id=uuid4(), + role=Role.ADMIN, + identity=IdentityData.from_claims(claims), + claims=claims, + ) + + +@pytest.fixture() +def client(session, monkeypatch, current_user): """TestClient using the per-test session for every request.""" def _override(): yield session def _current_user_override(): - claims = TokenClaims( - issuer="https://auth.test", - subject="test-user", - audience=("innercontext-web",), - expires_at=datetime.now(UTC) + timedelta(hours=1), - groups=("innercontext-admin",), - raw_claims={"iss": "https://auth.test", "sub": "test-user"}, - ) - return CurrentUser( - user_id=uuid4(), - role=Role.ADMIN, - identity=IdentityData.from_claims(claims), - claims=claims, - ) + return current_user app.dependency_overrides[get_session] = _override app.dependency_overrides[get_current_user] = _current_user_override diff --git a/backend/tests/test_ai_logs.py b/backend/tests/test_ai_logs.py index 47a168e..8fe2a60 100644 --- a/backend/tests/test_ai_logs.py +++ b/backend/tests/test_ai_logs.py @@ -4,12 +4,13 @@ from typing import Any, cast from innercontext.models.ai_log import AICallLog -def test_list_ai_logs_normalizes_tool_trace_string(client, session): +def test_list_ai_logs_normalizes_tool_trace_string(client, session, current_user): log = AICallLog( id=uuid.uuid4(), endpoint="routines/suggest", model="gemini-3-flash-preview", success=True, + user_id=current_user.user_id, ) log.tool_trace = cast( Any, @@ -26,12 +27,13 @@ def test_list_ai_logs_normalizes_tool_trace_string(client, session): assert data[0]["tool_trace"]["events"][0]["function"] == "get_product_inci" -def test_get_ai_log_normalizes_tool_trace_string(client, session): +def test_get_ai_log_normalizes_tool_trace_string(client, session, current_user): log = AICallLog( id=uuid.uuid4(), endpoint="routines/suggest", model="gemini-3-flash-preview", success=True, + user_id=current_user.user_id, ) log.tool_trace = cast(Any, '{"mode":"function_tools","round":1}') session.add(log) diff --git a/backend/tests/test_routines.py b/backend/tests/test_routines.py index dae411f..28c885f 100644 --- a/backend/tests/test_routines.py +++ b/backend/tests/test_routines.py @@ -3,7 +3,7 @@ from datetime import date from unittest.mock import patch from innercontext.models import Routine, SkinConditionSnapshot -from innercontext.models.enums import BarrierState, OverallSkinState +from innercontext.models.enums import BarrierState, OverallSkinState, PartOfDay # --------------------------------------------------------------------------- # Routines @@ -223,13 +223,14 @@ def test_delete_grooming_schedule_not_found(client): assert r.status_code == 404 -def test_suggest_routine(client, session): +def test_suggest_routine(client, session, current_user): with patch( "innercontext.api.routines.call_gemini_with_function_tools" ) as mock_gemini: session.add( SkinConditionSnapshot( id=uuid.uuid4(), + user_id=current_user.user_id, snapshot_date=date(2026, 2, 22), overall_state=OverallSkinState.GOOD, hydration_level=4, @@ -272,18 +273,20 @@ def test_suggest_routine(client, session): assert "get_product_details" in kwargs["function_handlers"] -def test_suggest_batch(client, session): +def test_suggest_batch(client, session, current_user): with patch("innercontext.api.routines.call_gemini") as mock_gemini: session.add( Routine( id=uuid.uuid4(), + user_id=current_user.user_id, routine_date=date(2026, 2, 27), - part_of_day="pm", + part_of_day=PartOfDay.PM, ) ) session.add( SkinConditionSnapshot( id=uuid.uuid4(), + user_id=current_user.user_id, snapshot_date=date(2026, 2, 20), overall_state=OverallSkinState.GOOD, hydration_level=4, diff --git a/backend/tests/test_routines_auth.py b/backend/tests/test_routines_auth.py new file mode 100644 index 0000000..9696556 --- /dev/null +++ b/backend/tests/test_routines_auth.py @@ -0,0 +1,112 @@ +from __future__ import annotations + +from datetime import UTC, datetime, timedelta +from unittest.mock import patch +from uuid import uuid4 + +from innercontext.api.auth_deps import get_current_user +from innercontext.auth import CurrentUser, IdentityData, TokenClaims +from innercontext.models import Role +from main import app + + +def _user(subject: str, *, role: Role = Role.MEMBER) -> CurrentUser: + claims = TokenClaims( + issuer="https://auth.test", + subject=subject, + audience=("innercontext-web",), + expires_at=datetime.now(UTC) + timedelta(hours=1), + raw_claims={"iss": "https://auth.test", "sub": subject}, + ) + return CurrentUser( + user_id=uuid4(), + role=role, + identity=IdentityData.from_claims(claims), + claims=claims, + ) + + +def _set_current_user(user: CurrentUser) -> None: + app.dependency_overrides[get_current_user] = lambda: user + + +def test_suggest_uses_current_user_profile_and_visible_products_only(client): + owner = _user("owner") + other = _user("other") + + _set_current_user(owner) + owner_profile = client.patch( + "/profile", json={"birth_date": "1991-01-15", "sex_at_birth": "male"} + ) + owner_product = client.post( + "/products", + json={ + "name": "Owner Serum", + "brand": "Test", + "category": "serum", + "recommended_time": "both", + "leave_on": True, + }, + ) + assert owner_profile.status_code == 200 + assert owner_product.status_code == 201 + + _set_current_user(other) + other_profile = client.patch( + "/profile", json={"birth_date": "1975-06-20", "sex_at_birth": "female"} + ) + other_product = client.post( + "/products", + json={ + "name": "Other Serum", + "brand": "Test", + "category": "serum", + "recommended_time": "both", + "leave_on": True, + }, + ) + assert other_profile.status_code == 200 + assert other_product.status_code == 201 + + _set_current_user(owner) + + with patch( + "innercontext.api.routines.call_gemini_with_function_tools" + ) as mock_gemini: + mock_response = type( + "Response", + (), + { + "text": '{"steps": [{"product_id": null, "action_type": "shaving_razor"}], "reasoning": "ok", "summary": {"primary_goal": "safe", "constraints_applied": [], "confidence": 0.7}}' + }, + ) + mock_gemini.return_value = (mock_response, None) + + response = client.post( + "/routines/suggest", + json={ + "routine_date": "2026-03-05", + "part_of_day": "am", + "include_minoxidil_beard": False, + }, + ) + assert response.status_code == 200 + + kwargs = mock_gemini.call_args.kwargs + prompt = kwargs["contents"] + assert "Birth date: 1991-01-15" in prompt + assert "Birth date: 1975-06-20" not in prompt + assert "Owner Serum" in prompt + assert "Other Serum" not in prompt + + handler = kwargs["function_handlers"]["get_product_details"] + payload = handler( + { + "product_ids": [ + owner_product.json()["id"], + other_product.json()["id"], + ] + } + ) + assert len(payload["products"]) == 1 + assert payload["products"][0]["name"] == "Owner Serum" diff --git a/backend/tests/test_routines_helpers.py b/backend/tests/test_routines_helpers.py index b547f62..4041c22 100644 --- a/backend/tests/test_routines_helpers.py +++ b/backend/tests/test_routines_helpers.py @@ -78,17 +78,22 @@ def test_ev(): assert _ev("string") == "string" -def test_build_skin_context(session: Session): +def test_build_skin_context(session: Session, current_user): # Empty reference_date = date(2026, 3, 10) assert ( - _build_skin_context(session, reference_date=reference_date) + _build_skin_context( + session, + target_user_id=current_user.user_id, + reference_date=reference_date, + ) == "SKIN CONDITION: no data\n" ) # With data snap = SkinConditionSnapshot( id=uuid.uuid4(), + user_id=current_user.user_id, snapshot_date=reference_date, overall_state=OverallSkinState.GOOD, hydration_level=4, @@ -100,7 +105,11 @@ def test_build_skin_context(session: Session): session.add(snap) session.commit() - ctx = _build_skin_context(session, reference_date=reference_date) + ctx = _build_skin_context( + session, + target_user_id=current_user.user_id, + reference_date=reference_date, + ) assert "SKIN CONDITION (snapshot from" in ctx assert "Overall state: good" in ctx assert "Hydration: 4/5" in ctx @@ -112,10 +121,12 @@ def test_build_skin_context(session: Session): def test_build_skin_context_falls_back_to_recent_snapshot_within_14_days( session: Session, + current_user, ): reference_date = date(2026, 3, 20) snap = SkinConditionSnapshot( id=uuid.uuid4(), + user_id=current_user.user_id, snapshot_date=reference_date - timedelta(days=10), overall_state=OverallSkinState.FAIR, hydration_level=3, @@ -126,16 +137,23 @@ def test_build_skin_context_falls_back_to_recent_snapshot_within_14_days( session.add(snap) session.commit() - ctx = _build_skin_context(session, reference_date=reference_date) + ctx = _build_skin_context( + session, + target_user_id=current_user.user_id, + reference_date=reference_date, + ) assert f"snapshot from {reference_date - timedelta(days=10)}" in ctx assert "Barrier: compromised" in ctx -def test_build_skin_context_ignores_snapshot_older_than_14_days(session: Session): +def test_build_skin_context_ignores_snapshot_older_than_14_days( + session: Session, current_user +): reference_date = date(2026, 3, 20) snap = SkinConditionSnapshot( id=uuid.uuid4(), + user_id=current_user.user_id, snapshot_date=reference_date - timedelta(days=15), overall_state=OverallSkinState.FAIR, hydration_level=3, @@ -145,15 +163,20 @@ def test_build_skin_context_ignores_snapshot_older_than_14_days(session: Session session.commit() assert ( - _build_skin_context(session, reference_date=reference_date) + _build_skin_context( + session, + target_user_id=current_user.user_id, + reference_date=reference_date, + ) == "SKIN CONDITION: no data\n" ) -def test_get_recent_skin_snapshot_prefers_window_match(session: Session): +def test_get_recent_skin_snapshot_prefers_window_match(session: Session, current_user): reference_date = date(2026, 3, 20) older = SkinConditionSnapshot( id=uuid.uuid4(), + user_id=current_user.user_id, snapshot_date=reference_date - timedelta(days=10), overall_state=OverallSkinState.POOR, hydration_level=2, @@ -161,6 +184,7 @@ def test_get_recent_skin_snapshot_prefers_window_match(session: Session): ) newer = SkinConditionSnapshot( id=uuid.uuid4(), + user_id=current_user.user_id, snapshot_date=reference_date - timedelta(days=2), overall_state=OverallSkinState.GOOD, hydration_level=4, @@ -169,7 +193,11 @@ def test_get_recent_skin_snapshot_prefers_window_match(session: Session): session.add_all([older, newer]) session.commit() - snapshot = _get_recent_skin_snapshot(session, reference_date=reference_date) + snapshot = _get_recent_skin_snapshot( + session, + target_user_id=current_user.user_id, + reference_date=reference_date, + ) assert snapshot is not None assert snapshot.id == newer.id @@ -177,10 +205,12 @@ def test_get_recent_skin_snapshot_prefers_window_match(session: Session): def test_get_latest_skin_snapshot_within_days_uses_latest_within_14_days( session: Session, + current_user, ): reference_date = date(2026, 3, 20) older = SkinConditionSnapshot( id=uuid.uuid4(), + user_id=current_user.user_id, snapshot_date=reference_date - timedelta(days=10), overall_state=OverallSkinState.POOR, hydration_level=2, @@ -188,6 +218,7 @@ def test_get_latest_skin_snapshot_within_days_uses_latest_within_14_days( ) newer = SkinConditionSnapshot( id=uuid.uuid4(), + user_id=current_user.user_id, snapshot_date=reference_date - timedelta(days=2), overall_state=OverallSkinState.GOOD, hydration_level=4, @@ -198,6 +229,7 @@ def test_get_latest_skin_snapshot_within_days_uses_latest_within_14_days( snapshot = _get_latest_skin_snapshot_within_days( session, + target_user_id=current_user.user_id, reference_date=reference_date, ) @@ -205,39 +237,65 @@ def test_get_latest_skin_snapshot_within_days_uses_latest_within_14_days( assert snapshot.id == newer.id -def test_build_grooming_context(session: Session): - assert _build_grooming_context(session) == "GROOMING SCHEDULE: none\n" +def test_build_grooming_context(session: Session, current_user): + assert ( + _build_grooming_context(session, target_user_id=current_user.user_id) + == "GROOMING SCHEDULE: none\n" + ) sch = GroomingSchedule( - id=uuid.uuid4(), day_of_week=0, action="shaving_oneblade", notes="Morning" + id=uuid.uuid4(), + user_id=current_user.user_id, + day_of_week=0, + action="shaving_oneblade", + notes="Morning", ) session.add(sch) session.commit() - ctx = _build_grooming_context(session) + ctx = _build_grooming_context(session, target_user_id=current_user.user_id) assert "GROOMING SCHEDULE:" in ctx assert "poniedziałek: shaving_oneblade (Morning)" in ctx # Test weekdays filter - ctx2 = _build_grooming_context(session, weekdays=[1]) # not monday + ctx2 = _build_grooming_context( + session, + target_user_id=current_user.user_id, + weekdays=[1], + ) # not monday assert "(no entries for specified days)" in ctx2 -def test_build_upcoming_grooming_context(session: Session): +def test_build_upcoming_grooming_context(session: Session, current_user): assert ( - _build_upcoming_grooming_context(session, start_date=date(2026, 3, 2), days=7) + _build_upcoming_grooming_context( + session, + target_user_id=current_user.user_id, + start_date=date(2026, 3, 2), + days=7, + ) == "UPCOMING GROOMING (next 7 days): none\n" ) monday = GroomingSchedule( - id=uuid.uuid4(), day_of_week=0, action="shaving_oneblade", notes="Morning" + id=uuid.uuid4(), + user_id=current_user.user_id, + day_of_week=0, + action="shaving_oneblade", + notes="Morning", + ) + wednesday = GroomingSchedule( + id=uuid.uuid4(), + user_id=current_user.user_id, + day_of_week=2, + action="dermarolling", ) - wednesday = GroomingSchedule(id=uuid.uuid4(), day_of_week=2, action="dermarolling") session.add_all([monday, wednesday]) session.commit() ctx = _build_upcoming_grooming_context( session, + target_user_id=current_user.user_id, start_date=date(2026, 3, 2), days=7, ) @@ -246,14 +304,23 @@ def test_build_upcoming_grooming_context(session: Session): assert "za 2 dni (2026-03-04, środa): dermarolling" in ctx -def test_build_recent_history(session: Session): +def test_build_recent_history(session: Session, current_user): reference_date = date(2026, 3, 10) assert ( - _build_recent_history(session, reference_date=reference_date) + _build_recent_history( + session, + target_user_id=current_user.user_id, + reference_date=reference_date, + ) == "RECENT ROUTINES: none\n" ) - r = Routine(id=uuid.uuid4(), routine_date=reference_date, part_of_day="am") + r = Routine( + id=uuid.uuid4(), + user_id=current_user.user_id, + routine_date=reference_date, + part_of_day="am", + ) session.add(r) p = Product( id=uuid.uuid4(), @@ -268,19 +335,37 @@ def test_build_recent_history(session: Session): session.add(p) session.commit() - s1 = RoutineStep(id=uuid.uuid4(), routine_id=r.id, order_index=1, product_id=p.id) + s1 = RoutineStep( + id=uuid.uuid4(), + user_id=current_user.user_id, + routine_id=r.id, + order_index=1, + product_id=p.id, + ) s2 = RoutineStep( - id=uuid.uuid4(), routine_id=r.id, order_index=2, action_type="shaving_razor" + id=uuid.uuid4(), + user_id=current_user.user_id, + routine_id=r.id, + order_index=2, + action_type="shaving_razor", ) # Step with non-existent product s3 = RoutineStep( - id=uuid.uuid4(), routine_id=r.id, order_index=3, product_id=uuid.uuid4() + id=uuid.uuid4(), + user_id=current_user.user_id, + routine_id=r.id, + order_index=3, + product_id=uuid.uuid4(), ) session.add_all([s1, s2, s3]) session.commit() - ctx = _build_recent_history(session, reference_date=reference_date) + ctx = _build_recent_history( + session, + target_user_id=current_user.user_id, + reference_date=reference_date, + ) assert "RECENT ROUTINES:" in ctx assert "AM:" in ctx assert "cleanser [" in ctx @@ -288,31 +373,38 @@ def test_build_recent_history(session: Session): assert "unknown [" in ctx -def test_build_recent_history_uses_reference_window(session: Session): +def test_build_recent_history_uses_reference_window(session: Session, current_user): reference_date = date(2026, 3, 10) recent = Routine( id=uuid.uuid4(), + user_id=current_user.user_id, routine_date=reference_date - timedelta(days=3), part_of_day="pm", ) old = Routine( id=uuid.uuid4(), + user_id=current_user.user_id, routine_date=reference_date - timedelta(days=6), part_of_day="am", ) session.add_all([recent, old]) session.commit() - ctx = _build_recent_history(session, reference_date=reference_date) + ctx = _build_recent_history( + session, + target_user_id=current_user.user_id, + reference_date=reference_date, + ) assert str(recent.routine_date) in ctx assert str(old.routine_date) not in ctx -def test_build_recent_history_excludes_future_routines(session: Session): +def test_build_recent_history_excludes_future_routines(session: Session, current_user): reference_date = date(2026, 3, 10) future = Routine( id=uuid.uuid4(), + user_id=current_user.user_id, routine_date=reference_date + timedelta(days=1), part_of_day="am", ) @@ -320,12 +412,16 @@ def test_build_recent_history_excludes_future_routines(session: Session): session.commit() assert ( - _build_recent_history(session, reference_date=reference_date) + _build_recent_history( + session, + target_user_id=current_user.user_id, + reference_date=reference_date, + ) == "RECENT ROUTINES: none\n" ) -def test_build_products_context_summary_list(session: Session): +def test_build_products_context_summary_list(session: Session, current_user): p1 = Product( id=uuid.uuid4(), short_id=str(uuid.uuid4())[:8], @@ -336,6 +432,7 @@ def test_build_products_context_summary_list(session: Session): recommended_time="both", leave_on=True, product_effect_profile={}, + user_id=current_user.user_id, ) p2 = Product( id=uuid.uuid4(), @@ -350,11 +447,16 @@ def test_build_products_context_summary_list(session: Session): context_rules={"safe_after_shaving": False}, min_interval_hours=12, max_frequency_per_week=7, + user_id=current_user.user_id, ) session.add_all([p1, p2]) session.commit() - products_am = _get_available_products(session, time_filter="am") + products_am = _get_available_products( + session, + current_user=current_user, + time_filter="am", + ) ctx = build_products_context_summary_list(products_am, {p2.id}) assert "Regaine Minoxidil" in ctx @@ -375,7 +477,7 @@ def test_build_day_context(): assert "Leaving home: no" in _build_day_context(False) -def test_get_available_products_respects_filters(session: Session): +def test_get_available_products_respects_filters(session: Session, current_user): regular_med = Product( id=uuid.uuid4(), name="Tretinoin", @@ -385,6 +487,7 @@ def test_get_available_products_respects_filters(session: Session): recommended_time="pm", leave_on=True, product_effect_profile={}, + user_id=current_user.user_id, ) minoxidil_med = Product( id=uuid.uuid4(), @@ -395,6 +498,7 @@ def test_get_available_products_respects_filters(session: Session): recommended_time="both", leave_on=True, product_effect_profile={}, + user_id=current_user.user_id, ) am_product = Product( id=uuid.uuid4(), @@ -404,6 +508,7 @@ def test_get_available_products_respects_filters(session: Session): recommended_time="am", leave_on=True, product_effect_profile={}, + user_id=current_user.user_id, ) pm_product = Product( id=uuid.uuid4(), @@ -413,11 +518,16 @@ def test_get_available_products_respects_filters(session: Session): recommended_time="pm", leave_on=True, product_effect_profile={}, + user_id=current_user.user_id, ) session.add_all([regular_med, minoxidil_med, am_product, pm_product]) session.commit() - am_available = _get_available_products(session, time_filter="am") + am_available = _get_available_products( + session, + current_user=current_user, + time_filter="am", + ) am_names = {p.name for p in am_available} assert "Tretinoin" not in am_names assert "Minoxidil 5%" in am_names @@ -508,7 +618,10 @@ def test_extract_active_names_uses_compact_distinct_names(session: Session): assert names == ["Niacinamide", "Zinc PCA"] -def test_get_available_products_excludes_minoxidil_when_flag_false(session: Session): +def test_get_available_products_excludes_minoxidil_when_flag_false( + session: Session, + current_user, +): minoxidil = Product( id=uuid.uuid4(), name="Minoxidil 5%", @@ -518,6 +631,7 @@ def test_get_available_products_excludes_minoxidil_when_flag_false(session: Sess recommended_time="both", leave_on=True, product_effect_profile={}, + user_id=current_user.user_id, ) regular = Product( id=uuid.uuid4(), @@ -527,18 +641,27 @@ def test_get_available_products_excludes_minoxidil_when_flag_false(session: Sess recommended_time="both", leave_on=False, product_effect_profile={}, + user_id=current_user.user_id, ) session.add_all([minoxidil, regular]) session.commit() # With flag True (default) - minoxidil included - products = _get_available_products(session, include_minoxidil=True) + products = _get_available_products( + session, + current_user=current_user, + include_minoxidil=True, + ) names = {p.name for p in products} assert "Minoxidil 5%" in names assert "Cleanser" in names # With flag False - minoxidil excluded - products = _get_available_products(session, include_minoxidil=False) + products = _get_available_products( + session, + current_user=current_user, + include_minoxidil=False, + ) names = {p.name for p in products} assert "Minoxidil 5%" not in names assert "Cleanser" in names diff --git a/backend/tests/test_skincare.py b/backend/tests/test_skincare.py index b6ce4b0..ac62a99 100644 --- a/backend/tests/test_skincare.py +++ b/backend/tests/test_skincare.py @@ -140,7 +140,7 @@ def test_analyze_photos_includes_user_profile_context(client, monkeypatch): def _fake_call_gemini(**kwargs): captured.update(kwargs) - return _FakeResponse() + return _FakeResponse(), None monkeypatch.setattr(skincare_api, "call_gemini", _fake_call_gemini) diff --git a/backend/tests/test_tenancy_domains.py b/backend/tests/test_tenancy_domains.py new file mode 100644 index 0000000..bbe1ce9 --- /dev/null +++ b/backend/tests/test_tenancy_domains.py @@ -0,0 +1,100 @@ +from __future__ import annotations + +from datetime import UTC, datetime, timedelta +from uuid import uuid4 + +from innercontext.api.auth_deps import get_current_user +from innercontext.auth import CurrentUser, IdentityData, TokenClaims +from innercontext.models import Role +from innercontext.models.ai_log import AICallLog +from main import app + + +def _user(subject: str, *, role: Role = Role.MEMBER) -> CurrentUser: + claims = TokenClaims( + issuer="https://auth.test", + subject=subject, + audience=("innercontext-web",), + expires_at=datetime.now(UTC) + timedelta(hours=1), + raw_claims={"iss": "https://auth.test", "sub": subject}, + ) + return CurrentUser( + user_id=uuid4(), + role=role, + identity=IdentityData.from_claims(claims), + claims=claims, + ) + + +def _set_current_user(user: CurrentUser) -> None: + app.dependency_overrides[get_current_user] = lambda: user + + +def test_profile_health_routines_skincare_ai_logs_are_user_scoped_by_default( + client, session +): + owner = _user("owner") + intruder = _user("intruder") + + _set_current_user(owner) + profile = client.patch( + "/profile", json={"birth_date": "1991-01-15", "sex_at_birth": "male"} + ) + medication = client.post( + "/health/medications", json={"kind": "prescription", "product_name": "Owner Rx"} + ) + routine = client.post( + "/routines", json={"routine_date": "2026-03-01", "part_of_day": "am"} + ) + snapshot = client.post("/skincare", json={"snapshot_date": "2026-03-01"}) + log = AICallLog(endpoint="routines/suggest", model="gemini-3-flash-preview") + log.user_id = owner.user_id + session.add(log) + session.commit() + session.refresh(log) + + assert profile.status_code == 200 + assert medication.status_code == 201 + assert routine.status_code == 201 + assert snapshot.status_code == 201 + + medication_id = medication.json()["record_id"] + routine_id = routine.json()["id"] + snapshot_id = snapshot.json()["id"] + + _set_current_user(intruder) + assert client.get("/profile").json() is None + assert client.get("/health/medications").json() == [] + assert client.get("/routines").json() == [] + assert client.get("/skincare").json() == [] + assert client.get("/ai-logs").json() == [] + + assert client.get(f"/health/medications/{medication_id}").status_code == 404 + assert client.get(f"/routines/{routine_id}").status_code == 404 + assert client.get(f"/skincare/{snapshot_id}").status_code == 404 + assert client.get(f"/ai-logs/{log.id}").status_code == 404 + + +def test_health_admin_override_requires_explicit_user_id(client): + owner = _user("owner") + admin = _user("admin", role=Role.ADMIN) + + _set_current_user(owner) + created = client.post( + "/health/lab-results", + json={ + "collected_at": "2026-03-01T00:00:00", + "test_code": "718-7", + "test_name_original": "Hemoglobin", + }, + ) + assert created.status_code == 201 + + _set_current_user(admin) + default_scope = client.get("/health/lab-results") + assert default_scope.status_code == 200 + assert default_scope.json()["items"] == [] + + overridden = client.get(f"/health/lab-results?user_id={owner.user_id}") + assert overridden.status_code == 200 + assert len(overridden.json()["items"]) == 1