Enable on-demand INCI retrieval in /routines/suggest through Gemini function calling so detailed ingredient data is fetched only when needed. Persist and normalize tool_trace data in AI logs to make function-call behavior directly inspectable via /ai-logs endpoints.
255 lines · 8.7 KiB · Python
"""Shared helpers for Gemini API access."""
|
|
|
|
import os
|
|
import time
|
|
from collections.abc import Callable
|
|
from contextlib import suppress
|
|
from typing import Any
|
|
|
|
from fastapi import HTTPException
|
|
from google import genai
|
|
from google.genai import types as genai_types
|
|
|
|
_DEFAULT_MODEL = "gemini-3-flash-preview"
|
|
|
|
|
|
def get_extraction_config(
    system_instruction: str,
    response_schema: Any,
    max_output_tokens: int = 8192,
) -> genai_types.GenerateContentConfig:
    """Build a config for strict, deterministic data extraction.

    Uses temperature 0.0 and minimal thinking so repeated calls on the same
    input produce stable JSON that conforms to ``response_schema``.
    """
    # Extraction should not "reason" its way into creative output.
    minimal_thinking = genai_types.ThinkingConfig(
        thinking_level=genai_types.ThinkingLevel.MINIMAL
    )
    return genai_types.GenerateContentConfig(
        system_instruction=system_instruction,
        response_mime_type="application/json",
        response_schema=response_schema,
        max_output_tokens=max_output_tokens,
        temperature=0.0,
        thinking_config=minimal_thinking,
    )
|
|
|
|
|
|
def get_creative_config(
    system_instruction: str,
    response_schema: Any,
    max_output_tokens: int = 4096,
) -> genai_types.GenerateContentConfig:
    """Build a config for creative tasks such as recommendations.

    Allows moderate sampling variety (temperature 0.4, top_p 0.8) with a low
    thinking level, while still constraining output to ``response_schema`` JSON.
    """
    # Some variety is desirable here, unlike the extraction config.
    low_thinking = genai_types.ThinkingConfig(
        thinking_level=genai_types.ThinkingLevel.LOW
    )
    return genai_types.GenerateContentConfig(
        system_instruction=system_instruction,
        response_mime_type="application/json",
        response_schema=response_schema,
        max_output_tokens=max_output_tokens,
        temperature=0.4,
        top_p=0.8,
        thinking_config=low_thinking,
    )
|
|
|
|
|
|
def get_gemini_client() -> tuple[genai.Client, str]:
    """Return an authenticated Gemini client and the configured model name.

    The model name is read from ``GEMINI_MODEL``, falling back to the module
    default when unset.

    Raises HTTP 503 if GEMINI_API_KEY is not set.
    """
    api_key = os.environ.get("GEMINI_API_KEY")
    if api_key is None or api_key == "":
        raise HTTPException(status_code=503, detail="GEMINI_API_KEY not configured")
    model_name = os.environ.get("GEMINI_MODEL", _DEFAULT_MODEL)
    client = genai.Client(api_key=api_key)
    return client, model_name
