feat(backend): move product pricing to async persisted jobs

This commit is contained in:
Piotr Oleszczyk 2026-03-04 22:46:16 +01:00
parent c869f88db2
commit 0e439b4ca7
18 changed files with 468 additions and 67 deletions

View file

@ -0,0 +1,85 @@
"""add_async_pricing_jobs_and_snapshot_fields
Revision ID: f1a2b3c4d5e6
Revises: 7c91e4b2af38
Create Date: 2026-03-04 00:00:00.000000
"""
from typing import Sequence, Union
import sqlalchemy as sa
import sqlmodel.sql.sqltypes
from alembic import op
revision: str = "f1a2b3c4d5e6"
down_revision: Union[str, None] = "7c91e4b2af38"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
    """Add the pricing snapshot columns to ``products`` and create the job table.

    Four nullable columns are added to ``products`` (with an index on
    ``price_tier``), then the ``pricing_recalc_jobs`` table is created with
    indexes on ``scope`` and ``status``.
    """
    products_table = "products"
    jobs_table = "pricing_recalc_jobs"
    auto_string = sqlmodel.sql.sqltypes.AutoString

    # Snapshot columns are all nullable so existing product rows stay valid.
    snapshot_columns = (
        sa.Column(
            "price_tier",
            sa.Enum("BUDGET", "MID", "PREMIUM", "LUXURY", name="pricetier"),
            nullable=True,
        ),
        sa.Column("price_per_use_pln", sa.Float(), nullable=True),
        sa.Column("price_tier_source", sa.String(length=32), nullable=True),
        sa.Column("pricing_computed_at", sa.DateTime(), nullable=True),
    )
    for snapshot_column in snapshot_columns:
        op.add_column(products_table, snapshot_column)
    op.create_index(
        op.f("ix_products_price_tier"), products_table, ["price_tier"], unique=False
    )

    op.create_table(
        jobs_table,
        sa.Column("id", sa.Uuid(), nullable=False),
        sa.Column("scope", auto_string(length=32), nullable=False),
        sa.Column("status", auto_string(length=16), nullable=False),
        sa.Column("attempts", sa.Integer(), nullable=False),
        sa.Column("error", auto_string(length=512), nullable=True),
        sa.Column("created_at", sa.DateTime(), nullable=False),
        sa.Column("started_at", sa.DateTime(), nullable=True),
        sa.Column("finished_at", sa.DateTime(), nullable=True),
        sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
        sa.PrimaryKeyConstraint("id"),
    )
    job_indexes = (
        ("ix_pricing_recalc_jobs_scope", "scope"),
        ("ix_pricing_recalc_jobs_status", "status"),
    )
    for index_name, indexed_column in job_indexes:
        op.create_index(
            op.f(index_name), jobs_table, [indexed_column], unique=False
        )
def downgrade() -> None:
    """Reverse ``upgrade``: drop the job table, then the product pricing columns.

    Objects are removed in the reverse order of their creation, and the
    ``pricetier`` enum type is dropped last.
    """
    jobs_table = "pricing_recalc_jobs"
    for job_index in (
        "ix_pricing_recalc_jobs_status",
        "ix_pricing_recalc_jobs_scope",
    ):
        op.drop_index(op.f(job_index), table_name=jobs_table)
    op.drop_table(jobs_table)

    op.drop_index(op.f("ix_products_price_tier"), table_name="products")
    for column_name in (
        "pricing_computed_at",
        "price_tier_source",
        "price_per_use_pln",
        "price_tier",
    ):
        op.drop_column("products", column_name)

    # The enum type is only safe to remove once no column references it.
    op.execute("DROP TYPE IF EXISTS pricetier")

View file

