feat(backend): move product pricing to async persisted jobs

This commit is contained in:
Piotr Oleszczyk 2026-03-04 22:46:16 +01:00
parent c869f88db2
commit 0e439b4ca7
18 changed files with 468 additions and 67 deletions

View file

@ -7,7 +7,8 @@ from fastapi import APIRouter, Depends, HTTPException, Query
from google.genai import types as genai_types
from pydantic import BaseModel as PydanticBase
from pydantic import ValidationError
from sqlmodel import Session, SQLModel, col, select
from sqlalchemy import inspect, select as sa_select
from sqlmodel import Field, Session, SQLModel, col, select
from db import get_session
from innercontext.api.utils import get_or_404
@ -18,6 +19,7 @@ from innercontext.llm import (
get_extraction_config,
)
from innercontext.services.fx import convert_to_pln
from innercontext.services.pricing_jobs import enqueue_pricing_recalc
from innercontext.models import (
Product,
ProductBase,
@ -43,6 +45,10 @@ from innercontext.models.product import (
router = APIRouter()
PricingSource = Literal["category", "fallback", "insufficient_data"]
PricingOutput = tuple[PriceTier | None, float | None, PricingSource | None]
PricingOutputs = dict[UUID, PricingOutput]
# ---------------------------------------------------------------------------
# Request / response schemas
@ -150,6 +156,19 @@ class ProductParseResponse(SQLModel):
needle_length_mm: Optional[float] = None
class ProductListItem(SQLModel):
id: UUID
name: str
brand: str
category: ProductCategory
recommended_time: DayTime
targets: list[SkinConcern] = Field(default_factory=list)
is_owned: bool
price_tier: PriceTier | None = None
price_per_use_pln: float | None = None
price_tier_source: PricingSource | None = None
class AIActiveIngredient(ActiveIngredient):
# Gemini API rejects int-enum values in response_schema; override with plain int.
strength_level: Optional[int] = None # type: ignore[assignment]
@ -317,14 +336,7 @@ def _thresholds(values: list[float]) -> tuple[float, float, float]:
def _compute_pricing_outputs(
products: list[Product],
) -> dict[
UUID,
tuple[
PriceTier | None,
float | None,
Literal["category", "fallback", "insufficient_data"] | None,
],
]:
) -> PricingOutputs:
price_per_use_by_id: dict[UUID, float] = {}
grouped: dict[ProductCategory, list[tuple[UUID, float]]] = {}
@ -335,14 +347,7 @@ def _compute_pricing_outputs(
price_per_use_by_id[product.id] = ppu
grouped.setdefault(product.category, []).append((product.id, ppu))
outputs: dict[
UUID,
tuple[
PriceTier | None,
float | None,
Literal["category", "fallback", "insufficient_data"] | None,
],
] = {
outputs: PricingOutputs = {
p.id: (
None,
price_per_use_by_id.get(p.id),
@ -385,21 +390,6 @@ def _compute_pricing_outputs(
return outputs
def _with_pricing(
view: ProductPublic,
pricing: tuple[
PriceTier | None,
float | None,
Literal["category", "fallback", "insufficient_data"] | None,
],
) -> ProductPublic:
price_tier, price_per_use_pln, price_tier_source = pricing
view.price_tier = price_tier
view.price_per_use_pln = price_per_use_pln
view.price_tier_source = price_tier_source
return view
# ---------------------------------------------------------------------------
# Product routes
# ---------------------------------------------------------------------------
@ -424,7 +414,7 @@ def list_products(
if is_tool is not None:
stmt = stmt.where(Product.is_tool == is_tool)
products = session.exec(stmt).all()
products = list(session.exec(stmt).all())
# Filter by targets (JSON column — done in Python)
if targets:
@ -454,12 +444,8 @@ def list_products(
inv_by_product.setdefault(inv.product_id, []).append(inv)
results = []
pricing_pool = list(session.exec(select(Product)).all()) if products else []
pricing_outputs = _compute_pricing_outputs(pricing_pool)
for p in products:
r = ProductWithInventory.model_validate(p, from_attributes=True)
_with_pricing(r, pricing_outputs.get(p.id, (None, None, None)))
r.inventory = inv_by_product.get(p.id, [])
results.append(r)
return results
@ -476,6 +462,7 @@ def create_product(data: ProductCreate, session: Session = Depends(get_session))
**payload,
)
session.add(product)
enqueue_pricing_recalc(session)
session.commit()
session.refresh(product)
return product
@ -631,17 +618,104 @@ def parse_product_text(data: ProductParseRequest) -> ProductParseResponse:
raise HTTPException(status_code=422, detail=e.errors())
@router.get("/summary", response_model=list[ProductListItem])
def list_products_summary(
category: Optional[ProductCategory] = None,
brand: Optional[str] = None,
targets: Optional[list[SkinConcern]] = Query(default=None),
is_medication: Optional[bool] = None,
is_tool: Optional[bool] = None,
session: Session = Depends(get_session),
):
product_table = inspect(Product).local_table
stmt = sa_select(
product_table.c.id,
product_table.c.name,
product_table.c.brand,
product_table.c.category,
product_table.c.recommended_time,
product_table.c.targets,
product_table.c.price_tier,
product_table.c.price_per_use_pln,
product_table.c.price_tier_source,
)
if category is not None:
stmt = stmt.where(product_table.c.category == category)
if brand is not None:
stmt = stmt.where(product_table.c.brand == brand)
if is_medication is not None:
stmt = stmt.where(product_table.c.is_medication == is_medication)
if is_tool is not None:
stmt = stmt.where(product_table.c.is_tool == is_tool)
rows = list(session.execute(stmt).all())
if targets:
target_values = {t.value for t in targets}
rows = [
row
for row in rows
if any(
(t.value if hasattr(t, "value") else t) in target_values
for t in (row[5] or [])
)
]
product_ids = [row[0] for row in rows]
inventory_rows = (
session.exec(
select(ProductInventory).where(
col(ProductInventory.product_id).in_(product_ids)
)
).all()
if product_ids
else []
)
owned_ids = {
inv.product_id
for inv in inventory_rows
if inv.product_id is not None and inv.finished_at is None
}
results: list[ProductListItem] = []
for row in rows:
(
product_id,
name,
brand_value,
category_value,
recommended_time,
row_targets,
price_tier,
price_per_use_pln,
price_tier_source,
) = row
results.append(
ProductListItem(
id=product_id,
name=name,
brand=brand_value,
category=category_value,
recommended_time=recommended_time,
targets=row_targets or [],
is_owned=product_id in owned_ids,
price_tier=price_tier,
price_per_use_pln=price_per_use_pln,
price_tier_source=price_tier_source,
)
)
return results
@router.get("/{product_id}", response_model=ProductWithInventory)
def get_product(product_id: UUID, session: Session = Depends(get_session)):
product = get_or_404(session, Product, product_id)
pricing_pool = list(session.exec(select(Product)).all())
pricing_outputs = _compute_pricing_outputs(pricing_pool)
inventory = session.exec(
select(ProductInventory).where(ProductInventory.product_id == product_id)
).all()
result = ProductWithInventory.model_validate(product, from_attributes=True)
_with_pricing(result, pricing_outputs.get(product.id, (None, None, None)))
result.inventory = list(inventory)
return result
@ -658,18 +732,17 @@ def update_product(
for key, value in patch_data.items():
setattr(product, key, value)
session.add(product)
enqueue_pricing_recalc(session)
session.commit()
session.refresh(product)
pricing_pool = list(session.exec(select(Product)).all())
pricing_outputs = _compute_pricing_outputs(pricing_pool)
result = ProductPublic.model_validate(product, from_attributes=True)
return _with_pricing(result, pricing_outputs.get(product.id, (None, None, None)))
return ProductPublic.model_validate(product, from_attributes=True)
@router.delete("/{product_id}", status_code=204)
def delete_product(product_id: UUID, session: Session = Depends(get_session)):
product = get_or_404(session, Product, product_id)
session.delete(product)
enqueue_pricing_recalc(session)
session.commit()

View file

@ -32,6 +32,7 @@ from .product import (
ProductPublic,
ProductWithInventory,
)
from .pricing import PricingRecalcJob
from .routine import GroomingSchedule, Routine, RoutineStep
from .skincare import (
SkinConditionSnapshot,
@ -77,6 +78,7 @@ __all__ = [
"ProductInventory",
"ProductPublic",
"ProductWithInventory",
"PricingRecalcJob",
# routine
"GroomingSchedule",
"Routine",

View file

@ -0,0 +1,33 @@
from datetime import datetime
from typing import ClassVar
from uuid import UUID, uuid4
from sqlalchemy import Column, DateTime
from sqlmodel import Field, SQLModel
from .base import utc_now
from .domain import Domain
class PricingRecalcJob(SQLModel, table=True):
__tablename__ = "pricing_recalc_jobs"
__domains__: ClassVar[frozenset[Domain]] = frozenset({Domain.SKINCARE})
id: UUID = Field(default_factory=uuid4, primary_key=True)
scope: str = Field(default="global", max_length=32, index=True)
status: str = Field(default="pending", max_length=16, index=True)
attempts: int = Field(default=0, ge=0)
error: str | None = Field(default=None, max_length=512)
created_at: datetime = Field(default_factory=utc_now, nullable=False)
started_at: datetime | None = Field(default=None)
finished_at: datetime | None = Field(default=None)
updated_at: datetime = Field(
default_factory=utc_now,
sa_column=Column(
DateTime(timezone=True),
default=utc_now,
onupdate=utc_now,
nullable=False,
),
)

View file

@ -174,6 +174,11 @@ class Product(ProductBase, table=True):
default=None, sa_column=Column(JSON, nullable=True)
)
price_tier: PriceTier | None = Field(default=None, index=True)
price_per_use_pln: float | None = Field(default=None)
price_tier_source: str | None = Field(default=None, max_length=32)
pricing_computed_at: datetime | None = Field(default=None)
created_at: datetime = Field(default_factory=utc_now, nullable=False)
updated_at: datetime = Field(
default_factory=utc_now,

View file

@ -0,0 +1,93 @@
from datetime import datetime
from sqlmodel import Session, col, select
from innercontext.models import PricingRecalcJob, Product
from innercontext.models.base import utc_now
def enqueue_pricing_recalc(
session: Session, *, scope: str = "global"
) -> PricingRecalcJob:
existing = session.exec(
select(PricingRecalcJob)
.where(PricingRecalcJob.scope == scope)
.where(col(PricingRecalcJob.status).in_(["pending", "running"]))
.order_by(col(PricingRecalcJob.created_at).asc())
).first()
if existing is not None:
return existing
job = PricingRecalcJob(scope=scope, status="pending")
session.add(job)
return job
def claim_next_pending_pricing_job(session: Session) -> PricingRecalcJob | None:
stmt = (
select(PricingRecalcJob)
.where(PricingRecalcJob.status == "pending")
.order_by(col(PricingRecalcJob.created_at).asc())
)
bind = session.get_bind()
if bind is not None and bind.dialect.name == "postgresql":
stmt = stmt.with_for_update(skip_locked=True)
job = session.exec(stmt).first()
if job is None:
return None
job.status = "running"
job.attempts += 1
job.started_at = utc_now()
job.finished_at = None
job.error = None
session.add(job)
session.commit()
session.refresh(job)
return job
def _apply_pricing_snapshot(session: Session, computed_at: datetime) -> int:
from innercontext.api.products import _compute_pricing_outputs
products = list(session.exec(select(Product)).all())
pricing_outputs = _compute_pricing_outputs(products)
for product in products:
tier, price_per_use_pln, tier_source = pricing_outputs.get(
product.id, (None, None, None)
)
product.price_tier = tier
product.price_per_use_pln = price_per_use_pln
product.price_tier_source = tier_source
product.pricing_computed_at = computed_at
return len(products)
def process_pricing_job(session: Session, job: PricingRecalcJob) -> int:
try:
updated_count = _apply_pricing_snapshot(session, computed_at=utc_now())
job.status = "succeeded"
job.finished_at = utc_now()
job.error = None
session.add(job)
session.commit()
return updated_count
except Exception as exc:
session.rollback()
job.status = "failed"
job.finished_at = utc_now()
job.error = str(exc)[:512]
session.add(job)
session.commit()
raise
def process_one_pending_pricing_job(session: Session) -> bool:
job = claim_next_pending_pricing_job(session)
if job is None:
return False
process_pricing_job(session, job)
return True

View file

@ -0,0 +1,18 @@
import time
from sqlmodel import Session
from db import engine
from innercontext.services.pricing_jobs import process_one_pending_pricing_job
def run_forever(poll_interval_seconds: float = 2.0) -> None:
while True:
with Session(engine) as session:
processed = process_one_pending_pricing_job(session)
if not processed:
time.sleep(poll_interval_seconds)
if __name__ == "__main__":
run_forever()