feat(db): backfill tenant ownership for existing records

This commit is contained in:
Piotr Oleszczyk 2026-03-12 14:54:24 +01:00
parent 04daadccda
commit 2704d58673

View file

@ -0,0 +1,289 @@
"""add auth tables and ownership
Revision ID: 4b7d2e9f1c3a
Revises: 9f3a2c1b4d5e
Create Date: 2026-03-12 12:00:00.000000
"""
import os
from collections.abc import Sequence
from datetime import datetime, timezone
from uuid import UUID, uuid4
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
# `revision` is this migration's own id; `down_revision` is the revision it
# revises (the same pair listed in the module docstring).
revision: str = "4b7d2e9f1c3a"
down_revision: str | Sequence[str] | None = "9f3a2c1b4d5e"
# No branch label and no dependency on another revision is declared here.
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
# Tables that gain a `user_id` ownership column in this migration.
# upgrade() processes them in this order and downgrade() in reverse, and the
# legacy-data check probes each of them for existing rows.
OWNED_TABLES: tuple[str, ...] = (
    "products",
    "product_inventory",
    "user_profiles",
    "medication_entries",
    "medication_usages",
    "lab_results",
    "routines",
    "routine_steps",
    "grooming_schedule",
    "skin_condition_snapshots",
    "ai_call_logs",
)
def _table_has_rows(connection: sa.Connection, table_name: str) -> bool:
    """Report whether ``table_name`` holds at least one row.

    The table name is interpolated directly into the SQL text, so it must be
    a trusted identifier; within this module it only ever comes from
    ``OWNED_TABLES``.

    Args:
        connection: Connection the probe query is executed on.
        table_name: Name of the table to probe.

    Returns:
        ``True`` if the probe query yields a row, ``False`` otherwise.
    """
    probe = sa.text(f"SELECT 1 FROM {table_name} LIMIT 1")
    first_row = connection.execute(probe).first()
    return first_row is not None
def _legacy_data_exists(connection: sa.Connection) -> bool:
    """Report whether any table listed in ``OWNED_TABLES`` contains rows.

    Tables are probed in declaration order and probing stops at the first
    non-empty one.

    Args:
        connection: Connection the probe queries are executed on.

    Returns:
        ``True`` as soon as one owned table is found to be non-empty,
        ``False`` if all of them are empty.
    """
    for owned_table in OWNED_TABLES:
        if _table_has_rows(connection, owned_table):
            return True
    return False
def _ensure_bootstrap_user_and_household(
    connection: sa.Connection,
    *,
    issuer: str,
    subject: str,
) -> UUID:
    """Return the bootstrap admin's user id, creating missing rows on the way.

    A ``users`` row is looked up by its OIDC issuer/subject pair and inserted
    with the ``ADMIN`` role when no match exists. If that user has no row in
    ``household_memberships``, a fresh household is inserted together with an
    ``OWNER`` membership tying the user to it.

    Args:
        connection: Connection the lookups and inserts are executed on.
        issuer: OIDC issuer of the bootstrap admin.
        subject: OIDC subject of the bootstrap admin.

    Returns:
        Primary key of the existing or newly inserted ``users`` row.
    """
    # One timestamp is shared by every row inserted during this call.
    timestamp = datetime.now(timezone.utc)

    # Lightweight table stubs listing only the columns touched below.
    users = sa.table(
        "users",
        sa.column("id", sa.Uuid()),
        sa.column("oidc_issuer", sa.String(length=512)),
        sa.column("oidc_subject", sa.String(length=512)),
        sa.column("role", sa.Enum("ADMIN", "MEMBER", name="role")),
        sa.column("created_at", sa.DateTime()),
        sa.column("updated_at", sa.DateTime(timezone=True)),
    )
    households = sa.table(
        "households",
        sa.column("id", sa.Uuid()),
        sa.column("created_at", sa.DateTime()),
        sa.column("updated_at", sa.DateTime(timezone=True)),
    )
    memberships = sa.table(
        "household_memberships",
        sa.column("id", sa.Uuid()),
        sa.column("user_id", sa.Uuid()),
        sa.column("household_id", sa.Uuid()),
        sa.column("role", sa.Enum("OWNER", "MEMBER", name="householdrole")),
        sa.column("created_at", sa.DateTime()),
        sa.column("updated_at", sa.DateTime(timezone=True)),
    )

    # Step 1: find the user by OIDC identity, inserting it when absent.
    find_user = (
        sa.select(users.c.id)
        .where(users.c.oidc_issuer == issuer)
        .where(users.c.oidc_subject == subject)
    )
    user_id = connection.execute(find_user).scalar_one_or_none()
    if user_id is None:
        user_id = uuid4()
        create_user = sa.insert(users).values(
            id=user_id,
            oidc_issuer=issuer,
            oidc_subject=subject,
            role="ADMIN",
            created_at=timestamp,
            updated_at=timestamp,
        )
        _ = connection.execute(create_user)

    # Step 2: nothing more to do when the user already has a membership.
    find_membership = sa.select(memberships.c.id).where(
        memberships.c.user_id == user_id
    )
    existing_membership = connection.execute(find_membership).scalar_one_or_none()
    if existing_membership is not None:
        return user_id

    # Step 3: create a household and make the user its owner. The household
    # goes in first so the membership row has a household to point at.
    new_household_id = uuid4()
    create_household = sa.insert(households).values(
        id=new_household_id,
        created_at=timestamp,
        updated_at=timestamp,
    )
    _ = connection.execute(create_household)

    create_membership = sa.insert(memberships).values(
        id=uuid4(),
        user_id=user_id,
        household_id=new_household_id,
        role="OWNER",
        created_at=timestamp,
        updated_at=timestamp,
    )
    _ = connection.execute(create_membership)

    return user_id
