Implement Phase 1: Safety & Validation for all LLM-based suggestion engines.

- Add an input sanitization module to prevent prompt injection attacks
- Implement 5 comprehensive validators (routine, batch, shopping, product parse, photo)
- Add 10+ critical safety checks (retinoid + acid conflicts, barrier compatibility, etc.)
- Integrate validation into all 5 API endpoints (routines, products, skincare)
- Add validation fields to the ai_call_logs table (validation_errors, validation_warnings, auto_fixed)
- Create a database migration for the validation fields
- Add a comprehensive test suite (9/9 tests passing, 88% coverage on validators)

Safety improvements:
- Blocks retinoid + acid conflicts in the same routine/day
- Rejects unknown product IDs
- Enforces min_interval_hours rules
- Protects compromised skin barriers
- Prevents prohibited fields (dose, amount) in responses
- Validates all enum values and score ranges

All validation failures are logged, and responses are rejected with HTTP 502.
276 lines
9.3 KiB
Python
276 lines
9.3 KiB
Python
"""Validator for batch routine suggestions (multi-day plans)."""
|
|
|
|
from collections import defaultdict
|
|
from dataclasses import dataclass
|
|
from datetime import date
|
|
from typing import Any
|
|
from uuid import UUID
|
|
|
|
from innercontext.validators.base import BaseValidator, ValidationResult
|
|
from innercontext.validators.routine_validator import (
|
|
RoutineSuggestionValidator,
|
|
RoutineValidationContext,
|
|
)
|
|
|
|
|
|
@dataclass
|
|
class BatchValidationContext:
|
|
"""Context needed to validate batch routine suggestions."""
|
|
|
|
valid_product_ids: set[UUID]
|
|
"""Set of product IDs that exist in the database."""
|
|
|
|
barrier_state: str | None
|
|
"""Current barrier state: 'intact', 'mildly_compromised', 'compromised'"""
|
|
|
|
products_by_id: dict[UUID, Any]
|
|
"""Map of product_id -> Product object for detailed checks."""
|
|
|
|
last_used_dates: dict[UUID, date]
|
|
"""Map of product_id -> last used date before batch period."""
|
|
|
|
|
|
class BatchValidator(BaseValidator):
    """Validates batch routine suggestions (multi-day AM+PM plans).

    Every AM/PM routine in the batch is delegated to a
    ``RoutineSuggestionValidator``. On top of that, this class enforces the
    cross-routine rules: no retinoid + acid on the same day (AM and PM
    combined) and ``max_frequency_per_week`` limits across the whole batch.
    """

    # Routine slots checked for each day: (part_of_day passed to the routine
    # validator, attribute on the day object that holds that slot's steps).
    _ROUTINE_SLOTS = (("am", "am_steps"), ("pm", "pm_steps"))

    def __init__(self):
        self.routine_validator = RoutineSuggestionValidator()

    def validate(
        self, response: Any, context: BatchValidationContext
    ) -> ValidationResult:
        """
        Validate a batch routine suggestion.

        Checks:
        1. All individual routines pass single-routine validation
        2. No retinoid + acid on same day (across AM/PM)
        3. Product frequency limits respected across the batch (per ISO week)
        4. Min interval hours, as far as the single-routine validator enforces
           them against ``context.last_used_dates``

        NOTE(review): usage earlier within this batch is not fed back into
        ``last_used_dates``. Min-interval rules are therefore only checked
        against usage from before the batch period, not between batch days.

        Args:
            response: Parsed batch suggestion. It must have a ``days`` list.
                Each day needs a ``date`` (``date`` or ISO string) and may
                have ``am_steps`` / ``pm_steps``.
            context: Validation context

        Returns:
            ValidationResult with any errors/warnings
        """
        result = ValidationResult()

        if not hasattr(response, "days"):
            result.add_error("Response missing 'days' field")
            return result

        days = response.days

        if not isinstance(days, list):
            result.add_error("'days' must be a list")
            return result

        if not days:
            result.add_error("'days' cannot be empty")
            return result

        # Every date each product is scheduled on, for the frequency check.
        product_usage_dates: dict[UUID, list[date]] = defaultdict(list)

        for day_num, day in enumerate(days, start=1):
            if not hasattr(day, "date"):
                result.add_error(f"Day {day_num}: missing 'date' field")
                continue

            day_date = self._parse_day_date(day.date)
            if day_date is None:
                result.add_error(f"Day {day_num}: invalid date format '{day.date}'")
                continue

            day_has_retinoid = False
            day_has_acid = False

            for part_of_day, steps_attr in self._ROUTINE_SLOTS:
                steps = getattr(day, steps_attr, None)
                if not steps:
                    continue

                routine_result = self._validate_single_routine(
                    steps,
                    day_date,
                    part_of_day,
                    context,
                    product_usage_dates,
                    f"Day {day_num} {part_of_day.upper()}",
                )
                # Route child errors through add_error so any bookkeeping it
                # does on the batch result stays consistent.
                for err in routine_result.errors:
                    result.add_error(err)
                result.warnings.extend(routine_result.warnings)

                _, has_retinoid, has_acid = self._get_routine_products(
                    steps, context
                )
                day_has_retinoid = day_has_retinoid or has_retinoid
                day_has_acid = day_has_acid or has_acid

            # Same-day conflict is evaluated over AM and PM together.
            if day_has_retinoid and day_has_acid:
                result.add_error(
                    f"Day {day_num} ({day_date}): SAFETY - cannot use retinoid and acid "
                    "on the same day (across AM+PM)"
                )

        self._check_batch_frequency_limits(product_usage_dates, context, result)

        return result

    @staticmethod
    def _parse_day_date(raw: Any) -> date | None:
        """Return ``raw`` as a ``date``, parsing ISO strings; None if unusable."""
        if isinstance(raw, str):
            try:
                return date.fromisoformat(raw)
            except ValueError:
                return None
        if isinstance(raw, date):
            return raw
        # Anything else (None, ints, ...) cannot be scheduled or bucketed by week.
        return None

    @staticmethod
    def _coerce_uuid(value: Any) -> UUID | None:
        """Normalize a step's product_id.

        Returns None for falsy values and unparsable strings. Strings are
        parsed into ``UUID``; any other value is returned unchanged.
        """
        if not value:
            return None
        if isinstance(value, str):
            try:
                return UUID(value)
            except ValueError:
                return None
        return value

    def _validate_single_routine(
        self,
        steps: list,
        routine_date: date,
        part_of_day: str,
        context: BatchValidationContext,
        product_usage_dates: dict[UUID, list[date]],
        routine_label: str,
    ) -> ValidationResult:
        """Validate one routine of the batch and record its product usage.

        Side effect: appends ``routine_date`` to ``product_usage_dates`` for
        every product referenced by ``steps``. Errors and warnings in the
        returned result are prefixed with ``routine_label``.
        """
        from types import SimpleNamespace

        routine_context = RoutineValidationContext(
            valid_product_ids=context.valid_product_ids,
            routine_date=routine_date,
            part_of_day=part_of_day,
            leaving_home=None,  # Not checked in batch mode
            barrier_state=context.barrier_state,
            products_by_id=context.products_by_id,
            last_used_dates=context.last_used_dates,
            just_shaved=False,  # Not checked in batch mode
        )

        # The routine validator reads a response object exposing `.steps`.
        result = self.routine_validator.validate(
            SimpleNamespace(steps=steps), routine_context
        )

        for step in steps:
            product_id = self._coerce_uuid(getattr(step, "product_id", None))
            if product_id is not None:
                product_usage_dates[product_id].append(routine_date)

        result.errors = [f"{routine_label}: {err}" for err in result.errors]
        result.warnings = [f"{routine_label}: {warn}" for warn in result.warnings]

        return result

    def _get_routine_products(
        self, steps: list, context: BatchValidationContext
    ) -> tuple[set[UUID], bool, bool]:
        """
        Get products used in routine and check for retinoids/acids.

        Products missing from ``context.products_by_id`` are included in the
        returned ID set but cannot contribute to the retinoid/acid flags.

        Returns:
            (product_ids, has_retinoid, has_acid)
        """
        products: set[UUID] = set()
        has_retinoid = False
        has_acid = False

        for step in steps:
            product_id = self._coerce_uuid(getattr(step, "product_id", None))
            if product_id is None:
                continue

            products.add(product_id)

            product = context.products_by_id.get(product_id)
            if not product:
                continue

            if self.routine_validator._has_retinoid(product):
                has_retinoid = True
            if self.routine_validator._has_acid(product):
                has_acid = True

        return products, has_retinoid, has_acid

    def _check_batch_frequency_limits(
        self,
        product_usage_dates: dict[UUID, list[date]],
        context: BatchValidationContext,
        result: ValidationResult,
    ) -> None:
        """Check max_frequency_per_week limits across the batch.

        Each recorded use (one per routine that contains the product) counts
        once. Uses are bucketed by ISO week.
        """
        for product_id, usage_dates in product_usage_dates.items():
            product = context.products_by_id.get(product_id)
            if not product:
                continue

            max_per_week = getattr(product, "max_frequency_per_week", None)
            if not max_per_week:
                continue

            # Key on the ISO year, not the calendar year. Near Jan 1 they
            # differ (e.g. 2024-12-30 is ISO week 1 of 2025). Mixing them
            # would split one ISO week into two buckets and hide overuse.
            weekly_counts: dict[tuple[int, int], int] = defaultdict(int)
            for usage_date in usage_dates:
                iso_year, iso_week, _ = usage_date.isocalendar()
                weekly_counts[(iso_year, iso_week)] += 1

            for (year, week_num), count in weekly_counts.items():
                if count > max_per_week:
                    result.add_error(
                        f"Product {product.name}: used {count}x in week {week_num}/{year}, "
                        f"exceeds max_frequency_per_week={max_per_week}"
                    )
|