|
|
|
|
|
|
def call_gemini(
    *,
    endpoint: str,
    contents,
    config: genai_types.GenerateContentConfig,
    user_input: str | None = None,
    tool_trace: dict[str, Any] | None = None,
):
    """Call Gemini, log full request + response to DB, return response unchanged.

    Args:
        endpoint: Logical API endpoint name recorded in the log row.
        contents: Prompt contents forwarded verbatim to ``generate_content``.
        config: Generation config; its ``system_instruction`` is also logged.
        user_input: Text to log as the user input; defaults to ``str(contents)``.
        tool_trace: Optional function-call trace dict persisted on the log row.

    Raises:
        HTTPException: 502 when the model stops early (non-STOP finish reason)
            or when the underlying API call fails; 503 (via
            ``get_gemini_client``) when no API key is configured.

    Note: logging is best-effort — the call's outcome is never changed by a
    logging failure.
    """
    # Imports are deferred so importing this module does not pull in the DB
    # layer (and to avoid a circular import at module load, see below).
    from sqlmodel import Session

    from db import engine  # deferred to avoid circular import at module load
    from innercontext.models.ai_log import AICallLog

    client, model = get_gemini_client()

    # Capture the system prompt as a plain string for the log row.
    sys_prompt = None
    if config.system_instruction:
        raw = config.system_instruction
        sys_prompt = raw if isinstance(raw, str) else str(raw)
    if user_input is None:
        # Best-effort stringification of arbitrary contents; never let a
        # repr failure break the actual API call.
        with suppress(Exception):
            user_input = str(contents)

    start = time.monotonic()
    # These four are read by the logging code in ``finally``, so they must be
    # initialized before anything can raise.
    success, error_detail, response, finish_reason = True, None, None, None
    try:
        response = client.models.generate_content(
            model=model, contents=contents, config=config
        )
        # Pull the finish reason off the first candidate, defensively: every
        # attribute access tolerates a missing/None value.
        candidates = getattr(response, "candidates", None)
        if candidates:
            first_candidate = candidates[0]
            reason = getattr(first_candidate, "finish_reason", None)
            reason_name = getattr(reason, "name", None)
            if isinstance(reason_name, str):
                finish_reason = reason_name
        # Anything other than a clean STOP (e.g. MAX_TOKENS, SAFETY) is
        # treated as a failed call and surfaced as a 502.
        if finish_reason and finish_reason != "STOP":
            success = False
            error_detail = f"finish_reason: {finish_reason}"
            raise HTTPException(
                status_code=502,
                detail=f"Gemini stopped early (finish_reason={finish_reason})",
            )
    except HTTPException:
        # Already shaped for FastAPI (success/error_detail were set above if
        # we raised it ourselves); re-raise untouched.
        raise
    except Exception as exc:
        # Wrap SDK/transport errors in a 502, preserving the original cause.
        success = False
        error_detail = str(exc)
        raise HTTPException(status_code=502, detail=f"Gemini API error: {exc}") from exc
    finally:
        # Always log — on success AND on failure — but never let logging
        # itself raise (hence the blanket suppress).
        duration_ms = int((time.monotonic() - start) * 1000)
        with suppress(Exception):
            log = AICallLog(
                endpoint=endpoint,
                model=model,
                system_prompt=sys_prompt,
                user_input=user_input,
                # ``response`` is None when the call itself failed.
                response_text=response.text if response else None,
                tool_trace=tool_trace,
                prompt_tokens=(
                    response.usage_metadata.prompt_token_count
                    if response and response.usage_metadata
                    else None
                ),
                completion_tokens=(
                    response.usage_metadata.candidates_token_count
                    if response and response.usage_metadata
                    else None
                ),
                total_tokens=(
                    response.usage_metadata.total_token_count
                    if response and response.usage_metadata
                    else None
                ),
                duration_ms=duration_ms,
                finish_reason=finish_reason,
                success=success,
                error_detail=error_detail,
            )
            with Session(engine) as s:
                s.add(log)
                s.commit()
    return response
