diff --git a/app/alembic/versions/01f3f05a5b11_add_primary_group_id.py b/app/alembic/versions/01f3f05a5b11_add_primary_group_id.py index 237a558ab..a3146a444 100644 --- a/app/alembic/versions/01f3f05a5b11_add_primary_group_id.py +++ b/app/alembic/versions/01f3f05a5b11_add_primary_group_id.py @@ -13,6 +13,10 @@ from sqlalchemy.orm import Session, selectinload from entities import Attribute, Directory, EntityType, Group +from enums import EntityTypeNames +from ldap_protocol.ldap_schema.attribute_value_validator import ( + AttributeValueValidator, +) from ldap_protocol.ldap_schema.entity_type_dao import EntityTypeDAO from ldap_protocol.ldap_schema.object_class_dao import ObjectClassDAO from ldap_protocol.roles.ace_dao import AccessControlEntryDAO @@ -48,6 +52,7 @@ async def _add_domain_computers_group(connection: AsyncConnection) -> None: entity_type_dao = EntityTypeDAO( session, object_class_dao=object_class_dao, + attribute_value_validator=AttributeValueValidator(), ) role_dao = RoleDAO(session) ace_dao = AccessControlEntryDAO(session) @@ -66,12 +71,15 @@ async def _add_domain_computers_group(connection: AsyncConnection) -> None: dir_, group_ = await create_group( name="domain computers", sid=515, + attribute_value_validator=AttributeValueValidator(), session=session, ) await session.flush() - computer_entity_type = await entity_type_dao.get("Computer") + computer_entity_type = await entity_type_dao.get( + EntityTypeNames.COMPUTER, + ) computer_dirs = await session.scalars( select(Directory) .where( @@ -126,7 +134,11 @@ async def _add_primary_group_id(connection: AsyncConnection) -> None: entity_type = await session.scalars( select(qa(EntityType.id)) - .where(qa(EntityType.name).in_(["User", "Computer"])), + .where( + qa(EntityType.name).in_( + [EntityTypeNames.USER, EntityTypeNames.COMPUTER], + ), + ), ) # fmt: skip entity_type_ids = list(entity_type.all()) diff --git a/app/alembic/versions/692ae64e0cc5_.py b/app/alembic/versions/692ae64e0cc5_.py index e2cef3b6b..fdc40696b 100755 --- a/app/alembic/versions/692ae64e0cc5_.py +++ b/app/alembic/versions/692ae64e0cc5_.py @@ -17,21 +17,17 @@ def upgrade() -> None: """Upgrade.""" - # ### commands auto generated by Alembic - please adjust! ### op.create_unique_constraint( "group_policy_uc", "GroupAccessPolicyMemberships", ["group_id", "policy_id"], ) - # ### end Alembic commands ### def downgrade() -> None: """Downgrade.""" - # ### commands auto generated by Alembic - please adjust! ### op.drop_constraint( "group_policy_uc", "GroupAccessPolicyMemberships", type_="unique", ) - # ### end Alembic commands ### diff --git a/app/alembic/versions/8164b4a9e1f1_add_ou_computers.py b/app/alembic/versions/8164b4a9e1f1_add_ou_computers.py index 8d6ab8b18..fc74bba3d 100644 --- a/app/alembic/versions/8164b4a9e1f1_add_ou_computers.py +++ b/app/alembic/versions/8164b4a9e1f1_add_ou_computers.py @@ -11,6 +11,9 @@ from sqlalchemy.ext.asyncio import AsyncConnection, AsyncSession from entities import Directory +from ldap_protocol.ldap_schema.attribute_value_validator import ( + AttributeValueValidator, +) from ldap_protocol.ldap_schema.entity_type_dao import EntityTypeDAO from ldap_protocol.ldap_schema.object_class_dao import ObjectClassDAO from ldap_protocol.roles.ace_dao import AccessControlEntryDAO @@ -43,11 +46,17 @@ async def _create_ou_computers(connection: AsyncConnection) -> None: session = AsyncSession(bind=connection) await session.begin() object_class_dao = ObjectClassDAO(session) - entity_type_dao = EntityTypeDAO(session, object_class_dao) + attribute_value_validator = AttributeValueValidator() + entity_type_dao = EntityTypeDAO( + session, + object_class_dao, + attribute_value_validator=attribute_value_validator, + ) setup_gateway = SetupGateway( session, PasswordUtils(), entity_type_dao, + attribute_value_validator=attribute_value_validator, ) base_directories = await get_base_directories(session) diff --git a/app/alembic/versions/ba78cef9700a_initial_entity_type.py b/app/alembic/versions/ba78cef9700a_initial_entity_type.py index 1e46ed611..e31c3958c 100644 --- a/app/alembic/versions/ba78cef9700a_initial_entity_type.py +++ b/app/alembic/versions/ba78cef9700a_initial_entity_type.py @@ -15,6 +15,9 @@ from constants import ENTITY_TYPE_DATAS from entities import Attribute, Directory, User from extra.alembic_utils import temporary_stub_entity_type_name +from ldap_protocol.ldap_schema.attribute_value_validator import ( + AttributeValueValidator, +) from ldap_protocol.ldap_schema.dto import EntityTypeDTO from ldap_protocol.ldap_schema.entity_type_dao import EntityTypeDAO from ldap_protocol.ldap_schema.entity_type_use_case import EntityTypeUseCase @@ -107,6 +110,7 @@ async def _create_entity_types(connection: AsyncConnection) -> None: entity_type_dao = EntityTypeDAO( session, object_class_dao=object_class_dao, + attribute_value_validator=AttributeValueValidator(), ) entity_type_use_case = EntityTypeUseCase( entity_type_dao, @@ -116,8 +120,8 @@ async def _create_entity_types(connection: AsyncConnection) -> None: for entity_type_data in ENTITY_TYPE_DATAS: await entity_type_use_case.create( EntityTypeDTO( - name=entity_type_data["name"], # type: ignore - object_class_names=entity_type_data["object_class_names"], # type: ignore + name=entity_type_data["name"], + object_class_names=entity_type_data["object_class_names"], is_system=True, ), ) @@ -175,6 +179,7 @@ async def _attach_entity_type_to_directories( entity_type_dao = EntityTypeDAO( session, object_class_dao=object_class_dao, + attribute_value_validator=AttributeValueValidator(), ) await entity_type_dao.attach_entity_type_to_directories() diff --git a/app/alembic/versions/c4888c68e221_fix_admin_attr_and_policy.py b/app/alembic/versions/c4888c68e221_fix_admin_attr_and_policy.py index bbd904958..ce182452a 100644 --- a/app/alembic/versions/c4888c68e221_fix_admin_attr_and_policy.py +++ b/app/alembic/versions/c4888c68e221_fix_admin_attr_and_policy.py @@ -12,6 +12,9 @@ from sqlalchemy.orm import joinedload from entities import Attribute, Directory, NetworkPolicy +from ldap_protocol.ldap_schema.attribute_value_validator import ( + AttributeValueValidator, +) from ldap_protocol.ldap_schema.entity_type_dao import EntityTypeDAO from ldap_protocol.ldap_schema.object_class_dao import ObjectClassDAO from ldap_protocol.utils.helpers import create_integer_hash @@ -43,6 +46,7 @@ async def _attach_entity_type_to_directories( entity_type_dao = EntityTypeDAO( session, object_class_dao=object_class_dao, + attribute_value_validator=AttributeValueValidator(), ) await entity_type_dao.attach_entity_type_to_directories() await session.commit() diff --git a/app/alembic/versions/f1abf7ef2443_add_container_object_class.py b/app/alembic/versions/f1abf7ef2443_add_container_object_class.py index d879bd591..fa9ef428c 100644 --- a/app/alembic/versions/f1abf7ef2443_add_container_object_class.py +++ b/app/alembic/versions/f1abf7ef2443_add_container_object_class.py @@ -11,6 +11,7 @@ from sqlalchemy.ext.asyncio import AsyncConnection, AsyncSession from entities import Attribute, Directory, EntityType +from enums import EntityTypeNames from repo.pg.tables import queryable_attr as qa # revision identifiers, used by Alembic. @@ -39,7 +40,7 @@ async def _migrate_ou_to_cn_containers( ) entity_type = await session.scalar( select(EntityType) - .where(qa(EntityType.name) == "Container"), + .where(qa(EntityType.name) == EntityTypeNames.CONTAINER), ) # fmt: skip for directory in directories: @@ -124,7 +125,7 @@ async def _migrate_cn_to_ou_containers( ) entity_type = await session.scalar( select(EntityType) - .where(qa(EntityType.name) == "Organizational Unit"), + .where(qa(EntityType.name) == EntityTypeNames.ORGANIZATIONAL_UNIT), ) # fmt: skip for directory in directories: diff --git a/app/alembic/versions/fafc3d0b11ec_.py b/app/alembic/versions/fafc3d0b11ec_.py index 40ce32bab..ba9c04b39 100644 --- a/app/alembic/versions/fafc3d0b11ec_.py +++ b/app/alembic/versions/fafc3d0b11ec_.py @@ -13,6 +13,9 @@ from entities import Directory from extra.alembic_utils import temporary_stub_entity_type_name +from ldap_protocol.ldap_schema.attribute_value_validator import ( + AttributeValueValidator, +) from ldap_protocol.utils.queries import ( create_group, get_base_directories, @@ -52,6 +55,7 @@ async def _create_readonly_grp_and_plcy( dir_, _ = await create_group( name="readonly domain controllers", sid=521, + attribute_value_validator=AttributeValueValidator(), session=session, ) diff --git a/app/api/ldap_schema/schema.py b/app/api/ldap_schema/schema.py index 86100ed05..9e6453eff 100644 --- a/app/api/ldap_schema/schema.py +++ b/app/api/ldap_schema/schema.py @@ -8,7 +8,7 @@ from pydantic import BaseModel, Field -from enums import KindType +from enums import EntityTypeNames, KindType from ldap_protocol.ldap_schema.constants import ( DEFAULT_ENTITY_TYPE_IS_SYSTEM, OID_REGEX_PATTERN, @@ -82,7 +82,7 @@ class EntityTypeSchema(BaseModel, Generic[_IdT]): """Entity Type Schema.""" id: _IdT = Field(default=None) # type: ignore[assignment] - name: str + name: EntityTypeNames | str is_system: bool object_class_names: list[str] = Field( default_factory=list, diff --git a/app/constants.py b/app/constants.py index 136335ffc..077d35ec1 100644 --- a/app/constants.py +++ b/app/constants.py @@ -4,6 +4,10 @@ License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE """ +from typing import TypedDict + +from enums import EntityTypeNames + group_attrs = { "objectClass": ["top"], "groupType": ["-2147483646"], @@ -207,24 +211,37 @@ ] -ENTITY_TYPE_DATAS = [ - { - "name": "Domain", - "object_class_names": ["top", "domain", "domainDNS"], - }, - {"name": "Computer", "object_class_names": ["top", "computer"]}, - {"name": "Container", "object_class_names": ["top", "container"]}, - { - "name": "Organizational Unit", - "object_class_names": ["top", "container", "organizationalUnit"], - }, - { - "name": "Group", - "object_class_names": ["top", "group", "posixGroup"], - }, - { - "name": "User", - "object_class_names": [ +class EntityTypeData(TypedDict): + """Entity Type data.""" + + name: EntityTypeNames + object_class_names: list[str] + + +ENTITY_TYPE_DATAS: tuple[EntityTypeData, ...] = ( + EntityTypeData( + name=EntityTypeNames.DOMAIN, + object_class_names=["top", "domain", "domainDNS"], + ), + EntityTypeData( + name=EntityTypeNames.COMPUTER, + object_class_names=["top", "computer"], + ), + EntityTypeData( + name=EntityTypeNames.CONTAINER, + object_class_names=["top", "container"], + ), + EntityTypeData( + name=EntityTypeNames.ORGANIZATIONAL_UNIT, + object_class_names=["top", "container", "organizationalUnit"], + ), + EntityTypeData( + name=EntityTypeNames.GROUP, + object_class_names=["top", "group", "posixGroup"], + ), + EntityTypeData( + name=EntityTypeNames.USER, + object_class_names=[ "top", "user", "person", @@ -233,28 +250,25 @@ "shadowAccount", "inetOrgPerson", ], - }, - {"name": "KRB Container", "object_class_names": ["krbContainer"]}, - { - "name": "KRB Principal", - "object_class_names": [ + ), + EntityTypeData( + name=EntityTypeNames.KRB_CONTAINER, + object_class_names=["krbContainer"], + ), + EntityTypeData( + name=EntityTypeNames.KRB_PRINCIPAL, + object_class_names=[ "krbprincipal", "krbprincipalaux", "krbTicketPolicyAux", ], - }, - { - "name": "KRB Realm Container", - "object_class_names": [ - "top", - "krbrealmcontainer", - "krbticketpolicyaux", - ], - }, -] -PRIMARY_ENTITY_TYPE_NAMES = { - entity_type_data["name"] for entity_type_data in ENTITY_TYPE_DATAS -} + ), + EntityTypeData( + name=EntityTypeNames.KRB_REALM_CONTAINER, + object_class_names=["top", "krbrealmcontainer", "krbticketpolicyaux"], + ), +) + FIRST_SETUP_DATA = [ { diff --git a/app/enums.py b/app/enums.py index fb95e3a2f..bd6c59bd0 100644 --- a/app/enums.py +++ b/app/enums.py @@ -45,6 +45,24 @@ class MFAChallengeStatuses(StrEnum): PENDING = "pending" +class EntityTypeNames(StrEnum): + """Enum of base (system) Entity Types. + + Used for system objects. + Custom Entity Types aren't included here. + """ + + DOMAIN = "Domain" + COMPUTER = "Computer" + CONTAINER = "Container" + ORGANIZATIONAL_UNIT = "Organizational Unit" + GROUP = "Group" + USER = "User" + KRB_CONTAINER = "KRB Container" + KRB_PRINCIPAL = "KRB Principal" + KRB_REALM_CONTAINER = "KRB Realm Container" + + class KindType(StrEnum): """Object kind types.""" diff --git a/app/extra/scripts/uac_sync.py b/app/extra/scripts/uac_sync.py index dd7b8514d..f0623e1b1 100644 --- a/app/extra/scripts/uac_sync.py +++ b/app/extra/scripts/uac_sync.py @@ -49,12 +49,10 @@ async def disable_accounts( String, ) conditions = [ - ( - cast(Attribute.value, Integer).op("&")( - UserAccountControlFlag.ACCOUNTDISABLE, - ) - == 0 - ), + cast(Attribute.value, Integer).op("&")( + UserAccountControlFlag.ACCOUNTDISABLE, + ) + == 0, qa(Attribute.directory_id).in_(subquery), qa(Attribute.name) == "userAccountControl", ] diff --git a/app/ioc.py b/app/ioc.py index 2c733e6a8..cd74d77f3 100644 --- a/app/ioc.py +++ b/app/ioc.py @@ -78,6 +78,9 @@ from ldap_protocol.ldap_schema.attribute_type_use_case import ( AttributeTypeUseCase, ) +from ldap_protocol.ldap_schema.attribute_value_validator import ( + AttributeValueValidator, +) from ldap_protocol.ldap_schema.entity_type_dao import EntityTypeDAO from ldap_protocol.ldap_schema.entity_type_use_case import EntityTypeUseCase from ldap_protocol.ldap_schema.object_class_dao import ObjectClassDAO @@ -416,6 +419,10 @@ async def get_dhcp_mngr( kea_dhcp_repository=dhcp_api_repository, ) + attribute_value_validator = provide( + AttributeValueValidator, + scope=Scope.RUNTIME, + ) attribute_type_dao = provide(AttributeTypeDAO, scope=Scope.REQUEST) object_class_dao = provide(ObjectClassDAO, scope=Scope.REQUEST) entity_type_dao = provide(EntityTypeDAO, scope=Scope.REQUEST) diff --git a/app/ldap_protocol/auth/setup_gateway.py b/app/ldap_protocol/auth/setup_gateway.py index d294eb0c3..d57b6a95f 100644 --- a/app/ldap_protocol/auth/setup_gateway.py +++ b/app/ldap_protocol/auth/setup_gateway.py @@ -12,6 +12,9 @@ from sqlalchemy.ext.asyncio import AsyncSession from entities import Attribute, Directory, Group, NetworkPolicy, User +from ldap_protocol.ldap_schema.attribute_value_validator import ( + AttributeValueValidator, +) from ldap_protocol.ldap_schema.entity_type_dao import EntityTypeDAO from ldap_protocol.utils.helpers import create_object_sid, generate_domain_sid from ldap_protocol.utils.queries import get_domain_object_class @@ -27,6 +30,7 @@ def __init__( session: AsyncSession, password_utils: PasswordUtils, entity_type_dao: EntityTypeDAO, + attribute_value_validator: AttributeValueValidator, ) -> None: """Initialize Setup use case. @@ -37,6 +41,7 @@ def __init__( self._session = session self._password_utils = password_utils self._entity_type_dao = entity_type_dao + self._attribute_value_validator = attribute_value_validator async def is_setup(self) -> bool: """Check if setup is performed. @@ -44,8 +49,9 @@ async def is_setup(self) -> bool: :return: bool (True if setup is performed, False otherwise) """ query = select( - exists(Directory).where(qa(Directory.parent_id).is_(None)), - ) + exists(Directory) + .where(qa(Directory.parent_id).is_(None)), + ) # fmt: skip retval = await self._session.scalars(query) return retval.one() @@ -61,10 +67,7 @@ async def setup_enviroment( logger.warning("dev data already set up") return - domain = Directory( - name=dn, - object_class="domain", - ) + domain = Directory(name=dn, object_class="domain") domain.object_sid = generate_domain_sid() domain.path = [f"dc={path}" for path in reversed(dn.split("."))] domain.depth = len(domain.path) @@ -94,6 +97,10 @@ async def setup_enviroment( directory=domain, is_system_entity_type=True, ) + if not self._attribute_value_validator.is_directory_valid(domain): + raise ValueError( + "Invalid directory attribute values during environment setup", # noqa: E501 + ) await self._session.flush() try: @@ -199,13 +206,15 @@ async def create_dir( await self._session.refresh( instance=dir_, - attribute_names=["attributes"], + attribute_names=["attributes", "user"], with_for_update=None, ) await self._entity_type_dao.attach_entity_type_to_directory( directory=dir_, is_system_entity_type=True, ) + if not self._attribute_value_validator.is_directory_valid(dir_): + raise ValueError("Invalid directory attribute values") await self._session.flush() if "children" in data: diff --git a/app/ldap_protocol/kerberos/utils.py b/app/ldap_protocol/kerberos/utils.py index 1f43443a3..518169475 100644 --- a/app/ldap_protocol/kerberos/utils.py +++ b/app/ldap_protocol/kerberos/utils.py @@ -9,7 +9,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from entities import Attribute, CatalogueSetting, Directory, EntityType -from enums import StrEnum +from enums import EntityTypeNames, StrEnum from repo.pg.tables import queryable_attr as qa from .exceptions import KRBAPIConnectionError, KRBAPIError @@ -122,7 +122,7 @@ async def unlock_principal(name: str, session: AsyncSession) -> None: .outerjoin(qa(Directory.entity_type)) .where( qa(Directory.name).ilike(name), - qa(EntityType.name) == "KRB Principal", + qa(EntityType.name) == EntityTypeNames.KRB_PRINCIPAL, ) .scalar_subquery() ) diff --git a/app/ldap_protocol/ldap_requests/add.py b/app/ldap_protocol/ldap_requests/add.py index 4e15bc75d..dfab7674d 100644 --- a/app/ldap_protocol/ldap_requests/add.py +++ b/app/ldap_protocol/ldap_requests/add.py @@ -11,7 +11,7 @@ from sqlalchemy.exc import IntegrityError from entities import Attribute, Directory, Group, User -from enums import AceType +from enums import AceType, EntityTypeNames from ldap_protocol.asn1parser import ASN1Row from ldap_protocol.kerberos.exceptions import ( KRBAPIAddPrincipalError, @@ -158,10 +158,21 @@ async def handle( # noqa: C901 object_class_names=self.object_class_names, ) ) - if entity_type and entity_type.name == "Container": + if entity_type and entity_type.name == EntityTypeNames.CONTAINER: yield AddResponse(result_code=LDAPCodes.INSUFFICIENT_ACCESS_RIGHTS) return + if not ctx.attribute_value_validator.is_value_valid( + entity_type.name if entity_type else "", + "name", + name, + ): + yield AddResponse( + result_code=LDAPCodes.UNDEFINED_ATTRIBUTE_TYPE, + errorMessage="Invalid attribute value(s)", + ) + return + can_add = ctx.access_manager.check_entity_level_access( aces=parent.access_control_entries, entity_type_id=entity_type.id if entity_type else None, @@ -399,10 +410,22 @@ async def handle( # noqa: C901 ), ) + if not ctx.attribute_value_validator.is_directory_attributes_valid( + entity_type.name if entity_type else "", + attributes, + ) or (user and not ctx.attribute_value_validator.is_user_valid(user)): + await ctx.session.rollback() + yield AddResponse( + result_code=LDAPCodes.UNDEFINED_ATTRIBUTE_TYPE, + errorMessage="Invalid attribute value(s)", + ) + return + try: items_to_add.extend(attributes) ctx.session.add_all(items_to_add) await ctx.session.flush() + await ctx.entity_type_dao.attach_entity_type_to_directory( directory=new_dir, is_system_entity_type=False, diff --git a/app/ldap_protocol/ldap_requests/contexts.py b/app/ldap_protocol/ldap_requests/contexts.py index df4918f8d..469939516 100644 --- a/app/ldap_protocol/ldap_requests/contexts.py +++ b/app/ldap_protocol/ldap_requests/contexts.py @@ -11,6 +11,9 @@ from config import Settings from ldap_protocol.dialogue import LDAPSession from ldap_protocol.kerberos import AbstractKadmin +from ldap_protocol.ldap_schema.attribute_value_validator import ( + AttributeValueValidator, +) from ldap_protocol.ldap_schema.entity_type_dao import EntityTypeDAO from ldap_protocol.multifactor import LDAPMultiFactorAPI from ldap_protocol.policies.password import PasswordPolicyUseCases @@ -33,6 +36,7 @@ class LDAPAddRequestContext: password_utils: PasswordUtils access_manager: AccessManager role_use_case: RoleUseCase + attribute_value_validator: AttributeValueValidator @dataclass @@ -48,6 +52,7 @@ class LDAPModifyRequestContext: access_manager: AccessManager password_use_cases: PasswordPolicyUseCases password_utils: PasswordUtils + attribute_value_validator: AttributeValueValidator @dataclass @@ -115,3 +120,4 @@ class LDAPModifyDNRequestContext: entity_type_dao: EntityTypeDAO access_manager: AccessManager role_use_case: RoleUseCase + attribute_value_validator: AttributeValueValidator diff --git a/app/ldap_protocol/ldap_requests/modify.py b/app/ldap_protocol/ldap_requests/modify.py index 5161ae754..7334ba225 100644 --- a/app/ldap_protocol/ldap_requests/modify.py +++ b/app/ldap_protocol/ldap_requests/modify.py @@ -14,9 +14,8 @@ from sqlalchemy.orm import joinedload, selectinload from config import Settings -from constants import PRIMARY_ENTITY_TYPE_NAMES from entities import Attribute, Directory, Group, User -from enums import AceType +from enums import AceType, EntityTypeNames from ldap_protocol.asn1parser import ASN1Row from ldap_protocol.dialogue import UserSchema from ldap_protocol.kerberos import AbstractKadmin, unlock_principal @@ -193,9 +192,7 @@ async def handle( names = {change.get_name() for change in self.changes} - password_change_requested = self._check_password_change_requested( - names, - ) + password_change_requested = self._is_password_change_requested(names) self_modify = directory.id == ctx.ldap_session.user.directory_id if ( @@ -210,7 +207,7 @@ async def handle( return before_attrs = self.get_directory_attrs(directory) - + entity_type = directory.entity_type try: if not can_modify and not ( password_change_requested and self_modify @@ -224,6 +221,17 @@ async def handle( if change.modification.type.lower() in Directory.ro_fields: continue + if not ctx.attribute_value_validator.is_partial_attribute_valid( # noqa: E501 + entity_type.name if entity_type else "", + change.modification, + ): + await ctx.session.rollback() + yield ModifyResponse( + result_code=LDAPCodes.UNDEFINED_ATTRIBUTE_TYPE, + message="Invalid attribute value(s)", + ) + return + await self._update_password_expiration( change, directory.user, @@ -289,8 +297,10 @@ async def handle( directory=directory, is_system_entity_type=False, ) + await ctx.session.commit() yield ModifyResponse(result_code=LDAPCodes.SUCCESS) + finally: query = self._get_dir_query() directory = await ctx.session.scalar(query) @@ -351,7 +361,7 @@ def _get_dir_query(self) -> Select[tuple[Directory]]: .filter(get_filter_from_path(self.object)) ) - def _check_password_change_requested( + def _is_password_change_requested( self, names: set[str], ) -> bool: @@ -591,7 +601,7 @@ async def _validate_object_class_modification( ) -> None: if not ( directory.entity_type - and directory.entity_type.name in PRIMARY_ENTITY_TYPE_NAMES + and directory.entity_type.name in EntityTypeNames ): return diff --git a/app/ldap_protocol/ldap_requests/modify_dn.py b/app/ldap_protocol/ldap_requests/modify_dn.py index d17120540..c1ff681d3 100644 --- a/app/ldap_protocol/ldap_requests/modify_dn.py +++ b/app/ldap_protocol/ldap_requests/modify_dn.py @@ -8,7 +8,7 @@ from sqlalchemy import delete, func, select, text, update from sqlalchemy.exc import IntegrityError -from sqlalchemy.orm import selectinload +from sqlalchemy.orm import joinedload, selectinload from entities import AccessControlEntry, Attribute, Directory from enums import AceType @@ -112,7 +112,8 @@ async def handle( query = ( select(Directory) .options( - selectinload(qa(Directory.parent)), + joinedload(qa(Directory.parent)), + joinedload(qa(Directory.entity_type)), ) .filter(get_filter_from_path(self.entry)) ) @@ -143,6 +144,21 @@ async def handle( old_depth = directory.depth + if ( + directory.entity_type + and not ctx.attribute_value_validator.is_value_valid( + entity_type_name=directory.entity_type.name, + attr_name="name", + attr_value=new_name, + ) + ): + await ctx.session.rollback() + yield ModifyDNResponse( + result_code=LDAPCodes.UNDEFINED_ATTRIBUTE_TYPE, + message="Invalid attribute value(s)", + ) + return + if ( self.new_superior and directory.parent diff --git a/app/ldap_protocol/ldap_schema/attribute_value_validator.py b/app/ldap_protocol/ldap_schema/attribute_value_validator.py new file mode 100644 index 000000000..9c7c77f56 --- /dev/null +++ b/app/ldap_protocol/ldap_schema/attribute_value_validator.py @@ -0,0 +1,270 @@ +"""Attribute Value Validator. + +Copyright (c) 2025 MultiFactor +License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE +""" + +import re +from collections import defaultdict +from typing import Callable, cast as tcast + +from entities import Attribute, Directory, User +from enums import EntityTypeNames +from ldap_protocol.objects import PartialAttribute + +type _AttrNameType = str +type _ValueType = str +type _ValueValidatorType = Callable[[_ValueType], bool] +type _CompiledValidatorsType = dict[ + EntityTypeNames, + dict[_AttrNameType, _ValueValidatorType], +] + + +class AttributeValueValidatorError(Exception): ... + + +# NOTE: Not validate `distinguishedName`, `member` and `memberOf` attributes, +# because it doesn't exist. +_ENTITY_NAME_AND_ATTR_NAME_VALIDATION_MAP: dict[ + tuple[EntityTypeNames, _AttrNameType], + tuple[str, ...], +] = { + (EntityTypeNames.ORGANIZATIONAL_UNIT, "name"): ( + "not_start_with_space", + "not_start_with_hash", + "not_end_with_space", + "not_contains_symbols", + ), + (EntityTypeNames.GROUP, "name"): ( + "not_start_with_space", + "not_start_with_hash", + "not_end_with_space", + "not_contains_symbols", + ), + (EntityTypeNames.USER, "name"): ( + "not_start_with_space", + "not_start_with_hash", + "not_end_with_space", + "not_contains_symbols", + ), + (EntityTypeNames.USER, "sAMAccountName"): ( + "not_contains_symbols_ext", + "not_end_with_dot", + "not_contains_control_characters", + "not_contains_at", + ), + (EntityTypeNames.COMPUTER, "name"): ( + "not_start_with_space", + "not_start_with_hash", + "not_end_with_space", + "not_contains_symbols", + ), + (EntityTypeNames.COMPUTER, "sAMAccountName"): ( + "not_contains_symbols_ext", + "not_end_with_dot", + "not_contains_control_characters", + "not_contains_spaces_and_dots", + "not_only_numbers", + "not_start_with_number", + ), +} + + +class _ValValidators: + @staticmethod + def not_start_with_space(value: _ValueType) -> bool: + return not value.startswith(" ") + + @staticmethod + def not_only_numbers(value: _ValueType) -> bool: + return not value.isdigit() + + @staticmethod + def not_contains_at(value: _ValueType) -> bool: + return "@" not in value + + @staticmethod + def not_start_with_number(value: _ValueType) -> bool: + return bool(value and not value[0].isdigit()) + + @staticmethod + def not_start_with_hash(value: _ValueType) -> bool: + return not value.startswith("#") + + @staticmethod + def not_end_with_space(value: _ValueType) -> bool: + return not value.endswith(" ") + + @staticmethod + def not_contains_control_characters(value: _ValueType) -> bool: + return all(ord(char) >= 32 and ord(char) != 127 for char in value) + + @staticmethod + def not_contains_spaces_and_dots(value: _ValueType) -> bool: + return " " not in value and "." not in value + + @staticmethod + def not_contains_symbols(value: _ValueType) -> bool: + return not re.search(r'[,+"\\<>;=]', value) + + @staticmethod + def not_contains_symbols_ext(value: _ValueType) -> bool: + return not re.search(r'["/\\\[\]:;\|=,\+\*\?<>]', value) + + @staticmethod + def not_end_with_dot(value: _ValueType) -> bool: + return not value.endswith(".") + + +class AttributeValueValidator: + _compiled_validators: _CompiledValidatorsType + + def __init__(self) -> None: + self._compiled_validators: _CompiledValidatorsType = ( + self.__compile_validators() + ) + + def __compile_validators(self) -> _CompiledValidatorsType: + res: _CompiledValidatorsType = defaultdict(dict) + + for ( + key, + validator_names, + ) in _ENTITY_NAME_AND_ATTR_NAME_VALIDATION_MAP.items(): + validators = [getattr(_ValValidators, n) for n in validator_names] + res[key[0]][key[1]] = self.__create_combined_validator(validators) + + return res + + def __create_combined_validator( + self, + funcs: list[_ValueValidatorType], + ) -> _ValueValidatorType: + def combined(value: _ValueType) -> bool: + return all(func(value) for func in funcs) + + return combined + + def _get_subset_validators( + self, + entity_type_name: EntityTypeNames | str, + ) -> dict[_AttrNameType, _ValueValidatorType] | None: + if entity_type_name in self._compiled_validators: + entity_type_name = tcast("EntityTypeNames", entity_type_name) + else: + return None + return self._compiled_validators.get(entity_type_name) + + def _get_validator( + self, + entity_type_name: EntityTypeNames | str, + attr_name: str, + ) -> _ValueValidatorType | None: + subset_validators = self._get_subset_validators(entity_type_name) + return subset_validators.get(attr_name) if subset_validators else None + + def is_value_valid( + self, + entity_type_name: EntityTypeNames | str, + attr_name: _AttrNameType, + attr_value: _ValueType, + ) -> bool: + validator = self._get_validator(entity_type_name, attr_name) + + if not validator: + return True + + return validator(attr_value) + + def is_partial_attribute_valid( + self, + entity_type_name: EntityTypeNames | str, + partial_attribute: PartialAttribute, + ) -> bool: + validator = self._get_validator( + entity_type_name, + partial_attribute.type, + ) + + if not validator: + return True + + for value in partial_attribute.vals: + if isinstance(value, str) and not validator(value): + return False + + return True + + def is_directory_attributes_valid( + self, + entity_type_name: EntityTypeNames | str, + attributes: list[Attribute], + ) -> bool: + subset_validators = self._get_subset_validators(entity_type_name) + if not subset_validators: + return True + + for attribute in attributes: + if not attribute.value: + continue + + validator = subset_validators.get(attribute.name) + if not validator: + continue + + if not validator(attribute.value): + return False + + return True + + def is_directory_valid(self, directory: Directory) -> bool: + if not directory.entity_type: + raise AttributeValueValidatorError( + "Directory must have an entity type", + ) + + entity_type_name = directory.entity_type.name + + if entity_type_name and not self.is_value_valid( + entity_type_name, + "name", + directory.name, + ): + return False + + if entity_type_name == EntityTypeNames.USER: + if not directory.user: + raise AttributeValueValidatorError( + "User directory must have associated User", + ) + + if not self.is_user_valid(directory.user): + return False + + if not self.is_directory_attributes_valid( # noqa: SIM103 + entity_type_name, + directory.attributes, + ): + return False + + return True + + def is_user_valid(self, user: User) -> bool: + user_entity_type_name = EntityTypeNames.USER + + if not self.is_value_valid( + user_entity_type_name, + "sAMAccountName", + user.sam_account_name, + ): + return False + + if not self.is_value_valid( # noqa: SIM103 + user_entity_type_name, + "userPrincipalName", + user.user_principal_name, + ): + return False + + return True diff --git a/app/ldap_protocol/ldap_schema/dto.py b/app/ldap_protocol/ldap_schema/dto.py index 0430d9a2e..118a6e1e8 100644 --- a/app/ldap_protocol/ldap_schema/dto.py +++ b/app/ldap_protocol/ldap_schema/dto.py @@ -7,7 +7,7 @@ from dataclasses import dataclass, field from typing import Generic, TypeVar -from enums import KindType +from enums import EntityTypeNames, KindType _IdT = TypeVar("_IdT", int, None) @@ -49,7 +49,7 @@ class ObjectClassDTO(Generic[_IdT, _LinkT]): class EntityTypeDTO(Generic[_IdT]): """Entity Type DTO.""" - name: str + name: EntityTypeNames | str is_system: bool object_class_names: list[str] id: _IdT = None # type: ignore diff --git a/app/ldap_protocol/ldap_schema/entity_type_dao.py b/app/ldap_protocol/ldap_schema/entity_type_dao.py index 6e30a0989..abfdc49d1 100644 --- a/app/ldap_protocol/ldap_schema/entity_type_dao.py +++ b/app/ldap_protocol/ldap_schema/entity_type_dao.py @@ -16,6 +16,10 @@ from abstract_dao import AbstractDAO from entities import Attribute, Directory, EntityType, ObjectClass +from ldap_protocol.ldap_schema.attribute_value_validator import ( + AttributeValueValidator, + AttributeValueValidatorError, +) from ldap_protocol.ldap_schema.dto import EntityTypeDTO from ldap_protocol.ldap_schema.exceptions import ( EntityTypeAlreadyExistsError, @@ -44,15 +48,18 @@ class EntityTypeDAO(AbstractDAO[EntityTypeDTO, str]): __session: AsyncSession __object_class_dao: ObjectClassDAO + __attribute_value_validator: AttributeValueValidator def __init__( self, session: AsyncSession, object_class_dao: ObjectClassDAO, + attribute_value_validator: AttributeValueValidator, ) -> None: """Initialize Entity Type DAO with a database session.""" self.__session = session self.__object_class_dao = object_class_dao + self.__attribute_value_validator = attribute_value_validator async def get_all(self) -> list[EntityTypeDTO[int]]: """Get all Entity Types.""" @@ -120,11 +127,20 @@ async def update(self, _id: str, dto: EntityTypeDTO[int]) -> None: for directory in result.scalars(): for object_class_name in entity_type.object_class_names: + if not self.__attribute_value_validator.is_value_valid( + entity_type.name, + "objectClass", + object_class_name, + ): + raise AttributeValueValidatorError( + f"Invalid objectClass value '{object_class_name}' for entity type '{entity_type.name}'.", # noqa: E501 + ) + self.__session.add( Attribute( directory_id=directory.id, - value=object_class_name, name="objectClass", + value=object_class_name, ), ) diff --git a/app/ldap_protocol/ldap_schema/entity_type_use_case.py b/app/ldap_protocol/ldap_schema/entity_type_use_case.py index bb6fb2729..5958e6a99 100644 --- a/app/ldap_protocol/ldap_schema/entity_type_use_case.py +++ b/app/ldap_protocol/ldap_schema/entity_type_use_case.py @@ -7,8 +7,8 @@ from typing import ClassVar from abstract_service import AbstractService -from constants import ENTITY_TYPE_DATAS, PRIMARY_ENTITY_TYPE_NAMES -from enums import AuthorizationRules +from constants import ENTITY_TYPE_DATAS +from enums import AuthorizationRules, EntityTypeNames from ldap_protocol.ldap_schema.dto import EntityTypeDTO from ldap_protocol.ldap_schema.entity_type_dao import EntityTypeDAO from ldap_protocol.ldap_schema.exceptions import ( @@ -65,7 +65,7 @@ async def _validate_name( self, name: str, ) -> None: - if name in PRIMARY_ENTITY_TYPE_NAMES: + if name in EntityTypeNames: raise EntityTypeCantModifyError( f"Can't change entity type name {name}", ) @@ -93,7 +93,7 @@ async def create_for_first_setup(self) -> None: for entity_type_data in ENTITY_TYPE_DATAS: await self.create( EntityTypeDTO( - name=entity_type_data["name"], # type: ignore + name=entity_type_data["name"], object_class_names=list( entity_type_data["object_class_names"], ), diff --git a/app/ldap_protocol/policies/password/dao.py b/app/ldap_protocol/policies/password/dao.py index 95a759bf7..9a85d2ad7 100644 --- a/app/ldap_protocol/policies/password/dao.py +++ b/app/ldap_protocol/policies/password/dao.py @@ -14,6 +14,11 @@ from abstract_dao import AbstractDAO from entities import Attribute, Group, PasswordPolicy, User +from enums import EntityTypeNames +from ldap_protocol.ldap_schema.attribute_value_validator import ( + AttributeValueValidator, + AttributeValueValidatorError, +) from ldap_protocol.objects import UserAccountControlFlag as UacFlag from ldap_protocol.policies.password.exceptions import ( PasswordPolicyAlreadyExistsError, @@ -63,13 +68,16 @@ class PasswordPolicyDAO(AbstractDAO[PasswordPolicyDTO, int]): """Password Policy DAO.""" _session: AsyncSession + __attribute_value_validator: AttributeValueValidator def __init__( self, session: AsyncSession, + attribute_value_validator: AttributeValueValidator, ) -> None: """Initialize Password Policy DAO with a database session.""" self._session = session + self.__attribute_value_validator = attribute_value_validator async def _get_total_count(self) -> int: """Count all Password Policies.""" @@ -392,6 +400,13 @@ async def get_or_create_pwd_last_set( ) # fmt: skip if not plset_attribute: + if not self.__attribute_value_validator.is_value_valid( + EntityTypeNames.USER, + "pwdLastSet", + ft_now(), + ): + raise AttributeValueValidatorError("Invalid pwdLastSet value") + plset_attribute = Attribute( directory_id=directory_id, name="pwdLastSet", diff --git a/app/ldap_protocol/utils/queries.py b/app/ldap_protocol/utils/queries.py index 9db9b889e..a1f9243de 100644 --- a/app/ldap_protocol/utils/queries.py +++ b/app/ldap_protocol/utils/queries.py @@ -15,6 +15,10 @@ from sqlalchemy.sql.expression import ColumnElement from entities import Attribute, Directory, Group, User +from ldap_protocol.ldap_schema.attribute_value_validator import ( + AttributeValueValidator, + AttributeValueValidatorError, +) from repo.pg.tables import ( directory_memberships_table, directory_table, @@ -336,6 +340,7 @@ def get_domain_object_class(domain: Directory) -> Iterator[Attribute]: async def create_group( name: str, sid: int | None, + attribute_value_validator: AttributeValueValidator, session: AsyncSession, ) -> tuple[Directory, Group]: """Create group in default groups path. @@ -388,9 +393,18 @@ async def create_group( for name, attr in attributes.items(): for val in attr: session.add(Attribute(name=name, value=val, directory_id=dir_.id)) - await session.flush() - await session.refresh(dir_) + + await session.refresh( + instance=dir_, + attribute_names=["attributes", "user"], + with_for_update=None, + ) + if not attribute_value_validator.is_directory_valid(dir_): + raise AttributeValueValidatorError( + "Invalid directory attributes values", + ) + await session.refresh(group) return dir_, group @@ -496,22 +510,22 @@ async def set_or_update_primary_group( f"group '{group_dn}'.", ) - existing_attr = await session.scalar( - select(Attribute) - .filter_by( - name="primaryGroupID", - directory_id=directory.id, - ), - ) # fmt: skip + updated_attribute = await session.scalar( + update(Attribute) + .values(value=group.directory.relative_id) + .where( + qa(Attribute.name) == "primaryGroupID", + qa(Attribute.directory_id) == directory.id, + ) + .returning(qa(Attribute.directory_id)), + ) - if existing_attr: - existing_attr.value = group.directory.relative_id - else: + if not updated_attribute: session.add( Attribute( name="primaryGroupID", - value=group.directory.relative_id, directory_id=directory.id, + value=group.directory.relative_id, ), ) diff --git a/app/repo/pg/tables.py b/app/repo/pg/tables.py index db63e6e30..f59d1e758 100644 --- a/app/repo/pg/tables.py +++ b/app/repo/pg/tables.py @@ -668,6 +668,12 @@ def _compile_create_uc( "PasswordBanWords", metadata, Column("word", String(255), primary_key=True), + Index( + "idx_password_ban_words_word_gin_trgm", + "word", + postgresql_ops={"word": "gin_trgm_ops"}, + postgresql_using="gin", + ), ) dedicated_servers_table = Table( diff --git a/interface b/interface index 21b31fed4..bef67d7cb 160000 --- a/interface +++ b/interface @@ -1 +1 @@ -Subproject commit 21b31fed42a5082311a458da4d475c839f99a717 +Subproject commit bef67d7cbfcb16648a4c4ebc5a870f97bdda0856 diff --git a/tests/conftest.py b/tests/conftest.py index 58937dbd3..4906ccc1d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -100,6 +100,9 @@ from ldap_protocol.ldap_schema.attribute_type_use_case import ( AttributeTypeUseCase, ) +from ldap_protocol.ldap_schema.attribute_value_validator import ( + AttributeValueValidator, +) from ldap_protocol.ldap_schema.dto import EntityTypeDTO from ldap_protocol.ldap_schema.entity_type_dao import EntityTypeDAO from ldap_protocol.ldap_schema.entity_type_use_case import EntityTypeUseCase @@ -153,7 +156,7 @@ class TestProvider(Provider): __test__ = False scope = Scope.RUNTIME - settings = from_context(provides=Settings, scope=Scope.RUNTIME) + settings = from_context(provides=Settings, scope=scope) _cached_session: AsyncSession | None = None _cached_kadmin: Mock | None = None _cached_audit_service: Mock | None = None @@ -328,13 +331,13 @@ def get_object_class_dao(self, session: AsyncSession) -> ObjectClassDAO: PasswordBanWordsFastAPIAdapter, scope=Scope.REQUEST, ) - password_utils = provide(PasswordUtils, scope=Scope.RUNTIME) + password_utils = provide(PasswordUtils, scope=scope) dns_fastapi_adapter = provide(DNSFastAPIAdapter, scope=Scope.REQUEST) dns_use_case = provide(DNSUseCase, scope=Scope.REQUEST) dns_state_gateway = provide(DNSStateGateway, scope=Scope.REQUEST) - @provide(scope=Scope.RUNTIME, provides=AsyncEngine) + @provide(scope=scope, provides=AsyncEngine) def get_engine(self, settings: Settings) -> AsyncEngine: """Get async engine.""" return settings.engine @@ -518,6 +521,7 @@ def get_krb_template_render( audit_policy_dao = provide(AuditPoliciesDAO, scope=Scope.REQUEST) audit_use_case = provide(AuditUseCase, scope=Scope.REQUEST) audit_destination_dao = provide(AuditDestinationDAO, scope=Scope.REQUEST) + attribute_value_validator = provide(AttributeValueValidator, scope=scope) @provide(scope=Scope.REQUEST, provides=AuditService) async def get_audit_service(self) -> AsyncIterator[AsyncMock]: @@ -543,7 +547,7 @@ async def get_audit_service(self) -> AsyncIterator[AsyncMock]: audit_adapter = provide(AuditPoliciesAdapter, scope=Scope.REQUEST) - @provide(scope=Scope.RUNTIME) + @provide(scope=scope) async def get_audit_redis_client( self, settings: Settings, @@ -879,13 +883,18 @@ async def setup_session( ) -> None: """Get session and acquire after completion.""" object_class_dao = ObjectClassDAO(session) - entity_type_dao = EntityTypeDAO(session, object_class_dao=object_class_dao) + attribute_value_validator = AttributeValueValidator() + entity_type_dao = EntityTypeDAO( + session, + object_class_dao=object_class_dao, + attribute_value_validator=attribute_value_validator, + ) for entity_type_data in ENTITY_TYPE_DATAS: await entity_type_dao.create( dto=EntityTypeDTO( id=None, - name=entity_type_data["name"], # type: ignore - object_class_names=entity_type_data["object_class_names"], # type: ignore + name=entity_type_data["name"], + object_class_names=entity_type_data["object_class_names"], is_system=True, ), ) @@ -899,7 +908,10 @@ async def setup_session( audit_destination_dao, raw_audit_manager, ) - password_policy_dao = PasswordPolicyDAO(session) + password_policy_dao = PasswordPolicyDAO( + session, + attribute_value_validator=attribute_value_validator, + ) password_policy_validator = PasswordPolicyValidator( PasswordValidatorSettings(), password_utils, @@ -910,7 +922,12 @@ async def setup_session( password_policy_validator, password_ban_word_repository, ) - setup_gateway = SetupGateway(session, password_utils, entity_type_dao) + setup_gateway = SetupGateway( + session, + password_utils, + entity_type_dao, + attribute_value_validator=attribute_value_validator, + ) await audit_use_case.create_policies() await setup_gateway.setup_enviroment(dn="md.test", data=TEST_DATA) @@ -998,7 +1015,14 @@ async def entity_type_dao( async with container(scope=Scope.APP) as container: session = await container.get(AsyncSession) object_class_dao = ObjectClassDAO(session) - yield EntityTypeDAO(session, object_class_dao) + attribute_value_validator = await container.get( + AttributeValueValidator, + ) + yield EntityTypeDAO( + session, + object_class_dao, + attribute_value_validator=attribute_value_validator, + ) @pytest_asyncio.fixture(scope="function") @@ -1008,7 +1032,13 @@ async def password_policy_dao( """Get session and acquire after completion.""" async with container(scope=Scope.APP) as container: session = await container.get(AsyncSession) - yield PasswordPolicyDAO(session) + attribute_value_validator = await container.get( + AttributeValueValidator, + ) + yield PasswordPolicyDAO( + session, + attribute_value_validator=attribute_value_validator, + ) @pytest_asyncio.fixture(scope="function") diff --git a/tests/test_api/test_ldap_schema/test_entity_type_router.py b/tests/test_api/test_ldap_schema/test_entity_type_router.py index c72d8c991..58815160c 100644 --- a/tests/test_api/test_ldap_schema/test_entity_type_router.py +++ b/tests/test_api/test_ldap_schema/test_entity_type_router.py @@ -5,6 +5,7 @@ from httpx import AsyncClient from constants import ENTITY_TYPE_DATAS +from enums import EntityTypeNames from .test_entity_type_router_datasets import ( test_create_one_entity_type_dataset, @@ -274,7 +275,7 @@ async def test_delete_bulk_entries( @pytest.mark.usefixtures("session") async def test_delete_entry_with_directory(http_client: AsyncClient) -> None: """Test deleting entry with directory.""" - entity_type_name = "User" + entity_type_name = EntityTypeNames.USER response = await http_client.post( "/schema/entity_type/delete", json={"entity_type_names": [entity_type_name]}, diff --git a/tests/test_api/test_ldap_schema/test_object_class_router.py b/tests/test_api/test_ldap_schema/test_object_class_router.py index e371fea40..8db0340b2 100644 --- a/tests/test_api/test_ldap_schema/test_object_class_router.py +++ b/tests/test_api/test_ldap_schema/test_object_class_router.py @@ -5,6 +5,7 @@ from httpx import AsyncClient from api.ldap_schema.schema import ObjectClassUpdateSchema +from enums import EntityTypeNames from .test_object_class_router_datasets import ( test_create_one_object_class_dataset, @@ -25,7 +26,7 @@ async def test_get_one_extended_object_class( assert response.status_code == status.HTTP_200_OK data = response.json() assert isinstance(data, dict) - assert data.get("entity_type_names") == ["User"] + assert data.get("entity_type_names") == [EntityTypeNames.USER] @pytest.mark.parametrize( diff --git a/tests/test_api/test_main/test_router/test_add.py b/tests/test_api/test_main/test_router/test_add.py index 8efd2d98c..3050bedec 100644 --- a/tests/test_api/test_main/test_router/test_add.py +++ b/tests/test_api/test_main/test_router/test_add.py @@ -23,24 +23,87 @@ async def test_api_correct_add(http_client: AsyncClient) -> None: "entry": "cn=test,dc=md,dc=test", "password": None, "attributes": [ + {"type": "name", "vals": ["test"]}, + {"type": "cn", "vals": ["test"]}, + {"type": "objectClass", "vals": ["organization", "top"]}, { - "type": "name", - "vals": ["test"], + "type": "memberOf", + "vals": ["cn=domain admins,cn=groups,dc=md,dc=test"], }, + ], + }, + ) + + data = response.json() + + assert isinstance(data, dict) + assert response.status_code == status.HTTP_200_OK + assert data.get("resultCode") == LDAPCodes.SUCCESS + assert data.get("errorMessage") == "" + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("session") +async def test_api_add_incorrect_computer_name( + http_client: AsyncClient, +) -> None: + """Test api incorrect (name) add.""" + response = await http_client.post( + "/entry/add", + json={ + "entry": "cn=test,dc=md,dc=test", + "password": None, + "attributes": [ + {"type": "name", "vals": [" test;incorrect"]}, + {"type": "cn", "vals": ["test"]}, + {"type": "objectClass", "vals": ["computer", "top"]}, { - "type": "cn", - "vals": ["test"], + "type": "memberOf", + "vals": ["cn=domain admins,cn=groups,dc=md,dc=test"], }, + ], + }, + ) + + data = response.json() + + assert isinstance(data, dict) + assert data.get("resultCode") == LDAPCodes.UNDEFINED_ATTRIBUTE_TYPE + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("session") +async def test_api_add_incorrect_user_samaccount_with_dot( + http_client: AsyncClient, +) -> None: + """Test api incorrect (sAMAccountName) add.""" + un = "test0" + + response = await http_client.post( + "/entry/add", + json={ + "entry": "cn=test0,dc=md,dc=test", + "password": "P@ssw0rd", + "attributes": [ + {"type": "name", "vals": [un]}, + {"type": "cn", "vals": [un]}, { "type": "objectClass", - "vals": ["organization", "top"], - }, - { - "type": "memberOf", "vals": [ - "cn=domain admins,cn=groups,dc=md,dc=test", + "top", + "user", + "person", + "organizationalPerson", + "posixAccount", + "shadowAccount", + "inetOrgPerson", ], }, + {"type": "sAMAccountName", "vals": ["test0."]}, + {"type": "userPrincipalName", "vals": [f"{un}@md.ru"]}, + {"type": "mail", "vals": [f"{un}@md.ru"]}, + {"type": "displayName", "vals": [un]}, + {"type": "userAccountControl", "vals": ["516"]}, ], }, ) @@ -48,9 +111,7 @@ async def test_api_correct_add(http_client: AsyncClient) -> None: data = response.json() assert isinstance(data, dict) - assert response.status_code == status.HTTP_200_OK - assert data.get("resultCode") == LDAPCodes.SUCCESS - assert data.get("errorMessage") == "" + assert data.get("resultCode") == LDAPCodes.UNDEFINED_ATTRIBUTE_TYPE @pytest.mark.asyncio diff --git a/tests/test_api/test_main/test_router/test_search.py b/tests/test_api/test_main/test_router/test_search.py index 59f60221a..9df3898c5 100644 --- a/tests/test_api/test_main/test_router/test_search.py +++ b/tests/test_api/test_main/test_router/test_search.py @@ -7,6 +7,7 @@ import pytest from httpx import AsyncClient +from enums import EntityTypeNames from ldap_protocol.ldap_codes import LDAPCodes from tests.search_request_datasets import ( test_search_by_rule_anr_dataset, @@ -432,7 +433,7 @@ async def test_api_search_by_entity_type_name( http_client: AsyncClient, ) -> None: """Test api search by entity type name.""" - entity_type_name = "User" + entity_type_name = EntityTypeNames.USER raw_response = await http_client.post( "entry/search", @@ -471,7 +472,7 @@ async def test_api_empty_search( http_client: AsyncClient, ) -> None: """Test api empty search.""" - entity_type_name = "User" + entity_type_name = EntityTypeNames.USER raw_response = await http_client.post( "entry/search", json={ diff --git a/tests/test_ldap/test_attribute_value_validator.py b/tests/test_ldap/test_attribute_value_validator.py new file mode 100644 index 000000000..084ca0a72 --- /dev/null +++ b/tests/test_ldap/test_attribute_value_validator.py @@ -0,0 +1,589 @@ +"""Tests for AttributeValueValidator. + +Copyright (c) 2025 MultiFactor +License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE +""" + +import pytest + +from enums import EntityTypeNames +from ldap_protocol.ldap_schema.attribute_value_validator import ( + AttributeValueValidator, +) + + +@pytest.fixture +def validator() -> AttributeValueValidator: + """Create validator instance.""" + return AttributeValueValidator() + + +class TestOrganizationalUnitName: + """Tests for Organizational Unit name validation.""" + + _entity_type_name = EntityTypeNames.ORGANIZATIONAL_UNIT + + def test_valid_names(self, validator: AttributeValueValidator) -> None: + """Test valid organizational unit names.""" + valid_names = [ + "IT Department", + "Sales", + "Marketing-Team", + "HR_Department", + "Department123", + ] + for name in valid_names: + assert validator.is_value_valid( + self._entity_type_name, + "name", + name, + ) + + def test_invalid_names_with_space_at_start( + self, + validator: AttributeValueValidator, + ) -> None: + """Test organizational unit names starting with space.""" + invalid_names = [" IT", " Sales", " Marketing"] + for name in invalid_names: + assert not validator.is_value_valid( + self._entity_type_name, + "name", + name, + ) + + def test_invalid_names_with_hash_at_start( + self, + validator: AttributeValueValidator, + ) -> None: + """Test organizational unit names starting with hash.""" + invalid_names = ["#IT", "#Sales", "#Marketing"] + for name in invalid_names: + assert not validator.is_value_valid( + self._entity_type_name, + "name", + name, + ) + + def test_invalid_names_with_space_at_end( + self, + validator: AttributeValueValidator, + ) -> None: + """Test organizational unit names ending with space.""" + invalid_names = ["IT ", "Sales ", "Marketing "] + for name in invalid_names: + assert not validator.is_value_valid( + self._entity_type_name, + "name", + name, + ) + + def test_invalid_names_with_forbidden_symbols( + self, + validator: AttributeValueValidator, + ) -> None: + """Test organizational unit names with forbidden symbols.""" + invalid_names = [ + 'IT"Dept', + "Sales,Team", + "Marketing+", + "HR\\Group", + "Dept<1>", + "Team;A", + "Group=B", + ] + for name in invalid_names: + assert not validator.is_value_valid( + self._entity_type_name, + "name", + name, + ) + + +class TestGroupName: + """Tests for Group name validation.""" + + _entity_type_name = EntityTypeNames.GROUP + + def test_valid_names(self, validator: AttributeValueValidator) -> None: + """Test valid group names.""" + valid_names = [ + "Administrators", + "Users", + "Power_Users", + "Group-123", + "TeamA", + ] + for name in valid_names: + assert validator.is_value_valid( + self._entity_type_name, + "name", + name, + ) + + def test_invalid_names_with_space_at_start( + self, + validator: AttributeValueValidator, + ) -> None: + """Test group names starting with space.""" + invalid_names = [" Admins", " Users", " PowerUsers"] + for name in invalid_names: + assert not validator.is_value_valid( + self._entity_type_name, + "name", + name, + ) + + def test_invalid_names_with_hash_at_start( + self, + validator: AttributeValueValidator, + ) -> None: + """Test group names starting with hash.""" + invalid_names = ["#Admins", "#Users", "#PowerUsers"] + for name in invalid_names: + assert not validator.is_value_valid( + self._entity_type_name, + "name", + name, + ) + + def test_invalid_names_with_space_at_end( + self, + validator: AttributeValueValidator, + ) -> None: + """Test group names ending with space.""" + invalid_names = ["Admins ", "Users ", "PowerUsers "] + for name in invalid_names: + assert not validator.is_value_valid( + self._entity_type_name, + "name", + name, + ) + + def test_invalid_names_with_forbidden_symbols( + self, + validator: AttributeValueValidator, + ) -> None: + """Test group names with forbidden symbols.""" + invalid_names = [ + 'Admins"Group', + "Users,Team", + "Power+Users", + "Group\\A", + "Team<1>", + "Users;B", + "Group=C", + ] + for name in invalid_names: + assert not validator.is_value_valid( + self._entity_type_name, + "name", + name, + ) + + +class TestUserName: + """Tests for User name validation.""" + + _entity_type_name = EntityTypeNames.USER + + def test_valid_names(self, validator: AttributeValueValidator) -> None: + """Test valid user names.""" + valid_names = [ + "John Doe", + "Jane_Smith", + "User-123", + "Administrator", + "User.Name", + ] + for name in valid_names: + assert validator.is_value_valid( + self._entity_type_name, + "name", + name, + ) + + def test_invalid_names_with_space_at_start( + self, + validator: AttributeValueValidator, + ) -> None: + """Test user names starting with space.""" + invalid_names = [" JohnDoe", " Jane", " User123"] + for name in invalid_names: + assert not validator.is_value_valid( + self._entity_type_name, + "name", + name, + ) + + def test_invalid_names_with_hash_at_start( + self, + validator: AttributeValueValidator, + ) -> None: + """Test user names starting with hash.""" + invalid_names = ["#JohnDoe", "#Jane", "#User123"] + for name in invalid_names: + assert not validator.is_value_valid( + self._entity_type_name, + "name", + name, + ) + + def test_invalid_names_with_space_at_end( + self, + validator: AttributeValueValidator, + ) -> None: + """Test user names ending with space.""" + invalid_names = ["JohnDoe ", "Jane ", "User123 "] + for name in invalid_names: + assert not validator.is_value_valid( + self._entity_type_name, + "name", + name, + ) + + def test_invalid_names_with_forbidden_symbols( + self, + validator: AttributeValueValidator, + ) -> None: + """Test user names with forbidden symbols.""" + invalid_names = [ + 'John"Doe', + "Jane,Smith", + "User+123", + "Name\\Test", + "User<1>", + "John;Doe", + "User=Name", + ] + for name in invalid_names: + assert not validator.is_value_valid( + self._entity_type_name, + "name", + name, + ) + + +class TestUserSAMAccountName: + """Tests for User sAMAccountName validation.""" + + _entity_type_name = EntityTypeNames.USER + + def test_valid_sam_account_names( + self, + validator: AttributeValueValidator, + ) -> None: + """Test valid sAMAccountName values.""" + valid_names = [ + "jdoe", + "john.doe", + "user123", + "admin_user", + "test-user", + ] + for name in valid_names: + assert validator.is_value_valid( + self._entity_type_name, + "sAMAccountName", + name, + ) + + def test_invalid_sam_account_names_with_forbidden_symbols( + self, + validator: AttributeValueValidator, + ) -> None: + """Test sAMAccountName with forbidden symbols.""" + invalid_names = [ + 'user"name', + "user/name", + "user\\name", + "user[name]", + "user:name", + "user;name", + "user|name", + "user=name", + "user,name", + "user+name", + "user*name", + "user?name", + "user", + ] + for name in invalid_names: + assert not validator.is_value_valid( + self._entity_type_name, + "sAMAccountName", + name, + ) + + def test_invalid_sam_account_names_ending_with_dot( + self, + validator: AttributeValueValidator, + ) -> None: + """Test sAMAccountName ending with dot.""" + invalid_names = ["user.", "john.doe.", "admin."] + for name in invalid_names: + assert not validator.is_value_valid( + self._entity_type_name, + "sAMAccountName", + name, + ) + + def test_invalid_sam_account_names_with_control_chars( + self, + validator: AttributeValueValidator, + ) -> None: + """Test sAMAccountName with control characters.""" + invalid_names = [ + "user\x00name", + "user\x01name", + "user\x1fname", + "user\x7fname", + ] + for name in invalid_names: + assert not validator.is_value_valid( + self._entity_type_name, + "sAMAccountName", + name, + ) + + def test_invalid_sam_account_names_with_at_symbol( + self, + validator: AttributeValueValidator, + ) -> None: + """Test sAMAccountName with @ symbol.""" + invalid_names = ["user@domain", "admin@test", "john@"] + for name in invalid_names: + assert not validator.is_value_valid( + self._entity_type_name, + "sAMAccountName", + name, + ) + + +class TestComputerName: + """Tests for Computer name validation.""" + + _entity_type_name = EntityTypeNames.COMPUTER + + def test_valid_names(self, validator: AttributeValueValidator) -> None: + """Test valid computer names.""" + valid_names = [ + "WORKSTATION01", + "Server-2024", + "PC_LAB_123", + "Desktop", + ] + for name in valid_names: + assert validator.is_value_valid( + self._entity_type_name, + "name", + name, + ) + + def test_invalid_names_with_space_at_start( + self, + validator: AttributeValueValidator, + ) -> None: + """Test computer names starting with space.""" + invalid_names = [" WORKSTATION", " Server", " PC123"] + for name in invalid_names: + assert not validator.is_value_valid( + self._entity_type_name, + "name", + name, + ) + + def test_invalid_names_with_hash_at_start( + self, + validator: AttributeValueValidator, + ) -> None: + """Test computer names starting with hash.""" + invalid_names = ["#WORKSTATION", "#Server", "#PC123"] + for name in invalid_names: + assert not validator.is_value_valid( + self._entity_type_name, + "name", + name, + ) + + def test_invalid_names_with_space_at_end( + self, + validator: AttributeValueValidator, + ) -> None: + """Test computer names ending with space.""" + invalid_names = ["WORKSTATION ", "Server ", "PC123 "] + for name in invalid_names: + assert not validator.is_value_valid( + self._entity_type_name, + "name", + name, + ) + + def test_invalid_names_with_forbidden_symbols( + self, + validator: AttributeValueValidator, + ) -> None: + """Test computer names with forbidden symbols.""" + invalid_names = [ + 'PC"01', + "Server,01", + "Work+Station", + "PC\\01", + "Server<1>", + "PC;01", + "Computer=01", + ] + for name in invalid_names: + assert not validator.is_value_valid( + self._entity_type_name, + "name", + name, + ) + + +class TestComputerSAMAccountName: + """Tests for Computer sAMAccountName validation.""" + + _entity_type_name = EntityTypeNames.COMPUTER + + def test_valid_sam_account_names( + self, + validator: AttributeValueValidator, + ) -> None: + """Test valid computer sAMAccountName values.""" + valid_names = [ + "WORKSTATION01$", + "SERVER-2024$", + "PC_LAB$", + ] + for name in valid_names: + assert validator.is_value_valid( + self._entity_type_name, + "sAMAccountName", + name, + ) + + def test_invalid_sam_account_names_with_forbidden_symbols( + self, + validator: AttributeValueValidator, + ) -> None: + """Test computer sAMAccountName with forbidden symbols.""" + invalid_names = [ + 'PC"01$', + "PC/01$", + "PC\\01$", + "PC[01]$", + "PC:01$", + "PC;01$", + "PC|01$", + "PC=01$", + "PC,01$", + "PC+01$", + "PC*01$", + "PC?01$", + "PC<01>$", + ] + for name in invalid_names: + assert not validator.is_value_valid( + self._entity_type_name, + "sAMAccountName", + name, + ) + + def test_invalid_sam_account_names_ending_with_dot( + self, + validator: AttributeValueValidator, + ) -> None: + """Test computer sAMAccountName ending with dot.""" + invalid_names = ["PC01.", "SERVER.", "WORKSTATION."] + for name in invalid_names: + assert not validator.is_value_valid( + self._entity_type_name, + "sAMAccountName", + name, + ) + + def test_invalid_sam_account_names_with_control_chars( + self, + validator: AttributeValueValidator, + ) -> None: + """Test computer sAMAccountName with control characters.""" + invalid_names = [ + "PC\x00NAME$", + "PC\x01NAME$", + "PC\x1fNAME$", + "PC\x7fNAME$", + ] + for name in invalid_names: + assert not validator.is_value_valid( + self._entity_type_name, + "sAMAccountName", + name, + ) + + def test_invalid_sam_account_names_with_spaces_and_dots( + self, + validator: AttributeValueValidator, + ) -> None: + """Test computer sAMAccountName with spaces and dots.""" + invalid_names = [ + "PC 01$", + "SERVER 2024$", + "WORK.STATION$", + "PC.01$", + ] + for name in invalid_names: + assert not validator.is_value_valid( + self._entity_type_name, + "sAMAccountName", + name, + ) + + def test_invalid_sam_account_names_only_numbers( + self, + validator: AttributeValueValidator, + ) -> None: + """Test computer sAMAccountName that are only numbers.""" + invalid_names = ["123", "456789", "0"] + for name in invalid_names: + assert not validator.is_value_valid( + self._entity_type_name, + "sAMAccountName", + name, + ) + + def test_invalid_sam_account_names_starting_with_number( + self, + validator: AttributeValueValidator, + ) -> None: + """Test computer sAMAccountName starting with number.""" + invalid_names = ["1PC$", "2SERVER$", "9WORKSTATION$"] + for name in invalid_names: + assert not validator.is_value_valid( + self._entity_type_name, + "sAMAccountName", + name, + ) + + +class TestNoValidationRules: + """Test validation for attributes without specific rules.""" + + def test_attributes_without_rules_always_valid( + self, + validator: AttributeValueValidator, + ) -> None: + """Test that attributes without validation rules always pass.""" + test_cases = [ + (EntityTypeNames.USER, "description", "Any value here!"), + (EntityTypeNames.GROUP, "description", " spaces and #symbols "), + (EntityTypeNames.COMPUTER, "location", "Building 1, Room 101"), + (EntityTypeNames.ORGANIZATIONAL_UNIT, "description", ""), + ] + + for entity_type, property_name, value in test_cases: + assert validator.is_value_valid( + entity_type, + property_name, + value, + ) diff --git a/tests/test_ldap/test_roles/test_multiple_access.py b/tests/test_ldap/test_roles/test_multiple_access.py index 5c6f83c27..da8cc17bc 100644 --- a/tests/test_ldap/test_roles/test_multiple_access.py +++ b/tests/test_ldap/test_roles/test_multiple_access.py @@ -13,7 +13,7 @@ from config import Settings from entities import Directory -from enums import AceType, RoleScope +from enums import AceType, EntityTypeNames, RoleScope from ldap_protocol.ldap_schema.attribute_type_dao import AttributeTypeDAO from ldap_protocol.ldap_schema.entity_type_dao import EntityTypeDAO from ldap_protocol.roles.ace_dao import AccessControlEntryDAO @@ -37,7 +37,7 @@ async def test_multiple_access( custom_role: RoleDTO, ) -> None: """Test multiple access control entries in a role.""" - user_entity_type = await entity_type_dao.get("User") + user_entity_type = await entity_type_dao.get(EntityTypeNames.USER) assert user_entity_type posix_email_attr = await attribute_type_dao.get("posixEmail") diff --git a/tests/test_ldap/test_roles/test_search.py b/tests/test_ldap/test_roles/test_search.py index 35115bc1d..a20e8f0dd 100644 --- a/tests/test_ldap/test_roles/test_search.py +++ b/tests/test_ldap/test_roles/test_search.py @@ -7,7 +7,7 @@ import pytest from config import Settings -from enums import AceType, RoleScope +from enums import AceType, EntityTypeNames, RoleScope from ldap_protocol.ldap_schema.attribute_type_dao import AttributeTypeDAO from ldap_protocol.ldap_schema.entity_type_dao import EntityTypeDAO from ldap_protocol.roles.ace_dao import AccessControlEntryDAO @@ -169,7 +169,7 @@ async def test_role_search_5( User with a custom role should see all Users objects. """ - user_entity_type = await entity_type_dao.get("User") + user_entity_type = await entity_type_dao.get(EntityTypeNames.USER) assert user_entity_type ace = AccessControlEntryDTO( @@ -221,7 +221,7 @@ async def test_role_search_6( User with a custom role should see only the posixEmail attribute. """ - user_entity_type = await entity_type_dao.get("User") + user_entity_type = await entity_type_dao.get(EntityTypeNames.USER) assert user_entity_type posix_email_attr = await attribute_type_dao.get("posixEmail") @@ -270,7 +270,7 @@ async def test_role_search_7( User with a custom role should see all attributes except description. """ - user_entity_type = await entity_type_dao.get("User") + user_entity_type = await entity_type_dao.get(EntityTypeNames.USER) assert user_entity_type description_attr = await attribute_type_dao.get("description") @@ -330,7 +330,7 @@ async def test_role_search_8( User with a custom role should see only the description attribute. """ - user_entity_type = await entity_type_dao.get("User") + user_entity_type = await entity_type_dao.get(EntityTypeNames.USER) assert user_entity_type description_attr = await attribute_type_dao.get("description") @@ -390,7 +390,7 @@ async def test_role_search_9( User with a custom role should see only the posixEmail attribute. """ - user_entity_type = await entity_type_dao.get("User") + user_entity_type = await entity_type_dao.get(EntityTypeNames.USER) assert user_entity_type description_attr = await attribute_type_dao.get("description")