Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -71,4 +71,4 @@ migrations: ## generate migration file
docker compose run ldap_server alembic revision --autogenerate

migrate: ## upgrade db
docker compose run ldap_server alembic upgrade head
Comment thread
TheMihMih marked this conversation as resolved.
docker compose run ldap_server alembic upgrade head
54 changes: 41 additions & 13 deletions app/alembic/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,18 @@
from logging.config import fileConfig

from alembic import context
from dishka import AsyncContainer, make_async_container
from sqlalchemy import Connection, text
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.ext.asyncio import AsyncConnection

from config import Settings
from ioc import (
HTTPProvider,
MainProvider,
MFACredsProvider,
MFAProvider,
MigrationProvider,
)
from repo.pg.tables import metadata

# this is the Alembic Config object, which provides
Expand All @@ -22,7 +30,11 @@
target_metadata = metadata


def run_sync_migrations(connection: Connection, schema_name: str) -> None:
def run_sync_migrations(
connection: Connection,
schema_name: str,
dishka_container: AsyncContainer,
) -> None:
"""Run sync migrations."""
if schema_name != "public":
connection.execute(text(f"SET search_path = {schema_name}, public;"))
Expand All @@ -35,18 +47,20 @@ def run_sync_migrations(connection: Connection, schema_name: str) -> None:
)

with context.begin_transaction():
context.run_migrations()
context.run_migrations(container=dishka_container)


async def run_async_migrations(settings: Settings) -> None:
async def run_async_migrations(
settings: Settings,
dishka_container: AsyncContainer,
) -> None:
"""Run async migrations."""
engine = create_async_engine(str(settings.POSTGRES_URI))

async with engine.connect() as connection:
await connection.run_sync(
run_sync_migrations,
schema_name=settings.TEST_POSTGRES_SCHEMA,
)
connection = await dishka_container.get(AsyncConnection)
await connection.run_sync(
run_sync_migrations,
schema_name=settings.TEST_POSTGRES_SCHEMA,
dishka_container=dishka_container,
)


def run_migrations_online() -> None:
Expand All @@ -60,11 +74,25 @@ def run_migrations_online() -> None:
"app_settings",
Settings.from_os(),
)
dishka_container = context.config.attributes.get("dishka_container", None)
if not dishka_container:
dishka_container = make_async_container(
MainProvider(),
MFACredsProvider(),
MFAProvider(),
HTTPProvider(),
MigrationProvider(),
context={Settings: settings},
)

if conn is None:
asyncio.run(run_async_migrations(settings))
asyncio.run(run_async_migrations(settings, dishka_container))
else:
run_sync_migrations(conn, schema_name=settings.TEST_POSTGRES_SCHEMA)
run_sync_migrations(
conn,
schema_name=settings.TEST_POSTGRES_SCHEMA,
dishka_container=dishka_container,
)


run_migrations_online()
5 changes: 3 additions & 2 deletions app/alembic/script.py.mako
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ Create Date: ${create_date}
"""
from alembic import op
import sqlalchemy as sa
from dishka import AsyncContainer
${imports if imports else ""}

# revision identifiers, used by Alembic.
Expand All @@ -16,11 +17,11 @@ branch_labels: None | list[str] = ${repr(branch_labels)}
depends_on: None | list[str] = ${repr(depends_on)}


def upgrade() -> None:
def upgrade(container: AsyncContainer) -> None:
"""Upgrade."""
${upgrades if upgrades else "pass"}


def downgrade() -> None:
def downgrade(container: AsyncContainer) -> None:
"""Downgrade."""
${downgrades if downgrades else "pass"}
31 changes: 8 additions & 23 deletions app/alembic/versions/01f3f05a5b11_add_primary_group_id.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
"""

