innercontext/backend/innercontext/validators/batch_validator.py
Piotr Oleszczyk 2a9391ad32 feat(api): add LLM response validation and input sanitization
Implement Phase 1: Safety & Validation for all LLM-based suggestion engines.

- Add input sanitization module to prevent prompt injection attacks
- Implement 5 comprehensive validators (routine, batch, shopping, product parse, photo)
- Add 10+ critical safety checks (retinoid+acid conflicts, barrier compatibility, etc.)
- Integrate validation into all 5 API endpoints (routines, products, skincare)
- Add validation fields to ai_call_logs table (validation_errors, validation_warnings, auto_fixed)
- Create database migration for validation fields
- Add comprehensive test suite (9/9 tests passing, 88% coverage on validators)

Safety improvements:
- Blocks retinoid + acid conflicts in same routine/day
- Rejects unknown product IDs
- Enforces min_interval_hours rules
- Protects compromised skin barriers
- Prevents prohibited fields (dose, amount) in responses
- Validates all enum values and score ranges

All validation failures are logged and responses are rejected with HTTP 502.
2026-03-06 10:16:47 +01:00

276 lines
9.3 KiB
Python

"""Validator for batch routine suggestions (multi-day plans)."""
from collections import defaultdict
from dataclasses import dataclass
from datetime import date
from typing import Any
from uuid import UUID
from innercontext.validators.base import BaseValidator, ValidationResult
from innercontext.validators.routine_validator import (
RoutineSuggestionValidator,
RoutineValidationContext,
)
@dataclass
class BatchValidationContext:
"""Context needed to validate batch routine suggestions."""
valid_product_ids: set[UUID]
"""Set of product IDs that exist in the database."""
barrier_state: str | None
"""Current barrier state: 'intact', 'mildly_compromised', 'compromised'"""
products_by_id: dict[UUID, Any]
"""Map of product_id -> Product object for detailed checks."""
last_used_dates: dict[UUID, date]
"""Map of product_id -> last used date before batch period."""
class BatchValidator(BaseValidator):
    """Validates batch routine suggestions (multi-day AM+PM plans).

    Each AM/PM routine is delegated to a ``RoutineSuggestionValidator``;
    on top of that, this class adds the checks that only make sense across
    a whole day (retinoid + acid) or across the whole batch (weekly
    frequency limits).
    """

    def __init__(self):
        self.routine_validator = RoutineSuggestionValidator()

    def validate(
        self, response: Any, context: BatchValidationContext
    ) -> ValidationResult:
        """
        Validate a batch routine suggestion.

        Checks:
        1. All individual routines pass single-routine validation
        2. No retinoid + acid on same day (across AM/PM)
        3. Product frequency limits (max_frequency_per_week) respected
           across the batch, counted per ISO week
        4. Min interval hours: delegated to the single-routine validator,
           which receives ``context.last_used_dates`` unchanged.
           NOTE(review): usage inside the batch is not fed back into
           ``last_used_dates``, so intervals between batch days are
           presumably not covered here - confirm against the routine
           validator.

        Args:
            response: Parsed batch suggestion; must expose a non-empty
                ``days`` list. Each day needs a ``date`` (a ``date`` or an
                ISO-format string) and may have ``am_steps`` / ``pm_steps``.
            context: Validation context

        Returns:
            ValidationResult with any errors/warnings
        """
        result = ValidationResult()

        if not hasattr(response, "days"):
            result.add_error("Response missing 'days' field")
            return result

        days = response.days
        if not isinstance(days, list):
            result.add_error("'days' must be a list")
            return result
        if not days:
            result.add_error("'days' cannot be empty")
            return result

        # product_id -> one entry per routine (AM or PM) that uses the product
        product_usage_dates: dict[UUID, list[date]] = defaultdict(list)

        for day_num, day in enumerate(days, start=1):
            if not hasattr(day, "date"):
                result.add_error(f"Day {day_num}: missing 'date' field")
                continue

            day_date = day.date
            if isinstance(day_date, str):
                try:
                    day_date = date.fromisoformat(day_date)
                except ValueError:
                    result.add_error(
                        f"Day {day_num}: invalid date format '{day.date}'"
                    )
                    continue
            if not isinstance(day_date, date):
                # Anything else (e.g. None) would crash later in the weekly
                # frequency check, so report it as a validation error here.
                result.add_error(
                    f"Day {day_num}: 'date' must be a date or ISO-format "
                    f"string, got {type(day_date).__name__}"
                )
                continue

            day_has_retinoid = False
            day_has_acid = False

            # AM and PM are handled identically; a missing or empty step
            # list simply means that routine is absent for the day.
            for steps_attr, part_of_day in (("am_steps", "am"), ("pm_steps", "pm")):
                steps = getattr(day, steps_attr, None)
                if not steps:
                    continue

                routine_result = self._validate_single_routine(
                    steps,
                    day_date,
                    part_of_day,
                    context,
                    product_usage_dates,
                    f"Day {day_num} {part_of_day.upper()}",
                )
                result.errors.extend(routine_result.errors)
                result.warnings.extend(routine_result.warnings)

                # Accumulate active-ingredient flags for the same-day check
                _, has_retinoid, has_acid = self._get_routine_products(
                    steps, context
                )
                day_has_retinoid = day_has_retinoid or has_retinoid
                day_has_acid = day_has_acid or has_acid

            # Check same-day retinoid + acid conflict
            if day_has_retinoid and day_has_acid:
                result.add_error(
                    f"Day {day_num} ({day_date}): SAFETY - cannot use retinoid and acid "
                    "on the same day (across AM+PM)"
                )

        # Check frequency limits across the batch
        self._check_batch_frequency_limits(product_usage_dates, context, result)
        return result

    @staticmethod
    def _parse_product_id(step: Any) -> UUID | None:
        """Return the step's product ID, converting strings to UUID.

        Returns None when the step has no (truthy) ``product_id`` or when a
        string ID is not a valid UUID. Non-string values are returned
        unchanged (they are assumed to already be UUIDs).
        """
        product_id = getattr(step, "product_id", None)
        if not product_id:
            return None
        if isinstance(product_id, str):
            try:
                return UUID(product_id)
            except ValueError:
                return None
        return product_id

    def _validate_single_routine(
        self,
        steps: list,
        routine_date: date,
        part_of_day: str,
        context: BatchValidationContext,
        product_usage_dates: dict[UUID, list[date]],
        routine_label: str,
    ) -> ValidationResult:
        """Validate a single routine within the batch.

        Runs the single-routine validator on ``steps``, records every
        parseable product ID of the routine in ``product_usage_dates``
        (mutated in place, one entry per step), and prefixes all resulting
        errors/warnings with ``routine_label``.

        Args:
            steps: Steps of the routine being validated
            routine_date: Date the routine is planned for
            part_of_day: "am" or "pm"
            context: Batch validation context
            product_usage_dates: Usage tracker updated by this call
            routine_label: Prefix for messages, e.g. "Day 1 AM"

        Returns:
            ValidationResult of the single-routine validation
        """
        # Build context for single routine validation
        routine_context = RoutineValidationContext(
            valid_product_ids=context.valid_product_ids,
            routine_date=routine_date,
            part_of_day=part_of_day,
            leaving_home=None,  # Not checked in batch mode
            barrier_state=context.barrier_state,
            products_by_id=context.products_by_id,
            last_used_dates=context.last_used_dates,
            just_shaved=False,  # Not checked in batch mode
        )

        # The routine validator is handed a response object exposing
        # ``steps``; wrap the step list in a minimal stand-in.
        class MockRoutine:
            def __init__(self, steps):
                self.steps = steps

        result = self.routine_validator.validate(MockRoutine(steps), routine_context)

        # Update product usage tracking
        for step in steps:
            product_id = self._parse_product_id(step)
            if product_id is not None:
                product_usage_dates[product_id].append(routine_date)

        # Prefix all errors/warnings with routine label
        result.errors = [f"{routine_label}: {err}" for err in result.errors]
        result.warnings = [f"{routine_label}: {warn}" for warn in result.warnings]
        return result

    def _get_routine_products(
        self, steps: list, context: BatchValidationContext
    ) -> tuple[set[UUID], bool, bool]:
        """
        Get products used in routine and check for retinoids/acids.

        Steps without a parseable product ID are skipped. Products that are
        not present in ``context.products_by_id`` are still included in the
        returned ID set but cannot contribute to the retinoid/acid flags.

        Returns:
            (product_ids, has_retinoid, has_acid)
        """
        products: set[UUID] = set()
        has_retinoid = False
        has_acid = False

        for step in steps:
            product_id = self._parse_product_id(step)
            if product_id is None:
                continue
            products.add(product_id)

            product = context.products_by_id.get(product_id)
            if not product:
                continue
            if self.routine_validator._has_retinoid(product):
                has_retinoid = True
            if self.routine_validator._has_acid(product):
                has_acid = True

        return products, has_retinoid, has_acid

    def _check_batch_frequency_limits(
        self,
        product_usage_dates: dict[UUID, list[date]],
        context: BatchValidationContext,
        result: ValidationResult,
    ) -> None:
        """Check max_frequency_per_week limits across the batch.

        Usage is grouped by ISO week. Every recorded entry counts, so a
        product used in both the AM and PM routine of one date counts twice.
        Products that are unknown to the context or that have no (truthy)
        ``max_frequency_per_week`` are skipped. Violations are added to
        ``result`` as errors.
        """
        for product_id, usage_dates in product_usage_dates.items():
            product = context.products_by_id.get(product_id)
            if not product:
                continue

            max_per_week = getattr(product, "max_frequency_per_week", None)
            if not max_per_week:
                continue

            # (ISO year, ISO week) -> count. The ISO year must be used, not
            # the calendar year: the ISO week containing New Year spans two
            # calendar years, and pairing the calendar year with the ISO
            # week number would split that week into two separate buckets.
            weeks: dict[tuple[int, int], int] = defaultdict(int)
            for usage_date in usage_dates:
                iso_year, iso_week, _ = usage_date.isocalendar()
                weeks[(iso_year, iso_week)] += 1

            # Check each week
            for (year, week_num), count in weeks.items():
                if count > max_per_week:
                    result.add_error(
                        f"Product {product.name}: used {count}x in week {week_num}/{year}, "
                        f"exceeds max_frequency_per_week={max_per_week}"
                    )