def _backfill_owned_rows(connection: sa.Connection, user_id: UUID) -> None:
    """Assign ``user_id`` to every owned row that has no owner yet.

    Each table in ``OWNED_TABLES`` receives one UPDATE that touches only rows
    whose ``user_id`` is NULL; rows that already have an owner are left alone.

    Args:
        connection: Connection the UPDATE statements are executed on.
        user_id: Owner id written into the unowned rows.
    """
    for owned_table in OWNED_TABLES:
        target = sa.table(owned_table, sa.column("user_id", sa.Uuid()))
        is_unowned = target.c.user_id.is_(None)
        assign_owner = sa.update(target).where(is_unowned).values(user_id=user_id)
        _ = connection.execute(assign_owner)
def _require_bootstrap_identity() -> tuple[str, str]:
    """Read the bootstrap admin's OIDC issuer and subject from the environment.

    Returns:
        The whitespace-stripped ``(issuer, subject)`` pair.

    Raises:
        RuntimeError: If ``BOOTSTRAP_ADMIN_OIDC_ISSUER`` or
            ``BOOTSTRAP_ADMIN_OIDC_SUB`` is unset or blank.
    """
    issuer = os.getenv("BOOTSTRAP_ADMIN_OIDC_ISSUER", "").strip()
    subject = os.getenv("BOOTSTRAP_ADMIN_OIDC_SUB", "").strip()

    missing_required: list[str] = []
    if not issuer:
        missing_required.append("BOOTSTRAP_ADMIN_OIDC_ISSUER")
    if not subject:
        missing_required.append("BOOTSTRAP_ADMIN_OIDC_SUB")
    if missing_required:
        missing_csv = ", ".join(missing_required)
        raise RuntimeError(
            f"Legacy data requires bootstrap admin identity; missing required env vars: {missing_csv}"
        )
    return issuer, subject


def _create_auth_tables(bind: sa.Connection) -> None:
    """Create the role enums plus the users/households/memberships tables."""
    role_enum = sa.Enum("ADMIN", "MEMBER", name="role")
    household_role_enum = sa.Enum("OWNER", "MEMBER", name="householdrole")
    # NOTE(review): the enums are created explicitly and then also passed to
    # create_table below. On PostgreSQL this may emit CREATE TYPE a second
    # time - confirm against the target dialect.
    role_enum.create(bind, checkfirst=True)
    household_role_enum.create(bind, checkfirst=True)

    _ = op.create_table(
        "users",
        sa.Column("id", sa.Uuid(), nullable=False),
        sa.Column("oidc_issuer", sa.String(length=512), nullable=False),
        sa.Column("oidc_subject", sa.String(length=512), nullable=False),
        sa.Column("role", role_enum, nullable=False),
        sa.Column("created_at", sa.DateTime(), nullable=False),
        sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
        sa.PrimaryKeyConstraint("id"),
        sa.UniqueConstraint(
            "oidc_issuer", "oidc_subject", name="uq_users_oidc_identity"
        ),
    )
    op.create_index(op.f("ix_users_role"), "users", ["role"], unique=False)

    _ = op.create_table(
        "households",
        sa.Column("id", sa.Uuid(), nullable=False),
        sa.Column("created_at", sa.DateTime(), nullable=False),
        sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
        sa.PrimaryKeyConstraint("id"),
    )

    _ = op.create_table(
        "household_memberships",
        sa.Column("id", sa.Uuid(), nullable=False),
        sa.Column("user_id", sa.Uuid(), nullable=False),
        sa.Column("household_id", sa.Uuid(), nullable=False),
        sa.Column("role", household_role_enum, nullable=False),
        sa.Column("created_at", sa.DateTime(), nullable=False),
        sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
        sa.ForeignKeyConstraint(
            ["household_id"], ["households.id"], ondelete="CASCADE"
        ),
        sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
        sa.PrimaryKeyConstraint("id"),
        sa.UniqueConstraint("user_id", name="uq_household_memberships_user_id"),
    )
    op.create_index(
        op.f("ix_household_memberships_household_id"),
        "household_memberships",
        ["household_id"],
        unique=False,
    )
    op.create_index(
        op.f("ix_household_memberships_role"),
        "household_memberships",
        ["role"],
        unique=False,
    )
    op.create_index(
        op.f("ix_household_memberships_user_id"),
        "household_memberships",
        ["user_id"],
        unique=False,
    )


