feat(api): add INCI tool-calling with normalized tool traces

Enable on-demand INCI retrieval in /routines/suggest through Gemini function calling so detailed ingredient data is fetched only when needed. Persist and normalize tool_trace data in AI logs to make function-call behavior directly inspectable via /ai-logs endpoints.
This commit is contained in:
Piotr Oleszczyk 2026-03-04 11:35:19 +01:00
parent c0eeb0425d
commit cfd2485b7e
8 changed files with 455 additions and 29 deletions

View file

@ -1,4 +1,5 @@
from typing import Optional
import json
from typing import Any, Optional
from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException
@ -10,6 +11,22 @@ from innercontext.models.ai_log import AICallLog
router = APIRouter()
def _normalize_tool_trace(value: object) -> dict[str, Any] | None:
if value is None:
return None
if isinstance(value, dict):
return {str(k): v for k, v in value.items()}
if isinstance(value, str):
try:
parsed = json.loads(value)
except json.JSONDecodeError:
return None
if isinstance(parsed, dict):
return {str(k): v for k, v in parsed.items()}
return None
return None
class AICallLogPublic(SQLModel):
"""List-friendly view: omits large text fields."""
@ -21,6 +38,7 @@ class AICallLogPublic(SQLModel):
completion_tokens: Optional[int] = None
total_tokens: Optional[int] = None
duration_ms: Optional[int] = None
tool_trace: Optional[dict[str, Any]] = None
success: bool
error_detail: Optional[str] = None
@ -37,7 +55,23 @@ def list_ai_logs(
stmt = stmt.where(AICallLog.endpoint == endpoint)
if success is not None:
stmt = stmt.where(AICallLog.success == success)
return session.exec(stmt).all()
logs = session.exec(stmt).all()
return [
AICallLogPublic(
id=log.id,
created_at=log.created_at,
endpoint=log.endpoint,
model=log.model,
prompt_tokens=log.prompt_tokens,
completion_tokens=log.completion_tokens,
total_tokens=log.total_tokens,
duration_ms=log.duration_ms,
tool_trace=_normalize_tool_trace(getattr(log, "tool_trace", None)),
success=log.success,
error_detail=log.error_detail,
)
for log in logs
]
@router.get("/{log_id}", response_model=AICallLog)
@ -45,4 +79,5 @@ def get_ai_log(log_id: UUID, session: Session = Depends(get_session)):
log = session.get(AICallLog, log_id)
if log is None:
raise HTTPException(status_code=404, detail="Log not found")
log.tool_trace = _normalize_tool_trace(getattr(log, "tool_trace", None))
return log

View file

@ -4,12 +4,17 @@ from typing import Optional
from uuid import UUID, uuid4
from fastapi import APIRouter, Depends, HTTPException
from google.genai import types as genai_types
from pydantic import BaseModel as PydanticBase
from sqlmodel import Field, Session, SQLModel, col, select
from db import get_session
from innercontext.api.utils import get_or_404
from innercontext.llm import call_gemini, get_creative_config
from innercontext.llm import (
call_gemini,
call_gemini_with_function_tools,
get_creative_config,
)
from innercontext.models import (
GroomingSchedule,
Product,
@ -302,8 +307,7 @@ def _build_products_context(
time_filter: Optional[str] = None,
reference_date: Optional[date] = None,
) -> str:
stmt = select(Product).where(col(Product.is_tool).is_(False))
products = session.exec(stmt).all()
products = _get_available_products(session, time_filter=time_filter)
product_ids = [p.id for p in products]
inventory_rows = (
session.exec(
@ -333,10 +337,6 @@ def _build_products_context(
lines = ["AVAILABLE PRODUCTS:"]
for p in products:
if p.is_medication and not _is_minoxidil_product(p):
continue
if time_filter and _ev(p.recommended_time) not in (time_filter, "both"):
continue
p.inventory = inv_by_product.get(p.id, [])
ctx = p.to_llm_context()
entry = (
@ -404,6 +404,83 @@ def _build_products_context(
return "\n".join(lines) + "\n"
def _get_available_products(
session: Session,
time_filter: Optional[str] = None,
) -> list[Product]:
stmt = select(Product).where(col(Product.is_tool).is_(False))
products = session.exec(stmt).all()
result: list[Product] = []
for p in products:
if p.is_medication and not _is_minoxidil_product(p):
continue
if time_filter and _ev(p.recommended_time) not in (time_filter, "both"):
continue
result.append(p)
return result
def _build_inci_tool_handler(
products: list[Product],
):
available_by_id = {str(p.id): p for p in products}
def _handler(args: dict[str, object]) -> dict[str, object]:
raw_ids = args.get("product_ids")
if not isinstance(raw_ids, list):
return {"products": []}
requested_ids: list[str] = []
seen: set[str] = set()
for raw_id in raw_ids:
if not isinstance(raw_id, str):
continue
if raw_id in seen:
continue
seen.add(raw_id)
requested_ids.append(raw_id)
if len(requested_ids) >= 8:
break
products_payload = []
for pid in requested_ids:
product = available_by_id.get(pid)
if product is None:
continue
inci = product.inci or []
compact_inci = [str(i)[:120] for i in inci[:128]]
products_payload.append(
{
"id": pid,
"name": product.name,
"inci": compact_inci,
}
)
return {"products": products_payload}
return _handler
_INCI_FUNCTION_DECLARATION = genai_types.FunctionDeclaration(
name="get_product_inci",
description=(
"Return exact INCI ingredient lists for products identified by UUID from "
"the AVAILABLE PRODUCTS list."
),
parameters=genai_types.Schema(
type=genai_types.Type.OBJECT,
properties={
"product_ids": genai_types.Schema(
type=genai_types.Type.ARRAY,
items=genai_types.Schema(type=genai_types.Type.STRING),
description="Product UUIDs from AVAILABLE PRODUCTS.",
)
},
required=["product_ids"],
),
)
def _build_objectives_context(include_minoxidil_beard: bool) -> str:
if include_minoxidil_beard:
return (
@ -574,6 +651,10 @@ def suggest_routine(
products_ctx = _build_products_context(
session, time_filter=data.part_of_day.value, reference_date=data.routine_date
)
available_products = _get_available_products(
session,
time_filter=data.part_of_day.value,
)
objectives_ctx = _build_objectives_context(data.include_minoxidil_beard)
mode_line = "MODE: standard"
@ -586,20 +667,43 @@ def suggest_routine(
f"{mode_line}\n"
"INPUT DATA:\n"
f"{skin_ctx}\n{grooming_ctx}\n{history_ctx}\n{day_ctx}\n{products_ctx}\n{objectives_ctx}"
"\nNARZEDZIA:\n"
"- Masz dostep do funkcji get_product_inci(product_ids).\n"
"- Uzyj jej tylko, gdy potrzebujesz dokladnego skladu INCI do oceny bezpieczenstwa, kompatybilnosci lub redundancji aktywnych.\n"
"- Nie zgaduj INCI; jesli potrzebujesz skladu, wywolaj funkcje dla konkretnych UUID.\n"
f"{notes_line}"
f"{_ROUTINES_SINGLE_EXTRA}\n"
"Zwróć JSON zgodny ze schematem."
)
response = call_gemini(
config = get_creative_config(
system_instruction=_ROUTINES_SYSTEM_PROMPT,
response_schema=_SuggestionOut,
max_output_tokens=4096,
).model_copy(
update={
"tools": [
genai_types.Tool(
function_declarations=[_INCI_FUNCTION_DECLARATION],
)
],
"tool_config": genai_types.ToolConfig(
function_calling_config=genai_types.FunctionCallingConfig(
mode=genai_types.FunctionCallingConfigMode.AUTO,
)
),
}
)
response = call_gemini_with_function_tools(
endpoint="routines/suggest",
contents=prompt,
config=get_creative_config(
system_instruction=_ROUTINES_SYSTEM_PROMPT,
response_schema=_SuggestionOut,
max_output_tokens=4096,
),
config=config,
function_handlers={
"get_product_inci": _build_inci_tool_handler(available_products)
},
user_input=prompt,
max_tool_roundtrips=2,
)
raw = response.text

View file

@ -2,6 +2,7 @@
import os
import time
from collections.abc import Callable
from contextlib import suppress
from typing import Any
@ -67,6 +68,7 @@ def call_gemini(
contents,
config: genai_types.GenerateContentConfig,
user_input: str | None = None,
tool_trace: dict[str, Any] | None = None,
):
"""Call Gemini, log full request + response to DB, return response unchanged."""
from sqlmodel import Session
@ -119,6 +121,7 @@ def call_gemini(
system_prompt=sys_prompt,
user_input=user_input,
response_text=response.text if response else None,
tool_trace=tool_trace,
prompt_tokens=(
response.usage_metadata.prompt_token_count
if response and response.usage_metadata
@ -143,3 +146,110 @@ def call_gemini(
s.add(log)
s.commit()
return response
def call_gemini_with_function_tools(
*,
endpoint: str,
contents,
config: genai_types.GenerateContentConfig,
function_handlers: dict[str, Callable[[dict[str, Any]], dict[str, Any]]],
user_input: str | None = None,
max_tool_roundtrips: int = 2,
):
"""Call Gemini with function-calling loop until final response text is produced."""
if max_tool_roundtrips < 0:
raise ValueError("max_tool_roundtrips must be >= 0")
history = list(contents) if isinstance(contents, list) else [contents]
rounds = 0
trace_events: list[dict[str, Any]] = []
while True:
response = call_gemini(
endpoint=endpoint,
contents=history,
config=config,
user_input=user_input,
tool_trace={
"mode": "function_tools",
"round": rounds,
"events": trace_events,
},
)
function_calls = list(getattr(response, "function_calls", None) or [])
if not function_calls:
return response
if rounds >= max_tool_roundtrips:
raise HTTPException(
status_code=502,
detail="Gemini requested too many function calls",
)
candidate_content = None
candidates = getattr(response, "candidates", None) or []
if candidates:
candidate_content = getattr(candidates[0], "content", None)
if candidate_content is not None:
history.append(candidate_content)
else:
history.append(
genai_types.ModelContent(
parts=[genai_types.Part(function_call=fc) for fc in function_calls]
)
)
response_parts: list[genai_types.Part] = []
for fc in function_calls:
name = getattr(fc, "name", None)
if not isinstance(name, str) or not name:
raise HTTPException(
status_code=502,
detail="Gemini requested a function without a valid name",
)
handler = function_handlers.get(name)
if handler is None:
raise HTTPException(
status_code=502,
detail=f"Gemini requested unknown function: {name}",
)
args = getattr(fc, "args", None) or {}
if not isinstance(args, dict):
raise HTTPException(
status_code=502,
detail=f"Gemini returned invalid arguments for function: {name}",
)
tool_response = handler(args)
if not isinstance(tool_response, dict):
raise HTTPException(
status_code=502,
detail=f"Function handler must return an object for: {name}",
)
trace_event: dict[str, Any] = {
"round": rounds + 1,
"function": name,
}
product_ids = args.get("product_ids")
if isinstance(product_ids, list):
clean_ids = [x for x in product_ids if isinstance(x, str)]
trace_event["requested_ids_count"] = len(clean_ids)
trace_event["requested_ids"] = clean_ids[:8]
products = tool_response.get("products")
if isinstance(products, list):
trace_event["returned_products_count"] = len(products)
trace_events.append(trace_event)
response_parts.append(
genai_types.Part.from_function_response(
name=name,
response=tool_response,
)
)
history.append(genai_types.UserContent(parts=response_parts))
rounds += 1

View file

@ -1,7 +1,8 @@
from datetime import datetime
from typing import ClassVar
from typing import Any, ClassVar
from uuid import UUID, uuid4
from sqlalchemy import JSON, Column
from sqlmodel import Field, SQLModel
from .base import utc_now
@ -24,5 +25,9 @@ class AICallLog(SQLModel, table=True):
total_tokens: int | None = Field(default=None)
duration_ms: int | None = Field(default=None)
finish_reason: str | None = Field(default=None)
tool_trace: dict[str, Any] | None = Field(
default=None,
sa_column=Column(JSON, nullable=True),
)
success: bool = Field(default=True, index=True)
error_detail: str | None = Field(default=None)