innercontext/backend/innercontext/llm.py
Piotr Oleszczyk cfd2485b7e feat(api): add INCI tool-calling with normalized tool traces
Enable on-demand INCI retrieval in /routines/suggest through Gemini function calling so detailed ingredient data is fetched only when needed. Persist and normalize tool_trace data in AI logs to make function-call behavior directly inspectable via /ai-logs endpoints.
2026-03-04 11:35:19 +01:00

255 lines
8.7 KiB
Python

"""Shared helpers for Gemini API access."""
import os
import time
from collections.abc import Callable
from contextlib import suppress
from typing import Any
from fastapi import HTTPException
from google import genai
from google.genai import types as genai_types
# Default Gemini model; override with the GEMINI_MODEL environment variable.
_DEFAULT_MODEL = "gemini-3-flash-preview"
def get_extraction_config(
    system_instruction: str,
    response_schema: Any,
    max_output_tokens: int = 8192,
) -> genai_types.GenerateContentConfig:
    """Build a generation config tuned for strict data extraction.

    Temperature 0.0 plus minimal thinking keeps outputs as deterministic
    as possible; JSON output is enforced through the response schema.
    """
    minimal_thinking = genai_types.ThinkingConfig(
        thinking_level=genai_types.ThinkingLevel.MINIMAL
    )
    return genai_types.GenerateContentConfig(
        system_instruction=system_instruction,
        response_mime_type="application/json",
        response_schema=response_schema,
        max_output_tokens=max_output_tokens,
        temperature=0.0,
        thinking_config=minimal_thinking,
    )
def get_creative_config(
    system_instruction: str,
    response_schema: Any,
    max_output_tokens: int = 4096,
) -> genai_types.GenerateContentConfig:
    """Build a generation config for creative tasks such as recommendations.

    Moderate temperature/top_p with a low thinking level balances variety
    against coherence; JSON output is enforced through the response schema.
    """
    low_thinking = genai_types.ThinkingConfig(
        thinking_level=genai_types.ThinkingLevel.LOW
    )
    return genai_types.GenerateContentConfig(
        system_instruction=system_instruction,
        response_mime_type="application/json",
        response_schema=response_schema,
        max_output_tokens=max_output_tokens,
        temperature=0.4,
        top_p=0.8,
        thinking_config=low_thinking,
    )
def get_gemini_client() -> tuple[genai.Client, str]:
    """Return an authenticated Gemini client and the configured model name.

    Raises HTTP 503 if GEMINI_API_KEY is not set.
    """
    api_key = os.environ.get("GEMINI_API_KEY")
    if api_key:
        # Model is configurable per-deployment; fall back to the module default.
        model_name = os.environ.get("GEMINI_MODEL", _DEFAULT_MODEL)
        return genai.Client(api_key=api_key), model_name
    raise HTTPException(status_code=503, detail="GEMINI_API_KEY not configured")
def call_gemini(
    *,
    endpoint: str,
    contents,
    config: genai_types.GenerateContentConfig,
    user_input: str | None = None,
    tool_trace: dict[str, Any] | None = None,
):
    """Call Gemini, log full request + response to DB, return response unchanged.

    Args:
        endpoint: Logical endpoint name stored on the AICallLog row.
        contents: Prompt contents forwarded verbatim to generate_content.
        config: Generation config; its system_instruction is captured for logging.
        user_input: Human-readable input to persist; defaults to str(contents).
        tool_trace: Optional structured tool-call trace persisted with the log.

    Raises:
        HTTPException: 503 if the API key is missing (via get_gemini_client);
            502 if the API call fails or the model stops before "STOP".
    """
    # Imports are deferred to avoid a circular import at module load time.
    from sqlmodel import Session
    from db import engine  # deferred to avoid circular import at module load
    from innercontext.models.ai_log import AICallLog
    client, model = get_gemini_client()
    # Capture the system prompt as plain text for the log row.
    sys_prompt = None
    if config.system_instruction:
        raw = config.system_instruction
        sys_prompt = raw if isinstance(raw, str) else str(raw)
    # Best-effort fallback: stringify the raw contents so the log is never empty.
    if user_input is None:
        with suppress(Exception):
            user_input = str(contents)
    start = time.monotonic()
    # State consumed by the logging code in `finally`; mutated on each path below.
    success, error_detail, response, finish_reason = True, None, None, None
    try:
        response = client.models.generate_content(
            model=model, contents=contents, config=config
        )
        # Pull the finish-reason name (if present) off the first candidate,
        # defensively via getattr since the SDK response shape may vary.
        candidates = getattr(response, "candidates", None)
        if candidates:
            first_candidate = candidates[0]
            reason = getattr(first_candidate, "finish_reason", None)
            reason_name = getattr(reason, "name", None)
            if isinstance(reason_name, str):
                finish_reason = reason_name
        # Any finish reason other than a normal STOP (e.g. MAX_TOKENS, SAFETY)
        # is recorded as a failure and surfaced to the caller as a 502.
        if finish_reason and finish_reason != "STOP":
            success = False
            error_detail = f"finish_reason: {finish_reason}"
            raise HTTPException(
                status_code=502,
                detail=f"Gemini stopped early (finish_reason={finish_reason})",
            )
    except HTTPException:
        # Already shaped for the API layer; re-raise untouched.
        raise
    except Exception as exc:
        success = False
        error_detail = str(exc)
        raise HTTPException(status_code=502, detail=f"Gemini API error: {exc}") from exc
    finally:
        # Always persist a log row — on success or failure — but never let a
        # logging problem mask the original outcome (hence the suppress).
        duration_ms = int((time.monotonic() - start) * 1000)
        with suppress(Exception):
            log = AICallLog(
                endpoint=endpoint,
                model=model,
                system_prompt=sys_prompt,
                user_input=user_input,
                response_text=response.text if response else None,
                tool_trace=tool_trace,
                prompt_tokens=(
                    response.usage_metadata.prompt_token_count
                    if response and response.usage_metadata
                    else None
                ),
                completion_tokens=(
                    response.usage_metadata.candidates_token_count
                    if response and response.usage_metadata
                    else None
                ),
                total_tokens=(
                    response.usage_metadata.total_token_count
                    if response and response.usage_metadata
                    else None
                ),
                duration_ms=duration_ms,
                finish_reason=finish_reason,
                success=success,
                error_detail=error_detail,
            )
            with Session(engine) as s:
                s.add(log)
                s.commit()
    return response
