diff --git a/app/api/network/adapters/network.py b/app/api/network/adapters/network.py index c478545c0..9316b8c7e 100644 --- a/app/api/network/adapters/network.py +++ b/app/api/network/adapters/network.py @@ -19,11 +19,11 @@ PolicyUpdate, SwapResponse, ) -from ldap_protocol.policies.network.dto import ( +from ldap_protocol.policies.network import ( NetworkPolicyDTO, NetworkPolicyUpdateDTO, + NetworkPolicyUseCase, ) -from ldap_protocol.policies.network.use_cases import NetworkPolicyUseCase def _convert_netmasks( diff --git a/app/enums.py b/app/enums.py index d94e3ba7a..5258c3a0d 100644 --- a/app/enums.py +++ b/app/enums.py @@ -207,6 +207,14 @@ class AuthorizationRules(IntFlag): SESSION_CLEAR_USER_SESSIONS = auto() SESSION_DELETE = auto() + NETWORK_POLICY_VALIDATOR_GET_BY_PROTOCOL = auto() + NETWORK_POLICY_VALIDATOR_GET_USER_NETWORK_POLICY = auto() + NETWORK_POLICY_VALIDATOR_GET_USER_HTTP_POLICY = auto() + NETWORK_POLICY_VALIDATOR_GET_USER_KERBEROS_POLICY = auto() + NETWORK_POLICY_VALIDATOR_GET_USER_LDAP_POLICY = auto() + NETWORK_POLICY_VALIDATOR_IS_USER_GROUP_VALID = auto() + NETWORK_POLICY_VALIDATOR_CHECK_MFA_GROUP = auto() + @classmethod def get_all(cls) -> Self: return cls(sum(cls)) @@ -218,6 +226,14 @@ def combine( return reduce(or_, permissions, AuthorizationRules(0)) +class ProtocolType(StrEnum): + """Protocol fields.""" + + LDAP = "is_ldap" + HTTP = "is_http" + KERBEROS = "is_kerberos" + + class DomainCodes(IntEnum): """Error code parts.""" diff --git a/app/ioc.py b/app/ioc.py index c1a1d7e93..3ad06c6b4 100644 --- a/app/ioc.py +++ b/app/ioc.py @@ -116,8 +116,13 @@ ) from ldap_protocol.policies.audit.policies_dao import AuditPoliciesDAO from ldap_protocol.policies.audit.service import AuditService -from ldap_protocol.policies.network.gateway import NetworkPolicyGateway -from ldap_protocol.policies.network.use_cases import NetworkPolicyUseCase +from ldap_protocol.policies.network import ( + NetworkPolicyGateway, + NetworkPolicyUseCase, + NetworkPolicyValidatorGateway, + NetworkPolicyValidatorProtocol, + NetworkPolicyValidatorUseCase, +) from ldap_protocol.policies.password import ( PasswordPolicyDAO, PasswordPolicyUseCases, @@ -514,6 +519,20 @@ class HTTPProvider(LDAPContextProvider): scope = Scope.REQUEST request = from_context(provides=Request, scope=Scope.REQUEST) monitor_use_case = provide(AuditMonitorUseCase, scope=Scope.REQUEST) + network_policy_gateway = provide(NetworkPolicyGateway, scope=Scope.REQUEST) + network_policy_use_case = provide( + NetworkPolicyUseCase, + scope=Scope.REQUEST, + ) + network_policy_validator_gateway = provide( + NetworkPolicyValidatorGateway, + provides=NetworkPolicyValidatorProtocol, + scope=Scope.REQUEST, + ) + network_policy_validator_use_case = provide( + NetworkPolicyValidatorUseCase, + scope=Scope.REQUEST, + ) @provide() async def get_audit_monitor( @@ -651,11 +670,6 @@ def get_krb_template_render( NetworkPolicyFastAPIAdapter, scope=Scope.REQUEST, ) - network_policy_use_case = provide( - NetworkPolicyUseCase, - scope=Scope.REQUEST, - ) - network_policy_gateway = provide(NetworkPolicyGateway, scope=Scope.REQUEST) class LDAPServerProvider(LDAPContextProvider): @@ -663,6 +677,21 @@ class LDAPServerProvider(LDAPContextProvider): scope = Scope.SESSION + network_policy_validator_gateway = provide( + NetworkPolicyValidatorGateway, + scope=Scope.REQUEST, + ) + + network_policy_validator = provide( + NetworkPolicyValidatorGateway, + provides=NetworkPolicyValidatorProtocol, + scope=Scope.REQUEST, + ) + network_policy_validator_use_case = provide( + NetworkPolicyValidatorUseCase, + scope=Scope.REQUEST, + ) + @provide(scope=Scope.SESSION, provides=LDAPSession) async def get_session( self, diff --git a/app/ldap_protocol/auth/auth_manager.py b/app/ldap_protocol/auth/auth_manager.py index 47dc4df3a..61c336f56 100644 --- a/app/ldap_protocol/auth/auth_manager.py +++ b/app/ldap_protocol/auth/auth_manager.py @@ -32,10 +32,7 @@ from ldap_protocol.multifactor import MultifactorAPI from ldap_protocol.objects import UserAccountControlFlag from ldap_protocol.policies.audit.monitor import AuditMonitorUseCase -from ldap_protocol.policies.network_policy import ( - check_mfa_group, - get_user_network_policy, -) +from ldap_protocol.policies.network import NetworkPolicyValidatorUseCase from ldap_protocol.policies.password import PasswordPolicyUseCases from ldap_protocol.session_storage import SessionStorage from ldap_protocol.session_storage.repository import SessionRepository @@ -61,6 +58,7 @@ def __init__( mfa_manager: MFAManager, setup_use_case: SetupUseCase, identity_provider: IdentityProvider, + network_policy_validator: NetworkPolicyValidatorUseCase, ) -> None: """Initialize dependencies of the manager (via DI). @@ -84,6 +82,7 @@ def __init__( self._mfa_manager = mfa_manager self._setup_use_case = setup_use_case self._identity_provider = identity_provider + self._network_policy_validator = network_policy_validator def __getattribute__(self, name: str) -> object: """Intercept attribute access.""" @@ -147,11 +146,11 @@ async def login( if user.is_expired(): raise LoginFailedError("User account is expired") - network_policy = await get_user_network_policy( - ip, - user, - self._session, - policy_type="is_http", + network_policy = ( + await self._network_policy_validator.get_user_http_policy( + ip, + user, + ) ) if network_policy is None: raise LoginFailedError("User not part of network policy") @@ -162,10 +161,11 @@ async def login( ): request_2fa = True if network_policy.mfa_status == MFAFlags.WHITELIST: - request_2fa = await check_mfa_group( - network_policy, - user, - self._session, + request_2fa = ( + await self._network_policy_validator.check_mfa_group( + network_policy, + user, + ) ) if request_2fa: ( diff --git a/app/ldap_protocol/auth/mfa_manager.py b/app/ldap_protocol/auth/mfa_manager.py index 50d7024d7..334a66a44 100644 --- a/app/ldap_protocol/auth/mfa_manager.py +++ b/app/ldap_protocol/auth/mfa_manager.py @@ -46,10 +46,7 @@ MultifactorAPI, ) from ldap_protocol.policies.audit.monitor import AuditMonitorUseCase -from ldap_protocol.policies.network_policy import ( - check_mfa_group, - get_user_network_policy, -) +from ldap_protocol.policies.network import NetworkPolicyValidatorUseCase from ldap_protocol.session_storage import SessionStorage from ldap_protocol.session_storage.repository import SessionRepository from password_utils import PasswordUtils @@ -72,6 +69,7 @@ def __init__( monitor: AuditMonitorUseCase, password_utils: PasswordUtils, identity_provider: IdentityProvider, + network_policy_validator: NetworkPolicyValidatorUseCase, ) -> None: """Initialize dependencies via DI. @@ -90,6 +88,7 @@ def __init__( self._monitor = monitor self._password_utils = password_utils self._identity_provider = identity_provider + self._network_policy_validator = network_policy_validator def __getattribute__(self, name: str) -> object: """Intercept attribute access.""" @@ -328,11 +327,11 @@ async def proxy_request(self, principal: str, ip: IPv4Address) -> None: f"User {principal} not found in the database.", ) - network_policy = await get_user_network_policy( - ip, - user, - self._session, - policy_type="is_kerberos", + network_policy = ( + await self._network_policy_validator.get_user_kerberos_policy( + ip, + user, + ) ) if network_policy is None or not network_policy.is_kerberos: @@ -351,10 +350,9 @@ async def proxy_request(self, principal: str, ip: IPv4Address) -> None: ): if ( network_policy.mfa_status == MFAFlags.WHITELIST - and not await check_mfa_group( + and not await self._network_policy_validator.check_mfa_group( network_policy, user, - self._session, ) ): return diff --git a/app/ldap_protocol/dialogue.py b/app/ldap_protocol/dialogue.py index c15a53051..a594b628c 100644 --- a/app/ldap_protocol/dialogue.py +++ b/app/ldap_protocol/dialogue.py @@ -16,10 +16,10 @@ from typing import TYPE_CHECKING, AsyncIterator import gssapi -from sqlalchemy.ext.asyncio import AsyncSession from entities import NetworkPolicy, User -from ldap_protocol.policies.network_policy import build_policy_query +from enums import ProtocolType +from ldap_protocol.policies.network import NetworkPolicyValidatorUseCase from .session_storage import SessionStorage @@ -142,21 +142,16 @@ async def lock(self) -> AsyncIterator[UserSchema | None]: async with self._lock: yield self._user - @staticmethod - async def _get_policy( - ip: IPv4Address, - session: AsyncSession, - ) -> NetworkPolicy | None: - query = build_policy_query(ip, "is_ldap") - return await session.scalar(query) - async def validate_conn( self, ip: IPv4Address | IPv6Address, - session: AsyncSession, + network_policy_use_case: NetworkPolicyValidatorUseCase, ) -> None: """Validate network policies.""" - policy = await self._get_policy(ip, session) # type: ignore + policy = await network_policy_use_case.get_by_protocol( + ip, + ProtocolType.LDAP, + ) if policy is not None: self.policy = policy await self.bind_session() diff --git a/app/ldap_protocol/ldap_requests/bind.py b/app/ldap_protocol/ldap_requests/bind.py index d64294e72..a303f0fdf 100644 --- a/app/ldap_protocol/ldap_requests/bind.py +++ b/app/ldap_protocol/ldap_requests/bind.py @@ -8,12 +8,10 @@ from typing import AsyncGenerator, ClassVar from pydantic import Field -from sqlalchemy.ext.asyncio import AsyncSession -from entities import NetworkPolicy, User +from entities import NetworkPolicy from enums import MFAFlags from ldap_protocol.asn1parser import ASN1Row -from ldap_protocol.dialogue import LDAPSession from ldap_protocol.kerberos.exceptions import ( KRBAPIAddPrincipalError, KRBAPIConnectionError, @@ -34,10 +32,6 @@ from ldap_protocol.ldap_responses import BaseResponse, BindResponse from ldap_protocol.multifactor import MultifactorAPI from ldap_protocol.objects import ProtocolRequests, UserAccountControlFlag -from ldap_protocol.policies.network_policy import ( - check_mfa_group, - is_user_group_valid, -) from ldap_protocol.user_account_control import get_check_uac from ldap_protocol.utils.queries import set_user_logon_attrs @@ -93,15 +87,6 @@ def from_data(cls, data: list[ASN1Row]) -> "BindRequest": AuthenticationChoice=auth_choice, ) - @staticmethod - async def is_user_group_valid( - user: User, - ldap_session: LDAPSession, - session: AsyncSession, - ) -> bool: - """Test compability.""" - return await is_user_group_valid(user, ldap_session.policy, session) - @staticmethod async def check_mfa( api: MultifactorAPI | None, @@ -173,11 +158,12 @@ async def handle( if uac_check(UserAccountControlFlag.ACCOUNTDISABLE): yield get_bad_response(LDAPBindErrors.ACCOUNT_DISABLED) return - - if not await self.is_user_group_valid( + policy = getattr(ctx.ldap_session, "policy", None) + if ( + policy is not None + ) and not await ctx.network_policy_validator.is_user_group_valid( user, - ctx.ldap_session, - ctx.session, + policy, ): yield get_bad_response(LDAPBindErrors.LOGON_FAILURE) return @@ -192,14 +178,18 @@ async def handle( yield get_bad_response(LDAPBindErrors.PASSWORD_MUST_CHANGE) return - if ( - (policy := getattr(ctx.ldap_session, "policy", None)) - and policy.mfa_status in (MFAFlags.ENABLED, MFAFlags.WHITELIST) + if (policy is not None) and ( + policy.mfa_status in (MFAFlags.ENABLED, MFAFlags.WHITELIST) and ctx.mfa is not None ): request_2fa = True if policy.mfa_status == MFAFlags.WHITELIST: - request_2fa = await check_mfa_group(policy, user, ctx.session) + request_2fa = ( + await ctx.network_policy_validator.check_mfa_group( + policy, + user, + ) + ) if request_2fa: mfa_status = await self.check_mfa( diff --git a/app/ldap_protocol/ldap_requests/contexts.py b/app/ldap_protocol/ldap_requests/contexts.py index 469939516..4926d826e 100644 --- a/app/ldap_protocol/ldap_requests/contexts.py +++ b/app/ldap_protocol/ldap_requests/contexts.py @@ -16,6 +16,7 @@ ) from ldap_protocol.ldap_schema.entity_type_dao import EntityTypeDAO from ldap_protocol.multifactor import LDAPMultiFactorAPI +from ldap_protocol.policies.network import NetworkPolicyValidatorUseCase from ldap_protocol.policies.password import PasswordPolicyUseCases from ldap_protocol.roles.access_manager import AccessManager from ldap_protocol.roles.role_use_case import RoleUseCase @@ -66,6 +67,7 @@ class LDAPBindRequestContext: password_use_cases: PasswordPolicyUseCases password_utils: PasswordUtils mfa: LDAPMultiFactorAPI + network_policy_validator: NetworkPolicyValidatorUseCase @dataclass diff --git a/app/ldap_protocol/policies/network/__init__.py b/app/ldap_protocol/policies/network/__init__.py index 09538aa8d..c4ea76e70 100644 --- a/app/ldap_protocol/policies/network/__init__.py +++ b/app/ldap_protocol/policies/network/__init__.py @@ -1,13 +1,26 @@ """Network policies module.""" -from .dto import NetworkPolicyDTO -from .exceptions import NetworkPolicyAlreadyExistsError +from .dto import NetworkPolicyDTO, NetworkPolicyUpdateDTO, SwapPrioritiesDTO +from .exceptions import ( + LastActivePolicyError, + NetworkPolicyAlreadyExistsError, + NetworkPolicyNotFoundError, +) from .gateway import NetworkPolicyGateway -from .use_cases import NetworkPolicyUseCase +from .use_cases import NetworkPolicyUseCase, NetworkPolicyValidatorUseCase +from .validator_gateway import NetworkPolicyValidatorGateway +from .validator_protocol import NetworkPolicyValidatorProtocol __all__ = [ "NetworkPolicyDTO", + "NetworkPolicyUpdateDTO", + "SwapPrioritiesDTO", "NetworkPolicyAlreadyExistsError", + "LastActivePolicyError", + "NetworkPolicyNotFoundError", "NetworkPolicyGateway", "NetworkPolicyUseCase", + "NetworkPolicyValidatorUseCase", + "NetworkPolicyValidatorGateway", + "NetworkPolicyValidatorProtocol", ] diff --git a/app/ldap_protocol/policies/network/use_cases.py b/app/ldap_protocol/policies/network/use_cases.py index a069b02d8..cde4294d6 100644 --- a/app/ldap_protocol/policies/network/use_cases.py +++ b/app/ldap_protocol/policies/network/use_cases.py @@ -4,6 +4,7 @@ License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE """ +from ipaddress import IPv4Address, IPv6Address from typing import ClassVar from adaptix import P @@ -11,8 +12,8 @@ from sqlalchemy.ext.asyncio import AsyncSession from abstract_service import AbstractService -from entities import NetworkPolicy -from enums import AuthorizationRules +from entities import NetworkPolicy, User +from enums import AuthorizationRules, ProtocolType from ldap_protocol.policies.network.dto import ( NetworkPolicyDTO, NetworkPolicyUpdateDTO, @@ -22,8 +23,10 @@ LastActivePolicyError, NetworkPolicyAlreadyExistsError, ) - -from .gateway import NetworkPolicyGateway +from ldap_protocol.policies.network.gateway import NetworkPolicyGateway +from ldap_protocol.policies.network.validator_protocol import ( + NetworkPolicyValidatorProtocol, +) def _convert_groups(policy: NetworkPolicy) -> list[str]: @@ -197,3 +200,103 @@ async def swap_priorities(self, id1: int, id2: int) -> SwapPrioritiesDTO: update.__name__: AuthorizationRules.NETWORK_POLICY_UPDATE, swap_priorities.__name__: AuthorizationRules.NETWORK_POLICY_SWAP_PRIORITIES, # noqa: E501 } + + +class NetworkPolicyValidatorUseCase(AbstractService): + """Network policies validator use cases.""" + + def __init__( + self, + network_policy_validator_gateway: NetworkPolicyValidatorProtocol, + ): + """Initialize Network policies validator use cases.""" + self._gateway = network_policy_validator_gateway + + async def get_by_protocol( + self, + ip: IPv4Address | IPv6Address, + protocol_type: ProtocolType, + ) -> NetworkPolicy | None: + """Get network policy by protocol.""" + return await self._gateway.get_by_protocol( + ip, + protocol_type, + ) + + async def get_user_network_policy( + self, + ip: IPv4Address | IPv6Address, + user: User, + policy_type: ProtocolType, + ) -> NetworkPolicy | None: + """Get user network policy.""" + return await self._gateway.get_user_network_policy( + ip, + user, + policy_type, + ) + + async def get_user_http_policy( + self, + ip: IPv4Address | IPv6Address, + user: User, + ) -> NetworkPolicy | None: + """Get user HTTP policy.""" + return await self._gateway.get_user_http_policy( + ip, + user, + ) + + async def get_user_kerberos_policy( + self, + ip: IPv4Address | IPv6Address, + user: User, + ) -> NetworkPolicy | None: + """Get user Kerberos policy.""" + return await self._gateway.get_user_kerberos_policy( + ip, + user, + ) + + async def get_user_ldap_policy( + self, + ip: IPv4Address | IPv6Address, + user: User, + ) -> NetworkPolicy | None: + """Get user LDAP policy.""" + return await self._gateway.get_user_ldap_policy( + ip, + user, + ) + + async def is_user_group_valid( + self, + user: User | None, + policy: NetworkPolicy | None, + ) -> bool: + """Validate user groups, is it including to policy.""" + return await self._gateway.is_user_group_valid( + user, + policy, + ) + + async def check_mfa_group( + self, + policy: NetworkPolicy, + user: User, + ) -> bool: + """Check if user is in a group with MFA policy.""" + return await self._gateway.check_mfa_group( + policy, + user, + ) + + PERMISSIONS: ClassVar[dict[str, AuthorizationRules]] = { + get_by_protocol.__name__: AuthorizationRules.NETWORK_POLICY_VALIDATOR_GET_BY_PROTOCOL, # noqa: E501 + get_user_network_policy.__name__: AuthorizationRules.NETWORK_POLICY_VALIDATOR_GET_USER_NETWORK_POLICY, # noqa: E501 + get_user_http_policy.__name__: AuthorizationRules.NETWORK_POLICY_VALIDATOR_GET_USER_HTTP_POLICY, # noqa: E501 + get_user_kerberos_policy.__name__: AuthorizationRules.NETWORK_POLICY_VALIDATOR_GET_USER_KERBEROS_POLICY, # noqa: E501 + get_user_ldap_policy.__name__: AuthorizationRules.NETWORK_POLICY_VALIDATOR_GET_USER_LDAP_POLICY, # noqa: E501 + is_user_group_valid.__name__: AuthorizationRules.NETWORK_POLICY_VALIDATOR_IS_USER_GROUP_VALID, # noqa: E501 + check_mfa_group.__name__: AuthorizationRules.NETWORK_POLICY_VALIDATOR_CHECK_MFA_GROUP, # noqa: E501 + } diff --git a/app/ldap_protocol/policies/network/validator_gateway.py b/app/ldap_protocol/policies/network/validator_gateway.py new file mode 100644 index 000000000..e04cb0dbe --- /dev/null +++ b/app/ldap_protocol/policies/network/validator_gateway.py @@ -0,0 +1,178 @@ +"""Network policy validator gateway. + +Copyright (c) 2025 MultiFactor +License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE +""" + +from ipaddress import IPv4Address, IPv6Address + +from sqlalchemy import exists, or_, select, text +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import selectinload +from sqlalchemy.sql.expression import Select, true + +from entities import Group, NetworkPolicy, User +from enums import ProtocolType +from repo.pg.tables import queryable_attr as qa + + +class NetworkPolicyValidatorGateway: + """Gateway for validating network policies.""" + + def __init__( + self, + session: AsyncSession, + ): + """Initialize validator gateway.""" + self._session = session + + def _build_base_query( + self, + ip: IPv4Address | IPv6Address, + protocol_type: ProtocolType, + ) -> Select: + """Build a base query for network policies. + + :param IPv4Address | IPv6Address ip: IP address to filter + :param ProtocolType protocol_type: Protocol to filter + :param list[int] | None user_group_ids: + List of user group IDs, optional + :return: Select query + """ + protocol_field = getattr(NetworkPolicy, protocol_type) + query = ( + select(NetworkPolicy) + .options( + selectinload(qa(NetworkPolicy.groups)), + selectinload(qa(NetworkPolicy.mfa_groups)), + ) + .filter( + qa(NetworkPolicy.enabled).is_(True), + text(':ip <<= ANY("Policies".netmasks)').bindparams(ip=ip), + protocol_field == true(), + ) + .order_by(qa(NetworkPolicy.priority).asc()) + .limit(1) + ) + + return query + + async def get_by_protocol( + self, + ip: IPv4Address | IPv6Address, + protocol_type: ProtocolType, + ) -> NetworkPolicy | None: + """Get network policy by protocol.""" + query = self._build_base_query(ip, protocol_type) + return await self._session.scalar(query) + + async def get_user_network_policy( + self, + ip: IPv4Address | IPv6Address, + user: User, + policy_type: ProtocolType, + ) -> NetworkPolicy | None: + """Get the highest priority network policy for user, ip and protocol. + + :param User user: user object + :return NetworkPolicy | None: a NetworkPolicy object + """ + user_group_ids = [group.id for group in user.groups] + + query = self._build_base_query(ip, policy_type) + + if user_group_ids is not None: + query = query.filter( + or_( + qa(NetworkPolicy.groups) == None, # noqa + qa(NetworkPolicy.groups).any( + qa(Group.id).in_(user_group_ids), + ), + ), + ) + + return await self._session.scalar(query) + + async def get_user_http_policy( + self, + ip: IPv4Address | IPv6Address, + user: User, + ) -> NetworkPolicy | None: + """Get user HTTP policy.""" + return await self.get_user_network_policy( + ip, + user, + ProtocolType.HTTP, + ) + + async def get_user_kerberos_policy( + self, + ip: IPv4Address | IPv6Address, + user: User, + ) -> NetworkPolicy | None: + """Get user Kerberos policy.""" + return await self.get_user_network_policy( + ip, + user, + ProtocolType.KERBEROS, + ) + + async def get_user_ldap_policy( + self, + ip: IPv4Address | IPv6Address, + user: User, + ) -> NetworkPolicy | None: + """Get user LDAP policy.""" + return await self.get_user_network_policy( + ip, + user, + ProtocolType.LDAP, + ) + + async def is_user_group_valid( + self, + user: User | None, + policy: NetworkPolicy | None, + ) -> bool: + """Validate user groups, is it including to policy. + + :param User user: db user + :param NetworkPolicy policy: db policy + :return bool: status + """ + if not (user and policy): + return False + + if not policy.groups: + return True + query = select( + select(Group) + .join(qa(Group.users)) + .join(qa(Group.policies), isouter=True) + .exists() + .where(qa(Group.users).contains(user)) + .where(qa(Group.policies).contains(policy)), + ) + group = await self._session.scalar(query) + + return bool(group) + + async def check_mfa_group( + self, + policy: NetworkPolicy, + user: User, + ) -> bool: + """Check if user is in a group with MFA policy. + + :param NetworkPolicy policy: policy object + :param User user: user object + :return bool: status + """ + return await self._session.scalar( + select( + exists().where( # type: ignore + qa(Group.mfa_policies).contains(policy), + qa(Group.users).contains(user), + ), + ), + ) diff --git a/app/ldap_protocol/policies/network/validator_protocol.py b/app/ldap_protocol/policies/network/validator_protocol.py new file mode 100644 index 000000000..5ead88a2f --- /dev/null +++ b/app/ldap_protocol/policies/network/validator_protocol.py @@ -0,0 +1,72 @@ +"""Network policy validator protocol. + +Copyright (c) 2025 MultiFactor +License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE +""" + +from ipaddress import IPv4Address, IPv6Address +from typing import Protocol + +from entities import NetworkPolicy, User +from enums import ProtocolType + + +class NetworkPolicyValidatorProtocol(Protocol): + """Protocol for validating network policies.""" + + async def get_user_http_policy( + self, + ip: IPv4Address | IPv6Address, + user: User, + ) -> NetworkPolicy | None: + """Get user HTTP policy.""" + ... + + async def get_user_kerberos_policy( + self, + ip: IPv4Address | IPv6Address, + user: User, + ) -> NetworkPolicy | None: + """Get user Kerberos policy.""" + ... + + async def get_user_ldap_policy( + self, + ip: IPv4Address | IPv6Address, + user: User, + ) -> NetworkPolicy | None: + """Get user LDAP policy.""" + ... + + async def check_mfa_group( + self, + policy: NetworkPolicy, + user: User, + ) -> bool: + """Check if user is in a group with MFA policy.""" + ... + + async def is_user_group_valid( + self, + user: User | None, + policy: NetworkPolicy | None, + ) -> bool: + """Validate user groups, is it including to policy.""" + ... + + async def get_user_network_policy( + self, + ip: IPv4Address | IPv6Address, + user: User, + policy_type: ProtocolType, + ) -> NetworkPolicy | None: + """Get the highest priority network policy.""" + ... + + async def get_by_protocol( + self, + ip: IPv4Address | IPv6Address, + protocol_type: ProtocolType, + ) -> NetworkPolicy | None: + """Get network policy by protocol.""" + ... diff --git a/app/ldap_protocol/policies/network_policy.py b/app/ldap_protocol/policies/network_policy.py deleted file mode 100644 index 264616ab7..000000000 --- a/app/ldap_protocol/policies/network_policy.py +++ /dev/null @@ -1,130 +0,0 @@ -"""Network policy manager. - -Copyright (c) 2024 MultiFactor -License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE -""" - -from ipaddress import IPv4Address, IPv6Address -from typing import Literal - -from sqlalchemy import exists, or_, select, text -from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.orm import selectinload -from sqlalchemy.sql.expression import Select, true - -from entities import Group, NetworkPolicy, User -from repo.pg.tables import queryable_attr as qa - - -def build_policy_query( - ip: IPv4Address | IPv6Address, - protocol_field_name: Literal["is_http", "is_ldap", "is_kerberos"], - user_group_ids: list[int] | None = None, -) -> Select: - """Build a base query for network policies with optional group filtering. - - :param IPv4Address ip: IP address to filter - :param Literal["is_http", "is_ldap", "is_kerberos"] protocol_field_name - protocol: Protocol to filter - :param list[int] | None user_group_ids: List of user group IDs, optional - :return: Select query - """ - protocol_field = getattr(NetworkPolicy, protocol_field_name) - query = ( - select(NetworkPolicy) - .filter_by(enabled=True) - .options( - selectinload(qa(NetworkPolicy.groups)), - selectinload(qa(NetworkPolicy.mfa_groups)), - ) - .filter( - text(':ip <<= ANY("Policies".netmasks)').bindparams(ip=ip), - protocol_field == true(), - ) - .order_by(qa(NetworkPolicy.priority).asc()) - .limit(1) - ) - - if user_group_ids is not None: - return query.filter( - or_( - qa(NetworkPolicy.groups) == None, # noqa - qa(NetworkPolicy.groups).any( - qa(Group.id).in_(user_group_ids), - ), - ), - ) - - return query - - -async def check_mfa_group( - policy: NetworkPolicy, - user: User, - session: AsyncSession, -) -> bool: - """Check if user is in a group with MFA policy. - - :param NetworkPolicy policy: policy object - :param User user: user object - :param AsyncSession session: db session - :return bool: status - """ - return await session.scalar( - select( - exists().where( # type: ignore - qa(Group.mfa_policies).contains(policy), - qa(Group.users).contains(user), - ), - ), - ) - - -async def get_user_network_policy( - ip: IPv4Address | IPv6Address, - user: User, - session: AsyncSession, - policy_type: Literal["is_http", "is_ldap", "is_kerberos"], -) -> NetworkPolicy | None: - """Get the highest priority network policy for user, ip and protocol. - - :param User user: user object - :param AsyncSession session: db session - :return NetworkPolicy | None: a NetworkPolicy object - """ - user_group_ids = [group.id for group in user.groups] - - query = build_policy_query(ip, policy_type, user_group_ids) - - return await session.scalar(query) - - -async def is_user_group_valid( - user: User | None, - policy: NetworkPolicy | None, - session: AsyncSession, -) -> bool: - """Validate user groups, is it including to policy. - - :param User user: db user - :param NetworkPolicy policy: db policy - :param AsyncSession session: db - :return bool: status - """ - if user is None or policy is None: - return False - - if not policy.groups: - return True - - query = ( - select(Group) - .join(qa(Group.users)) - .join(qa(Group.policies), isouter=True) - .where(qa(Group.users).contains(user)) - .where(qa(Group.policies).contains(policy)) - .limit(1) - ) - - group = await session.scalar(query) - return bool(group) diff --git a/app/ldap_protocol/server.py b/app/ldap_protocol/server.py index 597b107a6..6f24cd312 100644 --- a/app/ldap_protocol/server.py +++ b/app/ldap_protocol/server.py @@ -18,11 +18,11 @@ from proxyprotocol import ProxyProtocolIncompleteError from proxyprotocol.v2 import ProxyProtocolV2 from pydantic import ValidationError -from sqlalchemy.ext.asyncio import AsyncSession from config import Settings from ldap_protocol import LDAPRequestMessage, LDAPSession from ldap_protocol.ldap_requests.bind_methods import GSSAPISL +from ldap_protocol.policies.network import NetworkPolicyValidatorUseCase from .data_logger import DataLogger @@ -83,8 +83,13 @@ async def __call__( try: async with session_scope(scope=Scope.REQUEST) as r: try: - session = await r.get(AsyncSession) - await ldap_session.validate_conn(addr, session) + network_policy_use_case = await r.get( + NetworkPolicyValidatorUseCase, + ) + await ldap_session.validate_conn( + addr, + network_policy_use_case, + ) except PermissionError: log.warning(f"Whitelist violation from {addr}") return diff --git a/app/ldap_protocol/udp_server.py b/app/ldap_protocol/udp_server.py index b889d0264..84c69b16e 100644 --- a/app/ldap_protocol/udp_server.py +++ b/app/ldap_protocol/udp_server.py @@ -10,10 +10,10 @@ from dishka import AsyncContainer, Scope from loguru import logger from pydantic import ValidationError -from sqlalchemy.ext.asyncio import AsyncSession from config import Settings from ldap_protocol import LDAPRequestMessage, LDAPSession +from ldap_protocol.policies.network import NetworkPolicyValidatorUseCase from .data_logger import DataLogger from .utils.udp import create_udp_socket @@ -50,8 +50,13 @@ async def _handle( ldap_session.ip = ip_address(addr[0]) try: - session = await container.get(AsyncSession) - await ldap_session.validate_conn(ldap_session.ip, session) + network_policy_use_case = await container.get( + NetworkPolicyValidatorUseCase, + ) + await ldap_session.validate_conn( + ldap_session.ip, + network_policy_use_case, + ) except PermissionError: log.warning(f"Whitelist violation from UDP {addr_str}") raise ConnectionAbortedError diff --git a/interface b/interface index 242c01f0f..21b31fed4 160000 --- a/interface +++ b/interface @@ -1 +1 @@ -Subproject commit 242c01f0f26a5080beef14523f9dd9dfab3c89ec +Subproject commit 21b31fed42a5082311a458da4d475c839f99a717 diff --git a/tests/conftest.py b/tests/conftest.py index 1a41b3783..b97a8ce4a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -122,8 +122,13 @@ ) from ldap_protocol.policies.audit.policies_dao import AuditPoliciesDAO from ldap_protocol.policies.audit.service import AuditService -from ldap_protocol.policies.network.gateway import NetworkPolicyGateway -from ldap_protocol.policies.network.use_cases import NetworkPolicyUseCase +from ldap_protocol.policies.network import ( + NetworkPolicyGateway, + NetworkPolicyUseCase, + NetworkPolicyValidatorGateway, + NetworkPolicyValidatorProtocol, + NetworkPolicyValidatorUseCase, +) from ldap_protocol.policies.password import ( PasswordPolicyDAO, PasswordPolicyUseCases, @@ -336,6 +341,16 @@ def get_object_class_dao(self, session: AsyncSession) -> ObjectClassDAO: dns_fastapi_adapter = provide(DNSFastAPIAdapter, scope=Scope.REQUEST) dns_use_case = provide(DNSUseCase, scope=Scope.REQUEST) dns_state_gateway = provide(DNSStateGateway, scope=Scope.REQUEST) + network_policy_gateway = provide(NetworkPolicyGateway, scope=Scope.SESSION) + network_policy_validator_gateway = provide( + NetworkPolicyValidatorGateway, + provides=NetworkPolicyValidatorProtocol, + scope=Scope.SESSION, + ) + network_policy_validator = provide( + NetworkPolicyValidatorUseCase, + scope=Scope.SESSION, + ) @provide(scope=scope, provides=AsyncEngine) def get_engine(self, settings: Settings) -> AsyncEngine: @@ -682,7 +697,6 @@ async def get_audit_monitor( NetworkPolicyUseCase, scope=Scope.REQUEST, ) - network_policy_gateway = provide(NetworkPolicyGateway, scope=Scope.REQUEST) @provide( provides=AuthorizationProviderProtocol, @@ -1023,6 +1037,24 @@ async def ldap_bound_session( return +@pytest_asyncio.fixture(scope="function") +async def network_policy_gateway( + container: AsyncContainer, +) -> AsyncIterator[NetworkPolicyGateway]: + """Get network policy gateway.""" + async with container(scope=Scope.SESSION) as container: + yield await container.get(NetworkPolicyGateway) + + +@pytest_asyncio.fixture(scope="function") +async def network_policy_validator( + container: AsyncContainer, +) -> AsyncIterator[NetworkPolicyValidatorUseCase]: + """Get network policy validator.""" + async with container(scope=Scope.SESSION) as container: + yield await container.get(NetworkPolicyValidatorUseCase) + + @pytest_asyncio.fixture(scope="session") async def handler( settings: Settings, diff --git a/tests/test_ldap/policies/test_network/test_pool_client_handler.py b/tests/test_ldap/policies/test_network/test_pool_client_handler.py index 95bad8cc3..9f212986c 100644 --- a/tests/test_ldap/policies/test_network/test_pool_client_handler.py +++ b/tests/test_ldap/policies/test_network/test_pool_client_handler.py @@ -10,8 +10,8 @@ from sqlalchemy.ext.asyncio import AsyncSession from entities import NetworkPolicy -from ldap_protocol.dialogue import LDAPSession -from ldap_protocol.policies.network_policy import is_user_group_valid +from enums import ProtocolType +from ldap_protocol.policies.network import NetworkPolicyValidatorUseCase from ldap_protocol.utils.queries import get_group, get_user @@ -19,18 +19,20 @@ @pytest.mark.usefixtures("setup_session") @pytest.mark.usefixtures("session") async def test_check_policy( - ldap_session: LDAPSession, - session: AsyncSession, + network_policy_validator: NetworkPolicyValidatorUseCase, ) -> None: """Check policy.""" - policy = await ldap_session._get_policy(IPv4Address("127.0.0.1"), session) + policy = await network_policy_validator.get_by_protocol( + IPv4Address("127.0.0.1"), + ProtocolType.LDAP, + ) assert policy assert policy.netmasks == [IPv4Network("0.0.0.0/0")] @pytest.mark.asyncio async def test_specific_policy_ok( - ldap_session: LDAPSession, + network_policy_validator: NetworkPolicyValidatorUseCase, session: AsyncSession, ) -> None: """Test specific ip.""" @@ -44,15 +46,15 @@ async def test_specific_policy_ok( ), ) await session.commit() - policy = await ldap_session._get_policy( + policy = await network_policy_validator.get_by_protocol( ip=IPv4Address("127.100.10.5"), - session=session, + protocol_type=ProtocolType.LDAP, ) assert policy assert policy.netmasks == [IPv4Network("127.100.10.5/32")] - assert not await ldap_session._get_policy( + assert not await network_policy_validator.get_by_protocol( ip=IPv4Address("127.100.10.4"), - session=session, + protocol_type=ProtocolType.LDAP, ) @@ -60,17 +62,20 @@ async def test_specific_policy_ok( @pytest.mark.usefixtures("setup_session") @pytest.mark.usefixtures("settings") async def test_check_policy_group( - ldap_session: LDAPSession, + network_policy_validator: NetworkPolicyValidatorUseCase, session: AsyncSession, ) -> None: """Check policy.""" user = await get_user(session, "user0") assert user - policy = await ldap_session._get_policy(IPv4Address("127.0.0.1"), session) + policy = await network_policy_validator.get_by_protocol( + IPv4Address("127.0.0.1"), + ProtocolType.LDAP, + ) assert policy - assert await is_user_group_valid(user, policy, session) + assert await network_policy_validator.is_user_group_valid(user, policy) group = await get_group( dn="cn=domain admins,cn=groups,dc=md,dc=test", @@ -80,4 +85,4 @@ async def test_check_policy_group( policy.groups.append(group) await session.commit() - assert await is_user_group_valid(user, policy, session) + assert await network_policy_validator.is_user_group_valid(user, policy) diff --git a/tests/test_ldap/test_util/test_search.py b/tests/test_ldap/test_util/test_search.py index 137508bf6..903fb2598 100644 --- a/tests/test_ldap/test_util/test_search.py +++ b/tests/test_ldap/test_util/test_search.py @@ -14,13 +14,13 @@ from config import Settings from entities import User -from enums import AceType, RoleScope +from enums import AceType, ProtocolType, RoleScope from ldap_protocol.asn1parser import ASN1Row, TagNumbers from ldap_protocol.dialogue import LDAPSession from ldap_protocol.ldap_requests import SearchRequest from ldap_protocol.ldap_requests.contexts import LDAPSearchRequestContext from ldap_protocol.ldap_responses import SearchResultEntry -from ldap_protocol.policies.network_policy import is_user_group_valid +from ldap_protocol.policies.network import NetworkPolicyValidatorUseCase from ldap_protocol.roles.ace_dao import AccessControlEntryDAO from ldap_protocol.roles.dataclasses import AccessControlEntryDTO, RoleDTO from ldap_protocol.roles.role_dao import RoleDAO @@ -307,10 +307,13 @@ async def test_bind_policy( session: AsyncSession, settings: Settings, creds: TestCreds, - ldap_session: LDAPSession, + network_policy_validator: NetworkPolicyValidatorUseCase, ) -> None: """Bind with policy.""" - policy = await ldap_session._get_policy(IPv4Address("127.0.0.1"), session) # noqa: SLF001 + policy = await network_policy_validator.get_by_protocol( + IPv4Address("127.0.0.1"), + ProtocolType.LDAP, + ) assert policy group = await get_group( @@ -345,12 +348,15 @@ async def test_bind_policy( @pytest.mark.usefixtures("setup_session") async def test_bind_policy_missing_group( session: AsyncSession, - ldap_session: LDAPSession, settings: Settings, creds: TestCreds, + network_policy_validator: NetworkPolicyValidatorUseCase, ) -> None: """Bind policy fail.""" - policy = await ldap_session._get_policy(IPv4Address("127.0.0.1"), session) # noqa: SLF001 + policy = await network_policy_validator.get_by_protocol( + IPv4Address("127.0.0.1"), + ProtocolType.LDAP, + ) assert policy @@ -368,7 +374,7 @@ async def test_bind_policy_missing_group( user.groups.clear() await session.commit() - assert not await is_user_group_valid(user, policy, session) + assert not await network_policy_validator.is_user_group_valid(user, policy) proc = await asyncio.create_subprocess_exec( "ldapsearch",