diff --git a/backend/alembic/versions/4b7d2e9f1c3a_add_auth_tables_and_ownership.py b/backend/alembic/versions/4b7d2e9f1c3a_add_auth_tables_and_ownership.py new file mode 100644 index 0000000..a609602 --- /dev/null +++ b/backend/alembic/versions/4b7d2e9f1c3a_add_auth_tables_and_ownership.py @@ -0,0 +1,289 @@ +"""add auth tables and ownership + +Revision ID: 4b7d2e9f1c3a +Revises: 9f3a2c1b4d5e +Create Date: 2026-03-12 12:00:00.000000 + +""" + +import os +from collections.abc import Sequence +from datetime import datetime, timezone +from uuid import UUID, uuid4 + +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "4b7d2e9f1c3a" +down_revision: str | Sequence[str] | None = "9f3a2c1b4d5e" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + +OWNED_TABLES: tuple[str, ...] = ( + "products", + "product_inventory", + "user_profiles", + "medication_entries", + "medication_usages", + "lab_results", + "routines", + "routine_steps", + "grooming_schedule", + "skin_condition_snapshots", + "ai_call_logs", +) + + +def _table_has_rows(connection: sa.Connection, table_name: str) -> bool: + query = sa.text(f"SELECT 1 FROM {table_name} LIMIT 1") + return connection.execute(query).first() is not None + + +def _legacy_data_exists(connection: sa.Connection) -> bool: + return any(_table_has_rows(connection, table_name) for table_name in OWNED_TABLES) + + +def _ensure_bootstrap_user_and_household( + connection: sa.Connection, + *, + issuer: str, + subject: str, +) -> UUID: + now = datetime.now(timezone.utc) + + users_table = sa.table( + "users", + sa.column("id", sa.Uuid()), + sa.column("oidc_issuer", sa.String(length=512)), + sa.column("oidc_subject", sa.String(length=512)), + sa.column("role", sa.Enum("ADMIN", "MEMBER", name="role")), + sa.column("created_at", sa.DateTime()), + sa.column("updated_at", sa.DateTime(timezone=True)), + ) + + user_id = connection.execute( + sa.select(users_table.c.id).where( + users_table.c.oidc_issuer == issuer, + users_table.c.oidc_subject == subject, + ) + ).scalar_one_or_none() + + if user_id is None: + user_id = uuid4() + _ = connection.execute( + sa.insert(users_table).values( + id=user_id, + oidc_issuer=issuer, + oidc_subject=subject, + role="ADMIN", + created_at=now, + updated_at=now, + ) + ) + + households_table = sa.table( + "households", + sa.column("id", sa.Uuid()), + sa.column("created_at", sa.DateTime()), + sa.column("updated_at", sa.DateTime(timezone=True)), + ) + memberships_table = sa.table( + "household_memberships", + sa.column("id", sa.Uuid()), + sa.column("user_id", sa.Uuid()), + sa.column("household_id", sa.Uuid()), + sa.column("role", sa.Enum("OWNER", "MEMBER", name="householdrole")), + sa.column("created_at", sa.DateTime()), + sa.column("updated_at", sa.DateTime(timezone=True)), + ) + + membership_id = connection.execute( + sa.select(memberships_table.c.id).where(memberships_table.c.user_id == user_id) + ).scalar_one_or_none() + if membership_id is None: + household_id = uuid4() + _ = connection.execute( + sa.insert(households_table).values( + id=household_id, + created_at=now, + updated_at=now, + ) + ) + _ = connection.execute( + sa.insert(memberships_table).values( + id=uuid4(), + user_id=user_id, + household_id=household_id, + role="OWNER", + created_at=now, + updated_at=now, + ) + ) + + return user_id + + +def _backfill_owned_rows(connection: sa.Connection, user_id: UUID) -> None: + for table_name in OWNED_TABLES: + table = sa.table(table_name, sa.column("user_id", sa.Uuid())) + _ = connection.execute( + sa.update(table).where(table.c.user_id.is_(None)).values(user_id=user_id) + ) + + +def upgrade() -> None: + bind = op.get_bind() + + role_enum = sa.Enum("ADMIN", "MEMBER", name="role") + household_role_enum = sa.Enum("OWNER", "MEMBER", name="householdrole") + role_enum.create(bind, checkfirst=True) + household_role_enum.create(bind, checkfirst=True) + + _ = op.create_table( + "users", + sa.Column("id", sa.Uuid(), nullable=False), + sa.Column("oidc_issuer", sa.String(length=512), nullable=False), + sa.Column("oidc_subject", sa.String(length=512), nullable=False), + sa.Column("role", role_enum, nullable=False), + sa.Column("created_at", sa.DateTime(), nullable=False), + sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint( + "oidc_issuer", "oidc_subject", name="uq_users_oidc_identity" + ), + ) + op.create_index(op.f("ix_users_role"), "users", ["role"], unique=False) + + _ = op.create_table( + "households", + sa.Column("id", sa.Uuid(), nullable=False), + sa.Column("created_at", sa.DateTime(), nullable=False), + sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False), + sa.PrimaryKeyConstraint("id"), + ) + + _ = op.create_table( + "household_memberships", + sa.Column("id", sa.Uuid(), nullable=False), + sa.Column("user_id", sa.Uuid(), nullable=False), + sa.Column("household_id", sa.Uuid(), nullable=False), + sa.Column("role", household_role_enum, nullable=False), + sa.Column("created_at", sa.DateTime(), nullable=False), + sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False), + sa.ForeignKeyConstraint( + ["household_id"], ["households.id"], ondelete="CASCADE" + ), + sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("user_id", name="uq_household_memberships_user_id"), + ) + op.create_index( + op.f("ix_household_memberships_household_id"), + "household_memberships", + ["household_id"], + unique=False, + ) + op.create_index( + op.f("ix_household_memberships_role"), + "household_memberships", + ["role"], + unique=False, + ) + op.create_index( + op.f("ix_household_memberships_user_id"), + "household_memberships", + ["user_id"], + unique=False, + ) + + for table_name in OWNED_TABLES: + with op.batch_alter_table(table_name) as batch_op: + batch_op.add_column(sa.Column("user_id", sa.Uuid(), nullable=True)) + batch_op.create_index( + op.f(f"ix_{table_name}_user_id"), ["user_id"], unique=False + ) + batch_op.create_foreign_key( + f"fk_{table_name}_user_id_users", + "users", + ["user_id"], + ["id"], + ondelete="CASCADE", + ) + if table_name == "product_inventory": + batch_op.add_column( + sa.Column( + "is_household_shared", + sa.Boolean(), + nullable=False, + server_default=sa.false(), + ) + ) + + connection = op.get_bind() + legacy_data_exists = _legacy_data_exists(connection) + + issuer = os.getenv("BOOTSTRAP_ADMIN_OIDC_ISSUER", "").strip() + subject = os.getenv("BOOTSTRAP_ADMIN_OIDC_SUB", "").strip() + bootstrap_email = os.getenv("BOOTSTRAP_ADMIN_EMAIL", "").strip() + bootstrap_name = os.getenv("BOOTSTRAP_ADMIN_NAME", "").strip() + bootstrap_household_name = os.getenv("BOOTSTRAP_HOUSEHOLD_NAME", "").strip() + _ = (bootstrap_email, bootstrap_name, bootstrap_household_name) + + if legacy_data_exists: + missing_required: list[str] = [] + if not issuer: + missing_required.append("BOOTSTRAP_ADMIN_OIDC_ISSUER") + if not subject: + missing_required.append("BOOTSTRAP_ADMIN_OIDC_SUB") + + if missing_required: + missing_csv = ", ".join(missing_required) + raise RuntimeError( + f"Legacy data requires bootstrap admin identity; missing required env vars: {missing_csv}" + ) + + bootstrap_user_id = _ensure_bootstrap_user_and_household( + connection, + issuer=issuer, + subject=subject, + ) + _backfill_owned_rows(connection, bootstrap_user_id) + + for table_name in OWNED_TABLES: + with op.batch_alter_table(table_name) as batch_op: + batch_op.alter_column("user_id", existing_type=sa.Uuid(), nullable=False) + + +def downgrade() -> None: + for table_name in reversed(OWNED_TABLES): + with op.batch_alter_table(table_name) as batch_op: + batch_op.drop_constraint( + f"fk_{table_name}_user_id_users", type_="foreignkey" + ) + batch_op.drop_index(op.f(f"ix_{table_name}_user_id")) + if table_name == "product_inventory": + batch_op.drop_column("is_household_shared") + batch_op.drop_column("user_id") + + op.drop_index( + op.f("ix_household_memberships_user_id"), table_name="household_memberships" + ) + op.drop_index( + op.f("ix_household_memberships_role"), table_name="household_memberships" + ) + op.drop_index( + op.f("ix_household_memberships_household_id"), + table_name="household_memberships", + ) + op.drop_table("household_memberships") + op.drop_table("households") + op.drop_index(op.f("ix_users_role"), table_name="users") + op.drop_table("users") + + bind = op.get_bind() + household_role_enum = sa.Enum("OWNER", "MEMBER", name="householdrole") + role_enum = sa.Enum("ADMIN", "MEMBER", name="role") + household_role_enum.drop(bind, checkfirst=True) + role_enum.drop(bind, checkfirst=True)