def _add_ownership_columns() -> None:
    """Add a nullable, indexed ``user_id`` foreign key to every owned table."""
    for table_name in OWNED_TABLES:
        with op.batch_alter_table(table_name) as batch_op:
            # Nullable for now so existing rows survive until the backfill.
            batch_op.add_column(sa.Column("user_id", sa.Uuid(), nullable=True))
            batch_op.create_index(
                op.f(f"ix_{table_name}_user_id"), ["user_id"], unique=False
            )
            batch_op.create_foreign_key(
                f"fk_{table_name}_user_id_users",
                "users",
                ["user_id"],
                ["id"],
                ondelete="CASCADE",
            )
            if table_name == "product_inventory":
                # The server default lets existing inventory rows satisfy
                # the NOT NULL constraint without a separate backfill.
                batch_op.add_column(
                    sa.Column(
                        "is_household_shared",
                        sa.Boolean(),
                        nullable=False,
                        server_default=sa.false(),
                    )
                )


def upgrade() -> None:
    """Create the auth tables and attach every owned row to a user.

    Steps, in order:

    1. Check whether any table in ``OWNED_TABLES`` already has rows and, if
       so, require the bootstrap admin's OIDC identity from the environment.
    2. Create the role enums and the ``users``, ``households`` and
       ``household_memberships`` tables.
    3. Add a nullable ``user_id`` column, index and foreign key to each owned
       table (plus ``is_household_shared`` on ``product_inventory``).
    4. If legacy rows exist, ensure the bootstrap user and household exist
       and assign all unowned rows to that user.
    5. Make ``user_id`` NOT NULL on every owned table.

    Raises:
        RuntimeError: If legacy rows exist and ``BOOTSTRAP_ADMIN_OIDC_ISSUER``
            or ``BOOTSTRAP_ADMIN_OIDC_SUB`` is missing or blank.
    """
    connection = op.get_bind()

    # Validate preconditions before any DDL runs. The row probe only needs
    # the pre-existing tables, and failing here (instead of after the schema
    # changes) avoids leaving a half-migrated schema on backends that cannot
    # roll DDL back.
    bootstrap_identity: tuple[str, str] | None = None
    if _legacy_data_exists(connection):
        bootstrap_identity = _require_bootstrap_identity()

    _create_auth_tables(connection)
    _add_ownership_columns()

    if bootstrap_identity is not None:
        issuer, subject = bootstrap_identity
        bootstrap_user_id = _ensure_bootstrap_user_and_household(
            connection,
            issuer=issuer,
            subject=subject,
        )
        _backfill_owned_rows(connection, bootstrap_user_id)

    # Every row now has an owner (or the tables are empty), so the column
    # can be tightened to NOT NULL.
    for table_name in OWNED_TABLES:
        with op.batch_alter_table(table_name) as batch_op:
            batch_op.alter_column("user_id", existing_type=sa.Uuid(), nullable=False)
def downgrade() -> None:
    """Undo the upgrade: drop ownership columns, auth tables, then enums.

    Owned tables are processed in the reverse of ``OWNED_TABLES`` order. For
    each one the foreign key and index are removed before the ``user_id``
    column itself. The auth tables are then dropped from the most dependent
    (``household_memberships``) to the least (``users``), and finally both
    enum types are dropped.
    """
    for owned_table in reversed(OWNED_TABLES):
        fk_name = f"fk_{owned_table}_user_id_users"
        index_name = op.f(f"ix_{owned_table}_user_id")
        with op.batch_alter_table(owned_table) as batch_op:
            batch_op.drop_constraint(fk_name, type_="foreignkey")
            batch_op.drop_index(index_name)
            if owned_table == "product_inventory":
                batch_op.drop_column("is_household_shared")
            batch_op.drop_column("user_id")

    # Membership indexes go first, in the reverse of their creation order.
    membership_indexes = (
        "ix_household_memberships_user_id",
        "ix_household_memberships_role",
        "ix_household_memberships_household_id",
    )
    for membership_index in membership_indexes:
        op.drop_index(op.f(membership_index), table_name="household_memberships")

    op.drop_table("household_memberships")
    op.drop_table("households")
    op.drop_index(op.f("ix_users_role"), table_name="users")
    op.drop_table("users")

    # The enum types are dropped explicitly once no table references them.
    connection = op.get_bind()
    household_roles = sa.Enum("OWNER", "MEMBER", name="householdrole")
    user_roles = sa.Enum("ADMIN", "MEMBER", name="role")
    household_roles.drop(connection, checkfirst=True)
    user_roles.drop(connection, checkfirst=True)