feat(api): add tool-calling flow for shopping suggestions
Keep /products/suggest lean by exposing product UUIDs and fetching INCI, safety rules, actives, and usage notes on demand through Gemini function tools. Add conservative fallback behavior for tool roundtrip limits and expand helper tests to cover tool wiring and payload handlers.
This commit is contained in:
parent
558708653c
commit
b58fcb1440
2 changed files with 370 additions and 30 deletions
|
|
@ -4,13 +4,19 @@ from typing import Optional
|
|||
from uuid import UUID, uuid4
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from google.genai import types as genai_types
|
||||
from pydantic import BaseModel as PydanticBase
|
||||
from pydantic import ValidationError
|
||||
from sqlmodel import Session, SQLModel, col, select
|
||||
|
||||
from db import get_session
|
||||
from innercontext.api.utils import get_or_404
|
||||
from innercontext.llm import call_gemini, get_creative_config, get_extraction_config
|
||||
from innercontext.llm import (
|
||||
call_gemini,
|
||||
call_gemini_with_function_tools,
|
||||
get_creative_config,
|
||||
get_extraction_config,
|
||||
)
|
||||
from innercontext.models import (
|
||||
Product,
|
||||
ProductBase,
|
||||
|
|
@ -541,8 +547,7 @@ def _build_shopping_context(session: Session) -> str:
|
|||
else:
|
||||
skin_lines.append(" (brak danych)")
|
||||
|
||||
stmt = select(Product).where(col(Product.is_tool).is_(False))
|
||||
products = session.exec(stmt).all()
|
||||
products = _get_shopping_products(session)
|
||||
|
||||
product_ids = [p.id for p in products]
|
||||
inventory_rows = (
|
||||
|
|
@ -563,17 +568,11 @@ def _build_shopping_context(session: Session) -> str:
|
|||
" Legenda: [✓] = produkt dostępny (w magazynie), [✗] = brak w magazynie"
|
||||
)
|
||||
for p in products:
|
||||
if p.is_medication:
|
||||
continue
|
||||
active_inv = [i for i in inv_by_product.get(p.id, []) if i.finished_at is None]
|
||||
has_stock = len(active_inv) > 0 # any unfinished inventory = in stock
|
||||
stock = "✓" if has_stock else "✗"
|
||||
|
||||
actives = []
|
||||
for a in p.actives or []:
|
||||
name = a.get("name") if isinstance(a, dict) else getattr(a, "name", None)
|
||||
if name:
|
||||
actives.append(name)
|
||||
actives = _extract_active_names(p)
|
||||
actives_str = f", actives: {actives}" if actives else ""
|
||||
|
||||
ep = p.product_effect_profile
|
||||
|
|
@ -590,13 +589,226 @@ def _build_shopping_context(session: Session) -> str:
|
|||
targets = [_ev(t) for t in (p.targets or [])]
|
||||
|
||||
products_lines.append(
|
||||
f" [{stock}] {p.name} ({p.brand or ''}) - {_ev(p.category)}, "
|
||||
f" [{stock}] id={p.id} {p.name} ({p.brand or ''}) - {_ev(p.category)}, "
|
||||
f"targets: {targets}{actives_str}{effects_str}"
|
||||
)
|
||||
|
||||
return "\n".join(skin_lines) + "\n\n" + "\n".join(products_lines)
|
||||
|
||||
|
||||
def _get_shopping_products(session: Session) -> list[Product]:
|
||||
stmt = select(Product).where(col(Product.is_tool).is_(False))
|
||||
products = session.exec(stmt).all()
|
||||
return [p for p in products if not p.is_medication]
|
||||
|
||||
|
||||
def _extract_active_names(product: Product) -> list[str]:
|
||||
names: list[str] = []
|
||||
for active in product.actives or []:
|
||||
if isinstance(active, dict):
|
||||
name = str(active.get("name") or "").strip()
|
||||
else:
|
||||
name = str(getattr(active, "name", "") or "").strip()
|
||||
if not name:
|
||||
continue
|
||||
if name in names:
|
||||
continue
|
||||
names.append(name)
|
||||
if len(names) >= 12:
|
||||
break
|
||||
return names
|
||||
|
||||
|
||||
def _extract_requested_product_ids(
|
||||
args: dict[str, object], max_ids: int = 8
|
||||
) -> list[str]:
|
||||
raw_ids = args.get("product_ids")
|
||||
if not isinstance(raw_ids, list):
|
||||
return []
|
||||
|
||||
requested_ids: list[str] = []
|
||||
seen: set[str] = set()
|
||||
for raw_id in raw_ids:
|
||||
if not isinstance(raw_id, str):
|
||||
continue
|
||||
if raw_id in seen:
|
||||
continue
|
||||
seen.add(raw_id)
|
||||
requested_ids.append(raw_id)
|
||||
if len(requested_ids) >= max_ids:
|
||||
break
|
||||
return requested_ids
|
||||
|
||||
|
||||
def _build_product_details_tool_handler(products: list[Product], mapper):
|
||||
available_by_id = {str(p.id): p for p in products}
|
||||
|
||||
def _handler(args: dict[str, object]) -> dict[str, object]:
|
||||
requested_ids = _extract_requested_product_ids(args)
|
||||
products_payload = []
|
||||
for pid in requested_ids:
|
||||
product = available_by_id.get(pid)
|
||||
if product is None:
|
||||
continue
|
||||
products_payload.append(mapper(product, pid))
|
||||
return {"products": products_payload}
|
||||
|
||||
return _handler
|
||||
|
||||
|
||||
def _build_inci_tool_handler(products: list[Product]):
|
||||
def _mapper(product: Product, pid: str) -> dict[str, object]:
|
||||
inci = product.inci or []
|
||||
compact_inci = [str(i)[:120] for i in inci[:128]]
|
||||
return {"id": pid, "name": product.name, "inci": compact_inci}
|
||||
|
||||
return _build_product_details_tool_handler(products, mapper=_mapper)
|
||||
|
||||
|
||||
def _build_actives_tool_handler(products: list[Product]):
|
||||
def _mapper(product: Product, pid: str) -> dict[str, object]:
|
||||
payload = []
|
||||
for active in product.actives or []:
|
||||
if isinstance(active, dict):
|
||||
name = str(active.get("name") or "").strip()
|
||||
if not name:
|
||||
continue
|
||||
item = {"name": name}
|
||||
percent = active.get("percent")
|
||||
if percent is not None:
|
||||
item["percent"] = percent
|
||||
functions = active.get("functions")
|
||||
if isinstance(functions, list):
|
||||
item["functions"] = [str(f) for f in functions[:4]]
|
||||
strength_level = active.get("strength_level")
|
||||
if strength_level is not None:
|
||||
item["strength_level"] = str(strength_level)
|
||||
payload.append(item)
|
||||
continue
|
||||
|
||||
name = str(getattr(active, "name", "") or "").strip()
|
||||
if not name:
|
||||
continue
|
||||
item = {"name": name}
|
||||
percent = getattr(active, "percent", None)
|
||||
if percent is not None:
|
||||
item["percent"] = percent
|
||||
functions = getattr(active, "functions", None)
|
||||
if isinstance(functions, list):
|
||||
item["functions"] = [_ev(f) for f in functions[:4]]
|
||||
strength_level = getattr(active, "strength_level", None)
|
||||
if strength_level is not None:
|
||||
item["strength_level"] = _ev(strength_level)
|
||||
payload.append(item)
|
||||
return {"id": pid, "name": product.name, "actives": payload[:24]}
|
||||
|
||||
return _build_product_details_tool_handler(products, mapper=_mapper)
|
||||
|
||||
|
||||
def _build_usage_notes_tool_handler(products: list[Product]):
|
||||
def _mapper(product: Product, pid: str) -> dict[str, object]:
|
||||
notes = " ".join(str(product.usage_notes or "").split())
|
||||
if len(notes) > 500:
|
||||
notes = notes[:497] + "..."
|
||||
return {"id": pid, "name": product.name, "usage_notes": notes}
|
||||
|
||||
return _build_product_details_tool_handler(products, mapper=_mapper)
|
||||
|
||||
|
||||
def _build_safety_rules_tool_handler(products: list[Product]):
|
||||
def _mapper(product: Product, pid: str) -> dict[str, object]:
|
||||
ctx = product.to_llm_context()
|
||||
return {
|
||||
"id": pid,
|
||||
"name": product.name,
|
||||
"incompatible_with": (ctx.get("incompatible_with") or [])[:24],
|
||||
"contraindications": (ctx.get("contraindications") or [])[:24],
|
||||
"context_rules": ctx.get("context_rules") or {},
|
||||
"safety": ctx.get("safety") or {},
|
||||
"min_interval_hours": ctx.get("min_interval_hours"),
|
||||
"max_frequency_per_week": ctx.get("max_frequency_per_week"),
|
||||
}
|
||||
|
||||
return _build_product_details_tool_handler(products, mapper=_mapper)
|
||||
|
||||
|
||||
_INCI_FUNCTION_DECLARATION = genai_types.FunctionDeclaration(
|
||||
name="get_product_inci",
|
||||
description=(
|
||||
"Return exact INCI ingredient lists for selected product UUIDs from "
|
||||
"POSIADANE PRODUKTY."
|
||||
),
|
||||
parameters=genai_types.Schema(
|
||||
type=genai_types.Type.OBJECT,
|
||||
properties={
|
||||
"product_ids": genai_types.Schema(
|
||||
type=genai_types.Type.ARRAY,
|
||||
items=genai_types.Schema(type=genai_types.Type.STRING),
|
||||
description="Product UUIDs from POSIADANE PRODUKTY.",
|
||||
)
|
||||
},
|
||||
required=["product_ids"],
|
||||
),
|
||||
)
|
||||
|
||||
_SAFETY_RULES_FUNCTION_DECLARATION = genai_types.FunctionDeclaration(
|
||||
name="get_product_safety_rules",
|
||||
description=(
|
||||
"Return safety and compatibility rules for selected product UUIDs, "
|
||||
"including incompatible_with, contraindications, context_rules and safety flags."
|
||||
),
|
||||
parameters=genai_types.Schema(
|
||||
type=genai_types.Type.OBJECT,
|
||||
properties={
|
||||
"product_ids": genai_types.Schema(
|
||||
type=genai_types.Type.ARRAY,
|
||||
items=genai_types.Schema(type=genai_types.Type.STRING),
|
||||
description="Product UUIDs from POSIADANE PRODUKTY.",
|
||||
)
|
||||
},
|
||||
required=["product_ids"],
|
||||
),
|
||||
)
|
||||
|
||||
_ACTIVES_FUNCTION_DECLARATION = genai_types.FunctionDeclaration(
|
||||
name="get_product_actives",
|
||||
description=(
|
||||
"Return detailed active ingredients for selected product UUIDs, "
|
||||
"including concentration and functions when available."
|
||||
),
|
||||
parameters=genai_types.Schema(
|
||||
type=genai_types.Type.OBJECT,
|
||||
properties={
|
||||
"product_ids": genai_types.Schema(
|
||||
type=genai_types.Type.ARRAY,
|
||||
items=genai_types.Schema(type=genai_types.Type.STRING),
|
||||
description="Product UUIDs from POSIADANE PRODUKTY.",
|
||||
)
|
||||
},
|
||||
required=["product_ids"],
|
||||
),
|
||||
)
|
||||
|
||||
_USAGE_NOTES_FUNCTION_DECLARATION = genai_types.FunctionDeclaration(
|
||||
name="get_product_usage_notes",
|
||||
description=(
|
||||
"Return compact usage notes for selected product UUIDs "
|
||||
"(timing, application method and cautions)."
|
||||
),
|
||||
parameters=genai_types.Schema(
|
||||
type=genai_types.Type.OBJECT,
|
||||
properties={
|
||||
"product_ids": genai_types.Schema(
|
||||
type=genai_types.Type.ARRAY,
|
||||
items=genai_types.Schema(type=genai_types.Type.STRING),
|
||||
description="Product UUIDs from POSIADANE PRODUKTY.",
|
||||
)
|
||||
},
|
||||
required=["product_ids"],
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
_SHOPPING_SYSTEM_PROMPT = """Jesteś asystentem zakupowym w dziedzinie pielęgnacji skóry.
|
||||
Twoim zadaniem jest przeanalizować stan skóry użytkownika oraz produkty, które już posiada,
|
||||
a następnie zasugerować TYPY produktów (bez marek), które mogłyby uzupełnić ich rutynę.
|
||||
|
|
@ -623,23 +835,87 @@ Format odpowiedzi - zwróć wyłącznie JSON zgodny z podanym schematem."""
|
|||
@router.post("/suggest", response_model=ShoppingSuggestionResponse)
|
||||
def suggest_shopping(session: Session = Depends(get_session)):
|
||||
context = _build_shopping_context(session)
|
||||
shopping_products = _get_shopping_products(session)
|
||||
|
||||
prompt = (
|
||||
f"Na podstawie poniższych danych przeanalizuj, jakie TYPY produktów "
|
||||
f"mogłyby uzupełnić rutynę pielęgnacyjną użytkownika.\n\n"
|
||||
f"{context}\n\n"
|
||||
"NARZEDZIA:\n"
|
||||
"- Masz dostep do funkcji: get_product_inci, get_product_safety_rules, get_product_actives, get_product_usage_notes.\n"
|
||||
"- Wywoluj narzedzia tylko, gdy potrzebujesz detali do oceny konfliktow skladnikow lub ryzyka podraznien.\n"
|
||||
"- Grupuj UUID: staraj sie pobierac dane dla wielu produktow jednym wywolaniem.\n"
|
||||
f"Zwróć wyłącznie JSON zgodny ze schematem."
|
||||
)
|
||||
|
||||
response = call_gemini(
|
||||
config = get_creative_config(
|
||||
system_instruction=_SHOPPING_SYSTEM_PROMPT,
|
||||
response_schema=_ShoppingSuggestionsOut,
|
||||
max_output_tokens=4096,
|
||||
).model_copy(
|
||||
update={
|
||||
"tools": [
|
||||
genai_types.Tool(
|
||||
function_declarations=[
|
||||
_INCI_FUNCTION_DECLARATION,
|
||||
_SAFETY_RULES_FUNCTION_DECLARATION,
|
||||
_ACTIVES_FUNCTION_DECLARATION,
|
||||
_USAGE_NOTES_FUNCTION_DECLARATION,
|
||||
]
|
||||
)
|
||||
],
|
||||
"tool_config": genai_types.ToolConfig(
|
||||
function_calling_config=genai_types.FunctionCallingConfig(
|
||||
mode=genai_types.FunctionCallingConfigMode.AUTO,
|
||||
)
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
function_handlers = {
|
||||
"get_product_inci": _build_inci_tool_handler(shopping_products),
|
||||
"get_product_safety_rules": _build_safety_rules_tool_handler(shopping_products),
|
||||
"get_product_actives": _build_actives_tool_handler(shopping_products),
|
||||
"get_product_usage_notes": _build_usage_notes_tool_handler(shopping_products),
|
||||
}
|
||||
|
||||
try:
|
||||
response = call_gemini_with_function_tools(
|
||||
endpoint="products/suggest",
|
||||
contents=prompt,
|
||||
config=config,
|
||||
function_handlers=function_handlers,
|
||||
user_input=prompt,
|
||||
max_tool_roundtrips=3,
|
||||
)
|
||||
except HTTPException as exc:
|
||||
if (
|
||||
exc.status_code != 502
|
||||
or str(exc.detail) != "Gemini requested too many function calls"
|
||||
):
|
||||
raise
|
||||
|
||||
conservative_prompt = (
|
||||
f"{prompt}\n\n"
|
||||
"TRYB AWARYJNY (KONSERWATYWNY):\n"
|
||||
"- Osiagnieto limit wywolan narzedzi.\n"
|
||||
"- Nie wywoluj narzedzi ponownie.\n"
|
||||
"- Zasugeruj tylko najbardziej bezpieczne i realistyczne typy produktow do uzupelnienia brakow,"
|
||||
" unikaj agresywnych aktywnych przy niepelnych danych.\n"
|
||||
)
|
||||
response = call_gemini(
|
||||
endpoint="products/suggest",
|
||||
contents=conservative_prompt,
|
||||
config=get_creative_config(
|
||||
system_instruction=_SHOPPING_SYSTEM_PROMPT,
|
||||
response_schema=_ShoppingSuggestionsOut,
|
||||
max_output_tokens=4096,
|
||||
),
|
||||
user_input=prompt,
|
||||
user_input=conservative_prompt,
|
||||
tool_trace={
|
||||
"mode": "fallback_conservative",
|
||||
"reason": "max_tool_roundtrips_exceeded",
|
||||
},
|
||||
)
|
||||
|
||||
raw = response.text
|
||||
|
|
|
|||
|
|
@ -4,8 +4,16 @@ from unittest.mock import patch
|
|||
|
||||
from sqlmodel import Session
|
||||
|
||||
from innercontext.api.products import _build_shopping_context
|
||||
from innercontext.models import Product, SkinConditionSnapshot, ProductInventory
|
||||
from innercontext.api.products import (
|
||||
_build_actives_tool_handler,
|
||||
_build_inci_tool_handler,
|
||||
_build_safety_rules_tool_handler,
|
||||
_build_shopping_context,
|
||||
_build_usage_notes_tool_handler,
|
||||
_extract_requested_product_ids,
|
||||
)
|
||||
from innercontext.models import Product, ProductInventory, SkinConditionSnapshot
|
||||
|
||||
|
||||
def test_build_shopping_context(session: Session):
|
||||
# Empty context
|
||||
|
|
@ -23,7 +31,7 @@ def test_build_shopping_context(session: Session):
|
|||
sensitivity_level=4,
|
||||
barrier_state="mildly_compromised",
|
||||
active_concerns=["redness"],
|
||||
priorities=["soothing"]
|
||||
priorities=["soothing"],
|
||||
)
|
||||
session.add(snap)
|
||||
|
||||
|
|
@ -37,7 +45,7 @@ def test_build_shopping_context(session: Session):
|
|||
leave_on=True,
|
||||
targets=["redness"],
|
||||
product_effect_profile={"soothing_strength": 4, "hydration_immediate": 1},
|
||||
actives=[{"name": "Centella"}]
|
||||
actives=[{"name": "Centella"}],
|
||||
)
|
||||
session.add(p)
|
||||
session.commit()
|
||||
|
|
@ -55,7 +63,9 @@ def test_build_shopping_context(session: Session):
|
|||
assert "Priorytety: soothing" in ctx
|
||||
|
||||
# Check product
|
||||
assert "[✓] Soothing Serum" in ctx
|
||||
assert "[✓] id=" in ctx
|
||||
assert "Soothing Serum" in ctx
|
||||
assert f"id={p.id}" in ctx
|
||||
assert "BrandX" in ctx
|
||||
assert "targets: ['redness']" in ctx
|
||||
assert "actives: ['Centella']" in ctx
|
||||
|
|
@ -63,8 +73,16 @@ def test_build_shopping_context(session: Session):
|
|||
|
||||
|
||||
def test_suggest_shopping(client, session):
|
||||
with patch("innercontext.api.products.call_gemini") as mock_gemini:
|
||||
mock_response = type("Response", (), {"text": '{"suggestions": [{"category": "cleanser", "product_type": "cleanser", "priority": "high", "key_ingredients": [], "target_concerns": [], "why_needed": "reason", "recommended_time": "am", "frequency": "daily"}], "reasoning": "Test shopping"}'})
|
||||
with patch(
|
||||
"innercontext.api.products.call_gemini_with_function_tools"
|
||||
) as mock_gemini:
|
||||
mock_response = type(
|
||||
"Response",
|
||||
(),
|
||||
{
|
||||
"text": '{"suggestions": [{"category": "cleanser", "product_type": "cleanser", "priority": "high", "key_ingredients": [], "target_concerns": [], "why_needed": "reason", "recommended_time": "am", "frequency": "daily"}], "reasoning": "Test shopping"}'
|
||||
},
|
||||
)
|
||||
mock_gemini.return_value = mock_response
|
||||
|
||||
r = client.post("/products/suggest")
|
||||
|
|
@ -73,6 +91,13 @@ def test_suggest_shopping(client, session):
|
|||
assert len(data["suggestions"]) == 1
|
||||
assert data["suggestions"][0]["product_type"] == "cleanser"
|
||||
assert data["reasoning"] == "Test shopping"
|
||||
kwargs = mock_gemini.call_args.kwargs
|
||||
assert "function_handlers" in kwargs
|
||||
assert "get_product_inci" in kwargs["function_handlers"]
|
||||
assert "get_product_safety_rules" in kwargs["function_handlers"]
|
||||
assert "get_product_actives" in kwargs["function_handlers"]
|
||||
assert "get_product_usage_notes" in kwargs["function_handlers"]
|
||||
|
||||
|
||||
def test_shopping_context_medication_skip(session: Session):
|
||||
p = Product(
|
||||
|
|
@ -83,7 +108,7 @@ def test_shopping_context_medication_skip(session: Session):
|
|||
recommended_time="pm",
|
||||
leave_on=True,
|
||||
is_medication=True,
|
||||
product_effect_profile={}
|
||||
product_effect_profile={},
|
||||
)
|
||||
session.add(p)
|
||||
session.commit()
|
||||
|
|
@ -91,3 +116,42 @@ def test_shopping_context_medication_skip(session: Session):
|
|||
ctx = _build_shopping_context(session)
|
||||
assert "Epiduo" not in ctx
|
||||
|
||||
|
||||
def test_extract_requested_product_ids_dedupes_and_limits():
|
||||
ids = _extract_requested_product_ids(
|
||||
{"product_ids": ["a", "b", "a", 1, "c", "d"]},
|
||||
max_ids=3,
|
||||
)
|
||||
assert ids == ["a", "b", "c"]
|
||||
|
||||
|
||||
def test_shopping_tool_handlers_return_payloads(session: Session):
|
||||
product = Product(
|
||||
id=uuid.uuid4(),
|
||||
name="Test Product",
|
||||
brand="Brand",
|
||||
category="serum",
|
||||
recommended_time="both",
|
||||
leave_on=True,
|
||||
usage_notes="Use AM and PM on clean skin.",
|
||||
inci=["Water", "Niacinamide"],
|
||||
actives=[{"name": "Niacinamide", "percent": 5, "functions": ["niacinamide"]}],
|
||||
incompatible_with=[{"target": "Vitamin C", "scope": "same_step"}],
|
||||
context_rules={"safe_after_shaving": True},
|
||||
product_effect_profile={},
|
||||
)
|
||||
|
||||
payload = {"product_ids": [str(product.id)]}
|
||||
|
||||
inci_data = _build_inci_tool_handler([product])(payload)
|
||||
assert inci_data["products"][0]["inci"] == ["Water", "Niacinamide"]
|
||||
|
||||
actives_data = _build_actives_tool_handler([product])(payload)
|
||||
assert actives_data["products"][0]["actives"][0]["name"] == "Niacinamide"
|
||||
|
||||
notes_data = _build_usage_notes_tool_handler([product])(payload)
|
||||
assert notes_data["products"][0]["usage_notes"] == "Use AM and PM on clean skin."
|
||||
|
||||
safety_data = _build_safety_rules_tool_handler([product])(payload)
|
||||
assert "incompatible_with" in safety_data["products"][0]
|
||||
assert "context_rules" in safety_data["products"][0]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue