feat(api): add INCI tool-calling with normalized tool traces

Enable on-demand INCI retrieval in /routines/suggest through Gemini function calling so detailed ingredient data is fetched only when needed. Persist and normalize tool_trace data in AI logs to make function-call behavior directly inspectable via /ai-logs endpoints.
This commit is contained in:
Piotr Oleszczyk 2026-03-04 11:35:19 +01:00
parent c0eeb0425d
commit cfd2485b7e
8 changed files with 455 additions and 29 deletions

View file

@ -0,0 +1,29 @@
"""add_tool_trace_to_ai_call_logs
Revision ID: d3e4f5a6b7c8
Revises: b2c3d4e5f6a1
Create Date: 2026-03-04 00:00:00.000000
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
revision: str = "d3e4f5a6b7c8"
down_revision: Union[str, None] = "b2c3d4e5f6a1"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.add_column(
"ai_call_logs",
sa.Column("tool_trace", sa.JSON(), nullable=True),
)
def downgrade() -> None:
op.drop_column("ai_call_logs", "tool_trace")

View file

@ -1,4 +1,5 @@
from typing import Optional
import json
from typing import Any, Optional
from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException
@ -10,6 +11,22 @@ from innercontext.models.ai_log import AICallLog
router = APIRouter()
def _normalize_tool_trace(value: object) -> dict[str, Any] | None:
if value is None:
return None
if isinstance(value, dict):
return {str(k): v for k, v in value.items()}
if isinstance(value, str):
try:
parsed = json.loads(value)
except json.JSONDecodeError:
return None
if isinstance(parsed, dict):
return {str(k): v for k, v in parsed.items()}
return None
return None
class AICallLogPublic(SQLModel):
"""List-friendly view: omits large text fields."""
@ -21,6 +38,7 @@ class AICallLogPublic(SQLModel):
completion_tokens: Optional[int] = None
total_tokens: Optional[int] = None
duration_ms: Optional[int] = None
tool_trace: Optional[dict[str, Any]] = None
success: bool
error_detail: Optional[str] = None
@ -37,7 +55,23 @@ def list_ai_logs(
stmt = stmt.where(AICallLog.endpoint == endpoint)
if success is not None:
stmt = stmt.where(AICallLog.success == success)
return session.exec(stmt).all()
logs = session.exec(stmt).all()
return [
AICallLogPublic(
id=log.id,
created_at=log.created_at,
endpoint=log.endpoint,
model=log.model,
prompt_tokens=log.prompt_tokens,
completion_tokens=log.completion_tokens,
total_tokens=log.total_tokens,
duration_ms=log.duration_ms,
tool_trace=_normalize_tool_trace(getattr(log, "tool_trace", None)),
success=log.success,
error_detail=log.error_detail,
)
for log in logs
]
@router.get("/{log_id}", response_model=AICallLog)
@ -45,4 +79,5 @@ def get_ai_log(log_id: UUID, session: Session = Depends(get_session)):
log = session.get(AICallLog, log_id)
if log is None:
raise HTTPException(status_code=404, detail="Log not found")
log.tool_trace = _normalize_tool_trace(getattr(log, "tool_trace", None))
return log

View file

@ -4,12 +4,17 @@ from typing import Optional
from uuid import UUID, uuid4
from fastapi import APIRouter, Depends, HTTPException
from google.genai import types as genai_types
from pydantic import BaseModel as PydanticBase
from sqlmodel import Field, Session, SQLModel, col, select
from db import get_session
from innercontext.api.utils import get_or_404
from innercontext.llm import call_gemini, get_creative_config
from innercontext.llm import (
call_gemini,
call_gemini_with_function_tools,
get_creative_config,
)
from innercontext.models import (
GroomingSchedule,
Product,
@ -302,8 +307,7 @@ def _build_products_context(
time_filter: Optional[str] = None,
reference_date: Optional[date] = None,
) -> str:
stmt = select(Product).where(col(Product.is_tool).is_(False))
products = session.exec(stmt).all()
products = _get_available_products(session, time_filter=time_filter)
product_ids = [p.id for p in products]
inventory_rows = (
session.exec(
@ -333,10 +337,6 @@ def _build_products_context(
lines = ["AVAILABLE PRODUCTS:"]
for p in products:
if p.is_medication and not _is_minoxidil_product(p):
continue
if time_filter and _ev(p.recommended_time) not in (time_filter, "both"):
continue
p.inventory = inv_by_product.get(p.id, [])
ctx = p.to_llm_context()
entry = (
@ -404,6 +404,83 @@ def _build_products_context(
return "\n".join(lines) + "\n"
def _get_available_products(
session: Session,
time_filter: Optional[str] = None,
) -> list[Product]:
stmt = select(Product).where(col(Product.is_tool).is_(False))
products = session.exec(stmt).all()
result: list[Product] = []
for p in products:
if p.is_medication and not _is_minoxidil_product(p):
continue
if time_filter and _ev(p.recommended_time) not in (time_filter, "both"):
continue
result.append(p)
return result
def _build_inci_tool_handler(
products: list[Product],
):
available_by_id = {str(p.id): p for p in products}
def _handler(args: dict[str, object]) -> dict[str, object]:
raw_ids = args.get("product_ids")
if not isinstance(raw_ids, list):
return {"products": []}
requested_ids: list[str] = []
seen: set[str] = set()
for raw_id in raw_ids:
if not isinstance(raw_id, str):
continue
if raw_id in seen:
continue
seen.add(raw_id)
requested_ids.append(raw_id)
if len(requested_ids) >= 8:
break
products_payload = []
for pid in requested_ids:
product = available_by_id.get(pid)
if product is None:
continue
inci = product.inci or []
compact_inci = [str(i)[:120] for i in inci[:128]]
products_payload.append(
{
"id": pid,
"name": product.name,
"inci": compact_inci,
}
)
return {"products": products_payload}
return _handler
_INCI_FUNCTION_DECLARATION = genai_types.FunctionDeclaration(
name="get_product_inci",
description=(
"Return exact INCI ingredient lists for products identified by UUID from "
"the AVAILABLE PRODUCTS list."
),
parameters=genai_types.Schema(
type=genai_types.Type.OBJECT,
properties={
"product_ids": genai_types.Schema(
type=genai_types.Type.ARRAY,
items=genai_types.Schema(type=genai_types.Type.STRING),
description="Product UUIDs from AVAILABLE PRODUCTS.",
)
},
required=["product_ids"],
),
)
def _build_objectives_context(include_minoxidil_beard: bool) -> str:
if include_minoxidil_beard:
return (
@ -574,6 +651,10 @@ def suggest_routine(
products_ctx = _build_products_context(
session, time_filter=data.part_of_day.value, reference_date=data.routine_date
)
available_products = _get_available_products(
session,
time_filter=data.part_of_day.value,
)
objectives_ctx = _build_objectives_context(data.include_minoxidil_beard)
mode_line = "MODE: standard"
@ -586,20 +667,43 @@ def suggest_routine(
f"{mode_line}\n"
"INPUT DATA:\n"
f"{skin_ctx}\n{grooming_ctx}\n{history_ctx}\n{day_ctx}\n{products_ctx}\n{objectives_ctx}"
"\nNARZEDZIA:\n"
"- Masz dostep do funkcji get_product_inci(product_ids).\n"
"- Uzyj jej tylko, gdy potrzebujesz dokladnego skladu INCI do oceny bezpieczenstwa, kompatybilnosci lub redundancji aktywnych.\n"
"- Nie zgaduj INCI; jesli potrzebujesz skladu, wywolaj funkcje dla konkretnych UUID.\n"
f"{notes_line}"
f"{_ROUTINES_SINGLE_EXTRA}\n"
"Zwróć JSON zgodny ze schematem."
)
response = call_gemini(
endpoint="routines/suggest",
contents=prompt,
config = get_creative_config(
system_instruction=_ROUTINES_SYSTEM_PROMPT,
response_schema=_SuggestionOut,
max_output_tokens=4096,
).model_copy(
update={
"tools": [
genai_types.Tool(
function_declarations=[_INCI_FUNCTION_DECLARATION],
)
],
"tool_config": genai_types.ToolConfig(
function_calling_config=genai_types.FunctionCallingConfig(
mode=genai_types.FunctionCallingConfigMode.AUTO,
)
),
}
)
response = call_gemini_with_function_tools(
endpoint="routines/suggest",
contents=prompt,
config=config,
function_handlers={
"get_product_inci": _build_inci_tool_handler(available_products)
},
user_input=prompt,
max_tool_roundtrips=2,
)
raw = response.text

View file

@ -2,6 +2,7 @@
import os
import time
from collections.abc import Callable
from contextlib import suppress
from typing import Any
@ -67,6 +68,7 @@ def call_gemini(
contents,
config: genai_types.GenerateContentConfig,
user_input: str | None = None,
tool_trace: dict[str, Any] | None = None,
):
"""Call Gemini, log full request + response to DB, return response unchanged."""
from sqlmodel import Session
@ -119,6 +121,7 @@ def call_gemini(
system_prompt=sys_prompt,
user_input=user_input,
response_text=response.text if response else None,
tool_trace=tool_trace,
prompt_tokens=(
response.usage_metadata.prompt_token_count
if response and response.usage_metadata
@ -143,3 +146,110 @@ def call_gemini(
s.add(log)
s.commit()
return response
def call_gemini_with_function_tools(
*,
endpoint: str,
contents,
config: genai_types.GenerateContentConfig,
function_handlers: dict[str, Callable[[dict[str, Any]], dict[str, Any]]],
user_input: str | None = None,
max_tool_roundtrips: int = 2,
):
"""Call Gemini with function-calling loop until final response text is produced."""
if max_tool_roundtrips < 0:
raise ValueError("max_tool_roundtrips must be >= 0")
history = list(contents) if isinstance(contents, list) else [contents]
rounds = 0
trace_events: list[dict[str, Any]] = []
while True:
response = call_gemini(
endpoint=endpoint,
contents=history,
config=config,
user_input=user_input,
tool_trace={
"mode": "function_tools",
"round": rounds,
"events": trace_events,
},
)
function_calls = list(getattr(response, "function_calls", None) or [])
if not function_calls:
return response
if rounds >= max_tool_roundtrips:
raise HTTPException(
status_code=502,
detail="Gemini requested too many function calls",
)
candidate_content = None
candidates = getattr(response, "candidates", None) or []
if candidates:
candidate_content = getattr(candidates[0], "content", None)
if candidate_content is not None:
history.append(candidate_content)
else:
history.append(
genai_types.ModelContent(
parts=[genai_types.Part(function_call=fc) for fc in function_calls]
)
)
response_parts: list[genai_types.Part] = []
for fc in function_calls:
name = getattr(fc, "name", None)
if not isinstance(name, str) or not name:
raise HTTPException(
status_code=502,
detail="Gemini requested a function without a valid name",
)
handler = function_handlers.get(name)
if handler is None:
raise HTTPException(
status_code=502,
detail=f"Gemini requested unknown function: {name}",
)
args = getattr(fc, "args", None) or {}
if not isinstance(args, dict):
raise HTTPException(
status_code=502,
detail=f"Gemini returned invalid arguments for function: {name}",
)
tool_response = handler(args)
if not isinstance(tool_response, dict):
raise HTTPException(
status_code=502,
detail=f"Function handler must return an object for: {name}",
)
trace_event: dict[str, Any] = {
"round": rounds + 1,
"function": name,
}
product_ids = args.get("product_ids")
if isinstance(product_ids, list):
clean_ids = [x for x in product_ids if isinstance(x, str)]
trace_event["requested_ids_count"] = len(clean_ids)
trace_event["requested_ids"] = clean_ids[:8]
products = tool_response.get("products")
if isinstance(products, list):
trace_event["returned_products_count"] = len(products)
trace_events.append(trace_event)
response_parts.append(
genai_types.Part.from_function_response(
name=name,
response=tool_response,
)
)
history.append(genai_types.UserContent(parts=response_parts))
rounds += 1

View file

@ -1,7 +1,8 @@
from datetime import datetime
from typing import ClassVar
from typing import Any, ClassVar
from uuid import UUID, uuid4
from sqlalchemy import JSON, Column
from sqlmodel import Field, SQLModel
from .base import utc_now
@ -24,5 +25,9 @@ class AICallLog(SQLModel, table=True):
total_tokens: int | None = Field(default=None)
duration_ms: int | None = Field(default=None)
finish_reason: str | None = Field(default=None)
tool_trace: dict[str, Any] | None = Field(
default=None,
sa_column=Column(JSON, nullable=True),
)
success: bool = Field(default=True, index=True)
error_detail: str | None = Field(default=None)

View file

@ -0,0 +1,44 @@
import uuid
from typing import Any, cast
from innercontext.models.ai_log import AICallLog
def test_list_ai_logs_normalizes_tool_trace_string(client, session):
log = AICallLog(
id=uuid.uuid4(),
endpoint="routines/suggest",
model="gemini-3-flash-preview",
success=True,
)
log.tool_trace = cast(
Any,
'{"mode":"function_tools","events":[{"function":"get_product_inci"}]}',
)
session.add(log)
session.commit()
response = client.get("/ai-logs")
assert response.status_code == 200
data = response.json()
assert len(data) == 1
assert data[0]["tool_trace"]["mode"] == "function_tools"
assert data[0]["tool_trace"]["events"][0]["function"] == "get_product_inci"
def test_get_ai_log_normalizes_tool_trace_string(client, session):
log = AICallLog(
id=uuid.uuid4(),
endpoint="routines/suggest",
model="gemini-3-flash-preview",
success=True,
)
log.tool_trace = cast(Any, '{"mode":"function_tools","round":1}')
session.add(log)
session.commit()
response = client.get(f"/ai-logs/{log.id}")
assert response.status_code == 200
payload = response.json()
assert payload["tool_trace"]["mode"] == "function_tools"
assert payload["tool_trace"]["round"] == 1

View file

@ -220,7 +220,9 @@ def test_delete_grooming_schedule_not_found(client):
def test_suggest_routine(client, session):
with patch("innercontext.api.routines.call_gemini") as mock_gemini:
with patch(
"innercontext.api.routines.call_gemini_with_function_tools"
) as mock_gemini:
# Mock the Gemini response
mock_response = type(
"Response",
@ -245,6 +247,9 @@ def test_suggest_routine(client, session):
assert len(data["steps"]) == 1
assert data["steps"][0]["action_type"] == "shaving_razor"
assert data["reasoning"] == "because"
kwargs = mock_gemini.call_args.kwargs
assert "function_handlers" in kwargs
assert "get_product_inci" in kwargs["function_handlers"]
def test_suggest_batch(client, session):

View file

@ -1,26 +1,28 @@
from datetime import date, timedelta
import uuid
from datetime import date, timedelta
from sqlmodel import Session
from innercontext.api.routines import (
_contains_minoxidil_text,
_is_minoxidil_product,
_ev,
_build_skin_context,
_build_grooming_context,
_build_recent_history,
_build_products_context,
_build_objectives_context,
_build_day_context,
_build_grooming_context,
_build_inci_tool_handler,
_build_objectives_context,
_build_products_context,
_build_recent_history,
_build_skin_context,
_contains_minoxidil_text,
_ev,
_get_available_products,
_is_minoxidil_product,
)
from innercontext.models import (
Product,
SkinConditionSnapshot,
GroomingSchedule,
Product,
ProductInventory,
Routine,
RoutineStep,
ProductInventory,
SkinConditionSnapshot,
)
@ -242,3 +244,95 @@ def test_build_day_context():
assert _build_day_context(None) == ""
assert "Leaving home: yes" in _build_day_context(True)
assert "Leaving home: no" in _build_day_context(False)
def test_get_available_products_respects_filters(session: Session):
regular_med = Product(
id=uuid.uuid4(),
name="Tretinoin",
category="serum",
is_medication=True,
brand="Test",
recommended_time="pm",
leave_on=True,
product_effect_profile={},
)
minoxidil_med = Product(
id=uuid.uuid4(),
name="Minoxidil 5%",
category="serum",
is_medication=True,
brand="Test",
recommended_time="both",
leave_on=True,
product_effect_profile={},
)
am_product = Product(
id=uuid.uuid4(),
name="AM SPF",
category="spf",
brand="Test",
recommended_time="am",
leave_on=True,
product_effect_profile={},
)
pm_product = Product(
id=uuid.uuid4(),
name="PM Cream",
category="moisturizer",
brand="Test",
recommended_time="pm",
leave_on=True,
product_effect_profile={},
)
session.add_all([regular_med, minoxidil_med, am_product, pm_product])
session.commit()
am_available = _get_available_products(session, time_filter="am")
am_names = {p.name for p in am_available}
assert "Tretinoin" not in am_names
assert "Minoxidil 5%" in am_names
assert "AM SPF" in am_names
assert "PM Cream" not in am_names
def test_build_inci_tool_handler_returns_only_available_ids(session: Session):
available = Product(
id=uuid.uuid4(),
name="Available",
category="serum",
brand="Test",
recommended_time="both",
leave_on=True,
inci=["Water", "Niacinamide"],
product_effect_profile={},
)
unavailable = Product(
id=uuid.uuid4(),
name="Unavailable",
category="serum",
brand="Test",
recommended_time="both",
leave_on=True,
inci=["Water", "Retinol"],
product_effect_profile={},
)
handler = _build_inci_tool_handler([available])
payload = handler(
{
"product_ids": [
str(available.id),
str(unavailable.id),
str(available.id),
123,
]
}
)
assert "products" in payload
products = payload["products"]
assert len(products) == 1
assert products[0]["id"] == str(available.id)
assert products[0]["name"] == "Available"
assert products[0]["inci"] == ["Water", "Niacinamide"]