feat(api): add INCI tool-calling with normalized tool traces
Enable on-demand INCI retrieval in /routines/suggest through Gemini function calling so detailed ingredient data is fetched only when needed. Persist and normalize tool_trace data in AI logs to make function-call behavior directly inspectable via /ai-logs endpoints.
This commit is contained in:
parent
c0eeb0425d
commit
cfd2485b7e
8 changed files with 455 additions and 29 deletions
|
|
@ -0,0 +1,29 @@
|
||||||
|
"""add_tool_trace_to_ai_call_logs
|
||||||
|
|
||||||
|
Revision ID: d3e4f5a6b7c8
|
||||||
|
Revises: b2c3d4e5f6a1
|
||||||
|
Create Date: 2026-03-04 00:00:00.000000
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
revision: str = "d3e4f5a6b7c8"
|
||||||
|
down_revision: Union[str, None] = "b2c3d4e5f6a1"
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
op.add_column(
|
||||||
|
"ai_call_logs",
|
||||||
|
sa.Column("tool_trace", sa.JSON(), nullable=True),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.drop_column("ai_call_logs", "tool_trace")
|
||||||
|
|
@ -1,4 +1,5 @@
|
||||||
from typing import Optional
|
import json
|
||||||
|
from typing import Any, Optional
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
|
|
@ -10,6 +11,22 @@ from innercontext.models.ai_log import AICallLog
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_tool_trace(value: object) -> dict[str, Any] | None:
|
||||||
|
if value is None:
|
||||||
|
return None
|
||||||
|
if isinstance(value, dict):
|
||||||
|
return {str(k): v for k, v in value.items()}
|
||||||
|
if isinstance(value, str):
|
||||||
|
try:
|
||||||
|
parsed = json.loads(value)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
return None
|
||||||
|
if isinstance(parsed, dict):
|
||||||
|
return {str(k): v for k, v in parsed.items()}
|
||||||
|
return None
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
class AICallLogPublic(SQLModel):
|
class AICallLogPublic(SQLModel):
|
||||||
"""List-friendly view: omits large text fields."""
|
"""List-friendly view: omits large text fields."""
|
||||||
|
|
||||||
|
|
@ -21,6 +38,7 @@ class AICallLogPublic(SQLModel):
|
||||||
completion_tokens: Optional[int] = None
|
completion_tokens: Optional[int] = None
|
||||||
total_tokens: Optional[int] = None
|
total_tokens: Optional[int] = None
|
||||||
duration_ms: Optional[int] = None
|
duration_ms: Optional[int] = None
|
||||||
|
tool_trace: Optional[dict[str, Any]] = None
|
||||||
success: bool
|
success: bool
|
||||||
error_detail: Optional[str] = None
|
error_detail: Optional[str] = None
|
||||||
|
|
||||||
|
|
@ -37,7 +55,23 @@ def list_ai_logs(
|
||||||
stmt = stmt.where(AICallLog.endpoint == endpoint)
|
stmt = stmt.where(AICallLog.endpoint == endpoint)
|
||||||
if success is not None:
|
if success is not None:
|
||||||
stmt = stmt.where(AICallLog.success == success)
|
stmt = stmt.where(AICallLog.success == success)
|
||||||
return session.exec(stmt).all()
|
logs = session.exec(stmt).all()
|
||||||
|
return [
|
||||||
|
AICallLogPublic(
|
||||||
|
id=log.id,
|
||||||
|
created_at=log.created_at,
|
||||||
|
endpoint=log.endpoint,
|
||||||
|
model=log.model,
|
||||||
|
prompt_tokens=log.prompt_tokens,
|
||||||
|
completion_tokens=log.completion_tokens,
|
||||||
|
total_tokens=log.total_tokens,
|
||||||
|
duration_ms=log.duration_ms,
|
||||||
|
tool_trace=_normalize_tool_trace(getattr(log, "tool_trace", None)),
|
||||||
|
success=log.success,
|
||||||
|
error_detail=log.error_detail,
|
||||||
|
)
|
||||||
|
for log in logs
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{log_id}", response_model=AICallLog)
|
@router.get("/{log_id}", response_model=AICallLog)
|
||||||
|
|
@ -45,4 +79,5 @@ def get_ai_log(log_id: UUID, session: Session = Depends(get_session)):
|
||||||
log = session.get(AICallLog, log_id)
|
log = session.get(AICallLog, log_id)
|
||||||
if log is None:
|
if log is None:
|
||||||
raise HTTPException(status_code=404, detail="Log not found")
|
raise HTTPException(status_code=404, detail="Log not found")
|
||||||
|
log.tool_trace = _normalize_tool_trace(getattr(log, "tool_trace", None))
|
||||||
return log
|
return log
|
||||||
|
|
|
||||||
|
|
@ -4,12 +4,17 @@ from typing import Optional
|
||||||
from uuid import UUID, uuid4
|
from uuid import UUID, uuid4
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
|
from google.genai import types as genai_types
|
||||||
from pydantic import BaseModel as PydanticBase
|
from pydantic import BaseModel as PydanticBase
|
||||||
from sqlmodel import Field, Session, SQLModel, col, select
|
from sqlmodel import Field, Session, SQLModel, col, select
|
||||||
|
|
||||||
from db import get_session
|
from db import get_session
|
||||||
from innercontext.api.utils import get_or_404
|
from innercontext.api.utils import get_or_404
|
||||||
from innercontext.llm import call_gemini, get_creative_config
|
from innercontext.llm import (
|
||||||
|
call_gemini,
|
||||||
|
call_gemini_with_function_tools,
|
||||||
|
get_creative_config,
|
||||||
|
)
|
||||||
from innercontext.models import (
|
from innercontext.models import (
|
||||||
GroomingSchedule,
|
GroomingSchedule,
|
||||||
Product,
|
Product,
|
||||||
|
|
@ -302,8 +307,7 @@ def _build_products_context(
|
||||||
time_filter: Optional[str] = None,
|
time_filter: Optional[str] = None,
|
||||||
reference_date: Optional[date] = None,
|
reference_date: Optional[date] = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
stmt = select(Product).where(col(Product.is_tool).is_(False))
|
products = _get_available_products(session, time_filter=time_filter)
|
||||||
products = session.exec(stmt).all()
|
|
||||||
product_ids = [p.id for p in products]
|
product_ids = [p.id for p in products]
|
||||||
inventory_rows = (
|
inventory_rows = (
|
||||||
session.exec(
|
session.exec(
|
||||||
|
|
@ -333,10 +337,6 @@ def _build_products_context(
|
||||||
|
|
||||||
lines = ["AVAILABLE PRODUCTS:"]
|
lines = ["AVAILABLE PRODUCTS:"]
|
||||||
for p in products:
|
for p in products:
|
||||||
if p.is_medication and not _is_minoxidil_product(p):
|
|
||||||
continue
|
|
||||||
if time_filter and _ev(p.recommended_time) not in (time_filter, "both"):
|
|
||||||
continue
|
|
||||||
p.inventory = inv_by_product.get(p.id, [])
|
p.inventory = inv_by_product.get(p.id, [])
|
||||||
ctx = p.to_llm_context()
|
ctx = p.to_llm_context()
|
||||||
entry = (
|
entry = (
|
||||||
|
|
@ -404,6 +404,83 @@ def _build_products_context(
|
||||||
return "\n".join(lines) + "\n"
|
return "\n".join(lines) + "\n"
|
||||||
|
|
||||||
|
|
||||||
|
def _get_available_products(
|
||||||
|
session: Session,
|
||||||
|
time_filter: Optional[str] = None,
|
||||||
|
) -> list[Product]:
|
||||||
|
stmt = select(Product).where(col(Product.is_tool).is_(False))
|
||||||
|
products = session.exec(stmt).all()
|
||||||
|
result: list[Product] = []
|
||||||
|
for p in products:
|
||||||
|
if p.is_medication and not _is_minoxidil_product(p):
|
||||||
|
continue
|
||||||
|
if time_filter and _ev(p.recommended_time) not in (time_filter, "both"):
|
||||||
|
continue
|
||||||
|
result.append(p)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def _build_inci_tool_handler(
|
||||||
|
products: list[Product],
|
||||||
|
):
|
||||||
|
available_by_id = {str(p.id): p for p in products}
|
||||||
|
|
||||||
|
def _handler(args: dict[str, object]) -> dict[str, object]:
|
||||||
|
raw_ids = args.get("product_ids")
|
||||||
|
if not isinstance(raw_ids, list):
|
||||||
|
return {"products": []}
|
||||||
|
|
||||||
|
requested_ids: list[str] = []
|
||||||
|
seen: set[str] = set()
|
||||||
|
for raw_id in raw_ids:
|
||||||
|
if not isinstance(raw_id, str):
|
||||||
|
continue
|
||||||
|
if raw_id in seen:
|
||||||
|
continue
|
||||||
|
seen.add(raw_id)
|
||||||
|
requested_ids.append(raw_id)
|
||||||
|
if len(requested_ids) >= 8:
|
||||||
|
break
|
||||||
|
|
||||||
|
products_payload = []
|
||||||
|
for pid in requested_ids:
|
||||||
|
product = available_by_id.get(pid)
|
||||||
|
if product is None:
|
||||||
|
continue
|
||||||
|
inci = product.inci or []
|
||||||
|
compact_inci = [str(i)[:120] for i in inci[:128]]
|
||||||
|
products_payload.append(
|
||||||
|
{
|
||||||
|
"id": pid,
|
||||||
|
"name": product.name,
|
||||||
|
"inci": compact_inci,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return {"products": products_payload}
|
||||||
|
|
||||||
|
return _handler
|
||||||
|
|
||||||
|
|
||||||
|
_INCI_FUNCTION_DECLARATION = genai_types.FunctionDeclaration(
|
||||||
|
name="get_product_inci",
|
||||||
|
description=(
|
||||||
|
"Return exact INCI ingredient lists for products identified by UUID from "
|
||||||
|
"the AVAILABLE PRODUCTS list."
|
||||||
|
),
|
||||||
|
parameters=genai_types.Schema(
|
||||||
|
type=genai_types.Type.OBJECT,
|
||||||
|
properties={
|
||||||
|
"product_ids": genai_types.Schema(
|
||||||
|
type=genai_types.Type.ARRAY,
|
||||||
|
items=genai_types.Schema(type=genai_types.Type.STRING),
|
||||||
|
description="Product UUIDs from AVAILABLE PRODUCTS.",
|
||||||
|
)
|
||||||
|
},
|
||||||
|
required=["product_ids"],
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _build_objectives_context(include_minoxidil_beard: bool) -> str:
|
def _build_objectives_context(include_minoxidil_beard: bool) -> str:
|
||||||
if include_minoxidil_beard:
|
if include_minoxidil_beard:
|
||||||
return (
|
return (
|
||||||
|
|
@ -574,6 +651,10 @@ def suggest_routine(
|
||||||
products_ctx = _build_products_context(
|
products_ctx = _build_products_context(
|
||||||
session, time_filter=data.part_of_day.value, reference_date=data.routine_date
|
session, time_filter=data.part_of_day.value, reference_date=data.routine_date
|
||||||
)
|
)
|
||||||
|
available_products = _get_available_products(
|
||||||
|
session,
|
||||||
|
time_filter=data.part_of_day.value,
|
||||||
|
)
|
||||||
objectives_ctx = _build_objectives_context(data.include_minoxidil_beard)
|
objectives_ctx = _build_objectives_context(data.include_minoxidil_beard)
|
||||||
|
|
||||||
mode_line = "MODE: standard"
|
mode_line = "MODE: standard"
|
||||||
|
|
@ -586,20 +667,43 @@ def suggest_routine(
|
||||||
f"{mode_line}\n"
|
f"{mode_line}\n"
|
||||||
"INPUT DATA:\n"
|
"INPUT DATA:\n"
|
||||||
f"{skin_ctx}\n{grooming_ctx}\n{history_ctx}\n{day_ctx}\n{products_ctx}\n{objectives_ctx}"
|
f"{skin_ctx}\n{grooming_ctx}\n{history_ctx}\n{day_ctx}\n{products_ctx}\n{objectives_ctx}"
|
||||||
|
"\nNARZEDZIA:\n"
|
||||||
|
"- Masz dostep do funkcji get_product_inci(product_ids).\n"
|
||||||
|
"- Uzyj jej tylko, gdy potrzebujesz dokladnego skladu INCI do oceny bezpieczenstwa, kompatybilnosci lub redundancji aktywnych.\n"
|
||||||
|
"- Nie zgaduj INCI; jesli potrzebujesz skladu, wywolaj funkcje dla konkretnych UUID.\n"
|
||||||
f"{notes_line}"
|
f"{notes_line}"
|
||||||
f"{_ROUTINES_SINGLE_EXTRA}\n"
|
f"{_ROUTINES_SINGLE_EXTRA}\n"
|
||||||
"Zwróć JSON zgodny ze schematem."
|
"Zwróć JSON zgodny ze schematem."
|
||||||
)
|
)
|
||||||
|
|
||||||
response = call_gemini(
|
config = get_creative_config(
|
||||||
endpoint="routines/suggest",
|
|
||||||
contents=prompt,
|
|
||||||
config=get_creative_config(
|
|
||||||
system_instruction=_ROUTINES_SYSTEM_PROMPT,
|
system_instruction=_ROUTINES_SYSTEM_PROMPT,
|
||||||
response_schema=_SuggestionOut,
|
response_schema=_SuggestionOut,
|
||||||
max_output_tokens=4096,
|
max_output_tokens=4096,
|
||||||
|
).model_copy(
|
||||||
|
update={
|
||||||
|
"tools": [
|
||||||
|
genai_types.Tool(
|
||||||
|
function_declarations=[_INCI_FUNCTION_DECLARATION],
|
||||||
|
)
|
||||||
|
],
|
||||||
|
"tool_config": genai_types.ToolConfig(
|
||||||
|
function_calling_config=genai_types.FunctionCallingConfig(
|
||||||
|
mode=genai_types.FunctionCallingConfigMode.AUTO,
|
||||||
|
)
|
||||||
),
|
),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
response = call_gemini_with_function_tools(
|
||||||
|
endpoint="routines/suggest",
|
||||||
|
contents=prompt,
|
||||||
|
config=config,
|
||||||
|
function_handlers={
|
||||||
|
"get_product_inci": _build_inci_tool_handler(available_products)
|
||||||
|
},
|
||||||
user_input=prompt,
|
user_input=prompt,
|
||||||
|
max_tool_roundtrips=2,
|
||||||
)
|
)
|
||||||
|
|
||||||
raw = response.text
|
raw = response.text
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,7 @@
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
|
from collections.abc import Callable
|
||||||
from contextlib import suppress
|
from contextlib import suppress
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
|
@ -67,6 +68,7 @@ def call_gemini(
|
||||||
contents,
|
contents,
|
||||||
config: genai_types.GenerateContentConfig,
|
config: genai_types.GenerateContentConfig,
|
||||||
user_input: str | None = None,
|
user_input: str | None = None,
|
||||||
|
tool_trace: dict[str, Any] | None = None,
|
||||||
):
|
):
|
||||||
"""Call Gemini, log full request + response to DB, return response unchanged."""
|
"""Call Gemini, log full request + response to DB, return response unchanged."""
|
||||||
from sqlmodel import Session
|
from sqlmodel import Session
|
||||||
|
|
@ -119,6 +121,7 @@ def call_gemini(
|
||||||
system_prompt=sys_prompt,
|
system_prompt=sys_prompt,
|
||||||
user_input=user_input,
|
user_input=user_input,
|
||||||
response_text=response.text if response else None,
|
response_text=response.text if response else None,
|
||||||
|
tool_trace=tool_trace,
|
||||||
prompt_tokens=(
|
prompt_tokens=(
|
||||||
response.usage_metadata.prompt_token_count
|
response.usage_metadata.prompt_token_count
|
||||||
if response and response.usage_metadata
|
if response and response.usage_metadata
|
||||||
|
|
@ -143,3 +146,110 @@ def call_gemini(
|
||||||
s.add(log)
|
s.add(log)
|
||||||
s.commit()
|
s.commit()
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
def call_gemini_with_function_tools(
|
||||||
|
*,
|
||||||
|
endpoint: str,
|
||||||
|
contents,
|
||||||
|
config: genai_types.GenerateContentConfig,
|
||||||
|
function_handlers: dict[str, Callable[[dict[str, Any]], dict[str, Any]]],
|
||||||
|
user_input: str | None = None,
|
||||||
|
max_tool_roundtrips: int = 2,
|
||||||
|
):
|
||||||
|
"""Call Gemini with function-calling loop until final response text is produced."""
|
||||||
|
if max_tool_roundtrips < 0:
|
||||||
|
raise ValueError("max_tool_roundtrips must be >= 0")
|
||||||
|
|
||||||
|
history = list(contents) if isinstance(contents, list) else [contents]
|
||||||
|
rounds = 0
|
||||||
|
trace_events: list[dict[str, Any]] = []
|
||||||
|
|
||||||
|
while True:
|
||||||
|
response = call_gemini(
|
||||||
|
endpoint=endpoint,
|
||||||
|
contents=history,
|
||||||
|
config=config,
|
||||||
|
user_input=user_input,
|
||||||
|
tool_trace={
|
||||||
|
"mode": "function_tools",
|
||||||
|
"round": rounds,
|
||||||
|
"events": trace_events,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
function_calls = list(getattr(response, "function_calls", None) or [])
|
||||||
|
if not function_calls:
|
||||||
|
return response
|
||||||
|
|
||||||
|
if rounds >= max_tool_roundtrips:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=502,
|
||||||
|
detail="Gemini requested too many function calls",
|
||||||
|
)
|
||||||
|
|
||||||
|
candidate_content = None
|
||||||
|
candidates = getattr(response, "candidates", None) or []
|
||||||
|
if candidates:
|
||||||
|
candidate_content = getattr(candidates[0], "content", None)
|
||||||
|
if candidate_content is not None:
|
||||||
|
history.append(candidate_content)
|
||||||
|
else:
|
||||||
|
history.append(
|
||||||
|
genai_types.ModelContent(
|
||||||
|
parts=[genai_types.Part(function_call=fc) for fc in function_calls]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
response_parts: list[genai_types.Part] = []
|
||||||
|
for fc in function_calls:
|
||||||
|
name = getattr(fc, "name", None)
|
||||||
|
if not isinstance(name, str) or not name:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=502,
|
||||||
|
detail="Gemini requested a function without a valid name",
|
||||||
|
)
|
||||||
|
|
||||||
|
handler = function_handlers.get(name)
|
||||||
|
if handler is None:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=502,
|
||||||
|
detail=f"Gemini requested unknown function: {name}",
|
||||||
|
)
|
||||||
|
|
||||||
|
args = getattr(fc, "args", None) or {}
|
||||||
|
if not isinstance(args, dict):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=502,
|
||||||
|
detail=f"Gemini returned invalid arguments for function: {name}",
|
||||||
|
)
|
||||||
|
|
||||||
|
tool_response = handler(args)
|
||||||
|
if not isinstance(tool_response, dict):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=502,
|
||||||
|
detail=f"Function handler must return an object for: {name}",
|
||||||
|
)
|
||||||
|
|
||||||
|
trace_event: dict[str, Any] = {
|
||||||
|
"round": rounds + 1,
|
||||||
|
"function": name,
|
||||||
|
}
|
||||||
|
product_ids = args.get("product_ids")
|
||||||
|
if isinstance(product_ids, list):
|
||||||
|
clean_ids = [x for x in product_ids if isinstance(x, str)]
|
||||||
|
trace_event["requested_ids_count"] = len(clean_ids)
|
||||||
|
trace_event["requested_ids"] = clean_ids[:8]
|
||||||
|
products = tool_response.get("products")
|
||||||
|
if isinstance(products, list):
|
||||||
|
trace_event["returned_products_count"] = len(products)
|
||||||
|
trace_events.append(trace_event)
|
||||||
|
|
||||||
|
response_parts.append(
|
||||||
|
genai_types.Part.from_function_response(
|
||||||
|
name=name,
|
||||||
|
response=tool_response,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
history.append(genai_types.UserContent(parts=response_parts))
|
||||||
|
rounds += 1
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,8 @@
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import ClassVar
|
from typing import Any, ClassVar
|
||||||
from uuid import UUID, uuid4
|
from uuid import UUID, uuid4
|
||||||
|
|
||||||
|
from sqlalchemy import JSON, Column
|
||||||
from sqlmodel import Field, SQLModel
|
from sqlmodel import Field, SQLModel
|
||||||
|
|
||||||
from .base import utc_now
|
from .base import utc_now
|
||||||
|
|
@ -24,5 +25,9 @@ class AICallLog(SQLModel, table=True):
|
||||||
total_tokens: int | None = Field(default=None)
|
total_tokens: int | None = Field(default=None)
|
||||||
duration_ms: int | None = Field(default=None)
|
duration_ms: int | None = Field(default=None)
|
||||||
finish_reason: str | None = Field(default=None)
|
finish_reason: str | None = Field(default=None)
|
||||||
|
tool_trace: dict[str, Any] | None = Field(
|
||||||
|
default=None,
|
||||||
|
sa_column=Column(JSON, nullable=True),
|
||||||
|
)
|
||||||
success: bool = Field(default=True, index=True)
|
success: bool = Field(default=True, index=True)
|
||||||
error_detail: str | None = Field(default=None)
|
error_detail: str | None = Field(default=None)
|
||||||
|
|
|
||||||
44
backend/tests/test_ai_logs.py
Normal file
44
backend/tests/test_ai_logs.py
Normal file
|
|
@ -0,0 +1,44 @@
|
||||||
|
import uuid
|
||||||
|
from typing import Any, cast
|
||||||
|
|
||||||
|
from innercontext.models.ai_log import AICallLog
|
||||||
|
|
||||||
|
|
||||||
|
def test_list_ai_logs_normalizes_tool_trace_string(client, session):
|
||||||
|
log = AICallLog(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
endpoint="routines/suggest",
|
||||||
|
model="gemini-3-flash-preview",
|
||||||
|
success=True,
|
||||||
|
)
|
||||||
|
log.tool_trace = cast(
|
||||||
|
Any,
|
||||||
|
'{"mode":"function_tools","events":[{"function":"get_product_inci"}]}',
|
||||||
|
)
|
||||||
|
session.add(log)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
response = client.get("/ai-logs")
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert len(data) == 1
|
||||||
|
assert data[0]["tool_trace"]["mode"] == "function_tools"
|
||||||
|
assert data[0]["tool_trace"]["events"][0]["function"] == "get_product_inci"
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_ai_log_normalizes_tool_trace_string(client, session):
|
||||||
|
log = AICallLog(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
endpoint="routines/suggest",
|
||||||
|
model="gemini-3-flash-preview",
|
||||||
|
success=True,
|
||||||
|
)
|
||||||
|
log.tool_trace = cast(Any, '{"mode":"function_tools","round":1}')
|
||||||
|
session.add(log)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
response = client.get(f"/ai-logs/{log.id}")
|
||||||
|
assert response.status_code == 200
|
||||||
|
payload = response.json()
|
||||||
|
assert payload["tool_trace"]["mode"] == "function_tools"
|
||||||
|
assert payload["tool_trace"]["round"] == 1
|
||||||
|
|
@ -220,7 +220,9 @@ def test_delete_grooming_schedule_not_found(client):
|
||||||
|
|
||||||
|
|
||||||
def test_suggest_routine(client, session):
|
def test_suggest_routine(client, session):
|
||||||
with patch("innercontext.api.routines.call_gemini") as mock_gemini:
|
with patch(
|
||||||
|
"innercontext.api.routines.call_gemini_with_function_tools"
|
||||||
|
) as mock_gemini:
|
||||||
# Mock the Gemini response
|
# Mock the Gemini response
|
||||||
mock_response = type(
|
mock_response = type(
|
||||||
"Response",
|
"Response",
|
||||||
|
|
@ -245,6 +247,9 @@ def test_suggest_routine(client, session):
|
||||||
assert len(data["steps"]) == 1
|
assert len(data["steps"]) == 1
|
||||||
assert data["steps"][0]["action_type"] == "shaving_razor"
|
assert data["steps"][0]["action_type"] == "shaving_razor"
|
||||||
assert data["reasoning"] == "because"
|
assert data["reasoning"] == "because"
|
||||||
|
kwargs = mock_gemini.call_args.kwargs
|
||||||
|
assert "function_handlers" in kwargs
|
||||||
|
assert "get_product_inci" in kwargs["function_handlers"]
|
||||||
|
|
||||||
|
|
||||||
def test_suggest_batch(client, session):
|
def test_suggest_batch(client, session):
|
||||||
|
|
|
||||||
|
|
@ -1,26 +1,28 @@
|
||||||
from datetime import date, timedelta
|
|
||||||
import uuid
|
import uuid
|
||||||
|
from datetime import date, timedelta
|
||||||
|
|
||||||
from sqlmodel import Session
|
from sqlmodel import Session
|
||||||
|
|
||||||
from innercontext.api.routines import (
|
from innercontext.api.routines import (
|
||||||
_contains_minoxidil_text,
|
|
||||||
_is_minoxidil_product,
|
|
||||||
_ev,
|
|
||||||
_build_skin_context,
|
|
||||||
_build_grooming_context,
|
|
||||||
_build_recent_history,
|
|
||||||
_build_products_context,
|
|
||||||
_build_objectives_context,
|
|
||||||
_build_day_context,
|
_build_day_context,
|
||||||
|
_build_grooming_context,
|
||||||
|
_build_inci_tool_handler,
|
||||||
|
_build_objectives_context,
|
||||||
|
_build_products_context,
|
||||||
|
_build_recent_history,
|
||||||
|
_build_skin_context,
|
||||||
|
_contains_minoxidil_text,
|
||||||
|
_ev,
|
||||||
|
_get_available_products,
|
||||||
|
_is_minoxidil_product,
|
||||||
)
|
)
|
||||||
from innercontext.models import (
|
from innercontext.models import (
|
||||||
Product,
|
|
||||||
SkinConditionSnapshot,
|
|
||||||
GroomingSchedule,
|
GroomingSchedule,
|
||||||
|
Product,
|
||||||
|
ProductInventory,
|
||||||
Routine,
|
Routine,
|
||||||
RoutineStep,
|
RoutineStep,
|
||||||
ProductInventory,
|
SkinConditionSnapshot,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -242,3 +244,95 @@ def test_build_day_context():
|
||||||
assert _build_day_context(None) == ""
|
assert _build_day_context(None) == ""
|
||||||
assert "Leaving home: yes" in _build_day_context(True)
|
assert "Leaving home: yes" in _build_day_context(True)
|
||||||
assert "Leaving home: no" in _build_day_context(False)
|
assert "Leaving home: no" in _build_day_context(False)
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_available_products_respects_filters(session: Session):
|
||||||
|
regular_med = Product(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
name="Tretinoin",
|
||||||
|
category="serum",
|
||||||
|
is_medication=True,
|
||||||
|
brand="Test",
|
||||||
|
recommended_time="pm",
|
||||||
|
leave_on=True,
|
||||||
|
product_effect_profile={},
|
||||||
|
)
|
||||||
|
minoxidil_med = Product(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
name="Minoxidil 5%",
|
||||||
|
category="serum",
|
||||||
|
is_medication=True,
|
||||||
|
brand="Test",
|
||||||
|
recommended_time="both",
|
||||||
|
leave_on=True,
|
||||||
|
product_effect_profile={},
|
||||||
|
)
|
||||||
|
am_product = Product(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
name="AM SPF",
|
||||||
|
category="spf",
|
||||||
|
brand="Test",
|
||||||
|
recommended_time="am",
|
||||||
|
leave_on=True,
|
||||||
|
product_effect_profile={},
|
||||||
|
)
|
||||||
|
pm_product = Product(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
name="PM Cream",
|
||||||
|
category="moisturizer",
|
||||||
|
brand="Test",
|
||||||
|
recommended_time="pm",
|
||||||
|
leave_on=True,
|
||||||
|
product_effect_profile={},
|
||||||
|
)
|
||||||
|
session.add_all([regular_med, minoxidil_med, am_product, pm_product])
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
am_available = _get_available_products(session, time_filter="am")
|
||||||
|
am_names = {p.name for p in am_available}
|
||||||
|
assert "Tretinoin" not in am_names
|
||||||
|
assert "Minoxidil 5%" in am_names
|
||||||
|
assert "AM SPF" in am_names
|
||||||
|
assert "PM Cream" not in am_names
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_inci_tool_handler_returns_only_available_ids(session: Session):
|
||||||
|
available = Product(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
name="Available",
|
||||||
|
category="serum",
|
||||||
|
brand="Test",
|
||||||
|
recommended_time="both",
|
||||||
|
leave_on=True,
|
||||||
|
inci=["Water", "Niacinamide"],
|
||||||
|
product_effect_profile={},
|
||||||
|
)
|
||||||
|
unavailable = Product(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
name="Unavailable",
|
||||||
|
category="serum",
|
||||||
|
brand="Test",
|
||||||
|
recommended_time="both",
|
||||||
|
leave_on=True,
|
||||||
|
inci=["Water", "Retinol"],
|
||||||
|
product_effect_profile={},
|
||||||
|
)
|
||||||
|
|
||||||
|
handler = _build_inci_tool_handler([available])
|
||||||
|
payload = handler(
|
||||||
|
{
|
||||||
|
"product_ids": [
|
||||||
|
str(available.id),
|
||||||
|
str(unavailable.id),
|
||||||
|
str(available.id),
|
||||||
|
123,
|
||||||
|
]
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert "products" in payload
|
||||||
|
products = payload["products"]
|
||||||
|
assert len(products) == 1
|
||||||
|
assert products[0]["id"] == str(available.id)
|
||||||
|
assert products[0]["name"] == "Available"
|
||||||
|
assert products[0]["inci"] == ["Water", "Niacinamide"]
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue