"""add auth tables and ownership

Revision ID: 4b7d2e9f1c3a
Revises: 9f3a2c1b4d5e
Create Date: 2026-03-12 12:00:00.000000

"""

import os
from collections.abc import Sequence
from datetime import datetime, timezone
from uuid import UUID, uuid4

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision: str = "4b7d2e9f1c3a"
down_revision: str | Sequence[str] | None = "9f3a2c1b4d5e"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None

# Pre-existing tables that gain a ``user_id`` owner column in this revision.
# Order matters: ``downgrade`` walks it in reverse.
OWNED_TABLES: tuple[str, ...] = (
    "products",
    "product_inventory",
    "user_profiles",
    "medication_entries",
    "medication_usages",
    "lab_results",
    "routines",
    "routine_steps",
    "grooming_schedule",
    "skin_condition_snapshots",
    "ai_call_logs",
)


def _table_has_rows(connection: sa.Connection, table_name: str) -> bool:
    """Return True if *table_name* contains at least one row.

    The query is built with SQLAlchemy Core instead of string formatting so
    the active dialect renders (and, if needed, quotes) the table name.
    """
    query = (
        sa.select(sa.literal_column("1"))
        .select_from(sa.table(table_name))
        .limit(1)
    )
    return connection.execute(query).first() is not None


def _legacy_data_exists(connection: sa.Connection) -> bool:
    """Return True if any table in ``OWNED_TABLES`` already holds rows."""
    return any(_table_has_rows(connection, table_name) for table_name in OWNED_TABLES)


def _require_bootstrap_identity() -> tuple[str, str]:
    """Read the bootstrap admin's OIDC identity from the environment.

    Returns:
        ``(issuer, subject)`` taken from ``BOOTSTRAP_ADMIN_OIDC_ISSUER`` and
        ``BOOTSTRAP_ADMIN_OIDC_SUB``, with surrounding whitespace stripped.

    Raises:
        RuntimeError: if either variable is unset or blank.
    """
    issuer = os.getenv("BOOTSTRAP_ADMIN_OIDC_ISSUER", "").strip()
    subject = os.getenv("BOOTSTRAP_ADMIN_OIDC_SUB", "").strip()

    missing_required: list[str] = []
    if not issuer:
        missing_required.append("BOOTSTRAP_ADMIN_OIDC_ISSUER")
    if not subject:
        missing_required.append("BOOTSTRAP_ADMIN_OIDC_SUB")
    if missing_required:
        missing_csv = ", ".join(missing_required)
        raise RuntimeError(
            f"Legacy data requires bootstrap admin identity; missing required env vars: {missing_csv}"
        )
    return issuer, subject


def _ensure_bootstrap_user_and_household(
    connection: sa.Connection,
    *,
    issuer: str,
    subject: str,
) -> UUID:
    """Get or create the ADMIN user for (*issuer*, *subject*) and give it a household.

    If the user has no household membership yet, a new household is created
    and the user is added to it as OWNER.

    Returns:
        The id of the (existing or newly created) user.
    """
    # NOTE(review): ``created_at`` is declared as a naive ``DateTime()`` while
    # ``updated_at`` is ``DateTime(timezone=True)``; both receive this aware
    # UTC value. How the naive column stores it is driver/session dependent —
    # confirm this matches what the application writes.
    now = datetime.now(timezone.utc)

    # Lightweight table stubs: a migration must not import application models,
    # whose definitions may drift from the schema at this revision.
    users_table = sa.table(
        "users",
        sa.column("id", sa.Uuid()),
        sa.column("oidc_issuer", sa.String(length=512)),
        sa.column("oidc_subject", sa.String(length=512)),
        sa.column("role", sa.Enum("ADMIN", "MEMBER", name="role")),
        sa.column("created_at", sa.DateTime()),
        sa.column("updated_at", sa.DateTime(timezone=True)),
    )

    user_id = connection.execute(
        sa.select(users_table.c.id).where(
            users_table.c.oidc_issuer == issuer,
            users_table.c.oidc_subject == subject,
        )
    ).scalar_one_or_none()
    if user_id is None:
        user_id = uuid4()
        _ = connection.execute(
            sa.insert(users_table).values(
                id=user_id,
                oidc_issuer=issuer,
                oidc_subject=subject,
                role="ADMIN",
                created_at=now,
                updated_at=now,
            )
        )

    households_table = sa.table(
        "households",
        sa.column("id", sa.Uuid()),
        sa.column("created_at", sa.DateTime()),
        sa.column("updated_at", sa.DateTime(timezone=True)),
    )
    memberships_table = sa.table(
        "household_memberships",
        sa.column("id", sa.Uuid()),
        sa.column("user_id", sa.Uuid()),
        sa.column("household_id", sa.Uuid()),
        sa.column("role", sa.Enum("OWNER", "MEMBER", name="householdrole")),
        sa.column("created_at", sa.DateTime()),
        sa.column("updated_at", sa.DateTime(timezone=True)),
    )

    # ``uq_household_memberships_user_id`` guarantees at most one row here.
    membership_id = connection.execute(
        sa.select(memberships_table.c.id).where(memberships_table.c.user_id == user_id)
    ).scalar_one_or_none()
    if membership_id is None:
        household_id = uuid4()
        _ = connection.execute(
            sa.insert(households_table).values(
                id=household_id,
                created_at=now,
                updated_at=now,
            )
        )
        _ = connection.execute(
            sa.insert(memberships_table).values(
                id=uuid4(),
                user_id=user_id,
                household_id=household_id,
                role="OWNER",
                created_at=now,
                updated_at=now,
            )
        )

    return user_id


def _backfill_owned_rows(connection: sa.Connection, user_id: UUID) -> None:
    """Assign every row with a NULL ``user_id`` in ``OWNED_TABLES`` to *user_id*."""
    for table_name in OWNED_TABLES:
        table = sa.table(table_name, sa.column("user_id", sa.Uuid()))
        _ = connection.execute(
            sa.update(table).where(table.c.user_id.is_(None)).values(user_id=user_id)
        )


def upgrade() -> None:
    """Create the auth tables, add ``user_id`` ownership, and backfill legacy rows.

    Raises:
        RuntimeError: if any owned table already holds rows but the bootstrap
            admin identity env vars are missing (see
            ``_require_bootstrap_identity``).
    """
    # Do NOT call ``.create(bind, checkfirst=True)`` on these before
    # ``op.create_table``: Alembic fires the table's before_create event with
    # checkfirst=False, so on PostgreSQL SQLAlchemy would emit a second
    # CREATE TYPE and fail with "type already exists". ``create_table`` emits
    # the types itself; on non-native-enum backends (e.g. SQLite) they are
    # plain columns either way.
    role_enum = sa.Enum("ADMIN", "MEMBER", name="role")
    household_role_enum = sa.Enum("OWNER", "MEMBER", name="householdrole")

    _ = op.create_table(
        "users",
        sa.Column("id", sa.Uuid(), nullable=False),
        sa.Column("oidc_issuer", sa.String(length=512), nullable=False),
        sa.Column("oidc_subject", sa.String(length=512), nullable=False),
        sa.Column("role", role_enum, nullable=False),
        sa.Column("created_at", sa.DateTime(), nullable=False),
        sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
        sa.PrimaryKeyConstraint("id"),
        sa.UniqueConstraint(
            "oidc_issuer", "oidc_subject", name="uq_users_oidc_identity"
        ),
    )
    op.create_index(op.f("ix_users_role"), "users", ["role"], unique=False)

    _ = op.create_table(
        "households",
        sa.Column("id", sa.Uuid(), nullable=False),
        sa.Column("created_at", sa.DateTime(), nullable=False),
        sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
        sa.PrimaryKeyConstraint("id"),
    )

    _ = op.create_table(
        "household_memberships",
        sa.Column("id", sa.Uuid(), nullable=False),
        sa.Column("user_id", sa.Uuid(), nullable=False),
        sa.Column("household_id", sa.Uuid(), nullable=False),
        sa.Column("role", household_role_enum, nullable=False),
        sa.Column("created_at", sa.DateTime(), nullable=False),
        sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
        sa.ForeignKeyConstraint(
            ["household_id"], ["households.id"], ondelete="CASCADE"
        ),
        sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
        sa.PrimaryKeyConstraint("id"),
        sa.UniqueConstraint("user_id", name="uq_household_memberships_user_id"),
    )
    op.create_index(
        op.f("ix_household_memberships_household_id"),
        "household_memberships",
        ["household_id"],
        unique=False,
    )
    op.create_index(
        op.f("ix_household_memberships_role"),
        "household_memberships",
        ["role"],
        unique=False,
    )
    op.create_index(
        op.f("ix_household_memberships_user_id"),
        "household_memberships",
        ["user_id"],
        unique=False,
    )

    # Step 1: add ``user_id`` as NULLABLE so existing rows stay valid until
    # they are backfilled below.
    for table_name in OWNED_TABLES:
        with op.batch_alter_table(table_name) as batch_op:
            batch_op.add_column(sa.Column("user_id", sa.Uuid(), nullable=True))
            batch_op.create_index(
                op.f(f"ix_{table_name}_user_id"), ["user_id"], unique=False
            )
            batch_op.create_foreign_key(
                f"fk_{table_name}_user_id_users",
                "users",
                ["user_id"],
                ["id"],
                ondelete="CASCADE",
            )
            if table_name == "product_inventory":
                batch_op.add_column(
                    sa.Column(
                        "is_household_shared",
                        sa.Boolean(),
                        nullable=False,
                        server_default=sa.false(),
                    )
                )

    # Step 2: if pre-existing rows exist, they need an owner; assign them all
    # to a bootstrap admin identified via environment variables.
    connection = op.get_bind()
    if _legacy_data_exists(connection):
        issuer, subject = _require_bootstrap_identity()
        bootstrap_user_id = _ensure_bootstrap_user_and_household(
            connection,
            issuer=issuer,
            subject=subject,
        )
        _backfill_owned_rows(connection, bootstrap_user_id)

    # Step 3: every row now has an owner, so the column can be made NOT NULL.
    for table_name in OWNED_TABLES:
        with op.batch_alter_table(table_name) as batch_op:
            batch_op.alter_column("user_id", existing_type=sa.Uuid(), nullable=False)


def downgrade() -> None:
    """Remove ownership columns and auth tables, then drop the enum types.

    The enum types are dropped explicitly because ``op.drop_table(name)``
    does not know the table's columns and so never drops them itself.
    """
    for table_name in reversed(OWNED_TABLES):
        with op.batch_alter_table(table_name) as batch_op:
            batch_op.drop_constraint(
                f"fk_{table_name}_user_id_users", type_="foreignkey"
            )
            batch_op.drop_index(op.f(f"ix_{table_name}_user_id"))
            if table_name == "product_inventory":
                batch_op.drop_column("is_household_shared")
            batch_op.drop_column("user_id")

    op.drop_index(
        op.f("ix_household_memberships_user_id"), table_name="household_memberships"
    )
    op.drop_index(
        op.f("ix_household_memberships_role"), table_name="household_memberships"
    )
    op.drop_index(
        op.f("ix_household_memberships_household_id"),
        table_name="household_memberships",
    )
    op.drop_table("household_memberships")
    op.drop_table("households")
    op.drop_index(op.f("ix_users_role"), table_name="users")
    op.drop_table("users")

    bind = op.get_bind()
    household_role_enum = sa.Enum("OWNER", "MEMBER", name="householdrole")
    role_enum = sa.Enum("ADMIN", "MEMBER", name="role")
    household_role_enum.drop(bind, checkfirst=True)
    role_enum.drop(bind, checkfirst=True)