@ -7,7 +7,8 @@ from fastapi import APIRouter, Depends, HTTPException, Query
from google.genai import types as genai_types
from pydantic import BaseModel as PydanticBase
from pydantic import ValidationError
from sqlmodel import Session, SQLModel, col, select
from sqlalchemy import inspect, select as sa_select
from sqlmodel import Field, Session, SQLModel, col, select
from db import get_session
from innercontext.api.utils import get_or_404
@ -18,6 +19,7 @@ from innercontext.llm import (
get_extraction_config,
)
from innercontext.services.fx import convert_to_pln
from innercontext.services.pricing_jobs import enqueue_pricing_recalc
from innercontext.models import (
Product,
ProductBase,
@ -43,6 +45,10 @@ from innercontext.models.product import (
router = APIRouter()
# Label describing how a price tier was (or was not) derived.
PricingSource = Literal["category", "fallback", "insufficient_data"]
# Per-product pricing result: (price tier, price per use in PLN, tier source).
PricingOutput = tuple[PriceTier | None, float | None, PricingSource | None]
# Pricing results keyed by product id.
PricingOutputs = dict[UUID, PricingOutput]
# ---------------------------------------------------------------------------
# Request / response schemas
@ -150,6 +156,19 @@ class ProductParseResponse(SQLModel):
needle_length_mm: Optional[float] = None
class ProductListItem(SQLModel):
    # Lightweight product projection used as the response model of the
    # ``/summary`` listing route: identity, classification, ownership flag and
    # the persisted pricing snapshot fields.
    id: UUID
    name: str
    brand: str
    category: ProductCategory
    recommended_time: DayTime
    # Skin concerns the product targets; defaults to an empty list.
    targets: list[SkinConcern] = Field(default_factory=list)
    # The summary route sets this to True when the product has at least one
    # inventory entry whose ``finished_at`` is unset.
    is_owned: bool
    # Pricing snapshot values; None when not available.
    price_tier: PriceTier | None = None
    price_per_use_pln: float | None = None
    price_tier_source: PricingSource | None = None
class AIActiveIngredient(ActiveIngredient):
# Gemini API rejects int-enum values in response_schema; override with plain int.
strength_level: Optional[int] = None # type: ignore[assignment]
@ -317,14 +336,7 @@ def _thresholds(values: list[float]) -> tuple[float, float, float]:
def _compute_pricing_outputs(
products: list[Product],
) -> dict[
UUID,
tuple[
PriceTier | None,
float | None,
Literal["category", "fallback", "insufficient_data"] | None,
],
]:
) -> PricingOutputs:
price_per_use_by_id: dict[UUID, float] = {}
grouped: dict[ProductCategory, list[tuple[UUID, float]]] = {}
@ -335,14 +347,7 @@ def _compute_pricing_outputs(
price_per_use_by_id[product.id] = ppu
grouped.setdefault(product.category, []).append((product.id, ppu))
outputs: dict[
UUID,
tuple[
PriceTier | None,
float | None,
Literal["category", "fallback", "insufficient_data"] | None,
],
] = {
outputs: PricingOutputs = {
p.id: (
None,
price_per_use_by_id.get(p.id),
@ -385,21 +390,6 @@ def _compute_pricing_outputs(
return outputs
def _with_pricing(
    view: ProductPublic,
    pricing: tuple[
        PriceTier | None,
        float | None,
        Literal["category", "fallback", "insufficient_data"] | None,
    ],
) -> ProductPublic:
    """Copy a ``(tier, price per use, source)`` triple onto ``view`` and return it.

    The same ``view`` object is mutated in place and handed back so the call
    can be used inline.
    """
    (
        view.price_tier,
        view.price_per_use_pln,
        view.price_tier_source,
    ) = pricing
    return view
# ---------------------------------------------------------------------------
# Product routes
# ---------------------------------------------------------------------------
@ -424,7 +414,7 @@ def list_products(
if is_tool is not None:
stmt = stmt.where(Product.is_tool == is_tool)
products = session.exec(stmt).all()
products = list(session.exec(stmt).all())
# Filter by targets (JSON column — done in Python)
if targets:
@ -454,12 +444,8 @@ def list_products(
inv_by_product.setdefault(inv.product_id, []).append(inv)
results = []
pricing_pool = list(session.exec(select(Product)).all()) if products else []
pricing_outputs = _compute_pricing_outputs(pricing_pool)
for p in products:
r = ProductWithInventory.model_validate(p, from_attributes=True)
_with_pricing(r, pricing_outputs.get(p.id, (None, None, None)))
r.inventory = inv_by_product.get(p.id, [])
results.append(r)
return results
@ -476,6 +462,7 @@ def create_product(data: ProductCreate, session: Session = Depends(get_session))
**payload,
)
session.add(product)
enqueue_pricing_recalc(session)
session.commit()
session.refresh(product)
return product
@ -631,17 +618,104 @@ def parse_product_text(data: ProductParseRequest) -> ProductParseResponse:
raise HTTPException(status_code=422, detail=e.errors())
@router.get("/summary", response_model=list[ProductListItem])
def list_products_summary(
    category: Optional[ProductCategory] = None,
    brand: Optional[str] = None,
    targets: Optional[list[SkinConcern]] = Query(default=None),
    is_medication: Optional[bool] = None,
    is_tool: Optional[bool] = None,
    session: Session = Depends(get_session),
):
    """List products as lightweight summary items with persisted pricing.

    Only the columns needed for ``ProductListItem`` are selected. The
    ``category``, ``brand``, ``is_medication`` and ``is_tool`` filters are
    applied in SQL; ``targets`` is matched in Python because it is stored as
    JSON. ``is_owned`` is true when the product has at least one inventory
    entry whose ``finished_at`` is unset.
    """
    product_table = inspect(Product).local_table
    # NOTE: the column order here defines the positional layout used below.
    stmt = sa_select(
        product_table.c.id,
        product_table.c.name,
        product_table.c.brand,
        product_table.c.category,
        product_table.c.recommended_time,
        product_table.c.targets,
        product_table.c.price_tier,
        product_table.c.price_per_use_pln,
        product_table.c.price_tier_source,
    )
    if category is not None:
        stmt = stmt.where(product_table.c.category == category)
    if brand is not None:
        stmt = stmt.where(product_table.c.brand == brand)
    if is_medication is not None:
        stmt = stmt.where(product_table.c.is_medication == is_medication)
    if is_tool is not None:
        stmt = stmt.where(product_table.c.is_tool == is_tool)

    rows = list(session.execute(stmt).all())

    if targets:
        target_values = {t.value for t in targets}
        # row[5] is the ``targets`` column; its items may be enum members or
        # plain values, so both are normalised before comparison.
        rows = [
            row
            for row in rows
            if any(
                (t.value if hasattr(t, "value") else t) in target_values
                for t in (row[5] or [])
            )
        ]

    # row[0] is the product id.
    product_ids = [row[0] for row in rows]

    # Ask the database only for the ids of products with an unfinished
    # inventory entry instead of loading every inventory row.
    owned_ids: set[UUID] = set()
    if product_ids:
        owned_ids = set(
            session.exec(
                select(ProductInventory.product_id)
                .where(col(ProductInventory.product_id).in_(product_ids))
                .where(col(ProductInventory.finished_at).is_(None))
            ).all()
        )

    results: list[ProductListItem] = []
    for (
        product_id,
        name,
        brand_value,
        category_value,
        recommended_time,
        row_targets,
        price_tier,
        price_per_use_pln,
        price_tier_source,
    ) in rows:
        results.append(
            ProductListItem(
                id=product_id,
                name=name,
                brand=brand_value,
                category=category_value,
                recommended_time=recommended_time,
                targets=row_targets or [],
                is_owned=product_id in owned_ids,
                price_tier=price_tier,
                price_per_use_pln=price_per_use_pln,
                price_tier_source=price_tier_source,
            )
        )

    return results
@router.get("/{product_id}", response_model=ProductWithInventory)
def get_product(product_id: UUID, session: Session = Depends(get_session)):
    # Return a single product together with its inventory entries.
    # get_or_404 is presumably responsible for the 404 response when the
    # product does not exist (helper not visible here).
    product = get_or_404(session, Product, product_id)
    # Pricing is computed against the whole product catalogue.
    # NOTE(review): Product now carries persisted pricing columns; confirm
    # whether this per-request recomputation and overlay is still intended.
    pricing_pool = list(session.exec(select(Product)).all())
    pricing_outputs = _compute_pricing_outputs(pricing_pool)
    inventory = session.exec(
        select(ProductInventory).where(ProductInventory.product_id == product_id)
    ).all()
    result = ProductWithInventory.model_validate(product, from_attributes=True)
    # Falls back to an all-None pricing triple when the product has no entry.
    _with_pricing(result, pricing_outputs.get(product.id, (None, None, None)))
    result.inventory = list(inventory)
    return result
@ -658,18 +732,17 @@ def update_product(
for key, value in patch_data.items():
setattr(product, key, value)
session.add(product)
enqueue_pricing_recalc(session)
session.commit()
session.refresh(product)
pricing_pool = list(session.exec(select(Product)).all())
pricing_outputs = _compute_pricing_outputs(pricing_pool)
result = ProductPublic.model_validate(product, from_attributes=True)
return _with_pricing(result, pricing_outputs.get(product.id, (None, None, None)))
return ProductPublic.model_validate(product, from_attributes=True)
@router.delete("/{product_id}", status_code=204)
def delete_product(product_id: UUID, session: Session = Depends(get_session)):
    # Delete the product and request a pricing recalculation.
    product = get_or_404(session, Product, product_id)
    session.delete(product)
    # Enqueued before the commit so the deletion and the job request are
    # committed together.
    enqueue_pricing_recalc(session)
    session.commit()

View file

@ -32,6 +32,7 @@ from .product import (
ProductPublic,
ProductWithInventory,
)
from .pricing import PricingRecalcJob
from .routine import GroomingSchedule, Routine, RoutineStep
from .skincare import (
SkinConditionSnapshot,
@ -77,6 +78,7 @@ __all__ = [
"ProductInventory",
"ProductPublic",
"ProductWithInventory",
"PricingRecalcJob",
# routine
"GroomingSchedule",
"Routine",

View file

@ -0,0 +1,33 @@
from datetime import datetime
from typing import ClassVar
from uuid import UUID, uuid4
from sqlalchemy import Column, DateTime
from sqlmodel import Field, SQLModel
from .base import utc_now
from .domain import Domain
class PricingRecalcJob(SQLModel, table=True):
    # Persisted request to recalculate product pricing, stored in the
    # ``pricing_recalc_jobs`` table.
    __tablename__ = "pricing_recalc_jobs"
    __domains__: ClassVar[frozenset[Domain]] = frozenset({Domain.SKINCARE})

    id: UUID = Field(default_factory=uuid4, primary_key=True)
    # Which set of products the job covers; defaults to "global".
    scope: str = Field(default="global", max_length=32, index=True)
    # Lifecycle state; new jobs start as "pending".
    status: str = Field(default="pending", max_length=16, index=True)
    # Number of processing attempts; validated as non-negative.
    attempts: int = Field(default=0, ge=0)
    # Error text from a failed run, limited to 512 characters.
    error: str | None = Field(default=None, max_length=512)

    created_at: datetime = Field(default_factory=utc_now, nullable=False)
    started_at: datetime | None = Field(default=None)
    finished_at: datetime | None = Field(default=None)
    # Timezone-aware column refreshed via ``onupdate`` on every UPDATE.
    updated_at: datetime = Field(
        default_factory=utc_now,
        sa_column=Column(
            DateTime(timezone=True),
            default=utc_now,
            onupdate=utc_now,
            nullable=False,
        ),
    )

View file

@ -174,6 +174,11 @@ class Product(ProductBase, table=True):
default=None, sa_column=Column(JSON, nullable=True)
)
price_tier: PriceTier | None = Field(default=None, index=True)
price_per_use_pln: float | None = Field(default=None)
price_tier_source: str | None = Field(default=None, max_length=32)
pricing_computed_at: datetime | None = Field(default=None)
created_at: datetime = Field(default_factory=utc_now, nullable=False)
updated_at: datetime = Field(
default_factory=utc_now,

View file

@ -0,0 +1,93 @@
from datetime import datetime
from sqlmodel import Session, col, select
from innercontext.models import PricingRecalcJob, Product
from innercontext.models.base import utc_now
def enqueue_pricing_recalc(
    session: Session, *, scope: str = "global"
) -> PricingRecalcJob:
    """Ensure a pending pricing recalculation job exists for ``scope``.

    If a ``pending`` job for the scope already exists it is returned;
    otherwise a new one is added to ``session``. The caller is responsible
    for committing.

    Jobs that are already ``running`` are deliberately NOT reused: a running
    job may have read its product snapshot before the caller's change, so
    coalescing onto it would leave that change unpriced. A job stuck in
    ``running`` (for example after a worker crash) would also block every
    later enqueue.

    Args:
        session: Open database session; the new job is added but not committed.
        scope: Scope label stored on the job.

    Returns:
        The existing pending job, or the newly added one.
    """
    pending = session.exec(
        select(PricingRecalcJob)
        .where(PricingRecalcJob.scope == scope)
        .where(PricingRecalcJob.status == "pending")
        .order_by(col(PricingRecalcJob.created_at).asc())
    ).first()
    if pending is not None:
        return pending

    job = PricingRecalcJob(scope=scope, status="pending")
    session.add(job)
    return job
def claim_next_pending_pricing_job(session: Session) -> PricingRecalcJob | None:
    """Claim the oldest pending job and mark it as running.

    The job's status becomes ``running``, ``attempts`` is incremented,
    ``started_at`` is set, and ``finished_at`` / ``error`` are cleared. The
    change is committed before returning.

    Args:
        session: Open database session used for the claim transaction.

    Returns:
        The claimed job, or ``None`` when no pending job exists.
    """
    stmt = (
        select(PricingRecalcJob)
        .where(PricingRecalcJob.status == "pending")
        .order_by(col(PricingRecalcJob.created_at).asc())
        # Fetch (and, on PostgreSQL, lock) a single row only; ``.first()``
        # alone does not add a LIMIT to the query.
        .limit(1)
    )
    bind = session.get_bind()
    if bind is not None and bind.dialect.name == "postgresql":
        # Row-level lock that skips rows already locked by another session.
        stmt = stmt.with_for_update(skip_locked=True)

    job = session.exec(stmt).first()
    if job is None:
        return None

    job.status = "running"
    job.attempts += 1
    job.started_at = utc_now()
    job.finished_at = None
    job.error = None
    session.add(job)
    session.commit()
    session.refresh(job)
    return job
def _apply_pricing_snapshot(session: Session, computed_at: datetime) -> int:
    """Recompute pricing for every product and store it on the product rows.

    The changes are left pending on ``session``; nothing is committed here.

    Returns:
        The number of products whose snapshot fields were written.
    """
    # Imported lazily, inside the function, rather than at module import time.
    from innercontext.api.products import _compute_pricing_outputs

    catalogue = list(session.exec(select(Product)).all())
    outputs_by_id = _compute_pricing_outputs(catalogue)
    no_pricing = (None, None, None)

    for item in catalogue:
        (
            item.price_tier,
            item.price_per_use_pln,
            item.price_tier_source,
        ) = outputs_by_id.get(item.id, no_pricing)
        item.pricing_computed_at = computed_at

    return len(catalogue)
def process_pricing_job(session: Session, job: PricingRecalcJob) -> int:
    """Run one pricing job and record its outcome on the job row.

    On success the job is marked ``succeeded`` and committed together with
    the pricing snapshot. On any exception the transaction is rolled back,
    the job is marked ``failed`` with the truncated error text, that state
    is committed, and the exception is re-raised.

    Returns:
        The number of products updated by the snapshot.
    """
    try:
        refreshed = _apply_pricing_snapshot(session, computed_at=utc_now())
        job.status, job.finished_at, job.error = "succeeded", utc_now(), None
        session.add(job)
        session.commit()
    except Exception as failure:
        # Discard the partial snapshot before persisting the failure state.
        session.rollback()
        job.status, job.finished_at, job.error = (
            "failed",
            utc_now(),
            str(failure)[:512],
        )
        session.add(job)
        session.commit()
        raise
    return refreshed
def process_one_pending_pricing_job(session: Session) -> bool:
    """Claim and process a single pending job, if there is one.

    Returns:
        ``True`` when a job was claimed and processed, ``False`` when the
        queue had no pending job.
    """
    claimed = claim_next_pending_pricing_job(session)
    if claimed is not None:
        process_pricing_job(session, claimed)
        return True
    return False

View file

@ -0,0 +1,18 @@
import time
from sqlmodel import Session
from db import engine
from innercontext.services.pricing_jobs import process_one_pending_pricing_job
def run_forever(poll_interval_seconds: float = 2.0) -> None:
    """Poll for pending pricing jobs and process them until interrupted.

    Each iteration opens a fresh session and processes at most one job. When
    no job was processed, or when processing raised, the loop sleeps for
    ``poll_interval_seconds`` before polling again.

    A failing job or a transient database error is logged and does not stop
    the worker. ``KeyboardInterrupt`` and ``SystemExit`` still propagate.

    Args:
        poll_interval_seconds: Delay between polls when idle or after an error.
    """
    # Local import keeps this change self-contained within the function.
    import logging

    logger = logging.getLogger(__name__)

    while True:
        try:
            with Session(engine) as session:
                processed = process_one_pending_pricing_job(session)
        except Exception:
            logger.exception("pricing worker iteration failed; continuing")
            # Treat as idle so the loop backs off instead of spinning.
            processed = False
        if not processed:
            time.sleep(poll_interval_seconds)


if __name__ == "__main__":
    run_forever()

View file

@ -7,8 +7,9 @@ load_dotenv() # load .env before db.py reads DATABASE_URL
from fastapi import FastAPI # noqa: E402
from fastapi.middleware.cors import CORSMiddleware # noqa: E402
from sqlmodel import Session # noqa: E402
from db import create_db_and_tables # noqa: E402
from db import create_db_and_tables, engine # noqa: E402
from innercontext.api import ( # noqa: E402
ai_logs,
health,
@ -17,11 +18,18 @@ from innercontext.api import ( # noqa: E402
routines,
skincare,
)
from innercontext.services.pricing_jobs import enqueue_pricing_recalc # noqa: E402
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncIterator[None]:
    # Application lifespan hook: runs the startup work, then yields control
    # to the running app.
    create_db_and_tables()
    # Best-effort: request a pricing recalculation at startup. A failure here
    # is reported on stdout and does not prevent the app from starting.
    try:
        with Session(engine) as session:
            enqueue_pricing_recalc(session)
            session.commit()
    except Exception as exc:  # pragma: no cover
        print(f"[startup] failed to enqueue pricing recalculation job: {exc}")
    yield

View file

@ -1,8 +1,10 @@
import uuid
from innercontext.api import products as products_api
from innercontext.models import Product
from innercontext.models import PricingRecalcJob, Product
from innercontext.models.enums import DayTime, ProductCategory
from innercontext.services.pricing_jobs import process_one_pending_pricing_job
from sqlmodel import select
def _product(
@ -45,7 +47,7 @@ def test_compute_pricing_outputs_groups_by_category(monkeypatch):
assert cleanser_tiers[-1] == "luxury"
def test_price_tier_is_null_when_not_enough_products(client, monkeypatch):
def test_price_tier_is_null_when_not_enough_products(client, session, monkeypatch):
monkeypatch.setattr(products_api, "convert_to_pln", lambda amount, currency: amount)
base = {
@ -67,13 +69,15 @@ def test_price_tier_is_null_when_not_enough_products(client, monkeypatch):
)
assert response.status_code == 201
assert process_one_pending_pricing_job(session)
products = client.get("/products").json()
assert len(products) == 7
assert all(p["price_tier"] is None for p in products)
assert all(p["price_per_use_pln"] is not None for p in products)
def test_price_tier_is_computed_on_list(client, monkeypatch):
def test_price_tier_is_computed_by_worker(client, session, monkeypatch):
monkeypatch.setattr(products_api, "convert_to_pln", lambda amount, currency: amount)
base = {
@ -91,13 +95,15 @@ def test_price_tier_is_computed_on_list(client, monkeypatch):
)
assert response.status_code == 201
assert process_one_pending_pricing_job(session)
products = client.get("/products").json()
assert len(products) == 8
assert any(p["price_tier"] == "budget" for p in products)
assert any(p["price_tier"] == "luxury" for p in products)
def test_price_tier_uses_fallback_for_medium_categories(client, monkeypatch):
def test_price_tier_uses_fallback_for_medium_categories(client, session, monkeypatch):
monkeypatch.setattr(products_api, "convert_to_pln", lambda amount, currency: amount)
serum_base = {
@ -130,6 +136,8 @@ def test_price_tier_uses_fallback_for_medium_categories(client, monkeypatch):
)
assert response.status_code == 201
assert process_one_pending_pricing_job(session)
products = client.get("/products?category=toner").json()
assert len(products) == 5
assert all(p["price_tier"] is not None for p in products)
@ -137,7 +145,7 @@ def test_price_tier_uses_fallback_for_medium_categories(client, monkeypatch):
def test_price_tier_stays_null_for_tiny_categories_even_with_fallback_pool(
client, monkeypatch
client, session, monkeypatch
):
monkeypatch.setattr(products_api, "convert_to_pln", lambda amount, currency: amount)
@ -171,7 +179,27 @@ def test_price_tier_stays_null_for_tiny_categories_even_with_fallback_pool(
)
assert response.status_code == 201
assert process_one_pending_pricing_job(session)
oils = client.get("/products?category=oil").json()
assert len(oils) == 3
assert all(p["price_tier"] is None for p in oils)
assert all(p["price_tier_source"] == "insufficient_data" for p in oils)
def test_product_write_enqueues_pricing_job(client, session):
    """Creating a product through the API leaves exactly one pricing job."""
    payload = {
        "name": "Serum X",
        "brand": "B",
        "category": "serum",
        "recommended_time": "both",
        "leave_on": True,
    }
    created = client.post("/products", json=payload)
    assert created.status_code == 201

    queued_jobs = session.exec(select(PricingRecalcJob)).all()
    assert len(queued_jobs) == 1
    assert queued_jobs[0].status in {"pending", "running", "succeeded"}