diff --git a/backend/alembic/versions/d3e4f5a6b7c8_add_tool_trace_to_ai_call_logs.py b/backend/alembic/versions/d3e4f5a6b7c8_add_tool_trace_to_ai_call_logs.py new file mode 100644 index 0000000..a6def60 --- /dev/null +++ b/backend/alembic/versions/d3e4f5a6b7c8_add_tool_trace_to_ai_call_logs.py @@ -0,0 +1,29 @@ +"""add_tool_trace_to_ai_call_logs + +Revision ID: d3e4f5a6b7c8 +Revises: b2c3d4e5f6a1 +Create Date: 2026-03-04 00:00:00.000000 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa + +from alembic import op + +revision: str = "d3e4f5a6b7c8" +down_revision: Union[str, None] = "b2c3d4e5f6a1" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.add_column( + "ai_call_logs", + sa.Column("tool_trace", sa.JSON(), nullable=True), + ) + + +def downgrade() -> None: + op.drop_column("ai_call_logs", "tool_trace") diff --git a/backend/innercontext/api/ai_logs.py b/backend/innercontext/api/ai_logs.py index 184be95..040d47e 100644 --- a/backend/innercontext/api/ai_logs.py +++ b/backend/innercontext/api/ai_logs.py @@ -1,4 +1,5 @@ -from typing import Optional +import json +from typing import Any, Optional from uuid import UUID from fastapi import APIRouter, Depends, HTTPException @@ -10,6 +11,22 @@ from innercontext.models.ai_log import AICallLog router = APIRouter() +def _normalize_tool_trace(value: object) -> dict[str, Any] | None: + if value is None: + return None + if isinstance(value, dict): + return {str(k): v for k, v in value.items()} + if isinstance(value, str): + try: + parsed = json.loads(value) + except json.JSONDecodeError: + return None + if isinstance(parsed, dict): + return {str(k): v for k, v in parsed.items()} + return None + return None + + class AICallLogPublic(SQLModel): """List-friendly view: omits large text fields.""" @@ -21,6 +38,7 @@ class AICallLogPublic(SQLModel): completion_tokens: Optional[int] = None total_tokens: Optional[int] = None duration_ms: Optional[int] = None + tool_trace: Optional[dict[str, Any]] = None success: bool error_detail: Optional[str] = None @@ -37,7 +55,23 @@ def list_ai_logs( stmt = stmt.where(AICallLog.endpoint == endpoint) if success is not None: stmt = stmt.where(AICallLog.success == success) - return session.exec(stmt).all() + logs = session.exec(stmt).all() + return [ + AICallLogPublic( + id=log.id, + created_at=log.created_at, + endpoint=log.endpoint, + model=log.model, + prompt_tokens=log.prompt_tokens, + completion_tokens=log.completion_tokens, + total_tokens=log.total_tokens, + duration_ms=log.duration_ms, + tool_trace=_normalize_tool_trace(getattr(log, "tool_trace", None)), + success=log.success, + error_detail=log.error_detail, + ) + for log in logs + ] @router.get("/{log_id}", response_model=AICallLog) @@ -45,4 +79,5 @@ def get_ai_log(log_id: UUID, session: Session = Depends(get_session)): log = session.get(AICallLog, log_id) if log is None: raise HTTPException(status_code=404, detail="Log not found") + log.tool_trace = _normalize_tool_trace(getattr(log, "tool_trace", None)) return log diff --git a/backend/innercontext/api/routines.py b/backend/innercontext/api/routines.py index 475405f..3f682d5 100644 --- a/backend/innercontext/api/routines.py +++ b/backend/innercontext/api/routines.py @@ -4,12 +4,17 @@ from typing import Optional from uuid import UUID, uuid4 from fastapi import APIRouter, Depends, HTTPException +from google.genai import types as genai_types from pydantic import BaseModel as PydanticBase from sqlmodel import Field, Session, SQLModel, col, select from db import get_session from innercontext.api.utils import get_or_404 -from innercontext.llm import call_gemini, get_creative_config +from innercontext.llm import ( + call_gemini, + call_gemini_with_function_tools, + get_creative_config, +) from innercontext.models import ( GroomingSchedule, Product, @@ -302,8 +307,7 @@ def _build_products_context( time_filter: Optional[str] = None, reference_date: Optional[date] = None, ) -> str: - stmt = select(Product).where(col(Product.is_tool).is_(False)) - products = session.exec(stmt).all() + products = _get_available_products(session, time_filter=time_filter) product_ids = [p.id for p in products] inventory_rows = ( session.exec( @@ -333,10 +337,6 @@ def _build_products_context( lines = ["AVAILABLE PRODUCTS:"] for p in products: - if p.is_medication and not _is_minoxidil_product(p): - continue - if time_filter and _ev(p.recommended_time) not in (time_filter, "both"): - continue p.inventory = inv_by_product.get(p.id, []) ctx = p.to_llm_context() entry = ( @@ -404,6 +404,83 @@ def _build_products_context( return "\n".join(lines) + "\n" +def _get_available_products( + session: Session, + time_filter: Optional[str] = None, +) -> list[Product]: + stmt = select(Product).where(col(Product.is_tool).is_(False)) + products = session.exec(stmt).all() + result: list[Product] = [] + for p in products: + if p.is_medication and not _is_minoxidil_product(p): + continue + if time_filter and _ev(p.recommended_time) not in (time_filter, "both"): + continue + result.append(p) + return result + + +def _build_inci_tool_handler( + products: list[Product], +): + available_by_id = {str(p.id): p for p in products} + + def _handler(args: dict[str, object]) -> dict[str, object]: + raw_ids = args.get("product_ids") + if not isinstance(raw_ids, list): + return {"products": []} + + requested_ids: list[str] = [] + seen: set[str] = set() + for raw_id in raw_ids: + if not isinstance(raw_id, str): + continue + if raw_id in seen: + continue + seen.add(raw_id) + requested_ids.append(raw_id) + if len(requested_ids) >= 8: + break + + products_payload = [] + for pid in requested_ids: + product = available_by_id.get(pid) + if product is None: + continue + inci = product.inci or [] + compact_inci = [str(i)[:120] for i in inci[:128]] + products_payload.append( + { + "id": pid, + "name": product.name, + "inci": compact_inci, + } + ) + return {"products": products_payload} + + return _handler + + +_INCI_FUNCTION_DECLARATION = genai_types.FunctionDeclaration( + name="get_product_inci", + description=( + "Return exact INCI ingredient lists for products identified by UUID from " + "the AVAILABLE PRODUCTS list." + ), + parameters=genai_types.Schema( + type=genai_types.Type.OBJECT, + properties={ + "product_ids": genai_types.Schema( + type=genai_types.Type.ARRAY, + items=genai_types.Schema(type=genai_types.Type.STRING), + description="Product UUIDs from AVAILABLE PRODUCTS.", + ) + }, + required=["product_ids"], + ), +) + + def _build_objectives_context(include_minoxidil_beard: bool) -> str: if include_minoxidil_beard: return ( @@ -574,6 +651,10 @@ def suggest_routine( products_ctx = _build_products_context( session, time_filter=data.part_of_day.value, reference_date=data.routine_date ) + available_products = _get_available_products( + session, + time_filter=data.part_of_day.value, + ) objectives_ctx = _build_objectives_context(data.include_minoxidil_beard) mode_line = "MODE: standard" @@ -586,20 +667,43 @@ def suggest_routine( f"{mode_line}\n" "INPUT DATA:\n" f"{skin_ctx}\n{grooming_ctx}\n{history_ctx}\n{day_ctx}\n{products_ctx}\n{objectives_ctx}" + "\nNARZEDZIA:\n" + "- Masz dostep do funkcji get_product_inci(product_ids).\n" + "- Uzyj jej tylko, gdy potrzebujesz dokladnego skladu INCI do oceny bezpieczenstwa, kompatybilnosci lub redundancji aktywnych.\n" + "- Nie zgaduj INCI; jesli potrzebujesz skladu, wywolaj funkcje dla konkretnych UUID.\n" f"{notes_line}" f"{_ROUTINES_SINGLE_EXTRA}\n" "Zwróć JSON zgodny ze schematem." ) - response = call_gemini( + config = get_creative_config( + system_instruction=_ROUTINES_SYSTEM_PROMPT, + response_schema=_SuggestionOut, + max_output_tokens=4096, + ).model_copy( + update={ + "tools": [ + genai_types.Tool( + function_declarations=[_INCI_FUNCTION_DECLARATION], + ) + ], + "tool_config": genai_types.ToolConfig( + function_calling_config=genai_types.FunctionCallingConfig( + mode=genai_types.FunctionCallingConfigMode.AUTO, + ) + ), + } + ) + + response = call_gemini_with_function_tools( endpoint="routines/suggest", contents=prompt, - config=get_creative_config( - system_instruction=_ROUTINES_SYSTEM_PROMPT, - response_schema=_SuggestionOut, - max_output_tokens=4096, - ), + config=config, + function_handlers={ + "get_product_inci": _build_inci_tool_handler(available_products) + }, user_input=prompt, + max_tool_roundtrips=2, ) raw = response.text diff --git a/backend/innercontext/llm.py b/backend/innercontext/llm.py index 3381566..13f0b10 100644 --- a/backend/innercontext/llm.py +++ b/backend/innercontext/llm.py @@ -2,6 +2,7 @@ import os import time +from collections.abc import Callable from contextlib import suppress from typing import Any @@ -67,6 +68,7 @@ def call_gemini( contents, config: genai_types.GenerateContentConfig, user_input: str | None = None, + tool_trace: dict[str, Any] | None = None, ): """Call Gemini, log full request + response to DB, return response unchanged.""" from sqlmodel import Session @@ -119,6 +121,7 @@ def call_gemini( system_prompt=sys_prompt, user_input=user_input, response_text=response.text if response else None, + tool_trace=tool_trace, prompt_tokens=( response.usage_metadata.prompt_token_count if response and response.usage_metadata @@ -143,3 +146,110 @@ def call_gemini( s.add(log) s.commit() return response + + +def call_gemini_with_function_tools( + *, + endpoint: str, + contents, + config: genai_types.GenerateContentConfig, + function_handlers: dict[str, Callable[[dict[str, Any]], dict[str, Any]]], + user_input: str | None = None, + max_tool_roundtrips: int = 2, +): + """Call Gemini with function-calling loop until final response text is produced.""" + if max_tool_roundtrips < 0: + raise ValueError("max_tool_roundtrips must be >= 0") + + history = list(contents) if isinstance(contents, list) else [contents] + rounds = 0 + trace_events: list[dict[str, Any]] = [] + + while True: + response = call_gemini( + endpoint=endpoint, + contents=history, + config=config, + user_input=user_input, + tool_trace={ + "mode": "function_tools", + "round": rounds, + "events": trace_events, + }, + ) + function_calls = list(getattr(response, "function_calls", None) or []) + if not function_calls: + return response + + if rounds >= max_tool_roundtrips: + raise HTTPException( + status_code=502, + detail="Gemini requested too many function calls", + ) + + candidate_content = None + candidates = getattr(response, "candidates", None) or [] + if candidates: + candidate_content = getattr(candidates[0], "content", None) + if candidate_content is not None: + history.append(candidate_content) + else: + history.append( + genai_types.ModelContent( + parts=[genai_types.Part(function_call=fc) for fc in function_calls] + ) + ) + + response_parts: list[genai_types.Part] = [] + for fc in function_calls: + name = getattr(fc, "name", None) + if not isinstance(name, str) or not name: + raise HTTPException( + status_code=502, + detail="Gemini requested a function without a valid name", + ) + + handler = function_handlers.get(name) + if handler is None: + raise HTTPException( + status_code=502, + detail=f"Gemini requested unknown function: {name}", + ) + + args = getattr(fc, "args", None) or {} + if not isinstance(args, dict): + raise HTTPException( + status_code=502, + detail=f"Gemini returned invalid arguments for function: {name}", + ) + + tool_response = handler(args) + if not isinstance(tool_response, dict): + raise HTTPException( + status_code=502, + detail=f"Function handler must return an object for: {name}", + ) + + trace_event: dict[str, Any] = { + "round": rounds + 1, + "function": name, + } + product_ids = args.get("product_ids") + if isinstance(product_ids, list): + clean_ids = [x for x in product_ids if isinstance(x, str)] + trace_event["requested_ids_count"] = len(clean_ids) + trace_event["requested_ids"] = clean_ids[:8] + products = tool_response.get("products") + if isinstance(products, list): + trace_event["returned_products_count"] = len(products) + trace_events.append(trace_event) + + response_parts.append( + genai_types.Part.from_function_response( + name=name, + response=tool_response, + ) + ) + + history.append(genai_types.UserContent(parts=response_parts)) + rounds += 1 diff --git a/backend/innercontext/models/ai_log.py b/backend/innercontext/models/ai_log.py index 44bb749..dbd2d5e 100644 --- a/backend/innercontext/models/ai_log.py +++ b/backend/innercontext/models/ai_log.py @@ -1,7 +1,8 @@ from datetime import datetime -from typing import ClassVar +from typing import Any, ClassVar from uuid import UUID, uuid4 +from sqlalchemy import JSON, Column from sqlmodel import Field, SQLModel from .base import utc_now @@ -24,5 +25,9 @@ class AICallLog(SQLModel, table=True): total_tokens: int | None = Field(default=None) duration_ms: int | None = Field(default=None) finish_reason: str | None = Field(default=None) + tool_trace: dict[str, Any] | None = Field( + default=None, + sa_column=Column(JSON, nullable=True), + ) success: bool = Field(default=True, index=True) error_detail: str | None = Field(default=None) diff --git a/backend/tests/test_ai_logs.py b/backend/tests/test_ai_logs.py new file mode 100644 index 0000000..47a168e --- /dev/null +++ b/backend/tests/test_ai_logs.py @@ -0,0 +1,44 @@ +import uuid +from typing import Any, cast + +from innercontext.models.ai_log import AICallLog + + +def test_list_ai_logs_normalizes_tool_trace_string(client, session): + log = AICallLog( + id=uuid.uuid4(), + endpoint="routines/suggest", + model="gemini-3-flash-preview", + success=True, + ) + log.tool_trace = cast( + Any, + '{"mode":"function_tools","events":[{"function":"get_product_inci"}]}', + ) + session.add(log) + session.commit() + + response = client.get("/ai-logs") + assert response.status_code == 200 + data = response.json() + assert len(data) == 1 + assert data[0]["tool_trace"]["mode"] == "function_tools" + assert data[0]["tool_trace"]["events"][0]["function"] == "get_product_inci" + + +def test_get_ai_log_normalizes_tool_trace_string(client, session): + log = AICallLog( + id=uuid.uuid4(), + endpoint="routines/suggest", + model="gemini-3-flash-preview", + success=True, + ) + log.tool_trace = cast(Any, '{"mode":"function_tools","round":1}') + session.add(log) + session.commit() + + response = client.get(f"/ai-logs/{log.id}") + assert response.status_code == 200 + payload = response.json() + assert payload["tool_trace"]["mode"] == "function_tools" + assert payload["tool_trace"]["round"] == 1 diff --git a/backend/tests/test_routines.py b/backend/tests/test_routines.py index 8b63df1..9001c6c 100644 --- a/backend/tests/test_routines.py +++ b/backend/tests/test_routines.py @@ -220,7 +220,9 @@ def test_delete_grooming_schedule_not_found(client): def test_suggest_routine(client, session): - with patch("innercontext.api.routines.call_gemini") as mock_gemini: + with patch( + "innercontext.api.routines.call_gemini_with_function_tools" + ) as mock_gemini: # Mock the Gemini response mock_response = type( "Response", @@ -245,6 +247,9 @@ def test_suggest_routine(client, session): assert len(data["steps"]) == 1 assert data["steps"][0]["action_type"] == "shaving_razor" assert data["reasoning"] == "because" + kwargs = mock_gemini.call_args.kwargs + assert "function_handlers" in kwargs + assert "get_product_inci" in kwargs["function_handlers"] def test_suggest_batch(client, session): diff --git a/backend/tests/test_routines_helpers.py b/backend/tests/test_routines_helpers.py index 029d539..29dfc4e 100644 --- a/backend/tests/test_routines_helpers.py +++ b/backend/tests/test_routines_helpers.py @@ -1,26 +1,28 @@ -from datetime import date, timedelta import uuid +from datetime import date, timedelta from sqlmodel import Session from innercontext.api.routines import ( - _contains_minoxidil_text, - _is_minoxidil_product, - _ev, - _build_skin_context, - _build_grooming_context, - _build_recent_history, - _build_products_context, - _build_objectives_context, _build_day_context, + _build_grooming_context, + _build_inci_tool_handler, + _build_objectives_context, + _build_products_context, + _build_recent_history, + _build_skin_context, + _contains_minoxidil_text, + _ev, + _get_available_products, + _is_minoxidil_product, ) from innercontext.models import ( - Product, - SkinConditionSnapshot, GroomingSchedule, + Product, + ProductInventory, Routine, RoutineStep, - ProductInventory, + SkinConditionSnapshot, ) @@ -242,3 +244,95 @@ def test_build_day_context(): assert _build_day_context(None) == "" assert "Leaving home: yes" in _build_day_context(True) assert "Leaving home: no" in _build_day_context(False) + + +def test_get_available_products_respects_filters(session: Session): + regular_med = Product( + id=uuid.uuid4(), + name="Tretinoin", + category="serum", + is_medication=True, + brand="Test", + recommended_time="pm", + leave_on=True, + product_effect_profile={}, + ) + minoxidil_med = Product( + id=uuid.uuid4(), + name="Minoxidil 5%", + category="serum", + is_medication=True, + brand="Test", + recommended_time="both", + leave_on=True, + product_effect_profile={}, + ) + am_product = Product( + id=uuid.uuid4(), + name="AM SPF", + category="spf", + brand="Test", + recommended_time="am", + leave_on=True, + product_effect_profile={}, + ) + pm_product = Product( + id=uuid.uuid4(), + name="PM Cream", + category="moisturizer", + brand="Test", + recommended_time="pm", + leave_on=True, + product_effect_profile={}, + ) + session.add_all([regular_med, minoxidil_med, am_product, pm_product]) + session.commit() + + am_available = _get_available_products(session, time_filter="am") + am_names = {p.name for p in am_available} + assert "Tretinoin" not in am_names + assert "Minoxidil 5%" in am_names + assert "AM SPF" in am_names + assert "PM Cream" not in am_names + + +def test_build_inci_tool_handler_returns_only_available_ids(session: Session): + available = Product( + id=uuid.uuid4(), + name="Available", + category="serum", + brand="Test", + recommended_time="both", + leave_on=True, + inci=["Water", "Niacinamide"], + product_effect_profile={}, + ) + unavailable = Product( + id=uuid.uuid4(), + name="Unavailable", + category="serum", + brand="Test", + recommended_time="both", + leave_on=True, + inci=["Water", "Retinol"], + product_effect_profile={}, + ) + + handler = _build_inci_tool_handler([available]) + payload = handler( + { + "product_ids": [ + str(available.id), + str(unavailable.id), + str(available.id), + 123, + ] + } + ) + + assert "products" in payload + products = payload["products"] + assert len(products) == 1 + assert products[0]["id"] == str(available.id) + assert products[0]["name"] == "Available" + assert products[0]["inci"] == ["Water", "Niacinamide"]