from alembic import op
from dishka import AsyncContainer, Scope
from sqlalchemy import delete, exists, select
from sqlalchemy.exc import DBAPIError, IntegrityError
from sqlalchemy.ext.asyncio import AsyncConnection, AsyncSession
Expand All @@ -18,9 +19,6 @@
AttributeValueValidator,
)
from ldap_protocol.ldap_schema.entity_type_dao import EntityTypeDAO
from ldap_protocol.ldap_schema.object_class_dao import ObjectClassDAO
from ldap_protocol.roles.ace_dao import AccessControlEntryDAO
from ldap_protocol.roles.role_dao import RoleDAO
from ldap_protocol.roles.role_use_case import RoleUseCase
from ldap_protocol.utils.queries import (
create_group,
Expand All @@ -37,27 +35,19 @@
depends_on: None = None


def upgrade() -> None:
def upgrade(container: AsyncContainer) -> None:
"""Upgrade."""

async def _add_domain_computers_group(connection: AsyncConnection) -> None:
session = AsyncSession(connection)
await session.begin()
async def _add_domain_computers_group(connection: AsyncConnection) -> None: # noqa: ARG001
Comment thread
TheMihMih marked this conversation as resolved.
async with container(scope=Scope.REQUEST) as cnt:
session = await cnt.get(AsyncSession)
entity_type_dao = await cnt.get(EntityTypeDAO)
role_use_case = await cnt.get(RoleUseCase)

base_dn_list = await get_base_directories(session)
if not base_dn_list:
return

object_class_dao = ObjectClassDAO(session)
entity_type_dao = EntityTypeDAO(
session,
object_class_dao=object_class_dao,
attribute_value_validator=AttributeValueValidator(),
)
role_dao = RoleDAO(session)
ace_dao = AccessControlEntryDAO(session)
role_use_case = RoleUseCase(role_dao, ace_dao)

try:
group_dir_query = select(
exists(Directory)
Expand Down Expand Up @@ -126,7 +116,6 @@ async def _add_domain_computers_group(connection: AsyncConnection) -> None:

async def _add_primary_group_id(connection: AsyncConnection) -> None:
session = AsyncSession(connection)
await session.begin()

base_dn_list = await get_base_directories(session)
if not base_dn_list:
Expand Down Expand Up @@ -172,12 +161,10 @@ async def _add_primary_group_id(connection: AsyncConnection) -> None:
except (IntegrityError, DBAPIError):
pass

await session.close()

op.run_async(_add_primary_group_id)


def downgrade() -> None:
def downgrade(container: AsyncContainer) -> None: # noqa: ARG001
"""Downgrade."""
bind = op.get_bind()
session = Session(bind=bind)
Expand All @@ -186,8 +173,6 @@ async def _delete_domain_computers_group(
connection: AsyncConnection,
) -> None:
session = AsyncSession(connection)
Comment thread
TheMihMih marked this conversation as resolved.
await session.begin()

base_dn_list = await get_base_directories(session)
if not base_dn_list:
return
Expand Down
5 changes: 3 additions & 2 deletions app/alembic/versions/05ddc0bd562a_add_roles.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import sqlalchemy as sa
from alembic import op
from dishka import AsyncContainer
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncConnection, AsyncSession

Expand All @@ -24,7 +25,7 @@
depends_on: None = None


def upgrade() -> None:
def upgrade(container: AsyncContainer) -> None: # noqa: ARG001
"""Upgrade."""
op.create_table(
"Roles",
Expand Down Expand Up @@ -184,7 +185,7 @@ async def _create_system_roles(connection: AsyncConnection) -> None:
op.run_async(_create_system_roles)


def downgrade() -> None:
def downgrade(container: AsyncContainer) -> None: # noqa: ARG001
Comment thread
TheMihMih marked this conversation as resolved.
"""Downgrade."""
op.create_table(
"AccessPolicies",
Expand Down
5 changes: 3 additions & 2 deletions app/alembic/versions/16a9fa2c1f1e_rename_readonly_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
"""

from alembic import op
from dishka import AsyncContainer
from sqlalchemy import select, update
from sqlalchemy.exc import DBAPIError, IntegrityError
from sqlalchemy.orm import Session, selectinload
Expand All @@ -21,7 +22,7 @@
depends_on: None | list[str] = None


def upgrade() -> None:
def upgrade(container: AsyncContainer) -> None: # noqa: ARG001
"""Upgrade."""
bind = op.get_bind()
session = Session(bind=bind)
Expand Down Expand Up @@ -72,7 +73,7 @@ def upgrade() -> None:
session.close()


def downgrade() -> None:
def downgrade(container: AsyncContainer) -> None: # noqa: ARG001
"""Downgrade."""
bind = op.get_bind()
session = Session(bind=bind)
Expand Down
5 changes: 3 additions & 2 deletions app/alembic/versions/196f0d327c6a_.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
"""

from alembic import op
from dishka import AsyncContainer

# revision identifiers, used by Alembic.
revision = "196f0d327c6a"
Expand All @@ -15,7 +16,7 @@
depends_on: None | str = None


def upgrade() -> None:
def upgrade(container: AsyncContainer) -> None: # noqa: ARG001
"""Upgrade."""
op.drop_constraint(
"AccessPolicyMemberships_policy_id_fkey",
Expand Down Expand Up @@ -201,7 +202,7 @@ def upgrade() -> None:
)


def downgrade() -> None:
def downgrade(container: AsyncContainer) -> None: # noqa: ARG001
"""Downgrade."""
op.drop_constraint(
"PolicyMemberships_policy_id_fkey",
Expand Down
10 changes: 6 additions & 4 deletions app/alembic/versions/275222846605_initial_ldap_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import sqlalchemy as sa
from alembic import op
from dishka import AsyncContainer, Scope
from ldap3.protocol.schemas.ad2012R2 import ad_2012_r2_schema
from sqlalchemy import delete, or_
from sqlalchemy.ext.asyncio import AsyncConnection, AsyncSession
Expand All @@ -36,7 +37,7 @@


@temporary_stub_entity_type_name
def upgrade() -> None:
def upgrade(container: AsyncContainer) -> None:
"""Upgrade."""
bind = op.get_bind()
session = Session(bind=bind)
Expand Down Expand Up @@ -371,8 +372,9 @@ async def _modify_object_classes(connection: AsyncConnection) -> None:
session = AsyncSession(bind=connection)
await session.begin()

at_dao = AttributeTypeDAO(session)
oc_dao = ObjectClassDAO(session)
async with container(scope=Scope.REQUEST) as cnt:
at_dao = await cnt.get(AttributeTypeDAO)
oc_dao = await cnt.get(ObjectClassDAO)

for oc_name, at_names in (
("user", ["nsAccountLock", "shadowExpire"]),
Expand All @@ -393,7 +395,7 @@ async def _modify_object_classes(connection: AsyncConnection) -> None:
session.commit()


def downgrade() -> None:
def downgrade(container: AsyncContainer) -> None: # noqa: ARG001
"""Downgrade."""
op.drop_index(
"idx_object_classes_name_gin_trgm",
Expand Down
5 changes: 3 additions & 2 deletions app/alembic/versions/35d1542d2505_add_entity_id.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import sqlalchemy as sa
from alembic import op
from dishka import AsyncContainer
from sqlalchemy.sql import text

# revision identifiers, used by Alembic.
Expand All @@ -17,7 +18,7 @@
depends_on: None = None


def upgrade() -> None:
def upgrade(container: AsyncContainer) -> None: # noqa: ARG001
"""Upgrade."""
op.add_column(
"EntityTypes",
Expand Down Expand Up @@ -87,7 +88,7 @@ def upgrade() -> None:
op.drop_column("Directory", "entity_type_name")


def downgrade() -> None:
def downgrade(container: AsyncContainer) -> None: # noqa: ARG001
"""Downgrade."""
op.add_column(
"Directory",
Expand Down
5 changes: 3 additions & 2 deletions app/alembic/versions/4334e2e871a4_add_sessions_ttl.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import sqlalchemy as sa
from alembic import op
from dishka import AsyncContainer

# revision identifiers, used by Alembic.
revision = "4334e2e871a4"
Expand All @@ -16,7 +17,7 @@
depends_on: None | str = None


def upgrade() -> None:
def upgrade(container: AsyncContainer) -> None: # noqa: ARG001
"""Upgrade."""
op.add_column(
"Policies",
Expand All @@ -38,7 +39,7 @@ def upgrade() -> None:
)


def downgrade() -> None:
def downgrade(container: AsyncContainer) -> None: # noqa: ARG001
"""Downgrade."""
op.drop_column("Policies", "http_session_ttl")
op.drop_column("Policies", "ldap_session_ttl")
5 changes: 3 additions & 2 deletions app/alembic/versions/4442d1d982a4_remove_krb_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
"""

from alembic import op
from dishka import AsyncContainer
from sqlalchemy import delete
from sqlalchemy.orm import Session

Expand All @@ -21,13 +22,13 @@


@temporary_stub_entity_type_name
def upgrade() -> None:
def upgrade(container: AsyncContainer) -> None: # noqa: ARG001
"""Upgrade."""
bind = op.get_bind()
session = Session(bind=bind)
session.execute(delete(Directory).filter_by(name="default_policy"))
session.execute(delete(Attribute).filter_by(name="krbpwdpolicyreference"))


def downgrade() -> None:
def downgrade(container: AsyncContainer) -> None:
"""Downgrade."""
Loading