|
|
|
|
|
|
def call_gemini_with_function_tools(
    *,
    endpoint: str,
    contents,
    config: genai_types.GenerateContentConfig,
    function_handlers: dict[str, Callable[[dict[str, Any]], dict[str, Any]]],
    user_input: str | None = None,
    max_tool_roundtrips: int = 2,
):
    """Call Gemini with function-calling loop until final response text is produced.

    Each iteration calls :func:`call_gemini`; if the model requests function
    calls, the matching handlers are executed and their results are appended
    to the conversation history before re-calling the model. A trace of every
    executed call is passed along as ``tool_trace`` so it is persisted with
    the AI call logs.

    Args:
        endpoint: Logical endpoint name forwarded to ``call_gemini`` logging.
        contents: Initial prompt contents (a single item or a list).
        config: Generation config; assumed to declare the function tools
            referenced by ``function_handlers`` — TODO confirm at call sites.
        function_handlers: Maps function name -> handler taking the model's
            args dict and returning a dict used as the function response.
        user_input: Optional override for the logged user input.
        max_tool_roundtrips: Maximum number of tool rounds before giving up.

    Raises:
        ValueError: If ``max_tool_roundtrips`` is negative.
        HTTPException: 502 for malformed/unknown function calls, non-dict
            handler results, or exceeding ``max_tool_roundtrips``.
    """
    if max_tool_roundtrips < 0:
        raise ValueError("max_tool_roundtrips must be >= 0")

    # Mutable conversation history; grows with model turns + tool results.
    history = list(contents) if isinstance(contents, list) else [contents]
    rounds = 0
    # Shared across rounds: each log row references the same list, so the
    # final log reflects the complete trace.
    trace_events: list[dict[str, Any]] = []

    while True:
        response = call_gemini(
            endpoint=endpoint,
            contents=history,
            config=config,
            user_input=user_input,
            tool_trace={
                "mode": "function_tools",
                "round": rounds,
                "events": trace_events,
            },
        )
        # No requested calls means the model produced its final answer.
        function_calls = list(getattr(response, "function_calls", None) or [])
        if not function_calls:
            return response

        # The model wants more tools but we've exhausted the budget.
        if rounds >= max_tool_roundtrips:
            raise HTTPException(
                status_code=502,
                detail="Gemini requested too many function calls",
            )

        # Echo the model's own turn into history. Prefer the candidate's
        # content verbatim; otherwise rebuild it from the function calls.
        candidate_content = None
        candidates = getattr(response, "candidates", None) or []
        if candidates:
            candidate_content = getattr(candidates[0], "content", None)
        if candidate_content is not None:
            history.append(candidate_content)
        else:
            history.append(
                genai_types.ModelContent(
                    parts=[genai_types.Part(function_call=fc) for fc in function_calls]
                )
            )

        # Execute every requested call, collecting one response part each.
        response_parts: list[genai_types.Part] = []
        for fc in function_calls:
            name = getattr(fc, "name", None)
            if not isinstance(name, str) or not name:
                raise HTTPException(
                    status_code=502,
                    detail="Gemini requested a function without a valid name",
                )

            handler = function_handlers.get(name)
            if handler is None:
                raise HTTPException(
                    status_code=502,
                    detail=f"Gemini requested unknown function: {name}",
                )

            args = getattr(fc, "args", None) or {}
            if not isinstance(args, dict):
                raise HTTPException(
                    status_code=502,
                    detail=f"Gemini returned invalid arguments for function: {name}",
                )

            tool_response = handler(args)
            # Function responses must be JSON-object-shaped for the SDK.
            if not isinstance(tool_response, dict):
                raise HTTPException(
                    status_code=502,
                    detail=f"Function handler must return an object for: {name}",
                )

            # Record a compact, normalized trace event for this call.
            trace_event: dict[str, Any] = {
                "round": rounds + 1,
                "function": name,
            }
            # Special-case the known "product_ids" argument so ID requests
            # are inspectable without logging unbounded payloads (cap at 8).
            product_ids = args.get("product_ids")
            if isinstance(product_ids, list):
                clean_ids = [x for x in product_ids if isinstance(x, str)]
                trace_event["requested_ids_count"] = len(clean_ids)
                trace_event["requested_ids"] = clean_ids[:8]
            products = tool_response.get("products")
            if isinstance(products, list):
                trace_event["returned_products_count"] = len(products)
            trace_events.append(trace_event)

            response_parts.append(
                genai_types.Part.from_function_response(
                    name=name,
                    response=tool_response,
                )
            )

        # Feed all tool results back as a single user turn, then loop.
        history.append(genai_types.UserContent(parts=response_parts))
        rounds += 1