def call_gemini_with_function_tools(
    *,
    endpoint: str,
    contents,
    config: genai_types.GenerateContentConfig,
    function_handlers: dict[str, Callable[[dict[str, Any]], dict[str, Any]]],
    user_input: str | None = None,
    max_tool_roundtrips: int = 2,
):
    """Call Gemini with function-calling loop until final response text is produced.

    Args:
        endpoint: Logical endpoint name forwarded to call_gemini for logging.
        contents: Initial prompt contents; copied into a mutable history list.
        config: Generation config (expected to declare the function tools).
        function_handlers: Maps function name -> handler taking the call args
            dict and returning a dict used as the function response.
        user_input: Optional human-readable input forwarded to call_gemini.
        max_tool_roundtrips: Max tool rounds allowed after the first call
            before giving up with a 502.

    Raises:
        ValueError: If max_tool_roundtrips is negative.
        HTTPException: 502 on too many rounds, unnamed/unknown functions,
            non-dict arguments, or non-dict handler results.
    """
    if max_tool_roundtrips < 0:
        raise ValueError("max_tool_roundtrips must be >= 0")
    # Mutable conversation history; model and tool turns are appended per round.
    history = list(contents) if isinstance(contents, list) else [contents]
    rounds = 0
    # Accumulated per-call trace events; the same list object is passed into
    # each call_gemini tool_trace so the final log carries the full trace.
    trace_events: list[dict[str, Any]] = []
    while True:
        response = call_gemini(
            endpoint=endpoint,
            contents=history,
            config=config,
            user_input=user_input,
            tool_trace={
                "mode": "function_tools",
                "round": rounds,
                "events": trace_events,
            },
        )
        function_calls = list(getattr(response, "function_calls", None) or [])
        # No tool requests means the model produced its final answer.
        if not function_calls:
            return response
        if rounds >= max_tool_roundtrips:
            raise HTTPException(
                status_code=502,
                detail="Gemini requested too many function calls",
            )
        # Echo the model's tool-calling turn back into the history. Prefer the
        # candidate content verbatim; otherwise rebuild it from the calls.
        candidate_content = None
        candidates = getattr(response, "candidates", None) or []
        if candidates:
            candidate_content = getattr(candidates[0], "content", None)
        if candidate_content is not None:
            history.append(candidate_content)
        else:
            history.append(
                genai_types.ModelContent(
                    parts=[genai_types.Part(function_call=fc) for fc in function_calls]
                )
            )
        # Execute every requested function and collect the response parts.
        response_parts: list[genai_types.Part] = []
        for fc in function_calls:
            name = getattr(fc, "name", None)
            if not isinstance(name, str) or not name:
                raise HTTPException(
                    status_code=502,
                    detail="Gemini requested a function without a valid name",
                )
            handler = function_handlers.get(name)
            if handler is None:
                raise HTTPException(
                    status_code=502,
                    detail=f"Gemini requested unknown function: {name}",
                )
            args = getattr(fc, "args", None) or {}
            if not isinstance(args, dict):
                raise HTTPException(
                    status_code=502,
                    detail=f"Gemini returned invalid arguments for function: {name}",
                )
            tool_response = handler(args)
            if not isinstance(tool_response, dict):
                raise HTTPException(
                    status_code=502,
                    detail=f"Function handler must return an object for: {name}",
                )
            # Record a normalized trace event; "round" is 1-based here.
            trace_event: dict[str, Any] = {
                "round": rounds + 1,
                "function": name,
            }
            # INCI-specific enrichment: summarize requested product ids
            # (capped at 8 to bound log size) and returned product counts.
            product_ids = args.get("product_ids")
            if isinstance(product_ids, list):
                clean_ids = [x for x in product_ids if isinstance(x, str)]
                trace_event["requested_ids_count"] = len(clean_ids)
                trace_event["requested_ids"] = clean_ids[:8]
            products = tool_response.get("products")
            if isinstance(products, list):
                trace_event["returned_products_count"] = len(products)
            trace_events.append(trace_event)
            response_parts.append(
                genai_types.Part.from_function_response(
                    name=name,
                    response=tool_response,
                )
            )
        # Feed all tool results back as a single user turn, then loop again.
        history.append(genai_types.UserContent(parts=response_parts))
        rounds += 1