diff --git a/.docker/lint.Dockerfile b/.docker/lint.Dockerfile index 8b11ab301..0daa07beb 100644 --- a/.docker/lint.Dockerfile +++ b/.docker/lint.Dockerfile @@ -29,5 +29,4 @@ ENV VIRTUAL_ENV=/venvs/.venv \ COPY --from=builder ${VIRTUAL_ENV} ${VIRTUAL_ENV} -COPY app /app -COPY pyproject.toml ./ +COPY . /app diff --git a/.github/workflows/build-beta.yml b/.github/workflows/build-beta.yml index eaa4671b6..8e2f0b4a9 100644 --- a/.github/workflows/build-beta.yml +++ b/.github/workflows/build-beta.yml @@ -1,7 +1,8 @@ name: build on: push: - branches: [main] + branches: + - 'release-*' env: REPO: ${{ github.repository }} diff --git a/.github/workflows/build-dev.yml b/.github/workflows/build-dev.yml new file mode 100644 index 000000000..6bd6a2ce6 --- /dev/null +++ b/.github/workflows/build-dev.yml @@ -0,0 +1,218 @@ +name: build +on: + push: + branches: [dev] + +env: + REPO: ${{ github.repository }} + +jobs: + build-ssh: + runs-on: ubuntu-latest + steps: + - name: downcase REPO + run: | + echo "REPO=${GITHUB_REPOSITORY,,}" >>${GITHUB_ENV} + - uses: actions/checkout@v4 + - name: Login to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + + - name: Build docker image + env: + TAG: ghcr.io/${{ env.REPO }}_ssh_test:dev + DOCKER_BUILDKIT: '1' + run: | + echo $TAG + docker build integration_tests/ssh --tag $TAG --cache-from $TAG --build-arg BUILDKIT_INLINE_CACHE=1 + docker push $TAG + + build-tests: + runs-on: ubuntu-latest + steps: + - name: downcase REPO + run: | + echo "REPO=${GITHUB_REPOSITORY,,}" >>${GITHUB_ENV} + - uses: actions/checkout@v4 + - name: Login to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + + - name: Build docker image + env: + TAG: ghcr.io/${{ env.REPO }}_test:dev + DOCKER_BUILDKIT: '1' + run: | + echo $TAG + docker build --push 
--target=runtime -f .docker/test.Dockerfile . -t $TAG --cache-to type=gha,mode=max --cache-from $TAG --build-arg BUILDKIT_INLINE_CACHE=1 + + + run-ssh-test: + runs-on: ubuntu-latest + needs: [build-tests, build-ssh] + steps: + - uses: actions/checkout@v4 + - name: Login to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + - name: Pull tests + run: cd integration_tests/ssh && docker compose pull + - name: run test environment + run: cd integration_tests/ssh && docker compose up -d + - name: run ssh test + run: cd integration_tests/ssh && ./run.sh + - name: shutdown test environment + run: cd integration_tests/ssh && docker compose down + + run-tests: + runs-on: ubuntu-latest + needs: build-tests + steps: + - name: downcase REPO + run: | + echo "REPO=${GITHUB_REPOSITORY,,}" >>${GITHUB_ENV} + - uses: actions/checkout@v4 + - name: Login to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + - name: Pull tests + env: + TAG: ghcr.io/${{ env.REPO }}_test:dev + run: docker compose -f docker-compose.remote.test.yml pull + - name: Run tests + env: + TAG: ghcr.io/${{ env.REPO }}_test:dev + run: docker compose -f docker-compose.remote.test.yml up --no-log-prefix --attach md-test --exit-code-from md-test + - name: Teardown tests + env: + TAG: ghcr.io/${{ env.REPO }}_test:dev + run: docker compose -f docker-compose.remote.test.yml down + + build-app: + runs-on: ubuntu-latest + needs: [build-tests, run-ssh-test, run-tests] + steps: + - name: downcase REPO + run: | + echo "REPO=${GITHUB_REPOSITORY,,}" >>${GITHUB_ENV} + - uses: actions/checkout@v4 + - name: Login to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + + - name: Build docker image + env: + TAG: 
ghcr.io/${{ env.REPO }}:dev + DOCKER_BUILDKIT: '1' + run: | + echo $TAG + echo $TAG2 + docker build --push --target=runtime -f .docker/Dockerfile . -t $TAG --cache-to type=gha,mode=max --cache-from $TAG --build-arg BUILDKIT_INLINE_CACHE=1 + + build-kerberos: + runs-on: ubuntu-latest + needs: [build-tests, run-ssh-test, run-tests] + steps: + - name: downcase REPO + run: | + echo "REPO=${GITHUB_REPOSITORY,,}" >>${GITHUB_ENV} + - uses: actions/checkout@v4 + - name: Login to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + + - name: Build docker image + env: + TAG: ghcr.io/${{ env.REPO }}_kerberos:dev + DOCKER_BUILDKIT: '1' + run: | + echo $TAG + docker build \ + --push \ + --target=runtime \ + -f .docker/krb.Dockerfile . \ + -t $TAG \ + --cache-to type=gha,mode=max \ + --cache-from $TAG \ + --build-arg BUILDKIT_INLINE_CACHE=1 \ + --build-arg VERSION=dev + + build-bind9: + runs-on: ubuntu-latest + needs: [build-tests, run-ssh-test, run-tests] + steps: + - name: downcase REPO + run: | + echo "REPO=${GITHUB_REPOSITORY,,}" >>${GITHUB_ENV} + - uses: actions/checkout@v4 + - name: Login to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + + - name: Build docker image + env: + TAG: ghcr.io/${{ env.REPO }}_bind9:dev + DOCKER_BUILDKIT: '1' + run: | + echo $TAG + docker build \ + --push \ + --target=runtime \ + -f .docker/bind9.Dockerfile . 
\ + -t $TAG \ + --cache-to type=gha,mode=max \ + --cache-from $TAG \ + --build-arg BUILDKIT_INLINE_CACHE=1 \ + --build-arg VERSION=dev + + build-keadhcp4: + runs-on: ubuntu-latest + needs: [build-tests, run-ssh-test, run-tests] + steps: + - name: downcase REPO + run: | + echo "REPO=${GITHUB_REPOSITORY,,}" >>${GITHUB_ENV} + - uses: actions/checkout@v4 + - name: Login to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + + - name: Build docker image + env: + TAG: ghcr.io/${{ env.REPO }}_dhcp4:dev + DOCKER_BUILDKIT: '1' + run: | + echo $TAG + docker build \ + --push \ + --target=runtime \ + -f .docker/kea.Dockerfile . \ + -t $TAG \ + --cache-to type=gha,mode=max \ + --cache-from $TAG \ + --build-arg BUILDKIT_INLINE_CACHE=1 \ + --build-arg VERSION=dev diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml index 01b92ef8b..b21f307bf 100644 --- a/.github/workflows/checks.yml +++ b/.github/workflows/checks.yml @@ -4,86 +4,111 @@ on: branches: [main] pull_request: null -env: - REPO: ${{ github.repository }} - jobs: - ruff_linter: + build_linter_image: + runs-on: ubuntu-latest + steps: + - name: Checkout code + uses: actions/checkout@v4 + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 + - name: Build linter image + uses: docker/build-push-action@v6 + with: + context: . 
+ file: .docker/lint.Dockerfile + target: runtime + cache-from: type=gha,scope=${{ github.ref_name }} + cache-to: type=gha,mode=max,scope=${{ github.ref_name }} + outputs: type=docker,dest=/tmp/linter_image.tar + tags: linter:latest + - name: Upload linter image + uses: actions/upload-artifact@v4 + with: + name: linter-image + path: /tmp/linter_image.tar + retention-days: 1 + + run_ruff_check: runs-on: ubuntu-latest + needs: build_linter_image steps: - - uses: actions/checkout@v4 - - name: Login to GitHub Container Registry - uses: docker/login-action@v3 + - name: Download linter image + uses: actions/download-artifact@v4 with: - registry: ghcr.io - username: ${{ github.actor }} - password: ${{ secrets.GITHUB_TOKEN }} - - name: build linters - env: - TAG: ghcr.io/${{ env.REPO }}_linters:latest - NEW_TAG: linter - run: docker build --target=runtime -f .docker/lint.Dockerfile . -t $NEW_TAG --cache-to type=gha,mode=max --cache-from $TAG --build-arg BUILDKIT_INLINE_CACHE=1 - - name: Run linters - env: - NEW_TAG: linter - run: docker run $NEW_TAG ruff check --output-format=github . + name: linter-image + path: /tmp + - name: Load image + run: docker load -i /tmp/linter_image.tar + - name: Run ruff check + run: docker run linter:latest ruff check --output-format=github - ruff_format: + run_ruff_format: runs-on: ubuntu-latest + needs: build_linter_image steps: - - uses: actions/checkout@v4 - - name: Login to GitHub Container Registry - uses: docker/login-action@v3 + - name: Download linter image + uses: actions/download-artifact@v4 with: - registry: ghcr.io - username: ${{ github.actor }} - password: ${{ secrets.GITHUB_TOKEN }} - - name: build linters - env: - TAG: ghcr.io/${{ env.REPO }}_linters:latest - NEW_TAG: linter - run: docker build --target=runtime -f .docker/lint.Dockerfile . 
-t $NEW_TAG --cache-to type=gha,mode=max --cache-from $TAG --build-arg BUILDKIT_INLINE_CACHE=1 - - name: Run linters - env: - NEW_TAG: linter - run: docker run $NEW_TAG ruff format --check + name: linter-image + path: /tmp + - name: Load image + run: docker load -i /tmp/linter_image.tar + - name: Run ruff format + run: docker run linter:latest ruff format --check - mypy: + run_mypy: runs-on: ubuntu-latest + needs: build_linter_image steps: - - uses: actions/checkout@v4 - - name: Login to GitHub Container Registry - uses: docker/login-action@v3 + - name: Download linter image + uses: actions/download-artifact@v4 + with: + name: linter-image + path: /tmp + - name: Load image + run: docker load -i /tmp/linter_image.tar + - name: Run mypy + run: docker run linter:latest mypy . + + build_test_image: + runs-on: ubuntu-latest + steps: + - name: Checkout code + uses: actions/checkout@v4 + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 + - name: Build test image + uses: docker/build-push-action@v6 + with: + context: . + file: .docker/test.Dockerfile + target: runtime + cache-from: type=gha,scope=${{ github.ref_name }} + cache-to: type=gha,mode=max,scope=${{ github.ref_name }} + outputs: type=docker,dest=/tmp/test_image.tar + tags: test:latest + - name: Upload test image + uses: actions/upload-artifact@v4 with: - registry: ghcr.io - username: ${{ github.actor }} - password: ${{ secrets.GITHUB_TOKEN }} - - name: build linters - env: - TAG: ghcr.io/${{ env.REPO }}_linters:latest - NEW_TAG: linter - run: docker build --target=runtime -f .docker/lint.Dockerfile . -t $NEW_TAG --cache-to type=gha,mode=max --cache-from $TAG --build-arg BUILDKIT_INLINE_CACHE=1 - - name: Run linters - env: - NEW_TAG: linter - run: docker run $NEW_TAG mypy . 
+ name: test-image + path: /tmp/test_image.tar + retention-days: 1 - tests: + run_tests: runs-on: ubuntu-latest + needs: build_test_image + env: + TAG: test:latest # Injected into docker-compose.remote.test.yml as ${TAG} steps: - - uses: actions/checkout@v4 - - name: Login to GitHub Container Registry - uses: docker/login-action@v3 + - name: Checkout code + uses: actions/checkout@v4 + - name: Download test image + uses: actions/download-artifact@v4 with: - registry: ghcr.io - username: ${{ github.actor }} - password: ${{ secrets.GITHUB_TOKEN }} - - name: build tests - env: - CACHE: ghcr.io/${{ env.REPO }}_test:latest - TAG: tests - run: docker build --target=runtime -f .docker/test.Dockerfile . -t $TAG --cache-to type=gha,mode=max --cache-from $CACHE --build-arg BUILDKIT_INLINE_CACHE=1 + name: test-image + path: /tmp + - name: Load image + run: docker load -i /tmp/test_image.tar - name: Run tests - env: - TAG: tests - run: docker compose -f docker-compose.remote.test.yml up --no-log-prefix --attach md-test --exit-code-from md-test \ No newline at end of file + run: docker compose -f docker-compose.remote.test.yml up --no-log-prefix --attach md-test --exit-code-from md-test diff --git a/.kerberos/entrypoint.sh b/.kerberos/entrypoint.sh index a50bd3ec7..b14b1e9a5 100755 --- a/.kerberos/entrypoint.sh +++ b/.kerberos/entrypoint.sh @@ -2,8 +2,9 @@ set -e -sed -i 's/ou=users/cn=users/g' /etc/kdc/krb5.d/stash.keyfile || true -sed -i 's/ou=users/cn=users/g' /etc/kdc/krb5.conf || true +sed -i 's/ou=users/cn=users/g' /etc/krb5.d/stash.keyfile || true +sed -i 's/ou=users/cn=users/g' /etc/krb5.conf || true + cd /server uvicorn --factory config_server:create_app \ diff --git a/.package/docker-compose.yml b/.package/docker-compose.yml index 839a45e89..104a416bd 100644 --- a/.package/docker-compose.yml +++ b/.package/docker-compose.yml @@ -13,6 +13,9 @@ services: - "53:53/udp" - "80:80" - "389:389" + - "389:389/udp" + - "3268:3268" + - "3269:3269" - "443:443" - "464:464" - 
"636:636" @@ -67,8 +70,6 @@ services: depends_on: - api_server - - migrations: image: ghcr.io/multidirectorylab/multidirectory:${VERSION:-latest} container_name: multidirectory_migrations @@ -127,7 +128,7 @@ services: - traefik.tcp.routers.ldaps.tls.certResolver=md-resolver - traefik.tcp.services.ldaps.loadbalancer.server.port=636 - traefik.tcp.services.ldaps.loadbalancer.proxyprotocol.version=2 - + cldap_server: image: ghcr.io/multidirectorylab/multidirectory:${VERSION:-latest} restart: unless-stopped @@ -155,6 +156,53 @@ services: - traefik.udp.routers.cldap.service=cldap - traefik.udp.services.cldap.loadbalancer.server.port=389 + global_ldap_server: + deploy: + mode: replicated + replicas: 2 + endpoint_mode: dnsrr + resources: + reservations: + cpus: "0.25" + memory: 100M + image: ghcr.io/multidirectorylab/multidirectory:${VERSION:-latest} + restart: unless-stopped + environment: + - SERVICE_NAME=global_ldap_server + volumes: + - ./certs:/certs + - ./logs:/app/logs + - ldap_keytab:/LDAP_keytab/ + env_file: + - .env + command: python -OO multidirectory.py --global_ldap_server + tty: true + depends_on: + migrations: + condition: service_completed_successfully + healthcheck: + test: ["CMD-SHELL", "nc -zv 127.0.0.1 3268 3269"] + interval: 30s + timeout: 10s + retries: 10 + start_period: 3s + labels: + - traefik.enable=true + + - traefik.tcp.routers.global_ldap.rule=HostSNI(`*`) + - traefik.tcp.routers.global_ldap.entrypoints=global_ldap + - traefik.tcp.routers.global_ldap.service=global_ldap + - traefik.tcp.services.global_ldap.loadbalancer.server.port=3268 + - traefik.tcp.services.global_ldap.loadbalancer.proxyprotocol.version=2 + + - traefik.tcp.routers.global_ldap_tls.rule=HostSNI(`*`) + - traefik.tcp.routers.global_ldap_tls.entrypoints=global_ldap_tls + - traefik.tcp.routers.global_ldap_tls.service=global_ldap_tls + - traefik.tcp.routers.global_ldap_tls.tls=true + - traefik.tcp.routers.global_ldap_tls.tls.certresolver=md-resolver + - 
traefik.tcp.services.global_ldap_tls.loadbalancer.server.port=3269 + - traefik.tcp.services.global_ldap_tls.loadbalancer.proxyprotocol.version=2 + api_server: image: ghcr.io/multidirectorylab/multidirectory:${VERSION:-latest} container_name: multidirectory_api @@ -225,6 +273,7 @@ services: - ./certs:/certs - psync_queue:/var/spool/krb5-sync - ldap_keytab:/LDAP_keytab/ + - kdc:/etc/krb5kdc/ env_file: .env command: python multidirectory.py --scheduler @@ -269,6 +318,8 @@ services: condition: service_healthy cert_check: condition: service_completed_successfully + kdc: + condition: service_started working_dir: /server command: ./entrypoint.sh kadmind: diff --git a/.package/traefik.yml b/.package/traefik.yml index 77c0f9594..bb8d711de 100644 --- a/.package/traefik.yml +++ b/.package/traefik.yml @@ -24,6 +24,14 @@ entryPoints: address: ":389" proxyProtocol: insecure: true + global_ldap: + address: ":3268" + proxyProtocol: + insecure: true + global_ldap_tls: + address: ":3269" + proxyProtocol: + insecure: true ldaps: address: ":636" proxyProtocol: diff --git a/Makefile b/Makefile index dd78ff0c8..efc0edf41 100644 --- a/Makefile +++ b/Makefile @@ -2,10 +2,10 @@ help: ## show help message @awk 'BEGIN {FS = ":.*##"; printf "\nUsage:\n make \033[36m\033[0m\n"} /^[$$()% a-zA-Z_-]+:.*?##/ { printf " \033[36m%-15s\033[0m %s\n", $$1, $$2 } /^##@/ { printf "\n\033[1m%s\033[0m\n", substr($$0, 5) } ' $(MAKEFILE_LIST) -before_pr: - ruff format ./app - ruff check ./app --fix --unsafe-fixes - mypy ./app +before_pr: ## format, lint and type-check code + ruff format + ruff check --fix --unsafe-fixes + mypy . 
build: ## build app and manually generate self-signed cert make down @@ -15,27 +15,34 @@ build_test: docker compose -f docker-compose.test.yml build up: ## run tty container with related services, use with run command - make down; docker compose up + make down + docker compose up test: ## run tests docker compose -f docker-compose.test.yml down --remove-orphans - make down; + make down docker compose -f docker-compose.test.yml up --no-log-prefix --attach test --exit-code-from test run: ## runs server 386/636 port clear;docker exec -it multidirectory sh -c "python ." launch: ## run standalone app without tty container - docker compose down; - docker compose run sh -c "alembic upgrade head && python ." + docker compose down + docker compose run sh -c "python multidirectory.py --migrate && python ." -downgrade: ## re-run migration - docker exec -it multidirectory_api sh -c\ - "alembic downgrade -1; alembic upgrade head;" +rerun_last_migration: ## re-run migration + docker exec -it multidirectory_api sh -c "alembic downgrade -1; python multidirectory.py --migrate;" down: ## shutdown services docker compose -f docker-compose.test.yml down --remove-orphans docker compose down --remove-orphans + docker volume prune -f + +migrations: ## generate migration file + docker compose run ldap_server alembic revision --autogenerate + +migrate: ## upgrade db + docker compose run ldap_server python multidirectory.py --migrate # server stage/development commands @@ -47,28 +54,21 @@ stage_build: ## build stage server docker compose -f docker-compose.dev.yml build stage_up: ## run app and detach - make stage_down; + make stage_down docker compose -f docker-compose.dev.yml up -d stage_down: ## stop all services docker compose -f docker-compose.dev.yml down --remove-orphans stage_update: ## update service - make stage_down; - make stage_build; - docker compose -f docker-compose.dev.yml pull; - make stage_up; - docker exec -it multidirectory-ldap sh -c\ - "alembic downgrade -1; alembic 
upgrade head; python -m extra.setup_dev" + make stage_down + make stage_build + docker compose -f docker-compose.dev.yml pull + make stage_up + docker exec -it multidirectory-ldap sh -c "alembic downgrade -1; python multidirectory.py --migrate; python -m extra.setup_dev" krb_client_build: ## build krb client service docker build -f integration_tests/kerberos/Dockerfile . -t krbclient:runtime krb_client: ## run krb client bash docker run --rm --init -it --name krbclient --network multidirectory_default krbclient:runtime bash - -migrations: ## generate migration file - docker compose run ldap_server alembic revision --autogenerate - -migrate: ## upgrade db - docker compose run ldap_server alembic upgrade head diff --git a/app/alembic/env.py b/app/alembic/env.py index fc5acfcc7..6c08a67bd 100644 --- a/app/alembic/env.py +++ b/app/alembic/env.py @@ -4,10 +4,18 @@ from logging.config import fileConfig from alembic import context +from dishka import AsyncContainer, make_async_container from sqlalchemy import Connection, text -from sqlalchemy.ext.asyncio import create_async_engine +from sqlalchemy.ext.asyncio import AsyncConnection from config import Settings +from ioc import ( + HTTPProvider, + MainProvider, + MFACredsProvider, + MFAProvider, + MigrationProvider, +) from repo.pg.tables import metadata # this is the Alembic Config object, which provides @@ -22,7 +30,11 @@ target_metadata = metadata -def run_sync_migrations(connection: Connection, schema_name: str) -> None: +def run_sync_migrations( + connection: Connection, + schema_name: str, + dishka_container: AsyncContainer, +) -> None: """Run sync migrations.""" if schema_name != "public": connection.execute(text(f"SET search_path = {schema_name}, public;")) @@ -35,18 +47,20 @@ def run_sync_migrations(connection: Connection, schema_name: str) -> None: ) with context.begin_transaction(): - context.run_migrations() + context.run_migrations(container=dishka_container) -async def run_async_migrations(settings: Settings) 
-> None: +async def run_async_migrations( + settings: Settings, + dishka_container: AsyncContainer, +) -> None: """Run async migrations.""" - engine = create_async_engine(str(settings.POSTGRES_URI)) - - async with engine.connect() as connection: - await connection.run_sync( - run_sync_migrations, - schema_name=settings.TEST_POSTGRES_SCHEMA, - ) + connection = await dishka_container.get(AsyncConnection) + await connection.run_sync( + run_sync_migrations, + schema_name=settings.TEST_POSTGRES_SCHEMA, + dishka_container=dishka_container, + ) def run_migrations_online() -> None: @@ -60,11 +74,25 @@ def run_migrations_online() -> None: "app_settings", Settings.from_os(), ) + dishka_container = context.config.attributes.get("dishka_container", None) + if not dishka_container: + dishka_container = make_async_container( + MainProvider(), + MFACredsProvider(), + MFAProvider(), + HTTPProvider(), + MigrationProvider(), + context={Settings: settings}, + ) if conn is None: - asyncio.run(run_async_migrations(settings)) + asyncio.run(run_async_migrations(settings, dishka_container)) else: - run_sync_migrations(conn, schema_name=settings.TEST_POSTGRES_SCHEMA) + run_sync_migrations( + conn, + schema_name=settings.TEST_POSTGRES_SCHEMA, + dishka_container=dishka_container, + ) run_migrations_online() diff --git a/app/alembic/script.py.mako b/app/alembic/script.py.mako index cc49f4c6c..b2cdf925a 100644 --- a/app/alembic/script.py.mako +++ b/app/alembic/script.py.mako @@ -7,6 +7,7 @@ Create Date: ${create_date} """ from alembic import op import sqlalchemy as sa +from dishka import AsyncContainer ${imports if imports else ""} # revision identifiers, used by Alembic. 
@@ -16,11 +17,11 @@ branch_labels: None | list[str] = ${repr(branch_labels)} depends_on: None | list[str] = ${repr(depends_on)} -def upgrade() -> None: +def upgrade(container: AsyncContainer) -> None: """Upgrade.""" ${upgrades if upgrades else "pass"} -def downgrade() -> None: +def downgrade(container: AsyncContainer) -> None: """Downgrade.""" ${downgrades if downgrades else "pass"} diff --git a/app/alembic/versions/01f3f05a5b11_add_primary_group_id.py b/app/alembic/versions/01f3f05a5b11_add_primary_group_id.py index 237a558ab..73be02c1d 100644 --- a/app/alembic/versions/01f3f05a5b11_add_primary_group_id.py +++ b/app/alembic/versions/01f3f05a5b11_add_primary_group_id.py @@ -1,22 +1,27 @@ """Add primaryGroupId attribute and domain computers group. Revision ID: 01f3f05a5b11 -Revises: 8164b4a9e1f1 +Revises: c007129b7973 Create Date: 2025-09-26 12:36:05.974255 """ +import sqlalchemy as sa from alembic import op +from dishka import AsyncContainer, Scope from sqlalchemy import delete, exists, select from sqlalchemy.exc import DBAPIError, IntegrityError from sqlalchemy.ext.asyncio import AsyncConnection, AsyncSession from sqlalchemy.orm import Session, selectinload +from constants import DOMAIN_COMPUTERS_GROUP_NAME from entities import Attribute, Directory, EntityType, Group +from enums import EntityTypeNames +from extra.alembic_utils import temporary_stub_column +from ldap_protocol.ldap_schema.attribute_value_validator import ( + AttributeValueValidator, +) from ldap_protocol.ldap_schema.entity_type_dao import EntityTypeDAO -from ldap_protocol.ldap_schema.object_class_dao import ObjectClassDAO -from ldap_protocol.roles.ace_dao import AccessControlEntryDAO -from ldap_protocol.roles.role_dao import RoleDAO from ldap_protocol.roles.role_use_case import RoleUseCase from ldap_protocol.utils.queries import ( create_group, @@ -33,30 +38,24 @@ depends_on: None = None -def upgrade() -> None: +@temporary_stub_column("is_system", sa.Boolean()) +def upgrade(container: 
AsyncContainer) -> None: """Upgrade.""" - async def _add_domain_computers_group(connection: AsyncConnection) -> None: - session = AsyncSession(connection) - await session.begin() + async def _add_domain_computers_group(connection: AsyncConnection) -> None: # noqa: ARG001 + async with container(scope=Scope.REQUEST) as cnt: + session = await cnt.get(AsyncSession) + entity_type_dao = await cnt.get(EntityTypeDAO) + role_use_case = await cnt.get(RoleUseCase) base_dn_list = await get_base_directories(session) if not base_dn_list: return - object_class_dao = ObjectClassDAO(session) - entity_type_dao = EntityTypeDAO( - session, - object_class_dao=object_class_dao, - ) - role_dao = RoleDAO(session) - ace_dao = AccessControlEntryDAO(session) - role_use_case = RoleUseCase(role_dao, ace_dao) - try: group_dir_query = select( exists(Directory) - .where(qa(Directory.name) == "domain computers"), + .where(qa(Directory.name) == DOMAIN_COMPUTERS_GROUP_NAME), ) # fmt: skip group_dir = (await session.scalars(group_dir_query)).one() @@ -64,14 +63,17 @@ async def _add_domain_computers_group(connection: AsyncConnection) -> None: return dir_, group_ = await create_group( - name="domain computers", + name=DOMAIN_COMPUTERS_GROUP_NAME, sid=515, + attribute_value_validator=AttributeValueValidator(), session=session, ) await session.flush() - computer_entity_type = await entity_type_dao.get("Computer") + computer_entity_type = await entity_type_dao.get( + EntityTypeNames.COMPUTER, + ) computer_dirs = await session.scalars( select(Directory) .where( @@ -116,9 +118,9 @@ async def _add_domain_computers_group(connection: AsyncConnection) -> None: op.run_async(_add_domain_computers_group) - async def _add_primary_group_id(connection: AsyncConnection) -> None: - session = AsyncSession(connection) - await session.begin() + async def _add_primary_group_id(connection: AsyncConnection) -> None: # noqa: ARG001 + async with container(scope=Scope.REQUEST) as cnt: + session = await cnt.get(AsyncSession) 
base_dn_list = await get_base_directories(session) if not base_dn_list: @@ -126,7 +128,11 @@ async def _add_primary_group_id(connection: AsyncConnection) -> None: entity_type = await session.scalars( select(qa(EntityType.id)) - .where(qa(EntityType.name).in_(["User", "Computer"])), + .where( + qa(EntityType.name).in_( + [EntityTypeNames.USER, EntityTypeNames.COMPUTER], + ), + ), ) # fmt: skip entity_type_ids = list(entity_type.all()) @@ -160,21 +166,20 @@ async def _add_primary_group_id(connection: AsyncConnection) -> None: except (IntegrityError, DBAPIError): pass - await session.close() - op.run_async(_add_primary_group_id) -def downgrade() -> None: +@temporary_stub_column("is_system", sa.Boolean()) +def downgrade(container: AsyncContainer) -> None: """Downgrade.""" bind = op.get_bind() session = Session(bind=bind) async def _delete_domain_computers_group( - connection: AsyncConnection, + connection: AsyncConnection, # noqa: ARG001 ) -> None: - session = AsyncSession(connection) - await session.begin() + async with container(scope=Scope.REQUEST) as cnt: + session = await cnt.get(AsyncSession) base_dn_list = await get_base_directories(session) if not base_dn_list: diff --git a/app/alembic/versions/05ddc0bd562a_add_roles.py b/app/alembic/versions/05ddc0bd562a_add_roles.py index dd1205544..aff773460 100644 --- a/app/alembic/versions/05ddc0bd562a_add_roles.py +++ b/app/alembic/versions/05ddc0bd562a_add_roles.py @@ -7,12 +7,12 @@ import sqlalchemy as sa from alembic import op +from dishka import AsyncContainer, Scope from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncConnection, AsyncSession from entities import Directory, Group -from ldap_protocol.roles.ace_dao import AccessControlEntryDAO -from ldap_protocol.roles.role_dao import RoleDAO +from extra.alembic_utils import temporary_stub_column from ldap_protocol.roles.role_use_case import RoleUseCase from ldap_protocol.utils.queries import get_base_directories from repo.pg.tables import 
queryable_attr as qa @@ -24,7 +24,8 @@ depends_on: None = None -def upgrade() -> None: +@temporary_stub_column("is_system", sa.Boolean()) +def upgrade(container: AsyncContainer) -> None: """Upgrade.""" op.create_table( "Roles", @@ -152,17 +153,15 @@ def upgrade() -> None: op.drop_table("AccessPolicyMemberships") op.drop_table("AccessPolicies") - async def _create_system_roles(connection: AsyncConnection) -> None: - session = AsyncSession(connection) - await session.begin() + async def _create_system_roles(connection: AsyncConnection) -> None: # noqa: ARG001 + async with container(scope=Scope.REQUEST) as cnt: + session = await cnt.get(AsyncSession) + role_use_case = await cnt.get(RoleUseCase) base_dn_list = await get_base_directories(session) if not base_dn_list: return - role_dao = RoleDAO(session) - ace_dao = AccessControlEntryDAO(session) - role_use_case = RoleUseCase(role_dao, ace_dao) await role_use_case.create_domain_admins_role() await role_use_case.create_read_only_role() @@ -184,7 +183,7 @@ async def _create_system_roles(connection: AsyncConnection) -> None: op.run_async(_create_system_roles) -def downgrade() -> None: +def downgrade(container: AsyncContainer) -> None: # noqa: ARG001 """Downgrade.""" op.create_table( "AccessPolicies", diff --git a/app/alembic/versions/16a9fa2c1f1e_rename_readonly_group.py b/app/alembic/versions/16a9fa2c1f1e_rename_readonly_group.py index 06af9c227..b331dddd5 100644 --- a/app/alembic/versions/16a9fa2c1f1e_rename_readonly_group.py +++ b/app/alembic/versions/16a9fa2c1f1e_rename_readonly_group.py @@ -6,12 +6,16 @@ """ +import sqlalchemy as sa from alembic import op +from dishka import AsyncContainer from sqlalchemy import select, update from sqlalchemy.exc import DBAPIError, IntegrityError from sqlalchemy.orm import Session, selectinload +from constants import READ_ONLY_GROUP_NAME from entities import Attribute, Directory +from extra.alembic_utils import temporary_stub_column from repo.pg.tables import queryable_attr as qa # 
revision identifiers, used by Alembic. @@ -21,7 +25,8 @@ depends_on: None | list[str] = None -def upgrade() -> None: +@temporary_stub_column("is_system", sa.Boolean()) +def upgrade(container: AsyncContainer) -> None: # noqa: ARG001 """Upgrade.""" bind = op.get_bind() session = Session(bind=bind) @@ -29,19 +34,15 @@ def upgrade() -> None: try: ro_dir_query = ( select(Directory) - .options( - selectinload(qa(Directory.parent)), - ) - .where( - qa(Directory.name) == "readonly domain controllers", - ) - ) # fmt: skip + .options(selectinload(qa(Directory.parent))) + .where(qa(Directory.name) == "readonly domain controllers") + ) ro_dir = session.scalar(ro_dir_query) if not ro_dir: return - ro_dir.name = "read-only" + ro_dir.name = READ_ONLY_GROUP_NAME ro_dir.create_path(ro_dir.parent, ro_dir.get_dn_prefix()) @@ -72,7 +73,8 @@ def upgrade() -> None: session.close() -def downgrade() -> None: +@temporary_stub_column("is_system", sa.Boolean()) +def downgrade(container: AsyncContainer) -> None: # noqa: ARG001 """Downgrade.""" bind = op.get_bind() session = Session(bind=bind) @@ -80,13 +82,9 @@ def downgrade() -> None: try: ro_dir_query = ( select(Directory) - .options( - selectinload(qa(Directory.parent)), - ) - .where( - qa(Directory.name) == "read-only", - ) - ) # fmt: skip + .options(selectinload(qa(Directory.parent))) + .where(qa(Directory.name) == READ_ONLY_GROUP_NAME) + ) ro_dir = session.scalar(ro_dir_query) if not ro_dir: @@ -101,7 +99,7 @@ def downgrade() -> None: .filter_by( name="sAMAccountName", directory=ro_dir, - value="read-only", + value=READ_ONLY_GROUP_NAME, ) .values({"value": ro_dir.name}), ) @@ -111,7 +109,7 @@ def downgrade() -> None: .filter_by( name="cn", directory=ro_dir, - value="read-only", + value=READ_ONLY_GROUP_NAME, ) .values({"value": ro_dir.name}), ) diff --git a/app/alembic/versions/196f0d327c6a_.py b/app/alembic/versions/196f0d327c6a_.py index 4faf37682..877b27be6 100644 --- a/app/alembic/versions/196f0d327c6a_.py +++ 
b/app/alembic/versions/196f0d327c6a_.py @@ -7,6 +7,7 @@ """ from alembic import op +from dishka import AsyncContainer # revision identifiers, used by Alembic. revision = "196f0d327c6a" @@ -15,7 +16,7 @@ depends_on: None | str = None -def upgrade() -> None: +def upgrade(container: AsyncContainer) -> None: # noqa: ARG001 """Upgrade.""" op.drop_constraint( "AccessPolicyMemberships_policy_id_fkey", @@ -201,7 +202,7 @@ def upgrade() -> None: ) -def downgrade() -> None: +def downgrade(container: AsyncContainer) -> None: # noqa: ARG001 """Downgrade.""" op.drop_constraint( "PolicyMemberships_policy_id_fkey", diff --git a/app/alembic/versions/275222846605_initial_ldap_schema.py b/app/alembic/versions/275222846605_initial_ldap_schema.py index 82957bf42..226c9270b 100644 --- a/app/alembic/versions/275222846605_initial_ldap_schema.py +++ b/app/alembic/versions/275222846605_initial_ldap_schema.py @@ -10,16 +10,16 @@ import sqlalchemy as sa from alembic import op +from dishka import AsyncContainer, Scope from ldap3.protocol.schemas.ad2012R2 import ad_2012_r2_schema -from sqlalchemy import delete, or_ +from sqlalchemy import delete, or_, select from sqlalchemy.ext.asyncio import AsyncConnection, AsyncSession -from sqlalchemy.orm import Session +from sqlalchemy.orm import Session, selectinload -from entities import Attribute -from extra.alembic_utils import temporary_stub_entity_type_name +from entities import Attribute, AttributeType, ObjectClass +from extra.alembic_utils import temporary_stub_column from ldap_protocol.ldap_schema.attribute_type_dao import AttributeTypeDAO from ldap_protocol.ldap_schema.dto import AttributeTypeDTO -from ldap_protocol.ldap_schema.object_class_dao import ObjectClassDAO from ldap_protocol.utils.raw_definition_parser import ( RawDefinitionParser as RDParser, ) @@ -35,8 +35,8 @@ ad_2012_r2_schema_json = json.loads(ad_2012_r2_schema) -@temporary_stub_entity_type_name -def upgrade() -> None: +@temporary_stub_column("entity_type_id", sa.Integer()) +def 
upgrade(container: AsyncContainer) -> None: """Upgrade.""" bind = op.get_bind() session = Session(bind=bind) @@ -267,9 +267,9 @@ def upgrade() -> None: session.commit() # NOTE: Load objectClasses into the database - async def _create_object_classes(connection: AsyncConnection) -> None: - session = AsyncSession(bind=connection) - await session.begin() + async def _create_object_classes(connection: AsyncConnection) -> None: # noqa: ARG001 + async with container(scope=Scope.REQUEST) as cnt: + session = await cnt.get(AsyncSession) oc_already_created_oids = set() oc_first_priority_raw_definitions = ( @@ -342,11 +342,11 @@ async def _create_object_classes(connection: AsyncConnection) -> None: op.run_async(_create_object_classes) - async def _create_attribute_types(connection: AsyncConnection) -> None: - session = AsyncSession(bind=connection) - await session.begin() + async def _create_attribute_types(connection: AsyncConnection) -> None: # noqa: ARG001 + async with container(scope=Scope.REQUEST) as cnt: + session = await cnt.get(AsyncSession) + attribute_type_dao = await cnt.get(AttributeTypeDAO) - attribute_type_dao = AttributeTypeDAO(session) for oid, name in ( ("2.16.840.1.113730.3.1.610", "nsAccountLock"), ("1.3.6.1.4.1.99999.1.1", "posixEmail"), @@ -367,12 +367,9 @@ async def _create_attribute_types(connection: AsyncConnection) -> None: op.run_async(_create_attribute_types) - async def _modify_object_classes(connection: AsyncConnection) -> None: - session = AsyncSession(bind=connection) - await session.begin() - - at_dao = AttributeTypeDAO(session) - oc_dao = ObjectClassDAO(session) + async def _modify_object_classes(connection: AsyncConnection) -> None: # noqa: ARG001 + async with container(scope=Scope.REQUEST) as cnt: + session = await cnt.get(AsyncSession) for oc_name, at_names in ( ("user", ["nsAccountLock", "shadowExpire"]), @@ -380,9 +377,22 @@ async def _modify_object_classes(connection: AsyncConnection) -> None: ("posixAccount", ["posixEmail"]), 
("organizationalUnit", ["title", "jpegPhoto"]), ): - object_class = await oc_dao.get(oc_name) - attribute_types_may = await at_dao.get_all_by_names(at_names) - object_class.attribute_types_may.extend(attribute_types_may) + object_class = await session.scalar( + select(ObjectClass) + .filter_by(name=oc_name) + .options(selectinload(qa(ObjectClass.attribute_types_may))), + ) + + if not object_class: + continue + + attribute_types = await session.scalars( + select(AttributeType) + .where(qa(AttributeType.name).in_(at_names), + ), + ) # fmt: skip + + object_class.attribute_types_may.extend(attribute_types.all()) await session.commit() @@ -393,7 +403,7 @@ async def _modify_object_classes(connection: AsyncConnection) -> None: session.commit() -def downgrade() -> None: +def downgrade(container: AsyncContainer) -> None: # noqa: ARG001 """Downgrade.""" op.drop_index( "idx_object_classes_name_gin_trgm", diff --git a/app/alembic/versions/35d1542d2505_add_entity_id.py b/app/alembic/versions/35d1542d2505_add_entity_id.py index 4a3d6009d..2e0c020bd 100644 --- a/app/alembic/versions/35d1542d2505_add_entity_id.py +++ b/app/alembic/versions/35d1542d2505_add_entity_id.py @@ -8,6 +8,7 @@ import sqlalchemy as sa from alembic import op +from dishka import AsyncContainer from sqlalchemy.sql import text # revision identifiers, used by Alembic. 
@@ -17,7 +18,7 @@ depends_on: None = None -def upgrade() -> None: +def upgrade(container: AsyncContainer) -> None: # noqa: ARG001 """Upgrade.""" op.add_column( "EntityTypes", @@ -87,7 +88,7 @@ def upgrade() -> None: op.drop_column("Directory", "entity_type_name") -def downgrade() -> None: +def downgrade(container: AsyncContainer) -> None: # noqa: ARG001 """Downgrade.""" op.add_column( "Directory", diff --git a/app/alembic/versions/4334e2e871a4_add_sessions_ttl.py b/app/alembic/versions/4334e2e871a4_add_sessions_ttl.py index e656fd296..583a11afc 100644 --- a/app/alembic/versions/4334e2e871a4_add_sessions_ttl.py +++ b/app/alembic/versions/4334e2e871a4_add_sessions_ttl.py @@ -8,6 +8,7 @@ import sqlalchemy as sa from alembic import op +from dishka import AsyncContainer # revision identifiers, used by Alembic. revision = "4334e2e871a4" @@ -16,7 +17,7 @@ depends_on: None | str = None -def upgrade() -> None: +def upgrade(container: AsyncContainer) -> None: # noqa: ARG001 """Upgrade.""" op.add_column( "Policies", @@ -38,7 +39,7 @@ def upgrade() -> None: ) -def downgrade() -> None: +def downgrade(container: AsyncContainer) -> None: # noqa: ARG001 """Downgrade.""" op.drop_column("Policies", "http_session_ttl") op.drop_column("Policies", "ldap_session_ttl") diff --git a/app/alembic/versions/4442d1d982a4_remove_krb_policy.py b/app/alembic/versions/4442d1d982a4_remove_krb_policy.py index 92d706192..5673da6a8 100644 --- a/app/alembic/versions/4442d1d982a4_remove_krb_policy.py +++ b/app/alembic/versions/4442d1d982a4_remove_krb_policy.py @@ -6,12 +6,14 @@ """ +import sqlalchemy as sa from alembic import op +from dishka import AsyncContainer from sqlalchemy import delete from sqlalchemy.orm import Session from entities import Attribute, Directory -from extra.alembic_utils import temporary_stub_entity_type_name +from extra.alembic_utils import temporary_stub_column # revision identifiers, used by Alembic. 
revision = "4442d1d982a4" @@ -20,8 +22,8 @@ depends_on: None | str = None -@temporary_stub_entity_type_name -def upgrade() -> None: +@temporary_stub_column("entity_type_id", sa.Integer()) +def upgrade(container: AsyncContainer) -> None: # noqa: ARG001 """Upgrade.""" bind = op.get_bind() session = Session(bind=bind) @@ -29,5 +31,5 @@ def upgrade() -> None: session.execute(delete(Attribute).filter_by(name="krbpwdpolicyreference")) -def downgrade() -> None: +def downgrade(container: AsyncContainer) -> None: """Downgrade.""" diff --git a/app/alembic/versions/4798b12b97aa_dedicated_servers.py b/app/alembic/versions/4798b12b97aa_dedicated_servers.py index 460e36d20..ae625df78 100644 --- a/app/alembic/versions/4798b12b97aa_dedicated_servers.py +++ b/app/alembic/versions/4798b12b97aa_dedicated_servers.py @@ -8,6 +8,7 @@ import sqlalchemy as sa from alembic import op +from dishka import AsyncContainer from loguru import logger from sqlalchemy.orm import Session @@ -21,7 +22,7 @@ depends_on: None | str = None -def upgrade() -> None: +def upgrade(container: AsyncContainer) -> None: # noqa: ARG001 """Upgrade.""" op.create_table( "DedicatedServer", @@ -88,7 +89,7 @@ def upgrade() -> None: session.commit() -def downgrade() -> None: +def downgrade(container: AsyncContainer) -> None: # noqa: ARG001 """Downgrade.""" bind = op.get_bind() session = Session(bind=bind) diff --git a/app/alembic/versions/4e8772277cfe_add_web_permissions.py b/app/alembic/versions/4e8772277cfe_add_web_permissions.py index 437bcea4f..8567f40e5 100644 --- a/app/alembic/versions/4e8772277cfe_add_web_permissions.py +++ b/app/alembic/versions/4e8772277cfe_add_web_permissions.py @@ -7,6 +7,7 @@ """ from alembic import op +from dishka import AsyncContainer, Scope from sqlalchemy import Column, select, text from sqlalchemy.ext.asyncio import AsyncConnection, AsyncSession @@ -21,17 +22,19 @@ depends_on: None | list[str] = None -def upgrade() -> None: +def upgrade(container: AsyncContainer) -> None: """Upgrade.""" - 
async def _add_api_permissions(connection: AsyncConnection) -> None: - session = AsyncSession(connection) - await session.begin() + async def _add_api_permissions(connection: AsyncConnection) -> None: # noqa: ARG001 + async with container(scope=Scope.REQUEST) as cnt: + session = await cnt.get(AsyncSession) + query = ( select(Role) .filter_by(name=RoleConstants.DOMAIN_ADMINS_ROLE_NAME) ) # fmt: skip - role = (await session.scalars(query)).first() + role = await session.scalar(query) + if role: role.permissions = AuthorizationRules.get_all() await session.commit() @@ -48,6 +51,6 @@ async def _add_api_permissions(connection: AsyncConnection) -> None: op.run_async(_add_api_permissions) -def downgrade() -> None: +def downgrade(container: AsyncContainer) -> None: # noqa: ARG001 """Downgrade.""" op.drop_column("Roles", "permissions") diff --git a/app/alembic/versions/56082d7ac0d4_remove_old_templates.py b/app/alembic/versions/56082d7ac0d4_remove_old_templates.py index b2dd56e0c..655ab9bc6 100644 --- a/app/alembic/versions/56082d7ac0d4_remove_old_templates.py +++ b/app/alembic/versions/56082d7ac0d4_remove_old_templates.py @@ -6,6 +6,8 @@ """ +from dishka import AsyncContainer + # revision identifiers, used by Alembic. revision: None | str = "56082d7ac0d4" down_revision: None | str = "16a9fa2c1f1e" @@ -13,9 +15,9 @@ depends_on: None | list[str] = None -def upgrade() -> None: +def upgrade(container: AsyncContainer) -> None: """Upgrade.""" -def downgrade() -> None: +def downgrade(container: AsyncContainer) -> None: """Downgrade.""" diff --git a/app/alembic/versions/59e98bbd8ad8_.py b/app/alembic/versions/59e98bbd8ad8_.py index 810ca7178..aa8ba7145 100644 --- a/app/alembic/versions/59e98bbd8ad8_.py +++ b/app/alembic/versions/59e98bbd8ad8_.py @@ -8,6 +8,7 @@ import sqlalchemy as sa from alembic import op +from dishka import AsyncContainer from sqlalchemy.dialects import postgresql # revision identifiers, used by Alembic. 
@@ -17,7 +18,7 @@ depends_on: None | str = None -def upgrade() -> None: +def upgrade(container: AsyncContainer) -> None: # noqa: ARG001 """Upgrade.""" op.create_table( "AccessPolicies", @@ -331,7 +332,7 @@ def upgrade() -> None: ) -def downgrade() -> None: +def downgrade(container: AsyncContainer) -> None: # noqa: ARG001 """Downgrade.""" op.drop_index("ix_directory_objectGUID", table_name="Directory") op.drop_index(op.f("ix_Directory_path"), table_name="Directory") diff --git a/app/alembic/versions/6303f5c706ec_update_krbadmin_useraccountcontrol_.py b/app/alembic/versions/6303f5c706ec_update_krbadmin_useraccountcontrol_.py index f01dd0b00..6d101e95a 100644 --- a/app/alembic/versions/6303f5c706ec_update_krbadmin_useraccountcontrol_.py +++ b/app/alembic/versions/6303f5c706ec_update_krbadmin_useraccountcontrol_.py @@ -6,12 +6,15 @@ """ +import sqlalchemy as sa from alembic import op +from dishka import AsyncContainer, Scope from sqlalchemy import select, update from sqlalchemy.ext.asyncio import AsyncConnection, AsyncSession from sqlalchemy.orm import joinedload from entities import Attribute, Directory +from extra.alembic_utils import temporary_stub_column from ldap_protocol.objects import UserAccountControlFlag from ldap_protocol.utils.helpers import create_integer_hash from repo.pg.tables import queryable_attr as qa @@ -23,12 +26,13 @@ depends_on: None | list[str] = None -def upgrade() -> None: +@temporary_stub_column("is_system", sa.Boolean()) +def upgrade(container: AsyncContainer) -> None: """Upgrade.""" - async def _update_krbadmin_uac(connection: AsyncConnection) -> None: - session = AsyncSession(connection) - await session.begin() + async def _update_krbadmin_uac(connection: AsyncConnection) -> None: # noqa: ARG001 + async with container(scope=Scope.REQUEST) as cnt: + session = await cnt.get(AsyncSession) krbadmin_user_dir = await session.scalar( select(Directory) @@ -51,9 +55,9 @@ async def _update_krbadmin_uac(connection: AsyncConnection) -> None: ), ) - 
async def _change_uid_admin(connection: AsyncConnection) -> None: - session = AsyncSession(bind=connection) - await session.begin() + async def _change_uid_admin(connection: AsyncConnection) -> None: # noqa: ARG001 + async with container(scope=Scope.REQUEST) as cnt: + session = await cnt.get(AsyncSession) directory = await session.scalar( select(Directory) @@ -89,12 +93,13 @@ async def _change_uid_admin(connection: AsyncConnection) -> None: op.run_async(_change_uid_admin) -def downgrade() -> None: +@temporary_stub_column("is_system", sa.Boolean()) +def downgrade(container: AsyncContainer) -> None: """Downgrade.""" - async def _downgrade_krbadmin_uac(connection: AsyncConnection) -> None: - session = AsyncSession(connection) - await session.begin() + async def _downgrade_krbadmin_uac(connection: AsyncConnection) -> None: # noqa: ARG001 + async with container(scope=Scope.REQUEST) as cnt: + session = await cnt.get(AsyncSession) krbadmin_user_dir = await session.scalar( select(Directory) diff --git a/app/alembic/versions/692ae64e0cc5_.py b/app/alembic/versions/692ae64e0cc5_.py index e2cef3b6b..0e6ab434c 100755 --- a/app/alembic/versions/692ae64e0cc5_.py +++ b/app/alembic/versions/692ae64e0cc5_.py @@ -7,6 +7,7 @@ """ from alembic import op +from dishka import AsyncContainer # revision identifiers, used by Alembic. revision = "692ae64e0cc5" @@ -15,23 +16,19 @@ depends_on: None | str = None -def upgrade() -> None: +def upgrade(container: AsyncContainer) -> None: # noqa: ARG001 """Upgrade.""" - # ### commands auto generated by Alembic - please adjust! ### op.create_unique_constraint( "group_policy_uc", "GroupAccessPolicyMemberships", ["group_id", "policy_id"], ) - # ### end Alembic commands ### -def downgrade() -> None: +def downgrade(container: AsyncContainer) -> None: # noqa: ARG001 """Downgrade.""" - # ### commands auto generated by Alembic - please adjust! 
### op.drop_constraint( "group_policy_uc", "GroupAccessPolicyMemberships", type_="unique", ) - # ### end Alembic commands ### diff --git a/app/alembic/versions/6c858cc05da7_add_default_admin_name.py b/app/alembic/versions/6c858cc05da7_add_default_admin_name.py new file mode 100644 index 000000000..7b1a59f1e --- /dev/null +++ b/app/alembic/versions/6c858cc05da7_add_default_admin_name.py @@ -0,0 +1,57 @@ +"""Add givenName attribute to users without it. + +Revision ID: 6c858cc05da7 +Revises: 56082d7ac0d4 +Create Date: 2025-12-19 17:26:02.630201 + +""" + +import sqlalchemy as sa +from alembic import op +from dishka import AsyncContainer +from sqlalchemy.orm import Session + +from entities import Attribute, User +from extra.alembic_utils import temporary_stub_column +from repo.pg.tables import queryable_attr as qa + +# revision identifiers, used by Alembic. +revision: None | str = "6c858cc05da7" +down_revision: None | str = "56082d7ac0d4" +branch_labels: None | list[str] = None +depends_on: None | list[str] = None + + +@temporary_stub_column("is_system", sa.Boolean()) +def upgrade(container: AsyncContainer) -> None: # noqa: ARG001 + """Upgrade.""" + bind = op.get_bind() + session = Session(bind=bind) + + users_without_given_name = session.scalars( + sa.select(User).where( + ~sa.exists( + sa.select(1) + .where( + qa(Attribute.directory_id) == qa(User.directory_id), + qa(Attribute.name) == "givenName", + ) + .select_from(Attribute), + ), + ), + ).all() + + for user in users_without_given_name: + session.add( + Attribute( + directory_id=user.directory_id, + name="givenName", + value=user.sam_account_name, + ), + ) + + session.commit() + + +def downgrade(container: AsyncContainer) -> None: + """Downgrade.""" diff --git a/app/alembic/versions/6f8fe2548893_fix_read_only.py b/app/alembic/versions/6f8fe2548893_fix_read_only.py index f190cb676..8d0f87874 100644 --- a/app/alembic/versions/6f8fe2548893_fix_read_only.py +++ b/app/alembic/versions/6f8fe2548893_fix_read_only.py @@ 
-6,12 +6,15 @@ """ +import sqlalchemy as sa from alembic import op +from dishka import AsyncContainer from sqlalchemy import delete, select, update from sqlalchemy.orm import Session +from constants import DOMAIN_USERS_GROUP_NAME from entities import Attribute, Directory -from extra.alembic_utils import temporary_stub_entity_type_name +from extra.alembic_utils import temporary_stub_column from ldap_protocol.utils.helpers import create_integer_hash # revision identifiers, used by Alembic. @@ -21,8 +24,9 @@ depends_on: None = None -@temporary_stub_entity_type_name -def upgrade() -> None: +@temporary_stub_column("entity_type_id", sa.Integer()) +@temporary_stub_column("is_system", sa.Boolean()) +def upgrade(container: AsyncContainer) -> None: # noqa: ARG001 """Upgrade.""" bind = op.get_bind() session = Session(bind=bind) @@ -42,7 +46,7 @@ def upgrade() -> None: .filter_by( name="sAMAccountName", directory=ro_dir, - value="domain users", + value=DOMAIN_USERS_GROUP_NAME, ) .values({"value": ro_dir.name}), ) @@ -84,5 +88,5 @@ def upgrade() -> None: session.commit() -def downgrade() -> None: +def downgrade(container: AsyncContainer) -> None: """Downgrade.""" diff --git a/app/alembic/versions/71e642808369_add_directory_is_system.py b/app/alembic/versions/71e642808369_add_directory_is_system.py new file mode 100644 index 000000000..2526190e4 --- /dev/null +++ b/app/alembic/versions/71e642808369_add_directory_is_system.py @@ -0,0 +1,104 @@ +"""Add directory is_system column. 
+ +Revision ID: 71e642808369 +Revises: a99f866a7e3a +Create Date: 2026-01-15 09:08:12.866533 + +""" + +import sqlalchemy as sa +from alembic import op +from dishka import AsyncContainer, Scope +from sqlalchemy import update +from sqlalchemy.ext.asyncio import AsyncConnection, AsyncSession +from sqlalchemy.orm import Session + +from constants import ( + COMPUTERS_CONTAINER_NAME, + DOMAIN_ADMIN_GROUP_NAME, + DOMAIN_COMPUTERS_GROUP_NAME, + DOMAIN_USERS_GROUP_NAME, + GROUPS_CONTAINER_NAME, + READ_ONLY_GROUP_NAME, + USERS_CONTAINER_NAME, +) +from entities import Directory +from ldap_protocol.utils.queries import get_base_directories +from repo.pg.tables import queryable_attr as qa + +# revision identifiers, used by Alembic. +revision: None | str = "71e642808369" +down_revision: None | str = "a99f866a7e3a" +branch_labels: None | list[str] = None +depends_on: None | list[str] = None + + +def upgrade(container: AsyncContainer) -> None: + """Upgrade.""" + bind = op.get_bind() + session = Session(bind=bind) + + op.add_column( + "Directory", + sa.Column("is_system", sa.Boolean(), nullable=True), + ) + # NOTE: If instances of Directories exists, set default value + session.execute(update(Directory).values({"is_system": False})) + op.alter_column("Directory", "is_system", nullable=False) + + async def _indicate_system_directories( + connection: AsyncConnection, # noqa: ARG001 + ) -> None: + async with container(scope=Scope.REQUEST) as cnt: + session = await cnt.get(AsyncSession) + + base_dn_list = await get_base_directories(session) + if not base_dn_list: + return + + for base_dn in base_dn_list: + base_dn.is_system = True + + await session.flush() + + await session.execute( + update(Directory) + .where( + qa(Directory.is_system).is_(False), + qa(Directory.name).in_( + ( + GROUPS_CONTAINER_NAME, + DOMAIN_ADMIN_GROUP_NAME, + DOMAIN_USERS_GROUP_NAME, + READ_ONLY_GROUP_NAME, + DOMAIN_COMPUTERS_GROUP_NAME, + COMPUTERS_CONTAINER_NAME, + USERS_CONTAINER_NAME, + "services", + 
"krbadmin", + "kerberos", + ), + ), + ) + .values(is_system=True), + ) + await session.flush() + + # NOTE: It's required to mark only administrator users as system. + # Because only main administrator has object_class=='user'. + await session.execute( + update(Directory) + .where( + qa(Directory.is_system).is_(False), + qa(Directory.object_class) == "user", + ) + .values(is_system=True), + ) + await session.flush() + + op.run_async(_indicate_system_directories) + + +def downgrade(container: AsyncContainer) -> None: # noqa: ARG001 + """Downgrade.""" + op.drop_column("Directory", "is_system") diff --git a/app/alembic/versions/8164b4a9e1f1_add_ou_computers.py b/app/alembic/versions/8164b4a9e1f1_add_ou_computers.py index 8d6ab8b18..5f8608a4a 100644 --- a/app/alembic/versions/8164b4a9e1f1_add_ou_computers.py +++ b/app/alembic/versions/8164b4a9e1f1_add_ou_computers.py @@ -6,18 +6,17 @@ """ +import sqlalchemy as sa from alembic import op +from dishka import AsyncContainer, Scope from sqlalchemy import delete, exists, select from sqlalchemy.ext.asyncio import AsyncConnection, AsyncSession +from constants import COMPUTERS_CONTAINER_NAME from entities import Directory -from ldap_protocol.ldap_schema.entity_type_dao import EntityTypeDAO -from ldap_protocol.ldap_schema.object_class_dao import ObjectClassDAO -from ldap_protocol.roles.ace_dao import AccessControlEntryDAO -from ldap_protocol.roles.role_dao import RoleDAO +from extra.alembic_utils import temporary_stub_column from ldap_protocol.roles.role_use_case import RoleUseCase from ldap_protocol.utils.queries import get_base_directories -from password_utils import PasswordUtils from repo.pg.tables import queryable_attr as qa # revision identifiers, used by Alembic. 
@@ -28,27 +27,23 @@ _OU_COMPUTERS_DATA = { - "name": "computers", + "name": COMPUTERS_CONTAINER_NAME, "object_class": "organizationalUnit", "attributes": {"objectClass": ["top", "container"]}, "children": [], } -def upgrade() -> None: +@temporary_stub_column("is_system", sa.Boolean()) +def upgrade(container: AsyncContainer) -> None: """Upgrade.""" from ldap_protocol.auth.setup_gateway import SetupGateway - async def _create_ou_computers(connection: AsyncConnection) -> None: - session = AsyncSession(bind=connection) - await session.begin() - object_class_dao = ObjectClassDAO(session) - entity_type_dao = EntityTypeDAO(session, object_class_dao) - setup_gateway = SetupGateway( - session, - PasswordUtils(), - entity_type_dao, - ) + async def _create_ou_computers(connection: AsyncConnection) -> None: # noqa: ARG001 + async with container(scope=Scope.REQUEST) as cnt: + session = await cnt.get(AsyncSession) + setup_gateway = await cnt.get(SetupGateway) + role_use_case = await cnt.get(RoleUseCase) base_directories = await get_base_directories(session) if not base_directories: @@ -58,7 +53,7 @@ async def _create_ou_computers(connection: AsyncConnection) -> None: exists_ou_computers = await session.scalar( select( exists(Directory) - .where(qa(Directory.name) == "computers"), + .where(qa(Directory.name) == COMPUTERS_CONTAINER_NAME), ), ) # fmt: skip if exists_ou_computers: @@ -66,20 +61,18 @@ async def _create_ou_computers(connection: AsyncConnection) -> None: await setup_gateway.create_dir( _OU_COMPUTERS_DATA, - domain_dir, - domain_dir, + is_system=True, + domain=domain_dir, + parent=domain_dir, ) ou_computers_dir = await session.scalar( select(Directory) - .where(qa(Directory.name) == "computers"), + .where(qa(Directory.name) == COMPUTERS_CONTAINER_NAME), ) # fmt: skip if not ou_computers_dir: raise Exception("Directory 'ou=computers' not found.") - role_dao = RoleDAO(session) - ace_dao = AccessControlEntryDAO(session) - role_use_case = RoleUseCase(role_dao, ace_dao) 
await role_use_case.inherit_parent_aces( parent_directory=domain_dir, directory=ou_computers_dir, @@ -90,12 +83,13 @@ async def _create_ou_computers(connection: AsyncConnection) -> None: op.run_async(_create_ou_computers) -def downgrade() -> None: +@temporary_stub_column("is_system", sa.Boolean()) +def downgrade(container: AsyncContainer) -> None: """Downgrade.""" - async def _delete_ou_computers(connection: AsyncConnection) -> None: - session = AsyncSession(bind=connection) - await session.begin() + async def _delete_ou_computers(connection: AsyncConnection) -> None: # noqa: ARG001 + async with container(scope=Scope.REQUEST) as cnt: + session = await cnt.get(AsyncSession) base_dn_list = await get_base_directories(session) if not base_dn_list: @@ -103,7 +97,7 @@ async def _delete_ou_computers(connection: AsyncConnection) -> None: await session.execute( delete(Directory) - .where(qa(Directory.name) == "computers"), + .where(qa(Directory.name) == COMPUTERS_CONTAINER_NAME), ) # fmt: skip await session.commit() diff --git a/app/alembic/versions/8c2bd40dd809_add_protocols_attr.py b/app/alembic/versions/8c2bd40dd809_add_protocols_attr.py index 8d83c9170..b18a512ea 100644 --- a/app/alembic/versions/8c2bd40dd809_add_protocols_attr.py +++ b/app/alembic/versions/8c2bd40dd809_add_protocols_attr.py @@ -8,6 +8,7 @@ import sqlalchemy as sa from alembic import op +from dishka import AsyncContainer # revision identifiers, used by Alembic. 
revision = "8c2bd40dd809" @@ -16,7 +17,7 @@ depends_on: None = None -def upgrade() -> None: +def upgrade(container: AsyncContainer) -> None: # noqa: ARG001 """Upgrade.""" for protocol_field in ("is_http", "is_ldap", "is_kerberos"): op.add_column( @@ -30,7 +31,7 @@ def upgrade() -> None: ) -def downgrade() -> None: +def downgrade(container: AsyncContainer) -> None: # noqa: ARG001 """Downgrade.""" for protocol_field in ("is_http", "is_ldap", "is_kerberos"): op.drop_column("Policies", protocol_field) diff --git a/app/alembic/versions/93ba193c6a53_add_hash_index_on_dir_path.py b/app/alembic/versions/93ba193c6a53_add_hash_index_on_dir_path.py index 37baac07c..5610e4661 100644 --- a/app/alembic/versions/93ba193c6a53_add_hash_index_on_dir_path.py +++ b/app/alembic/versions/93ba193c6a53_add_hash_index_on_dir_path.py @@ -8,6 +8,7 @@ import sqlalchemy as sa from alembic import op +from dishka import AsyncContainer # revision identifiers, used by Alembic. revision: None | str = "93ba193c6a53" @@ -16,7 +17,7 @@ depends_on: None | list[str] = None -def upgrade() -> None: +def upgrade(container: AsyncContainer) -> None: # noqa: ARG001 """Upgrade.""" op.execute( sa.text( @@ -26,6 +27,6 @@ def upgrade() -> None: ) -def downgrade() -> None: +def downgrade(container: AsyncContainer) -> None: # noqa: ARG001 """Downgrade.""" op.execute(sa.text("DROP INDEX idx_directory_path_hash")) diff --git a/app/alembic/versions/a1b2c3d4e5f6_rename_services_to_system.py b/app/alembic/versions/a1b2c3d4e5f6_rename_services_to_system.py new file mode 100644 index 000000000..e8d480d94 --- /dev/null +++ b/app/alembic/versions/a1b2c3d4e5f6_rename_services_to_system.py @@ -0,0 +1,143 @@ +"""Rename services container to System for AD compatibility. 
+ +Revision ID: a1b2c3d4e5f6 +Revises: 6c858cc05da7 +Create Date: 2026-01-13 12:00:00.000000 + +""" + +from alembic import op +from dishka import AsyncContainer, Scope +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncConnection, AsyncSession + +from entities import Attribute, Directory +from repo.pg.tables import queryable_attr as qa + +# revision identifiers, used by Alembic. +revision: None | str = "a1b2c3d4e5f6" +down_revision: None | str = "c5a9b3f2e8d7" +branch_labels: None | list[str] = None +depends_on: None | list[str] = None + + +async def _update_descendants( + session: AsyncSession, + parent_id: int, + ou_from: str, + ou_to: str, +) -> None: + """Recursively update paths of all descendants.""" + child_dirs = await session.scalars( + select(Directory) + .where(qa(Directory.parent_id) == parent_id), + ) # fmt: skip + + for child_dir in child_dirs: + child_dir.path = [ou_to if p == ou_from else p for p in child_dir.path] + await session.flush() + await _update_descendants( + session, + child_dir.id, + ou_from=ou_from, + ou_to=ou_to, + ) + + +async def _update_attributes( + session: AsyncSession, + old_value: str, + new_value: str, +) -> None: + """Update attribute values during downgrade.""" + result = await session.execute( + select(Attribute) + .where( + Attribute.value.ilike(f"%{old_value}%"), # type: ignore + ), + ) # fmt: skip + attributes = result.scalars().all() + + for attr in attributes: + if attr.value and old_value in attr.value: + attr.value = attr.value.replace(old_value, new_value) + + await session.flush() + + +def upgrade(container: AsyncContainer) -> None: + """Upgrade: Rename 'services' container to 'System'.""" + + async def _rename_services_to_system(connection: AsyncConnection) -> None: # noqa: ARG001 + async with container(scope=Scope.REQUEST) as cnt: + session = await cnt.get(AsyncSession) + + service_dir = await session.scalar( + select(Directory).where( + qa(Directory.name) == "services", + 
qa(Directory.is_system).is_(True), + ), + ) + if not service_dir: + return + ou_to = "ou=System" + ou_from = "ou=services" + + service_dir.name = "System" + service_dir.path = [ + ou_to if p == ou_from else p for p in service_dir.path + ] + + await session.flush() + await _update_descendants( + session, + service_dir.id, + ou_from=ou_from, + ou_to=ou_to, + ) + + await _update_attributes(session, ou_from, ou_to) + await session.commit() + + op.run_async(_rename_services_to_system) + + +def downgrade(container: AsyncContainer) -> None: + """Downgrade: Rename 'System' container back to 'services'.""" + + async def _rename_system_to_services(connection: AsyncConnection) -> None: # noqa: ARG001 + async with container(scope=Scope.REQUEST) as cnt: + session = await cnt.get(AsyncSession) + + system_dir = await session.scalar( + select(Directory).where( + qa(Directory.name) == "System", + qa(Directory.is_system).is_(True), + ), + ) + if not system_dir: + return + ou_to = "ou=services" + ou_from = "ou=System" + + system_dir.name = "services" + system_dir.path = [ + ou_to if p == ou_from else p for p in system_dir.path + ] + + await session.flush() + await _update_descendants( + session, + system_dir.id, + ou_from=ou_from, + ou_to=ou_to, + ) + + await _update_attributes( + session, + ou_from, + ou_to, + ) + await session.commit() + + op.run_async(_rename_system_to_services) diff --git a/app/alembic/versions/a7971f00ba4d_index_single_level.py b/app/alembic/versions/a7971f00ba4d_index_single_level.py index 611594085..0a147ea98 100644 --- a/app/alembic/versions/a7971f00ba4d_index_single_level.py +++ b/app/alembic/versions/a7971f00ba4d_index_single_level.py @@ -8,6 +8,7 @@ import sqlalchemy as sa from alembic import op +from dishka import AsyncContainer # revision identifiers, used by Alembic. 
revision = "a7971f00ba4d" @@ -16,7 +17,7 @@ depends_on: None | str = None -def upgrade() -> None: +def upgrade(container: AsyncContainer) -> None: # noqa: ARG001 """Create index for Directory depth field.""" op.execute( sa.text( @@ -89,7 +90,7 @@ def upgrade() -> None: ) -def downgrade() -> None: +def downgrade(container: AsyncContainer) -> None: # noqa: ARG001 """Remove indexes for Directory depth field and Attributes table.""" op.drop_index("idx_User_san_gin", "Users") op.drop_index("idx_User_upn_gin", "Users") diff --git a/app/alembic/versions/a99f866a7e3a_add_user_pwd_reset_permission.py b/app/alembic/versions/a99f866a7e3a_add_user_pwd_reset_permission.py new file mode 100644 index 000000000..47aece016 --- /dev/null +++ b/app/alembic/versions/a99f866a7e3a_add_user_pwd_reset_permission.py @@ -0,0 +1,59 @@ +"""Add user reset password history permission to Domain Admins role. + +Revision ID: a99f866a7e3a +Revises: 6c858cc05da7 +Create Date: 2025-12-23 10:20:29.147813 + +""" + +from alembic import op +from dishka import AsyncContainer, Scope +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncConnection, AsyncSession + +from entities import Role +from enums import AuthorizationRules, RoleConstants + +# revision identifiers, used by Alembic. 
+revision: None | str = "a99f866a7e3a" +down_revision: None | str = "6c858cc05da7" +branch_labels: None | list[str] = None +depends_on: None | list[str] = None + + +def upgrade(container: AsyncContainer) -> None: + """Upgrade.""" + + async def _add_api_permission(connection: AsyncConnection) -> None: # noqa: ARG001 + async with container(scope=Scope.REQUEST) as cnt: + session = await cnt.get(AsyncSession) + + query = ( + select(Role) + .filter_by(name=RoleConstants.DOMAIN_ADMINS_ROLE_NAME) + ) # fmt: skip + role = await session.scalar(query) + + if role: + role.permissions |= AuthorizationRules.USER_CLEAR_PASSWORD_HISTORY + + op.run_async(_add_api_permission) + + +def downgrade(container: AsyncContainer) -> None: + """Downgrade.""" + + async def _remove_api_permission(connection: AsyncConnection) -> None: # noqa: ARG001 + async with container(scope=Scope.REQUEST) as cnt: + session = await cnt.get(AsyncSession) + + query = ( + select(Role) + .filter_by(name=RoleConstants.DOMAIN_ADMINS_ROLE_NAME) + ) # fmt: skip + role = await session.scalar(query) + + if role: + role.permissions &= ~AuthorizationRules.USER_CLEAR_PASSWORD_HISTORY + + op.run_async(_remove_api_permission) diff --git a/app/alembic/versions/ad52bc16b87d_extend_password_policy_for_api.py b/app/alembic/versions/ad52bc16b87d_extend_password_policy_for_api.py index 5246e47c3..26b592787 100644 --- a/app/alembic/versions/ad52bc16b87d_extend_password_policy_for_api.py +++ b/app/alembic/versions/ad52bc16b87d_extend_password_policy_for_api.py @@ -8,6 +8,7 @@ import sqlalchemy as sa from alembic import op +from dishka import AsyncContainer from sqlalchemy import update from sqlalchemy.orm import Session @@ -23,7 +24,7 @@ depends_on: None | list[str] = None -def upgrade() -> None: +def upgrade(container: AsyncContainer) -> None: # noqa: ARG001 """Upgrade.""" bind = op.get_bind() session = Session(bind=bind) @@ -101,7 +102,7 @@ def upgrade() -> None: ) -def downgrade() -> None: +def downgrade(container: 
AsyncContainer) -> None: # noqa: ARG001 """Downgrade.""" op.drop_table("GroupPasswordPolicyMemberships") diff --git a/app/alembic/versions/ba78cef9700a_initial_entity_type.py b/app/alembic/versions/ba78cef9700a_initial_entity_type.py index 1e46ed611..0e6744919 100644 --- a/app/alembic/versions/ba78cef9700a_initial_entity_type.py +++ b/app/alembic/versions/ba78cef9700a_initial_entity_type.py @@ -8,17 +8,17 @@ import sqlalchemy as sa from alembic import op +from dishka import AsyncContainer, Scope from sqlalchemy import exists, or_, select from sqlalchemy.dialects import postgresql from sqlalchemy.ext.asyncio import AsyncConnection, AsyncSession from constants import ENTITY_TYPE_DATAS from entities import Attribute, Directory, User -from extra.alembic_utils import temporary_stub_entity_type_name +from extra.alembic_utils import temporary_stub_column from ldap_protocol.ldap_schema.dto import EntityTypeDTO from ldap_protocol.ldap_schema.entity_type_dao import EntityTypeDAO from ldap_protocol.ldap_schema.entity_type_use_case import EntityTypeUseCase -from ldap_protocol.ldap_schema.object_class_dao import ObjectClassDAO from ldap_protocol.utils.queries import get_base_directories from repo.pg.tables import queryable_attr as qa @@ -29,8 +29,9 @@ depends_on: None | str = None -@temporary_stub_entity_type_name -def upgrade() -> None: +@temporary_stub_column("entity_type_id", sa.Integer()) +@temporary_stub_column("is_system", sa.Boolean()) +def upgrade(container: AsyncContainer) -> None: """Upgrade database schema and data, creating Entity Types.""" op.create_table( "EntityTypes", @@ -96,28 +97,19 @@ def upgrade() -> None: ["oid"], ) - async def _create_entity_types(connection: AsyncConnection) -> None: - session = AsyncSession(bind=connection) - await session.begin() + async def _create_entity_types(connection: AsyncConnection) -> None: # noqa: ARG001 + async with container(scope=Scope.REQUEST) as cnt: + session = await cnt.get(AsyncSession) + entity_type_use_case = await 
cnt.get(EntityTypeUseCase) if not await get_base_directories(session): return - object_class_dao = ObjectClassDAO(session) - entity_type_dao = EntityTypeDAO( - session, - object_class_dao=object_class_dao, - ) - entity_type_use_case = EntityTypeUseCase( - entity_type_dao, - object_class_dao, - ) - for entity_type_data in ENTITY_TYPE_DATAS: await entity_type_use_case.create( EntityTypeDTO( - name=entity_type_data["name"], # type: ignore - object_class_names=entity_type_data["object_class_names"], # type: ignore + name=entity_type_data["name"], + object_class_names=entity_type_data["object_class_names"], is_system=True, ), ) @@ -125,10 +117,10 @@ async def _create_entity_types(connection: AsyncConnection) -> None: await session.commit() async def _append_object_class_to_user_dirs( - connection: AsyncConnection, + connection: AsyncConnection, # noqa: ARG001 ) -> None: - session = AsyncSession(bind=connection) - await session.begin() + async with container(scope=Scope.REQUEST) as cnt: + session = await cnt.get(AsyncSession) if not await get_base_directories(session): return @@ -163,20 +155,15 @@ async def _append_object_class_to_user_dirs( await session.commit() async def _attach_entity_type_to_directories( - connection: AsyncConnection, + connection: AsyncConnection, # noqa: ARG001 ) -> None: - session = AsyncSession(bind=connection) - await session.begin() + async with container(scope=Scope.REQUEST) as cnt: + session = await cnt.get(AsyncSession) + entity_type_dao = await cnt.get(EntityTypeDAO) if not await get_base_directories(session): return - object_class_dao = ObjectClassDAO(session) - entity_type_dao = EntityTypeDAO( - session, - object_class_dao=object_class_dao, - ) - await entity_type_dao.attach_entity_type_to_directories() await session.commit() @@ -187,7 +174,7 @@ async def _attach_entity_type_to_directories( op.drop_column("EntityTypes", "id") -def downgrade() -> None: +def downgrade(container: AsyncContainer) -> None: # noqa: ARG001 """Downgrade database 
schema and data back to the previous state.""" op.drop_index( "idx_entity_types_name_gin_trgm", diff --git a/app/alembic/versions/bf435bbd95ff_add_rdn_attr_name.py b/app/alembic/versions/bf435bbd95ff_add_rdn_attr_name.py index c0a1342f9..88eaf4581 100644 --- a/app/alembic/versions/bf435bbd95ff_add_rdn_attr_name.py +++ b/app/alembic/versions/bf435bbd95ff_add_rdn_attr_name.py @@ -8,10 +8,11 @@ import sqlalchemy as sa from alembic import op +from dishka import AsyncContainer from sqlalchemy.orm import Session from entities import Attribute, Directory -from extra.alembic_utils import temporary_stub_entity_type_name +from extra.alembic_utils import temporary_stub_column from repo.pg.tables import queryable_attr as qa # revision identifiers, used by Alembic. @@ -21,8 +22,9 @@ depends_on: None | str = None -@temporary_stub_entity_type_name -def upgrade() -> None: +@temporary_stub_column("entity_type_id", sa.Integer()) +@temporary_stub_column("is_system", sa.Boolean()) +def upgrade(container: AsyncContainer) -> None: # noqa: ARG001 """Upgrade.""" op.add_column("Directory", sa.Column("rdname", sa.String(length=64))) @@ -56,8 +58,9 @@ def upgrade() -> None: op.alter_column("Directory", "rdname", nullable=False) -@temporary_stub_entity_type_name -def downgrade() -> None: +@temporary_stub_column("entity_type_id", sa.Integer()) +@temporary_stub_column("is_system", sa.Boolean()) +def downgrade(container: AsyncContainer) -> None: # noqa: ARG001 """Downgrade.""" bind = op.get_bind() session = Session(bind=bind) diff --git a/app/alembic/versions/bv546ccd35fa_fix_krbadmin_attrs.py b/app/alembic/versions/bv546ccd35fa_fix_krbadmin_attrs.py index 1f4c7ac0f..dfaa36aa0 100644 --- a/app/alembic/versions/bv546ccd35fa_fix_krbadmin_attrs.py +++ b/app/alembic/versions/bv546ccd35fa_fix_krbadmin_attrs.py @@ -8,10 +8,11 @@ import sqlalchemy as sa from alembic import op +from dishka import AsyncContainer from sqlalchemy.orm import Session from entities import Attribute, Directory -from 
extra.alembic_utils import temporary_stub_entity_type_name +from extra.alembic_utils import temporary_stub_column from repo.pg.tables import queryable_attr as qa # revision identifiers, used by Alembic. @@ -21,8 +22,9 @@ depends_on: None | str = None -@temporary_stub_entity_type_name -def upgrade() -> None: +@temporary_stub_column("entity_type_id", sa.Integer()) +@temporary_stub_column("is_system", sa.Boolean()) +def upgrade(container: AsyncContainer) -> None: # noqa: ARG001 """Upgrade.""" bind = op.get_bind() session = Session(bind=bind) @@ -79,5 +81,5 @@ def upgrade() -> None: session.commit() -def downgrade() -> None: +def downgrade(container: AsyncContainer) -> None: """Downgrade.""" diff --git a/app/alembic/versions/c007129b7973_renew_sls_gpos_rsops.py b/app/alembic/versions/c007129b7973_renew_sls_gpos_rsops.py index 1e3b7c40c..bc6bb3701 100644 --- a/app/alembic/versions/c007129b7973_renew_sls_gpos_rsops.py +++ b/app/alembic/versions/c007129b7973_renew_sls_gpos_rsops.py @@ -6,6 +6,8 @@ """ +from dishka import AsyncContainer + # revision identifiers, used by Alembic. 
revision: None | str = "c007129b7973" down_revision: None | str = "8164b4a9e1f1" @@ -13,9 +15,9 @@ depends_on: None | list[str] = None -def upgrade() -> None: +def upgrade(container: AsyncContainer) -> None: """Upgrade.""" -def downgrade() -> None: +def downgrade(container: AsyncContainer) -> None: """Downgrade.""" diff --git a/app/alembic/versions/c4888c68e221_fix_admin_attr_and_policy.py b/app/alembic/versions/c4888c68e221_fix_admin_attr_and_policy.py index bbd904958..dbaa321be 100644 --- a/app/alembic/versions/c4888c68e221_fix_admin_attr_and_policy.py +++ b/app/alembic/versions/c4888c68e221_fix_admin_attr_and_policy.py @@ -8,12 +8,13 @@ import sqlalchemy as sa from alembic import op +from dishka import AsyncContainer, Scope from sqlalchemy.ext.asyncio import AsyncConnection, AsyncSession from sqlalchemy.orm import joinedload from entities import Attribute, Directory, NetworkPolicy +from extra.alembic_utils import temporary_stub_column from ldap_protocol.ldap_schema.entity_type_dao import EntityTypeDAO -from ldap_protocol.ldap_schema.object_class_dao import ObjectClassDAO from ldap_protocol.utils.helpers import create_integer_hash from ldap_protocol.utils.queries import get_base_directories from repo.pg.tables import queryable_attr as qa @@ -25,31 +26,26 @@ depends_on: None | list[str] = None -def upgrade() -> None: +@temporary_stub_column("is_system", sa.Boolean()) +def upgrade(container: AsyncContainer) -> None: """Upgrade.""" async def _attach_entity_type_to_directories( - connection: AsyncConnection, + connection: AsyncConnection, # noqa: ARG001 ) -> None: - session = AsyncSession(bind=connection) - await session.begin() + async with container(scope=Scope.REQUEST) as cnt: + session = await cnt.get(AsyncSession) + entity_type_dao = await cnt.get(EntityTypeDAO) if not await get_base_directories(session): return - object_class_dao = ObjectClassDAO( - session, - ) - entity_type_dao = EntityTypeDAO( - session, - object_class_dao=object_class_dao, - ) await 
entity_type_dao.attach_entity_type_to_directories() await session.commit() - async def _change_uid_admin(connection: AsyncConnection) -> None: - session = AsyncSession(bind=connection) - await session.begin() + async def _change_uid_admin(connection: AsyncConnection) -> None: # noqa: ARG001 + async with container(scope=Scope.REQUEST) as cnt: + session = await cnt.get(AsyncSession) directory = await session.scalar( sa.select(Directory) @@ -81,9 +77,9 @@ async def _change_uid_admin(connection: AsyncConnection) -> None: ) await session.commit() - async def _change_ldap_session_ttl(connection: AsyncConnection) -> None: - session = AsyncSession(bind=connection) - await session.begin() + async def _change_ldap_session_ttl(connection: AsyncConnection) -> None: # noqa: ARG001 + async with container(scope=Scope.REQUEST) as cnt: + session = await cnt.get(AsyncSession) await session.execute( sa.update(NetworkPolicy) @@ -101,5 +97,5 @@ async def _change_ldap_session_ttl(connection: AsyncConnection) -> None: op.run_async(_attach_entity_type_to_directories) -def downgrade() -> None: +def downgrade(container: AsyncContainer) -> None: """Downgrade.""" diff --git a/app/alembic/versions/c5a9b3f2e8d7_add_contact_object_class.py b/app/alembic/versions/c5a9b3f2e8d7_add_contact_object_class.py new file mode 100644 index 000000000..33bd3a433 --- /dev/null +++ b/app/alembic/versions/c5a9b3f2e8d7_add_contact_object_class.py @@ -0,0 +1,82 @@ +"""Add Contact objectClass and mailRecipient to LDAP schema. 
+ +Revision ID: c5a9b3f2e8d7 +Revises: 8164b4a9e1f1, f1abf7ef2443 +Create Date: 2026-01-19 12:00:00.000000 + +""" + +from alembic import op +from dishka import AsyncContainer, Scope +from sqlalchemy import delete +from sqlalchemy.ext.asyncio import AsyncConnection, AsyncSession + +from entities import EntityType +from enums import EntityTypeNames +from ldap_protocol.ldap_schema.dto import EntityTypeDTO +from ldap_protocol.ldap_schema.entity_type_use_case import EntityTypeUseCase +from ldap_protocol.utils.queries import get_base_directories +from repo.pg.tables import queryable_attr as qa + +# revision identifiers, used by Alembic. +revision = "c5a9b3f2e8d7" +down_revision = "71e642808369" +branch_labels: None | str = None +depends_on: None | str = None + + +def upgrade(container: AsyncContainer) -> None: + """Add Contact objectClass and mailRecipient to LDAP schema.""" + + async def _create_entity_type( + connection: AsyncConnection, # noqa: ARG001 + ) -> None: + """Create Contact Entity Type.""" + async with container(scope=Scope.REQUEST) as cnt: + session = await cnt.get(AsyncSession) + entity_type_use_case = await cnt.get(EntityTypeUseCase) + + if not await get_base_directories(session): + return + + await entity_type_use_case.create( + EntityTypeDTO( + name=EntityTypeNames.CONTACT, + object_class_names=[ + "top", + "person", + "organizationalPerson", + "contact", + "mailRecipient", + ], + is_system=True, + ), + ) + + await session.commit() + + op.run_async(_create_entity_type) + + +def downgrade(container: AsyncContainer) -> None: + """Remove Contact objectClass and mailRecipient from LDAP schema.""" + + async def _delete_entity_type( + connection: AsyncConnection, # noqa: ARG001 + ) -> None: + """Delete Contact Entity Type.""" + async with container(scope=Scope.REQUEST) as cnt: + session = await cnt.get(AsyncSession) + + if not await get_base_directories(session): + return + + await session.execute( + delete(EntityType).where( + qa(EntityType.name) == 
EntityTypeNames.CONTACT, + ), + ) + + await session.commit() + + op.run_async(_delete_entity_type) diff --git a/app/alembic/versions/dafg3a4b22ab_add_preauth_princ.py b/app/alembic/versions/dafg3a4b22ab_add_preauth_princ.py index 06731ae67..38d982694 100644 --- a/app/alembic/versions/dafg3a4b22ab_add_preauth_princ.py +++ b/app/alembic/versions/dafg3a4b22ab_add_preauth_princ.py @@ -8,10 +8,11 @@ import sqlalchemy as sa from alembic import op +from dishka import AsyncContainer from sqlalchemy.orm import Session from entities import Attribute, CatalogueSetting, User -from extra.alembic_utils import temporary_stub_entity_type_name +from extra.alembic_utils import temporary_stub_column from ldap_protocol.kerberos import KERBEROS_STATE_NAME from repo.pg.tables import queryable_attr as qa @@ -22,8 +23,9 @@ depends_on: None | str = None -@temporary_stub_entity_type_name -def upgrade() -> None: +@temporary_stub_column("entity_type_id", sa.Integer()) +@temporary_stub_column("is_system", sa.Boolean()) +def upgrade(container: AsyncContainer) -> None: # noqa: ARG001 """Upgrade.""" bind = op.get_bind() session = Session(bind=bind) @@ -76,7 +78,7 @@ def upgrade() -> None: ) -def downgrade() -> None: +def downgrade(container: AsyncContainer) -> None: # noqa: ARG001 """Downgrade.""" op.drop_index(op.f("ix_Settings_name"), table_name="Settings") op.create_index( diff --git a/app/alembic/versions/df4c52a613e5_migrate_password_prop_from_E.py b/app/alembic/versions/df4c52a613e5_migrate_password_prop_from_E.py index d1c078db1..2a2c7d436 100644 --- a/app/alembic/versions/df4c52a613e5_migrate_password_prop_from_E.py +++ b/app/alembic/versions/df4c52a613e5_migrate_password_prop_from_E.py @@ -8,6 +8,7 @@ import sqlalchemy as sa from alembic import op +from dishka import AsyncContainer, Scope from sqlalchemy import update from sqlalchemy.ext.asyncio import AsyncConnection, AsyncSession from sqlalchemy.orm import Session @@ -27,7 +28,7 @@ _BAN_WORDS = set(f.read().split("\n")) -def upgrade() 
-> None: +def upgrade(container: AsyncContainer) -> None: """Upgrade.""" bind = op.get_bind() session = Session(bind=bind) @@ -45,11 +46,11 @@ def upgrade() -> None: ), ) - async def _create_common_passwords(connection: AsyncConnection) -> None: - session = AsyncSession(bind=connection) - await session.begin() + async def _create_common_passwords(connection: AsyncConnection) -> None: # noqa: ARG001 + async with container(scope=Scope.REQUEST) as cnt: + session = await cnt.get(AsyncSession) + password_ban_word_repo = await cnt.get(PasswordBanWordRepository) - password_ban_word_repo = PasswordBanWordRepository(session) await password_ban_word_repo.replace(_BAN_WORDS) await session.commit() @@ -256,7 +257,7 @@ async def _create_common_passwords(connection: AsyncConnection) -> None: ) -def downgrade() -> None: +def downgrade(container: AsyncContainer) -> None: # noqa: ARG001 """Downgrade.""" op.execute( sa.text("DROP INDEX IF EXISTS idx_password_ban_words_word_gin_trgm"), diff --git a/app/alembic/versions/e4d6d99d32bd_add_audit_policies.py b/app/alembic/versions/e4d6d99d32bd_add_audit_policies.py index 77ad0f091..0a64e7fb7 100644 --- a/app/alembic/versions/e4d6d99d32bd_add_audit_policies.py +++ b/app/alembic/versions/e4d6d99d32bd_add_audit_policies.py @@ -10,9 +10,11 @@ import sqlalchemy as sa from alembic import op +from dishka import AsyncContainer, Scope from sqlalchemy.dialects import postgresql from sqlalchemy.ext.asyncio import AsyncConnection, AsyncSession +from extra.alembic_utils import temporary_stub_column from ldap_protocol.policies.audit.audit_use_case import AuditUseCase from ldap_protocol.policies.audit.destination_dao import AuditDestinationDAO from ldap_protocol.policies.audit.events.managers import RawAuditManager @@ -26,16 +28,19 @@ depends_on: None | str = None -def upgrade() -> None: +@temporary_stub_column("is_system", sa.Boolean()) +def upgrade(container: AsyncContainer) -> None: """Upgrade.""" - async def _create_audit_policies(connection: 
AsyncConnection) -> None: - session = AsyncSession(bind=connection) + async def _create_audit_policies(connection: AsyncConnection) -> None: # noqa: ARG001 + async with container(scope=Scope.REQUEST) as cnt: + session = await cnt.get(AsyncSession) + audit_dao = await cnt.get(AuditPoliciesDAO) + dest_dao = await cnt.get(AuditDestinationDAO) if not await get_base_directories(session): return - audit_dao = AuditPoliciesDAO(session) - dest_dao = AuditDestinationDAO(session) + manager = Mock(spec=RawAuditManager) use_case = AuditUseCase(audit_dao, dest_dao, manager) await use_case.create_policies() @@ -139,7 +144,7 @@ async def _create_audit_policies(connection: AsyncConnection) -> None: op.run_async(_create_audit_policies) -def downgrade() -> None: +def downgrade(container: AsyncContainer) -> None: # noqa: ARG001 """Downgrade.""" op.drop_table("AuditPolicyTriggers") op.drop_table("AuditPolicies") diff --git a/app/alembic/versions/ec45e3e8aa0f_drop_password_policy_id_column.py b/app/alembic/versions/ec45e3e8aa0f_drop_password_policy_id_column.py new file mode 100644 index 000000000..a01bd542d --- /dev/null +++ b/app/alembic/versions/ec45e3e8aa0f_drop_password_policy_id_column.py @@ -0,0 +1,47 @@ +"""Drop unused Directory.password_policy_id column. + +Revision ID: ec45e3e8aa0f +Revises: a1b2c3d4e5f6 +Create Date: 2026-01-20 14:33:36.236135 + +""" + +import sqlalchemy as sa +from alembic import op +from dishka import AsyncContainer + +# revision identifiers, used by Alembic. 
+revision: None | str = "ec45e3e8aa0f" +down_revision: None | str = "a1b2c3d4e5f6" +branch_labels: None | list[str] = None +depends_on: None | list[str] = None + + +def upgrade(container: AsyncContainer) -> None: # noqa: ARG001 + """Upgrade.""" + op.drop_constraint( + op.f("Directory_password_policy_id_fkey"), + "Directory", + type_="foreignkey", + ) + op.drop_column("Directory", "password_policy_id") + + +def downgrade(container: AsyncContainer) -> None: # noqa: ARG001 + """Downgrade.""" + op.add_column( + "Directory", + sa.Column( + "password_policy_id", + sa.INTEGER(), + autoincrement=False, + nullable=True, + ), + ) + op.create_foreign_key( + op.f("Directory_password_policy_id_fkey"), + "Directory", + "PasswordPolicies", + ["password_policy_id"], + ["id"], + ) diff --git a/app/alembic/versions/eeaed5989eb0_group_policies.py b/app/alembic/versions/eeaed5989eb0_group_policies.py index 49b46fed6..de76492b8 100644 --- a/app/alembic/versions/eeaed5989eb0_group_policies.py +++ b/app/alembic/versions/eeaed5989eb0_group_policies.py @@ -6,6 +6,8 @@ """ +from dishka import AsyncContainer + # revision identifiers, used by Alembic. 
revision = "eeaed5989eb0" down_revision = "e4d6d99d32bd" @@ -13,9 +15,9 @@ depends_on: None | str = None -def upgrade() -> None: +def upgrade(container: AsyncContainer) -> None: """Upgrade.""" -def downgrade() -> None: +def downgrade(container: AsyncContainer) -> None: """Downgrade.""" diff --git a/app/alembic/versions/f1abf7ef2443_add_container_object_class.py b/app/alembic/versions/f1abf7ef2443_add_container_object_class.py index d879bd591..cf1c80e1b 100644 --- a/app/alembic/versions/f1abf7ef2443_add_container_object_class.py +++ b/app/alembic/versions/f1abf7ef2443_add_container_object_class.py @@ -6,11 +6,15 @@ """ +import sqlalchemy as sa from alembic import op +from dishka import AsyncContainer, Scope from sqlalchemy import delete, func, insert, select, update from sqlalchemy.ext.asyncio import AsyncConnection, AsyncSession from entities import Attribute, Directory, EntityType +from enums import EntityTypeNames +from extra.alembic_utils import temporary_stub_column from repo.pg.tables import queryable_attr as qa # revision identifiers, used by Alembic. 
@@ -20,15 +24,16 @@ depends_on: None | str = None -def upgrade() -> None: +@temporary_stub_column("is_system", sa.Boolean()) +def upgrade(container: AsyncContainer) -> None: """Upgrade.""" async def _migrate_ou_to_cn_containers( - connection: AsyncConnection, + connection: AsyncConnection, # noqa: ARG001 ) -> None: """Migrate existing ou= containers to cn= containers.""" - session = AsyncSession(bind=connection) - await session.begin() + async with container(scope=Scope.REQUEST) as cnt: + session = await cnt.get(AsyncSession) containers_to_migrate = ["groups", "computers", "users"] directories = await session.scalars( @@ -39,7 +44,7 @@ async def _migrate_ou_to_cn_containers( ) entity_type = await session.scalar( select(EntityType) - .where(qa(EntityType.name) == "Container"), + .where(qa(EntityType.name) == EntityTypeNames.CONTAINER), ) # fmt: skip for directory in directories: @@ -105,15 +110,16 @@ async def _migrate_ou_to_cn_containers( op.run_async(_migrate_ou_to_cn_containers) -def downgrade() -> None: +@temporary_stub_column("is_system", sa.Boolean()) +def downgrade(container: AsyncContainer) -> None: """Downgrade.""" async def _migrate_cn_to_ou_containers( - connection: AsyncConnection, + connection: AsyncConnection, # noqa: ARG001 ) -> None: """Migrate existing cn= containers back to ou= containers.""" - session = AsyncSession(bind=connection) - await session.begin() + async with container(scope=Scope.REQUEST) as cnt: + session = await cnt.get(AsyncSession) containers_to_migrate = ["groups", "computers", "users"] directories = await session.scalars( @@ -124,7 +130,7 @@ async def _migrate_cn_to_ou_containers( ) entity_type = await session.scalar( select(EntityType) - .where(qa(EntityType.name) == "Organizational Unit"), + .where(qa(EntityType.name) == EntityTypeNames.ORGANIZATIONAL_UNIT), ) # fmt: skip for directory in directories: diff --git a/app/alembic/versions/f24ed0e49df2_add_filter_anr.py b/app/alembic/versions/f24ed0e49df2_add_filter_anr.py index 
991da67c7..b6ec3ee1a 100644 --- a/app/alembic/versions/f24ed0e49df2_add_filter_anr.py +++ b/app/alembic/versions/f24ed0e49df2_add_filter_anr.py @@ -8,6 +8,7 @@ import sqlalchemy as sa from alembic import op +from dishka import AsyncContainer from sqlalchemy.dialects import postgresql from sqlalchemy.orm import Session @@ -34,7 +35,7 @@ ) -def upgrade() -> None: +def upgrade(container: AsyncContainer) -> None: # noqa: ARG001 """Upgrade.""" bind = op.get_bind() session = Session(bind=bind) @@ -67,7 +68,7 @@ def upgrade() -> None: session.commit() -def downgrade() -> None: +def downgrade(container: AsyncContainer) -> None: # noqa: ARG001 """Downgrade.""" op.alter_column( "EntityTypes", diff --git a/app/alembic/versions/f68a134a3685_add_bypass.py b/app/alembic/versions/f68a134a3685_add_bypass.py index c39e0136d..fe8b90e23 100644 --- a/app/alembic/versions/f68a134a3685_add_bypass.py +++ b/app/alembic/versions/f68a134a3685_add_bypass.py @@ -8,6 +8,7 @@ import sqlalchemy as sa from alembic import op +from dishka import AsyncContainer # revision identifiers, used by Alembic. 
revision = "f68a134a3685" @@ -16,7 +17,7 @@ depends_on: None | str = None -def upgrade() -> None: +def upgrade(container: AsyncContainer) -> None: # noqa: ARG001 """Upgrade.""" op.add_column( "Policies", @@ -38,7 +39,7 @@ def upgrade() -> None: ) -def downgrade() -> None: +def downgrade(container: AsyncContainer) -> None: # noqa: ARG001 """Downgrade.""" op.drop_column("Policies", "bypass_service_failure") op.drop_column("Policies", "bypass_no_connection") diff --git a/app/alembic/versions/fafc3d0b11ec_.py b/app/alembic/versions/fafc3d0b11ec_.py index 40ce32bab..1ab09b6b3 100644 --- a/app/alembic/versions/fafc3d0b11ec_.py +++ b/app/alembic/versions/fafc3d0b11ec_.py @@ -6,13 +6,18 @@ """ +import sqlalchemy as sa from alembic import op +from dishka import AsyncContainer, Scope from sqlalchemy import delete, exists, select from sqlalchemy.exc import DBAPIError, IntegrityError from sqlalchemy.ext.asyncio import AsyncConnection, AsyncSession from entities import Directory -from extra.alembic_utils import temporary_stub_entity_type_name +from extra.alembic_utils import temporary_stub_column +from ldap_protocol.ldap_schema.attribute_value_validator import ( + AttributeValueValidator, +) from ldap_protocol.utils.queries import ( create_group, get_base_directories, @@ -27,15 +32,20 @@ depends_on: None | str = None -@temporary_stub_entity_type_name -def upgrade() -> None: +@temporary_stub_column("entity_type_id", sa.Integer()) +@temporary_stub_column("is_system", sa.Boolean()) +def upgrade(container: AsyncContainer) -> None: """Upgrade.""" async def _create_readonly_grp_and_plcy( - connection: AsyncConnection, + connection: AsyncConnection, # noqa: ARG001 ) -> None: - session = AsyncSession(bind=connection) - await session.begin() + async with container(scope=Scope.REQUEST) as cnt: + session = await cnt.get(AsyncSession) + attribute_value_validator = await cnt.get( + AttributeValueValidator, + ) + base_dn_list = await get_base_directories(session) if not base_dn_list: return 
@@ -52,6 +62,7 @@ async def _create_readonly_grp_and_plcy( dir_, _ = await create_group( name="readonly domain controllers", sid=521, + attribute_value_validator=attribute_value_validator, session=session, ) @@ -65,15 +76,17 @@ async def _create_readonly_grp_and_plcy( op.run_async(_create_readonly_grp_and_plcy) -@temporary_stub_entity_type_name -def downgrade() -> None: +@temporary_stub_column("entity_type_id", sa.Integer()) +@temporary_stub_column("is_system", sa.Boolean()) +def downgrade(container: AsyncContainer) -> None: """Downgrade.""" async def _delete_readonly_grp_and_plcy( - connection: AsyncConnection, + connection: AsyncConnection, # noqa: ARG001 ) -> None: - session = AsyncSession(bind=connection) - await session.begin() + async with container(scope=Scope.REQUEST) as cnt: + session = await cnt.get(AsyncSession) + base_dn_list = await get_base_directories(session) if not base_dn_list: return diff --git a/app/alembic/versions/fc8b7617c60a_attr_index.py b/app/alembic/versions/fc8b7617c60a_attr_index.py index 4ce948f4c..ced401ee1 100644 --- a/app/alembic/versions/fc8b7617c60a_attr_index.py +++ b/app/alembic/versions/fc8b7617c60a_attr_index.py @@ -7,6 +7,7 @@ import sqlalchemy as sa from alembic import op +from dishka import AsyncContainer # revision identifiers, used by Alembic. 
revision = "fc8b7617c60a" @@ -15,7 +16,7 @@ depends_on: str | None = None -def upgrade() -> None: +def upgrade(container: AsyncContainer) -> None: # noqa: ARG001 """Create index for Attribute name field.""" op.execute( sa.text( @@ -31,7 +32,7 @@ def upgrade() -> None: ) -def downgrade() -> None: +def downgrade(container: AsyncContainer) -> None: # noqa: ARG001 """Drop index for Attribute name field.""" op.execute(sa.text("DROP INDEX idx_attributes_lw_name_btree")) op.execute(sa.text("DROP INDEX idx_attributes_name_gin_trgm")) diff --git a/app/api/__init__.py b/app/api/__init__.py index 468cea5c5..69f1e8f37 100644 --- a/app/api/__init__.py +++ b/app/api/__init__.py @@ -14,7 +14,11 @@ from .main.krb5_router import krb5_router from .main.router import entry_router from .network.router import network_router -from .password_policy import password_ban_word_router, password_policy_router +from .password_policy import ( + password_ban_word_router, + password_policy_router, + user_password_history_router, +) from .shadow.router import shadow_router __all__ = [ @@ -25,6 +29,7 @@ "mfa_router", "password_ban_word_router", "password_policy_router", + "user_password_history_router", "ldap_schema_router", "dns_router", "krb5_router", diff --git a/app/api/audit/adapter.py b/app/api/audit/adapter.py index f06e24231..e7a39665d 100644 --- a/app/api/audit/adapter.py +++ b/app/api/audit/adapter.py @@ -4,19 +4,11 @@ License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE """ -from typing import ParamSpec, TypeVar - -from fastapi import status - from api.base_adapter import BaseAdapter from ldap_protocol.policies.audit.dataclasses import ( AuditDestinationDTO, AuditPolicyDTO, ) -from ldap_protocol.policies.audit.exception import ( - AuditAlreadyExistsError, - AuditNotFoundError, -) from ldap_protocol.policies.audit.schemas import ( AuditDestinationResponse, AuditDestinationSchemaRequest, @@ -25,18 +17,10 @@ ) from ldap_protocol.policies.audit.service import 
AuditService -P = ParamSpec("P") -R = TypeVar("R") - class AuditPoliciesAdapter(BaseAdapter[AuditService]): """Adapter for audit policies.""" - _exceptions_map: dict[type[Exception], int] = { - AuditNotFoundError: status.HTTP_404_NOT_FOUND, - AuditAlreadyExistsError: status.HTTP_409_CONFLICT, - } - async def get_policies(self) -> list[AuditPolicyResponse]: """Get all audit policies.""" return [ diff --git a/app/api/audit/router.py b/app/api/audit/router.py index bbafe770a..4a328e2ef 100644 --- a/app/api/audit/router.py +++ b/app/api/audit/router.py @@ -5,10 +5,21 @@ """ from dishka import FromDishka -from dishka.integrations.fastapi import DishkaRoute -from fastapi import APIRouter, Depends, status +from fastapi import Depends, status +from fastapi_error_map.routing import ErrorAwareRouter +from fastapi_error_map.rules import rule from api.auth.utils import verify_auth +from api.error_routing import ( + ERROR_MAP_TYPE, + DishkaErrorAwareRoute, + DomainErrorTranslator, +) +from enums import DomainCodes +from ldap_protocol.policies.audit.exception import ( + AuditAlreadyExistsError, + AuditNotFoundError, +) from ldap_protocol.policies.audit.schemas import ( AuditDestinationResponse, AuditDestinationSchemaRequest, @@ -18,15 +29,29 @@ from .adapter import AuditPoliciesAdapter -audit_router = APIRouter( +translator = DomainErrorTranslator(DomainCodes.AUDIT) + + +error_map: ERROR_MAP_TYPE = { + AuditNotFoundError: rule( + status=status.HTTP_400_BAD_REQUEST, + translator=translator, + ), + AuditAlreadyExistsError: rule( + status=status.HTTP_400_BAD_REQUEST, + translator=translator, + ), +} + +audit_router = ErrorAwareRouter( prefix="/audit", tags=["Audit policy"], dependencies=[Depends(verify_auth)], - route_class=DishkaRoute, + route_class=DishkaErrorAwareRoute, ) -@audit_router.get("/policies") +@audit_router.get("/policies", error_map=error_map) async def get_audit_policies( audit_adapter: FromDishka[AuditPoliciesAdapter], ) -> list[AuditPolicyResponse]: @@ -34,7 +59,7 
@@ async def get_audit_policies( return await audit_adapter.get_policies() -@audit_router.put("/policy/{policy_id}") +@audit_router.put("/policy/{policy_id}", error_map=error_map) async def update_audit_policy( policy_id: int, policy_data: AuditPolicySchemaRequest, @@ -44,7 +69,7 @@ async def update_audit_policy( return await audit_adapter.update_policy(policy_id, policy_data) -@audit_router.get("/destinations") +@audit_router.get("/destinations", error_map=error_map) async def get_audit_destinations( audit_adapter: FromDishka[AuditPoliciesAdapter], ) -> list[AuditDestinationResponse]: @@ -52,7 +77,11 @@ async def get_audit_destinations( return await audit_adapter.get_destinations() -@audit_router.post("/destination", status_code=status.HTTP_201_CREATED) +@audit_router.post( + "/destination", + status_code=status.HTTP_201_CREATED, + error_map=error_map, +) async def create_audit_destination( destination_data: AuditDestinationSchemaRequest, audit_adapter: FromDishka[AuditPoliciesAdapter], @@ -61,7 +90,7 @@ async def create_audit_destination( return await audit_adapter.create_destination(destination_data) -@audit_router.delete("/destination/{destination_id}") +@audit_router.delete("/destination/{destination_id}", error_map=error_map) async def delete_audit_destination( destination_id: int, audit_adapter: FromDishka[AuditPoliciesAdapter], @@ -70,7 +99,7 @@ async def delete_audit_destination( await audit_adapter.delete_destination(destination_id) -@audit_router.put("/destination/{destination_id}") +@audit_router.put("/destination/{destination_id}", error_map=error_map) async def update_audit_destination( destination_id: int, destination_data: AuditDestinationSchemaRequest, diff --git a/app/api/auth/adapters/auth.py b/app/api/auth/adapters/auth.py index 5b121b58e..50ed85ee7 100644 --- a/app/api/auth/adapters/auth.py +++ b/app/api/auth/adapters/auth.py @@ -7,32 +7,17 @@ from ipaddress import IPv4Address, IPv6Address from adaptix.conversion import get_converter -from 
fastapi import Request, status +from fastapi import Request from api.base_adapter import BaseAdapter from ldap_protocol.auth import AuthManager from ldap_protocol.auth.dto import SetupDTO -from ldap_protocol.auth.exceptions.mfa import ( - MFAAPIError, - MFAConnectError, - MFARequiredError, - MissingMFACredentialsError, -) from ldap_protocol.auth.schemas import ( MFAChallengeResponse, OAuth2Form, SetupRequest, ) from ldap_protocol.dialogue import UserSchema -from ldap_protocol.identity.exceptions import ( - AlreadyConfiguredError, - AuthValidationError, - LoginFailedError, - PasswordPolicyError, - UnauthorizedError, - UserNotFoundError, -) -from ldap_protocol.kerberos.exceptions import KRBAPIChangePasswordError _convert_request_to_dto = get_converter(SetupRequest, SetupDTO) @@ -40,21 +25,6 @@ class AuthFastAPIAdapter(BaseAdapter[AuthManager]): """Adapter for using IdentityManager with FastAPI.""" - _exceptions_map: dict[type[Exception], int] = { - UnauthorizedError: status.HTTP_401_UNAUTHORIZED, - LoginFailedError: status.HTTP_403_FORBIDDEN, - MFARequiredError: status.HTTP_426_UPGRADE_REQUIRED, - PasswordPolicyError: status.HTTP_422_UNPROCESSABLE_ENTITY, - AuthValidationError: status.HTTP_422_UNPROCESSABLE_ENTITY, - PermissionError: status.HTTP_403_FORBIDDEN, - UserNotFoundError: status.HTTP_404_NOT_FOUND, - KRBAPIChangePasswordError: status.HTTP_424_FAILED_DEPENDENCY, - AlreadyConfiguredError: status.HTTP_423_LOCKED, - MissingMFACredentialsError: status.HTTP_403_FORBIDDEN, - MFAAPIError: status.HTTP_406_NOT_ACCEPTABLE, - MFAConnectError: status.HTTP_406_NOT_ACCEPTABLE, - } - async def login( self, form: OAuth2Form, diff --git a/app/api/auth/adapters/mfa.py b/app/api/auth/adapters/mfa.py index 96011271f..9fa3b4a02 100644 --- a/app/api/auth/adapters/mfa.py +++ b/app/api/auth/adapters/mfa.py @@ -11,16 +11,7 @@ from api.base_adapter import BaseAdapter from ldap_protocol.auth import MFAManager -from ldap_protocol.auth.exceptions.mfa import ( - ForbiddenError, - 
InvalidCredentialsError, - MFAAPIError, - MFAConnectError, - MFATokenError, - MissingMFACredentialsError, - NetworkPolicyError, - NotFoundError, -) +from ldap_protocol.auth.exceptions.mfa import MFATokenError from ldap_protocol.auth.schemas import MFACreateRequest, MFAGetResponse from ldap_protocol.multifactor import MFA_HTTP_Creds, MFA_LDAP_Creds @@ -28,16 +19,6 @@ class MFAFastAPIAdapter(BaseAdapter[MFAManager]): """Adapter for using MFAManager with FastAPI.""" - _exceptions_map: dict[type[Exception], int] = { - MissingMFACredentialsError: status.HTTP_403_FORBIDDEN, - NetworkPolicyError: status.HTTP_403_FORBIDDEN, - ForbiddenError: status.HTTP_403_FORBIDDEN, - InvalidCredentialsError: status.HTTP_422_UNPROCESSABLE_ENTITY, - NotFoundError: status.HTTP_404_NOT_FOUND, - MFAAPIError: status.HTTP_406_NOT_ACCEPTABLE, - MFAConnectError: status.HTTP_406_NOT_ACCEPTABLE, - } - async def setup_mfa(self, mfa: MFACreateRequest) -> bool: """Create or update MFA keys. diff --git a/app/api/auth/adapters/session_gateway.py b/app/api/auth/adapters/session_gateway.py index 2b02bf35e..165c8db74 100644 --- a/app/api/auth/adapters/session_gateway.py +++ b/app/api/auth/adapters/session_gateway.py @@ -5,8 +5,6 @@ from ipaddress import IPv4Address, IPv6Address from typing import Literal, ParamSpec, TypeVar -from fastapi import status - from api.base_adapter import BaseAdapter from ldap_protocol.session_storage import SessionRepository @@ -54,9 +52,9 @@ class UserSessionsResponseSchema: class SessionFastAPIGateway(BaseAdapter[SessionRepository]): """Base class for session storage.""" - _exceptions_map: dict[type[Exception], int] = { - LookupError: status.HTTP_404_NOT_FOUND, - } + def __init__(self, repository: SessionRepository) -> None: + """Initialize the session gateway with a repository.""" + self._service = repository async def get_user_sessions( self, diff --git a/app/api/auth/router_auth.py b/app/api/auth/router_auth.py index 28d06a9db..ae8df7bfd 100644 --- 
a/app/api/auth/router_auth.py +++ b/app/api/auth/router_auth.py @@ -8,25 +8,111 @@ from typing import Annotated from dishka import FromDishka -from dishka.integrations.fastapi import DishkaRoute -from fastapi import APIRouter, Body, Depends, Request, Response, status +from fastapi import Body, Depends, Request, Response, status +from fastapi_error_map.routing import ErrorAwareRouter +from fastapi_error_map.rules import rule from api.auth.adapters import AuthFastAPIAdapter from api.auth.utils import get_ip_from_request, get_user_agent_from_request +from api.error_routing import ( + ERROR_MAP_TYPE, + DishkaErrorAwareRoute, + DomainErrorTranslator, +) +from enums import DomainCodes +from ldap_protocol.auth.exceptions.mfa import ( + MFAAPIError, + MFAConnectError, + MFARequiredError, + MissingMFACredentialsError, +) from ldap_protocol.auth.schemas import ( MFAChallengeResponse, OAuth2Form, SetupRequest, ) from ldap_protocol.dialogue import UserSchema +from ldap_protocol.identity.exceptions import ( + AlreadyConfiguredError, + AuthValidationError, + ForbiddenError, + LoginFailedError, + PasswordPolicyError, + UnauthorizedError, + UserNotFoundError, +) +from ldap_protocol.kerberos.exceptions import KRBAPIChangePasswordError from ldap_protocol.session_storage import SessionStorage from .utils import verify_auth -auth_router = APIRouter(prefix="/auth", tags=["Auth"], route_class=DishkaRoute) +translator = DomainErrorTranslator(DomainCodes.AUTH) + + +error_map: ERROR_MAP_TYPE = { + UnauthorizedError: rule( + status=status.HTTP_401_UNAUTHORIZED, + translator=translator, + ), + AlreadyConfiguredError: rule( + status=status.HTTP_400_BAD_REQUEST, + translator=translator, + ), + ForbiddenError: rule( + status=status.HTTP_400_BAD_REQUEST, + translator=translator, + ), + LoginFailedError: rule( + status=status.HTTP_400_BAD_REQUEST, + translator=translator, + ), + PasswordPolicyError: rule( + status=status.HTTP_422_UNPROCESSABLE_ENTITY, + translator=translator, + ), + 
UserNotFoundError: rule( + status=status.HTTP_400_BAD_REQUEST, + translator=translator, + ), + AuthValidationError: rule( + status=status.HTTP_422_UNPROCESSABLE_ENTITY, + translator=translator, + ), + MFARequiredError: rule( + status=status.HTTP_400_BAD_REQUEST, + translator=translator, + ), + MissingMFACredentialsError: rule( + status=status.HTTP_400_BAD_REQUEST, + translator=translator, + ), + MFAAPIError: rule( + status=status.HTTP_400_BAD_REQUEST, + translator=translator, + ), + MFAConnectError: rule( + status=status.HTTP_400_BAD_REQUEST, + translator=translator, + ), + PermissionError: rule( + status=status.HTTP_400_BAD_REQUEST, + translator=translator, + ), + KRBAPIChangePasswordError: rule( + status=status.HTTP_400_BAD_REQUEST, + translator=translator, + ), +} + + +auth_router = ErrorAwareRouter( + prefix="/auth", + tags=["Auth"], + route_class=DishkaErrorAwareRoute, +) -@auth_router.post("/") +@auth_router.post("/", error_map=error_map) async def login( form: Annotated[OAuth2Form, Depends()], request: Request, @@ -62,7 +148,7 @@ async def login( ) -@auth_router.get("/me") +@auth_router.get("/me", error_map=error_map) async def users_me( identity_adapter: FromDishka[AuthFastAPIAdapter], ) -> UserSchema: @@ -75,7 +161,11 @@ async def users_me( return await identity_adapter.get_current_user() -@auth_router.delete("/", response_class=Response) +@auth_router.delete( + "/", + response_class=Response, + error_map=error_map, +) async def logout( response: Response, storage: FromDishka[SessionStorage], @@ -97,6 +187,7 @@ async def logout( "/user/password", status_code=200, dependencies=[Depends(verify_auth)], + error_map=error_map, ) async def password_reset( auth_manager: FromDishka[AuthFastAPIAdapter], @@ -121,7 +212,7 @@ async def password_reset( await auth_manager.reset_password(identity, new_password, old_password) -@auth_router.get("/setup") +@auth_router.get("/setup", error_map=error_map) async def check_setup( auth_manager: FromDishka[AuthFastAPIAdapter], ) 
-> bool: @@ -137,6 +228,7 @@ async def check_setup( "/setup", status_code=status.HTTP_200_OK, responses={423: {"detail": "Locked"}}, + error_map=error_map, ) async def first_setup( request: SetupRequest, diff --git a/app/api/auth/router_mfa.py b/app/api/auth/router_mfa.py index 57daeb214..18424c8ca 100644 --- a/app/api/auth/router_mfa.py +++ b/app/api/auth/router_mfa.py @@ -8,10 +8,10 @@ from typing import Annotated, Literal from dishka import FromDishka -from dishka.integrations.fastapi import DishkaRoute from fastapi import Depends, Form, status from fastapi.responses import RedirectResponse -from fastapi.routing import APIRouter +from fastapi_error_map.routing import ErrorAwareRouter +from fastapi_error_map.rules import rule from api.auth.adapters import MFAFastAPIAdapter from api.auth.utils import ( @@ -19,13 +19,62 @@ get_user_agent_from_request, verify_auth, ) +from api.error_routing import ( + ERROR_MAP_TYPE, + DishkaErrorAwareRoute, + DomainErrorTranslator, +) +from enums import DomainCodes +from ldap_protocol.auth.exceptions.mfa import ( + ForbiddenError, + InvalidCredentialsError, + MFAAPIError, + MFAConnectError, + MissingMFACredentialsError, + NetworkPolicyError, + NotFoundError, +) from ldap_protocol.auth.schemas import MFACreateRequest, MFAGetResponse from ldap_protocol.multifactor import MFA_HTTP_Creds, MFA_LDAP_Creds -mfa_router = APIRouter( +translator = DomainErrorTranslator(DomainCodes.MFA) + + +error_map: ERROR_MAP_TYPE = { + MFAAPIError: rule( + status=status.HTTP_400_BAD_REQUEST, + translator=translator, + ), + MFAConnectError: rule( + status=status.HTTP_400_BAD_REQUEST, + translator=translator, + ), + MissingMFACredentialsError: rule( + status=status.HTTP_400_BAD_REQUEST, + translator=translator, + ), + NetworkPolicyError: rule( + status=status.HTTP_400_BAD_REQUEST, + translator=translator, + ), + ForbiddenError: rule( + status=status.HTTP_400_BAD_REQUEST, + translator=translator, + ), + InvalidCredentialsError: rule( + 
status=status.HTTP_422_UNPROCESSABLE_ENTITY, + translator=translator, + ), + NotFoundError: rule( + status=status.HTTP_400_BAD_REQUEST, + translator=translator, + ), +} + +mfa_router = ErrorAwareRouter( prefix="/multifactor", tags=["Multifactor"], - route_class=DishkaRoute, + route_class=DishkaErrorAwareRoute, ) @@ -33,6 +82,7 @@ "/setup", status_code=status.HTTP_201_CREATED, dependencies=[Depends(verify_auth)], + error_map=error_map, ) async def setup_mfa( mfa: MFACreateRequest, @@ -51,6 +101,7 @@ async def setup_mfa( @mfa_router.delete( "/keys", dependencies=[Depends(verify_auth)], + error_map=error_map, ) async def remove_mfa( scope: Literal["ldap", "http"], @@ -60,7 +111,11 @@ async def remove_mfa( await mfa_manager.remove_mfa(scope) -@mfa_router.post("/get", dependencies=[Depends(verify_auth)]) +@mfa_router.post( + "/get", + dependencies=[Depends(verify_auth)], + error_map=error_map, +) async def get_mfa( mfa_creds: FromDishka[MFA_HTTP_Creds], mfa_creds_ldap: FromDishka[MFA_LDAP_Creds], @@ -74,7 +129,12 @@ async def get_mfa( return await mfa_manager.get_mfa(mfa_creds, mfa_creds_ldap) -@mfa_router.post("/create", name="callback_mfa", include_in_schema=True) +@mfa_router.post( + "/create", + name="callback_mfa", + include_in_schema=True, + error_map=error_map, +) async def callback_mfa( access_token: Annotated[ str, diff --git a/app/api/auth/session_router.py b/app/api/auth/session_router.py index 3958b9a8c..43da97542 100644 --- a/app/api/auth/session_router.py +++ b/app/api/auth/session_router.py @@ -1,9 +1,17 @@ """Session router for handling user sessions.""" from dishka import FromDishka -from dishka.integrations.fastapi import DishkaRoute from fastapi import Depends, status -from fastapi.routing import APIRouter +from fastapi_error_map.routing import ErrorAwareRouter +from fastapi_error_map.rules import rule + +from api.error_routing import ( + ERROR_MAP_TYPE, + DishkaErrorAwareRoute, + DomainErrorTranslator, +) +from enums import DomainCodes +from 
ldap_protocol.session_storage.exceptions import SessionUserNotFoundError from .adapters.session_gateway import ( SessionContentResponseSchema, @@ -11,15 +19,25 @@ ) from .utils import verify_auth -session_router = APIRouter( +translator = DomainErrorTranslator(DomainCodes.SESSION) + + +error_map: ERROR_MAP_TYPE = { + SessionUserNotFoundError: rule( + status=status.HTTP_400_BAD_REQUEST, + translator=translator, + ), +} + +session_router = ErrorAwareRouter( prefix="/sessions", tags=["Session"], - route_class=DishkaRoute, + route_class=DishkaErrorAwareRoute, dependencies=[Depends(verify_auth)], ) -@session_router.get("/{upn}") +@session_router.get("/{upn}", error_map=error_map) async def get_user_session( upn: str, gateway: FromDishka[SessionFastAPIGateway], @@ -28,7 +46,11 @@ async def get_user_session( return await gateway.get_user_sessions(upn) -@session_router.delete("/{upn}", status_code=status.HTTP_204_NO_CONTENT) +@session_router.delete( + "/{upn}", + status_code=status.HTTP_204_NO_CONTENT, + error_map=error_map, +) async def delete_user_sessions( upn: str, gateway: FromDishka[SessionFastAPIGateway], @@ -40,6 +62,7 @@ async def delete_user_sessions( @session_router.delete( "/session/{session_id}", status_code=status.HTTP_204_NO_CONTENT, + error_map=error_map, ) async def delete_session( session_id: str, diff --git a/app/api/auth/utils.py b/app/api/auth/utils.py index f0461f7c5..0557267cd 100644 --- a/app/api/auth/utils.py +++ b/app/api/auth/utils.py @@ -43,7 +43,7 @@ def get_ip_from_request(request: Request) -> IPv4Address | IPv6Address: client_ip = forwarded_for.split(",")[0] else: if request.client is None: - raise HTTPException(status.HTTP_403_FORBIDDEN) + raise HTTPException(status.HTTP_400_BAD_REQUEST) client_ip = request.client.host return ip_address(client_ip) diff --git a/app/api/base_adapter.py b/app/api/base_adapter.py index e9f359374..991669195 100644 --- a/app/api/base_adapter.py +++ b/app/api/base_adapter.py @@ -4,16 +4,10 @@ License: 
https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE """ -from asyncio import iscoroutinefunction -from functools import wraps -from typing import Awaitable, Callable, NoReturn, ParamSpec, Protocol, TypeVar - -from fastapi import HTTPException, status -from loguru import logger +from typing import ParamSpec, Protocol, TypeVar from abstract_service import AbstractService from authorization_provider_protocol import AuthorizationProviderProtocol -from ldap_protocol.permissions_checker import AuthorizationError _P = ParamSpec("_P") _R = TypeVar("_R") @@ -23,7 +17,6 @@ class BaseAdapter(Protocol[_T]): """Abstract Adapter interface.""" - _exceptions_map: dict[type[Exception], int] _service: _T def __init__( @@ -34,66 +27,3 @@ def __init__( """Set service.""" self._service = service self._service.set_permissions_checker(perm_checker) - - def __new__( - cls, - *_: tuple, - **__: dict, - ) -> "BaseAdapter[_T]": - """Wrap all public methods with try catch for _exceptions_map.""" - instance = super().__new__(cls) - - def wrap_sync(func: Callable[_P, _R]) -> Callable[_P, _R]: - @wraps(func) - def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R: - try: - return func(*args, **kwargs) - except Exception as err: - instance._reraise(err) - - return wrapper - - def wrap_async( - func: Callable[_P, Awaitable[_R]], - ) -> Callable[_P, Awaitable[_R]]: - @wraps(func) - async def awrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R: - try: - return await func(*args, **kwargs) - except Exception as err: - instance._reraise(err) - - return awrapper - - for name in dir(instance): - if name.startswith("_"): - continue - - attr = getattr(instance, name) - - if not callable(attr): - continue - - if iscoroutinefunction(attr): - wrapped = wrap_async(attr) - else: - wrapped = wrap_sync(attr) - - setattr(instance, name, wrapped) - - return instance - - def _reraise(self, exc: Exception) -> NoReturn: - """Reraise exception with mapped HTTPException.""" - exceptions_map = 
self._exceptions_map | { - AuthorizationError: status.HTTP_403_FORBIDDEN, - } - code = exceptions_map.get(type(exc)) - logger.debug(f"Reraising exception {exc} with code {code}") - if code is None: - raise - - raise HTTPException( - status_code=code, - detail=str(exc), - ) from exc diff --git a/app/api/dhcp/adapter.py b/app/api/dhcp/adapter.py index f69f4aebc..d063ad144 100644 --- a/app/api/dhcp/adapter.py +++ b/app/api/dhcp/adapter.py @@ -6,27 +6,18 @@ from ipaddress import IPv4Address -from fastapi import status - from api.base_adapter import BaseAdapter from ldap_protocol.dhcp import ( AbstractDHCPManager, - DHCPAPIError, DHCPChangeStateSchemaRequest, - DHCPEntryAddError, - DHCPEntryDeleteError, - DHCPEntryNotFoundError, - DHCPEntryUpdateError, DHCPLeaseSchemaRequest, DHCPLeaseSchemaResponse, DHCPLeaseToReservationErrorResponse, - DHCPOperationError, DHCPReservationSchemaRequest, DHCPReservationSchemaResponse, DHCPStateSchemaResponse, DHCPSubnetSchemaAddRequest, DHCPSubnetSchemaResponse, - DHCPValidatonError, ) from ldap_protocol.dhcp.dataclasses import ( DHCPLease, @@ -40,16 +31,6 @@ class DHCPAdapter(BaseAdapter[AbstractDHCPManager]): """Adapter for DHCP management using KeaDHCPManager.""" - _exceptions_map: dict[type[Exception], int] = { - DHCPEntryNotFoundError: status.HTTP_404_NOT_FOUND, - DHCPEntryDeleteError: status.HTTP_409_CONFLICT, - DHCPEntryAddError: status.HTTP_409_CONFLICT, - DHCPEntryUpdateError: status.HTTP_409_CONFLICT, - DHCPAPIError: status.HTTP_400_BAD_REQUEST, - DHCPValidatonError: status.HTTP_422_UNPROCESSABLE_ENTITY, - DHCPOperationError: status.HTTP_400_BAD_REQUEST, - } - async def create_subnet( self, subnet_data: DHCPSubnetSchemaAddRequest, diff --git a/app/api/dhcp/router.py b/app/api/dhcp/router.py index 2790f12cc..053241809 100644 --- a/app/api/dhcp/router.py +++ b/app/api/dhcp/router.py @@ -7,10 +7,26 @@ from ipaddress import IPv4Address from dishka import FromDishka -from dishka.integrations.fastapi import DishkaRoute -from fastapi 
import APIRouter, Depends, status +from fastapi import Depends, status +from fastapi_error_map.routing import ErrorAwareRouter +from fastapi_error_map.rules import rule from api.auth.utils import verify_auth +from api.error_routing import ( + ERROR_MAP_TYPE, + DishkaErrorAwareRoute, + DomainErrorTranslator, +) +from enums import DomainCodes +from ldap_protocol.dhcp.exceptions import ( + DHCPAPIError, + DHCPEntryAddError, + DHCPEntryDeleteError, + DHCPEntryNotFoundError, + DHCPEntryUpdateError, + DHCPOperationError, + DHCPValidatonError, +) from ldap_protocol.dhcp.schemas import ( DHCPChangeStateSchemaRequest, DHCPLeaseSchemaRequest, @@ -25,15 +41,53 @@ from .adapter import DHCPAdapter -dhcp_router = APIRouter( +translator = DomainErrorTranslator(DomainCodes.DHCP) + + +error_map: ERROR_MAP_TYPE = { + DHCPEntryNotFoundError: rule( + status=status.HTTP_400_BAD_REQUEST, + translator=translator, + ), + DHCPEntryDeleteError: rule( + status=status.HTTP_400_BAD_REQUEST, + translator=translator, + ), + DHCPEntryAddError: rule( + status=status.HTTP_400_BAD_REQUEST, + translator=translator, + ), + DHCPEntryUpdateError: rule( + status=status.HTTP_400_BAD_REQUEST, + translator=translator, + ), + DHCPAPIError: rule( + status=status.HTTP_400_BAD_REQUEST, + translator=translator, + ), + DHCPValidatonError: rule( + status=status.HTTP_422_UNPROCESSABLE_ENTITY, + translator=translator, + ), + DHCPOperationError: rule( + status=status.HTTP_400_BAD_REQUEST, + translator=translator, + ), +} + +dhcp_router = ErrorAwareRouter( prefix="/dhcp", tags=["DHCP"], dependencies=[Depends(verify_auth)], - route_class=DishkaRoute, + route_class=DishkaErrorAwareRoute, ) -@dhcp_router.post("/service/change_state", status_code=status.HTTP_200_OK) +@dhcp_router.post( + "/service/change_state", + status_code=status.HTTP_200_OK, + error_map=error_map, +) async def setup_dhcp( state_data: DHCPChangeStateSchemaRequest, dhcp_adapter: FromDishka[DHCPAdapter], @@ -42,7 +96,7 @@ async def setup_dhcp( await 
dhcp_adapter.change_state(state_data) -@dhcp_router.get("/service/state") +@dhcp_router.get("/service/state", error_map=error_map) async def get_dhcp_state( dhcp_adapter: FromDishka[DHCPAdapter], ) -> DHCPStateSchemaResponse: @@ -50,7 +104,11 @@ async def get_dhcp_state( return await dhcp_adapter.get_state() -@dhcp_router.post("/subnet", status_code=status.HTTP_201_CREATED) +@dhcp_router.post( + "/subnet", + status_code=status.HTTP_201_CREATED, + error_map=error_map, +) async def create_dhcp_subnet( subnet_data: DHCPSubnetSchemaAddRequest, dhcp_adapter: FromDishka[DHCPAdapter], @@ -59,7 +117,7 @@ async def create_dhcp_subnet( await dhcp_adapter.create_subnet(subnet_data) -@dhcp_router.get("/subnets") +@dhcp_router.get("/subnets", error_map=error_map) async def get_dhcp_subnets( dhcp_adapter: FromDishka[DHCPAdapter], ) -> list[DHCPSubnetSchemaResponse]: @@ -67,7 +125,7 @@ async def get_dhcp_subnets( return await dhcp_adapter.get_subnets() -@dhcp_router.put("/subnet/{subnet_id}") +@dhcp_router.put("/subnet/{subnet_id}", error_map=error_map) async def update_dhcp_subnet( subnet_id: int, subnet_data: DHCPSubnetSchemaAddRequest, @@ -77,7 +135,7 @@ async def update_dhcp_subnet( await dhcp_adapter.update_subnet(subnet_id, subnet_data) -@dhcp_router.delete("/subnet/{subnet_id}") +@dhcp_router.delete("/subnet/{subnet_id}", error_map=error_map) async def delete_dhcp_subnet( subnet_id: int, dhcp_adapter: FromDishka[DHCPAdapter], @@ -86,7 +144,11 @@ async def delete_dhcp_subnet( await dhcp_adapter.delete_subnet(subnet_id) -@dhcp_router.post("/lease", status_code=status.HTTP_201_CREATED) +@dhcp_router.post( + "/lease", + status_code=status.HTTP_201_CREATED, + error_map=error_map, +) async def create_dhcp_lease( lease_data: DHCPLeaseSchemaRequest, dhcp_adapter: FromDishka[DHCPAdapter], @@ -95,7 +157,7 @@ async def create_dhcp_lease( await dhcp_adapter.create_lease(lease_data) -@dhcp_router.get("/lease/{subnet_id}") +@dhcp_router.get("/lease/{subnet_id}", error_map=error_map) 
async def get_dhcp_leases( subnet_id: int, dhcp_adapter: FromDishka[DHCPAdapter], @@ -104,7 +166,7 @@ async def get_dhcp_leases( return await dhcp_adapter.list_active_leases(subnet_id) -@dhcp_router.get("/lease/") +@dhcp_router.get("/lease/", error_map=error_map) async def find_dhcp_lease( dhcp_adapter: FromDishka[DHCPAdapter], mac_address: str | None = None, @@ -114,7 +176,7 @@ async def find_dhcp_lease( return await dhcp_adapter.find_lease(mac_address, hostname) -@dhcp_router.delete("/lease/{ip_address}") +@dhcp_router.delete("/lease/{ip_address}", error_map=error_map) async def delete_dhcp_lease( ip_address: IPv4Address, dhcp_adapter: FromDishka[DHCPAdapter], @@ -123,7 +185,7 @@ async def delete_dhcp_lease( await dhcp_adapter.release_lease(ip_address) -@dhcp_router.patch("/lease/to_reservation") +@dhcp_router.patch("/lease/to_reservation", error_map=error_map) async def lease_to_reservation( data: list[DHCPReservationSchemaRequest], dhcp_adapter: FromDishka[DHCPAdapter], @@ -132,7 +194,11 @@ async def lease_to_reservation( return await dhcp_adapter.lease_to_reservation(data) -@dhcp_router.post("/reservation", status_code=status.HTTP_201_CREATED) +@dhcp_router.post( + "/reservation", + status_code=status.HTTP_201_CREATED, + error_map=error_map, +) async def create_dhcp_reservation( reservation_data: DHCPReservationSchemaRequest, dhcp_adapter: FromDishka[DHCPAdapter], @@ -141,7 +207,7 @@ async def create_dhcp_reservation( await dhcp_adapter.add_reservation(reservation_data) -@dhcp_router.get("/reservation/{subnet_id}") +@dhcp_router.get("/reservation/{subnet_id}", error_map=error_map) async def get_dhcp_reservation( subnet_id: int, dhcp_adapter: FromDishka[DHCPAdapter], @@ -150,7 +216,7 @@ async def get_dhcp_reservation( return await dhcp_adapter.get_reservations(subnet_id) -@dhcp_router.put("/reservation") +@dhcp_router.put("/reservation", error_map=error_map) async def update_dhcp_reservation( data: DHCPReservationSchemaRequest, dhcp_adapter: 
FromDishka[DHCPAdapter], @@ -159,7 +225,7 @@ async def update_dhcp_reservation( await dhcp_adapter.update_reservation(data) -@dhcp_router.delete("/reservation") +@dhcp_router.delete("/reservation", error_map=error_map) async def delete_dhcp_reservation( mac_address: str, ip_address: IPv4Address, diff --git a/app/api/error_routing.py b/app/api/error_routing.py new file mode 100644 index 000000000..805132fb7 --- /dev/null +++ b/app/api/error_routing.py @@ -0,0 +1,56 @@ +"""Error routing. + +Copyright (c) 2025 MultiFactor +License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE +""" + +from dataclasses import dataclass +from enum import IntEnum + +from dishka.integrations.fastapi import DishkaRoute +from fastapi_error_map.routing import ErrorAwareRoute +from fastapi_error_map.rules import Rule +from fastapi_error_map.translators import ErrorTranslator + +from enums import DomainCodes +from errors import BaseDomainException + +ERROR_MAP_TYPE = dict[type[Exception], int | Rule] | None + + +@dataclass +class ErrorResponse: + """Error response.""" + + detail: str + domain_code: DomainCodes + error_code: IntEnum + + +class DishkaErrorAwareRoute(ErrorAwareRoute, DishkaRoute): + """Route class that combines ErrorAwareRoute and DishkaRoute.""" + + +class DomainErrorTranslator(ErrorTranslator[ErrorResponse]): + """DNS error translator.""" + + domain_code: DomainCodes + + def __init__(self, domain_code: DomainCodes) -> None: + """Initialize error translator.""" + self.domain_code = domain_code + + @property + def error_response_model_cls(self) -> type[ErrorResponse]: + return ErrorResponse + + def from_error(self, err: Exception) -> ErrorResponse: + """Translate exception to error response.""" + if not isinstance(err, BaseDomainException): + raise TypeError(f"Expected BaseDomainException, got {type(err)}") + + return ErrorResponse( + detail=str(err), + domain_code=self.domain_code, + error_code=err.code, + ) diff --git a/app/api/exception_handlers.py 
b/app/api/exception_handlers.py index 510c68c8f..edf5ce363 100644 --- a/app/api/exception_handlers.py +++ b/app/api/exception_handlers.py @@ -24,30 +24,11 @@ def handle_db_connect_error( raise HTTPException(status.HTTP_503_SERVICE_UNAVAILABLE) -async def handle_dns_error( +async def handle_auth_error( request: Request, # noqa: ARG001 exc: Exception, ) -> NoReturn: - """Handle EmptyLabel exception.""" - logger.critical("DNS manager error: {}", exc) - raise HTTPException(status.HTTP_503_SERVICE_UNAVAILABLE) - - -async def handle_dns_api_error( - request: Request, # noqa: ARG001 - exc: Exception, -) -> NoReturn: - """Handle DNS API error.""" - logger.critical("DNS API error: {}", exc) - raise HTTPException(status.HTTP_400_BAD_REQUEST, detail=str(exc)) - - -async def handle_not_implemented_error( - request: Request, # noqa: ARG001 - exc: Exception, # noqa: ARG001 -) -> NoReturn: - """Handle Not Implemented error.""" - raise HTTPException( - status_code=status.HTTP_501_NOT_IMPLEMENTED, - detail="This feature is supported with selfhosted DNS server.", - ) + """Handle Auth error.""" + # fastapi-error-map doesn't handle exceptions from dependencies + # (get_ldap_session), so we catch them manually here + raise HTTPException(status.HTTP_401_UNAUTHORIZED, detail=str(exc)) diff --git a/app/api/ldap_schema/__init__.py b/app/api/ldap_schema/__init__.py index 6deb0c2fc..a0314c320 100644 --- a/app/api/ldap_schema/__init__.py +++ b/app/api/ldap_schema/__init__.py @@ -7,10 +7,28 @@ from typing import Annotated from annotated_types import Len -from dishka.integrations.fastapi import DishkaRoute -from fastapi import APIRouter, Body, Depends +from fastapi import Body, Depends, status +from fastapi_error_map.routing import ErrorAwareRouter +from fastapi_error_map.rules import rule from api.auth.utils import verify_auth +from api.error_routing import ( + ERROR_MAP_TYPE, + DishkaErrorAwareRoute, + DomainErrorTranslator, +) +from enums import DomainCodes +from 
ldap_protocol.ldap_schema.exceptions import ( + AttributeTypeAlreadyExistsError, + AttributeTypeCantModifyError, + AttributeTypeNotFoundError, + EntityTypeAlreadyExistsError, + EntityTypeCantModifyError, + EntityTypeNotFoundError, + ObjectClassAlreadyExistsError, + ObjectClassCantModifyError, + ObjectClassNotFoundError, +) LimitedListType = Annotated[ list[str], @@ -18,9 +36,51 @@ Body(embed=True), ] -ldap_schema_router = APIRouter( +translator = DomainErrorTranslator(DomainCodes.LDAP_SCHEMA) + + +error_map: ERROR_MAP_TYPE = { + AttributeTypeAlreadyExistsError: rule( + status=status.HTTP_400_BAD_REQUEST, + translator=translator, + ), + AttributeTypeNotFoundError: rule( + status=status.HTTP_400_BAD_REQUEST, + translator=translator, + ), + AttributeTypeCantModifyError: rule( + status=status.HTTP_400_BAD_REQUEST, + translator=translator, + ), + ObjectClassAlreadyExistsError: rule( + status=status.HTTP_400_BAD_REQUEST, + translator=translator, + ), + ObjectClassNotFoundError: rule( + status=status.HTTP_400_BAD_REQUEST, + translator=translator, + ), + ObjectClassCantModifyError: rule( + status=status.HTTP_400_BAD_REQUEST, + translator=translator, + ), + EntityTypeAlreadyExistsError: rule( + status=status.HTTP_400_BAD_REQUEST, + translator=translator, + ), + EntityTypeNotFoundError: rule( + status=status.HTTP_400_BAD_REQUEST, + translator=translator, + ), + EntityTypeCantModifyError: rule( + status=status.HTTP_400_BAD_REQUEST, + translator=translator, + ), +} + +ldap_schema_router = ErrorAwareRouter( prefix="/schema", tags=["Schema"], dependencies=[Depends(verify_auth)], - route_class=DishkaRoute, + route_class=DishkaErrorAwareRoute, ) diff --git a/app/api/ldap_schema/adapters/attribute_type.py b/app/api/ldap_schema/adapters/attribute_type.py index d803ce8ff..ad1ea6516 100644 --- a/app/api/ldap_schema/adapters/attribute_type.py +++ b/app/api/ldap_schema/adapters/attribute_type.py @@ -12,7 +12,6 @@ get_converter, link_function, ) -from fastapi import status from 
api.base_adapter import BaseAdapter from api.ldap_schema.adapters.base_ldap_schema_adapter import ( @@ -32,11 +31,6 @@ DEFAULT_ATTRIBUTE_TYPE_SYNTAX, ) from ldap_protocol.ldap_schema.dto import AttributeTypeDTO -from ldap_protocol.ldap_schema.exceptions import ( - AttributeTypeAlreadyExistsError, - AttributeTypeCantModifyError, - AttributeTypeNotFoundError, -) def _convert_update_uschema_to_dto( @@ -97,9 +91,3 @@ class AttributeTypeFastAPIAdapter( _converter_to_dto = staticmethod(_convert_schema_to_dto) _converter_to_schema = staticmethod(_convert_dto_to_schema) _converter_update_sch_to_dto = staticmethod(_convert_update_uschema_to_dto) - - _exceptions_map: dict[type[Exception], int] = { - AttributeTypeAlreadyExistsError: status.HTTP_409_CONFLICT, - AttributeTypeNotFoundError: status.HTTP_404_NOT_FOUND, - AttributeTypeCantModifyError: status.HTTP_403_FORBIDDEN, - } diff --git a/app/api/ldap_schema/adapters/entity_type.py b/app/api/ldap_schema/adapters/entity_type.py index 4d56b37ed..03199b634 100644 --- a/app/api/ldap_schema/adapters/entity_type.py +++ b/app/api/ldap_schema/adapters/entity_type.py @@ -5,7 +5,6 @@ """ from adaptix.conversion import get_converter -from fastapi import status from api.base_adapter import BaseAdapter from api.ldap_schema.adapters.base_ldap_schema_adapter import ( @@ -19,11 +18,6 @@ from ldap_protocol.ldap_schema.constants import DEFAULT_ENTITY_TYPE_IS_SYSTEM from ldap_protocol.ldap_schema.dto import EntityTypeDTO from ldap_protocol.ldap_schema.entity_type_use_case import EntityTypeUseCase -from ldap_protocol.ldap_schema.exceptions import ( - EntityTypeCantModifyError, - EntityTypeNotFoundError, - ObjectClassNotFoundError, -) def _convert_update_chema_to_dto( @@ -66,12 +60,6 @@ class LDAPEntityTypeFastAPIAdapter( _converter_to_schema = staticmethod(_convert_dto_to_schema) _converter_update_sch_to_dto = staticmethod(_convert_update_chema_to_dto) - _exceptions_map: dict[type[Exception], int] = { - EntityTypeNotFoundError: 
status.HTTP_404_NOT_FOUND, - EntityTypeCantModifyError: status.HTTP_403_FORBIDDEN, - ObjectClassNotFoundError: status.HTTP_404_NOT_FOUND, - } - async def get_entity_type_attributes(self, name: str) -> list[str]: """Get all attribute names for an Entity Type. diff --git a/app/api/ldap_schema/adapters/object_class.py b/app/api/ldap_schema/adapters/object_class.py index e2f77a5e5..7c0199a88 100644 --- a/app/api/ldap_schema/adapters/object_class.py +++ b/app/api/ldap_schema/adapters/object_class.py @@ -6,7 +6,6 @@ from adaptix import P from adaptix.conversion import get_converter, link_function -from fastapi import status from api.base_adapter import BaseAdapter from api.ldap_schema.adapters.base_ldap_schema_adapter import ( @@ -20,11 +19,6 @@ from enums import KindType from ldap_protocol.ldap_schema.constants import DEFAULT_OBJECT_CLASS_IS_SYSTEM from ldap_protocol.ldap_schema.dto import AttributeTypeDTO, ObjectClassDTO -from ldap_protocol.ldap_schema.exceptions import ( - ObjectClassAlreadyExistsError, - ObjectClassCantModifyError, - ObjectClassNotFoundError, -) from ldap_protocol.ldap_schema.object_class_use_case import ObjectClassUseCase @@ -96,9 +90,3 @@ class ObjectClassFastAPIAdapter( _converter_to_dto = staticmethod(_convert_schema_to_dto) _converter_to_schema = staticmethod(_convert_dto_to_schema) _converter_update_sch_to_dto = staticmethod(_convert_update_schema_to_dto) - - _exceptions_map: dict[type[Exception], int] = { - ObjectClassAlreadyExistsError: status.HTTP_409_CONFLICT, - ObjectClassNotFoundError: status.HTTP_404_NOT_FOUND, - ObjectClassCantModifyError: status.HTTP_403_FORBIDDEN, - } diff --git a/app/api/ldap_schema/attribute_type_router.py b/app/api/ldap_schema/attribute_type_router.py index 0029218f1..5a2f1f368 100644 --- a/app/api/ldap_schema/attribute_type_router.py +++ b/app/api/ldap_schema/attribute_type_router.py @@ -9,7 +9,7 @@ from dishka.integrations.fastapi import FromDishka from fastapi import Query, status -from api.ldap_schema import 
LimitedListType, ldap_schema_router +from api.ldap_schema import LimitedListType, error_map, ldap_schema_router from api.ldap_schema.adapters.attribute_type import AttributeTypeFastAPIAdapter from api.ldap_schema.schema import ( AttributeTypePaginationSchema, @@ -22,6 +22,7 @@ @ldap_schema_router.post( "/attribute_type", status_code=status.HTTP_201_CREATED, + error_map=error_map, ) async def create_one_attribute_type( request_data: AttributeTypeSchema[None], @@ -31,7 +32,10 @@ async def create_one_attribute_type( await adapter.create(request_data) -@ldap_schema_router.get("/attribute_type/{attribute_type_name}") +@ldap_schema_router.get( + "/attribute_type/{attribute_type_name}", + error_map=error_map, +) async def get_one_attribute_type( attribute_type_name: str, adapter: FromDishka[AttributeTypeFastAPIAdapter], @@ -40,7 +44,10 @@ async def get_one_attribute_type( return await adapter.get(attribute_type_name) -@ldap_schema_router.get("/attribute_types") +@ldap_schema_router.get( + "/attribute_types", + error_map=error_map, +) async def get_list_attribute_types_with_pagination( adapter: FromDishka[AttributeTypeFastAPIAdapter], params: Annotated[PaginationParams, Query()], @@ -49,7 +56,10 @@ async def get_list_attribute_types_with_pagination( return await adapter.get_list_paginated(params) -@ldap_schema_router.patch("/attribute_type/{attribute_type_name}") +@ldap_schema_router.patch( + "/attribute_type/{attribute_type_name}", + error_map=error_map, +) async def modify_one_attribute_type( attribute_type_name: str, request_data: AttributeTypeUpdateSchema, @@ -59,7 +69,10 @@ async def modify_one_attribute_type( await adapter.update(name=attribute_type_name, data=request_data) -@ldap_schema_router.post("/attribute_types/delete") +@ldap_schema_router.post( + "/attribute_types/delete", + error_map=error_map, +) async def delete_bulk_attribute_types( attribute_types_names: LimitedListType, adapter: FromDishka[AttributeTypeFastAPIAdapter], diff --git 
a/app/api/ldap_schema/entity_type_router.py b/app/api/ldap_schema/entity_type_router.py index db6b3607b..31de91616 100644 --- a/app/api/ldap_schema/entity_type_router.py +++ b/app/api/ldap_schema/entity_type_router.py @@ -9,7 +9,7 @@ from dishka.integrations.fastapi import FromDishka from fastapi import Query, status -from api.ldap_schema import LimitedListType +from api.ldap_schema import LimitedListType, error_map from api.ldap_schema.adapters.entity_type import LDAPEntityTypeFastAPIAdapter from api.ldap_schema.object_class_router import ldap_schema_router from api.ldap_schema.schema import ( @@ -20,7 +20,11 @@ from ldap_protocol.utils.pagination import PaginationParams -@ldap_schema_router.post("/entity_type", status_code=status.HTTP_201_CREATED) +@ldap_schema_router.post( + "/entity_type", + status_code=status.HTTP_201_CREATED, + error_map=error_map, +) async def create_one_entity_type( request_data: EntityTypeSchema[None], adapter: FromDishka[LDAPEntityTypeFastAPIAdapter], @@ -29,7 +33,7 @@ async def create_one_entity_type( await adapter.create(request_data) -@ldap_schema_router.get("/entity_type/{entity_type_name}") +@ldap_schema_router.get("/entity_type/{entity_type_name}", error_map=error_map) async def get_one_entity_type( entity_type_name: str, adapter: FromDishka[LDAPEntityTypeFastAPIAdapter], @@ -38,7 +42,7 @@ async def get_one_entity_type( return await adapter.get(entity_type_name) -@ldap_schema_router.get("/entity_types") +@ldap_schema_router.get("/entity_types", error_map=error_map) async def get_list_entity_types_with_pagination( adapter: FromDishka[LDAPEntityTypeFastAPIAdapter], params: Annotated[PaginationParams, Query()], @@ -47,7 +51,10 @@ async def get_list_entity_types_with_pagination( return await adapter.get_list_paginated(params=params) -@ldap_schema_router.get("/entity_type/{entity_type_name}/attrs") +@ldap_schema_router.get( + "/entity_type/{entity_type_name}/attrs", + error_map=error_map, +) async def get_entity_type_attributes( 
entity_type_name: str, adapter: FromDishka[LDAPEntityTypeFastAPIAdapter], @@ -56,7 +63,10 @@ async def get_entity_type_attributes( return await adapter.get_entity_type_attributes(entity_type_name) -@ldap_schema_router.patch("/entity_type/{entity_type_name}") +@ldap_schema_router.patch( + "/entity_type/{entity_type_name}", + error_map=error_map, +) async def modify_one_entity_type( entity_type_name: str, request_data: EntityTypeUpdateSchema, @@ -66,7 +76,7 @@ async def modify_one_entity_type( await adapter.update(name=entity_type_name, data=request_data) -@ldap_schema_router.post("/entity_type/delete") +@ldap_schema_router.post("/entity_type/delete", error_map=error_map) async def delete_bulk_entity_types( entity_type_names: LimitedListType, adapter: FromDishka[LDAPEntityTypeFastAPIAdapter], diff --git a/app/api/ldap_schema/object_class_router.py b/app/api/ldap_schema/object_class_router.py index a658161bf..a351f3b33 100644 --- a/app/api/ldap_schema/object_class_router.py +++ b/app/api/ldap_schema/object_class_router.py @@ -9,7 +9,7 @@ from dishka.integrations.fastapi import FromDishka from fastapi import Query, status -from api.ldap_schema import LimitedListType +from api.ldap_schema import LimitedListType, error_map from api.ldap_schema.adapters.object_class import ObjectClassFastAPIAdapter from api.ldap_schema.attribute_type_router import ldap_schema_router from api.ldap_schema.schema import ( @@ -20,7 +20,11 @@ from ldap_protocol.utils.pagination import PaginationParams -@ldap_schema_router.post("/object_class", status_code=status.HTTP_201_CREATED) +@ldap_schema_router.post( + "/object_class", + status_code=status.HTTP_201_CREATED, + error_map=error_map, +) async def create_one_object_class( request_data: ObjectClassSchema[None], adapter: FromDishka[ObjectClassFastAPIAdapter], @@ -29,7 +33,10 @@ async def create_one_object_class( await adapter.create(request_data) -@ldap_schema_router.get("/object_class/{object_class_name}") +@ldap_schema_router.get( + 
"/object_class/{object_class_name}", + error_map=error_map, +) async def get_one_object_class( object_class_name: str, adapter: FromDishka[ObjectClassFastAPIAdapter], @@ -38,7 +45,7 @@ async def get_one_object_class( return await adapter.get(object_class_name) -@ldap_schema_router.get("/object_classes") +@ldap_schema_router.get("/object_classes", error_map=error_map) async def get_list_object_classes_with_pagination( adapter: FromDishka[ObjectClassFastAPIAdapter], params: Annotated[PaginationParams, Query()], @@ -47,7 +54,10 @@ async def get_list_object_classes_with_pagination( return await adapter.get_list_paginated(params=params) -@ldap_schema_router.patch("/object_class/{object_class_name}") +@ldap_schema_router.patch( + "/object_class/{object_class_name}", + error_map=error_map, +) async def modify_one_object_class( object_class_name: str, request_data: ObjectClassUpdateSchema, @@ -57,7 +67,7 @@ async def modify_one_object_class( await adapter.update(object_class_name, request_data) -@ldap_schema_router.post("/object_class/delete") +@ldap_schema_router.post("/object_class/delete", error_map=error_map) async def delete_bulk_object_classes( object_classes_names: LimitedListType, adapter: FromDishka[ObjectClassFastAPIAdapter], diff --git a/app/api/ldap_schema/schema.py b/app/api/ldap_schema/schema.py index 86100ed05..9e6453eff 100644 --- a/app/api/ldap_schema/schema.py +++ b/app/api/ldap_schema/schema.py @@ -8,7 +8,7 @@ from pydantic import BaseModel, Field -from enums import KindType +from enums import EntityTypeNames, KindType from ldap_protocol.ldap_schema.constants import ( DEFAULT_ENTITY_TYPE_IS_SYSTEM, OID_REGEX_PATTERN, @@ -82,7 +82,7 @@ class EntityTypeSchema(BaseModel, Generic[_IdT]): """Entity Type Schema.""" id: _IdT = Field(default=None) # type: ignore[assignment] - name: str + name: EntityTypeNames | str is_system: bool object_class_names: list[str] = Field( default_factory=list, diff --git a/app/api/main/adapters/dns.py b/app/api/main/adapters/dns.py 
index 9044c009c..352099fad 100644 --- a/app/api/main/adapters/dns.py +++ b/app/api/main/adapters/dns.py @@ -4,9 +4,6 @@ License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE """ -from fastapi import status - -import ldap_protocol.dns.exceptions as dns_exc from api.base_adapter import BaseAdapter from api.main.schema import ( DNSServiceForwardZoneCheckRequest, @@ -32,17 +29,6 @@ class DNSFastAPIAdapter(BaseAdapter[DNSUseCase]): """DNS adapter.""" - _exceptions_map = { - dns_exc.DNSSetupError: status.HTTP_424_FAILED_DEPENDENCY, - dns_exc.DNSRecordCreateError: status.HTTP_400_BAD_REQUEST, - dns_exc.DNSRecordUpdateError: status.HTTP_400_BAD_REQUEST, - dns_exc.DNSRecordDeleteError: status.HTTP_400_BAD_REQUEST, - dns_exc.DNSZoneCreateError: status.HTTP_400_BAD_REQUEST, - dns_exc.DNSZoneUpdateError: status.HTTP_400_BAD_REQUEST, - dns_exc.DNSZoneDeleteError: status.HTTP_400_BAD_REQUEST, - dns_exc.DNSUpdateServerOptionsError: status.HTTP_400_BAD_REQUEST, - } - async def create_record( self, data: DNSServiceRecordCreateRequest, diff --git a/app/api/main/adapters/kerberos.py b/app/api/main/adapters/kerberos.py index 9fbe5bbd7..1bbe252e2 100644 --- a/app/api/main/adapters/kerberos.py +++ b/app/api/main/adapters/kerberos.py @@ -4,9 +4,9 @@ License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE """ -from typing import Any, AsyncGenerator, ParamSpec, TypeVar +from typing import Any, AsyncGenerator -from fastapi import Request, Response, status +from fastapi import Request, Response from fastapi.responses import StreamingResponse from pydantic import SecretStr from starlette.background import BackgroundTask @@ -15,31 +15,13 @@ from api.main.schema import KerberosSetupRequest from ldap_protocol.dialogue import LDAPSession, UserSchema from ldap_protocol.kerberos import KerberosState -from ldap_protocol.kerberos.exceptions import ( - KerberosBaseDnNotFoundError, - KerberosConflictError, - KerberosDependencyError, - 
KerberosNotFoundError, - KerberosUnavailableError, -) from ldap_protocol.kerberos.service import KerberosService from ldap_protocol.ldap_requests.contexts import LDAPAddRequestContext -P = ParamSpec("P") -R = TypeVar("R") - class KerberosFastAPIAdapter(BaseAdapter[KerberosService]): """Adapter for using KerberosService with FastAPI and background tasks.""" - _exceptions_map: dict[type[Exception], int] = { - KerberosBaseDnNotFoundError: status.HTTP_503_SERVICE_UNAVAILABLE, - KerberosConflictError: status.HTTP_409_CONFLICT, - KerberosDependencyError: status.HTTP_424_FAILED_DEPENDENCY, - KerberosNotFoundError: status.HTTP_404_NOT_FOUND, - KerberosUnavailableError: status.HTTP_503_SERVICE_UNAVAILABLE, - } - async def setup_krb_catalogue( self, mail: str, diff --git a/app/api/main/dns_router.py b/app/api/main/dns_router.py index bac6a4915..d93382512 100644 --- a/app/api/main/dns_router.py +++ b/app/api/main/dns_router.py @@ -5,11 +5,18 @@ """ from dishka import FromDishka -from dishka.integrations.fastapi import DishkaRoute -from fastapi import Depends -from fastapi.routing import APIRouter +from dns.exception import DNSException +from fastapi import Depends, status +from fastapi_error_map import rule +from fastapi_error_map.routing import ErrorAwareRouter +import ldap_protocol.dns.exceptions as dns_exc from api.auth.utils import verify_auth +from api.error_routing import ( + ERROR_MAP_TYPE, + DishkaErrorAwareRoute, + DomainErrorTranslator, +) from api.main.adapters.dns import DNSFastAPIAdapter from api.main.schema import ( DNSServiceForwardZoneCheckRequest, @@ -22,6 +29,7 @@ DNSServiceZoneDeleteRequest, DNSServiceZoneUpdateRequest, ) +from enums import DomainCodes from ldap_protocol.dns import ( DNSForwardServerStatus, DNSForwardZone, @@ -30,15 +38,65 @@ DNSZone, ) -dns_router = APIRouter( +translator = DomainErrorTranslator(DomainCodes.DNS) + + +error_map: ERROR_MAP_TYPE = { + dns_exc.DNSSetupError: rule( + status=status.HTTP_422_UNPROCESSABLE_ENTITY, + 
translator=translator, + ), + dns_exc.DNSRecordCreateError: rule( + status=status.HTTP_400_BAD_REQUEST, + translator=translator, + ), + dns_exc.DNSRecordUpdateError: rule( + status=status.HTTP_400_BAD_REQUEST, + translator=translator, + ), + dns_exc.DNSRecordDeleteError: rule( + status=status.HTTP_400_BAD_REQUEST, + translator=translator, + ), + dns_exc.DNSZoneCreateError: rule( + status=status.HTTP_400_BAD_REQUEST, + translator=translator, + ), + dns_exc.DNSZoneUpdateError: rule( + status=status.HTTP_400_BAD_REQUEST, + translator=translator, + ), + dns_exc.DNSZoneDeleteError: rule( + status=status.HTTP_400_BAD_REQUEST, + translator=translator, + ), + dns_exc.DNSUpdateServerOptionsError: rule( + status=status.HTTP_400_BAD_REQUEST, + translator=translator, + ), + DNSException: rule( + status=status.HTTP_500_INTERNAL_SERVER_ERROR, + translator=translator, + ), + dns_exc.DNSConnectionError: rule( + status=status.HTTP_400_BAD_REQUEST, + translator=translator, + ), + dns_exc.DNSNotImplementedError: rule( + status=status.HTTP_400_BAD_REQUEST, + translator=translator, + ), +} + +dns_router = ErrorAwareRouter( prefix="/dns", tags=["DNS_SERVICE"], dependencies=[Depends(verify_auth)], - route_class=DishkaRoute, + route_class=DishkaErrorAwareRoute, ) -@dns_router.post("/record") +@dns_router.post("/record", error_map=error_map) async def create_record( data: DNSServiceRecordCreateRequest, adapter: FromDishka[DNSFastAPIAdapter], @@ -47,7 +105,7 @@ async def create_record( await adapter.create_record(data) -@dns_router.delete("/record") +@dns_router.delete("/record", error_map=error_map) async def delete_single_record( data: DNSServiceRecordDeleteRequest, adapter: FromDishka[DNSFastAPIAdapter], @@ -56,7 +114,7 @@ async def delete_single_record( await adapter.delete_record(data) -@dns_router.patch("/record") +@dns_router.patch("/record", error_map=error_map) async def update_record( data: DNSServiceRecordUpdateRequest, adapter: FromDishka[DNSFastAPIAdapter], @@ -65,7 +123,7 @@ 
async def update_record( await adapter.update_record(data) -@dns_router.get("/record") +@dns_router.get("/record", error_map=error_map) async def get_all_records( adapter: FromDishka[DNSFastAPIAdapter], ) -> list[DNSRecords]: @@ -73,7 +131,7 @@ async def get_all_records( return await adapter.get_all_records() -@dns_router.get("/status") +@dns_router.get("/status", error_map=error_map) async def get_dns_status( adapter: FromDishka[DNSFastAPIAdapter], ) -> dict[str, str | None]: @@ -81,7 +139,7 @@ async def get_dns_status( return await adapter.get_dns_status() -@dns_router.post("/setup") +@dns_router.post("/setup", error_map=error_map) async def setup_dns( data: DNSServiceSetupRequest, adapter: FromDishka[DNSFastAPIAdapter], @@ -90,7 +148,7 @@ async def setup_dns( await adapter.setup_dns(data) -@dns_router.get("/zone") +@dns_router.get("/zone", error_map=error_map) async def get_dns_zone( adapter: FromDishka[DNSFastAPIAdapter], ) -> list[DNSZone]: @@ -98,7 +156,7 @@ async def get_dns_zone( return await adapter.get_dns_zone() -@dns_router.get("/zone/forward") +@dns_router.get("/zone/forward", error_map=error_map) async def get_forward_dns_zones( adapter: FromDishka[DNSFastAPIAdapter], ) -> list[DNSForwardZone]: @@ -106,7 +164,12 @@ async def get_forward_dns_zones( return await adapter.get_forward_dns_zones() -@dns_router.post("/zone") +@dns_router.post( + "/zone", + error_map=error_map, + warn_on_unmapped=False, + default_client_error_translator=translator, +) async def create_zone( data: DNSServiceZoneCreateRequest, adapter: FromDishka[DNSFastAPIAdapter], @@ -115,7 +178,7 @@ async def create_zone( await adapter.create_zone(data) -@dns_router.patch("/zone") +@dns_router.patch("/zone", error_map=error_map) async def update_zone( data: DNSServiceZoneUpdateRequest, adapter: FromDishka[DNSFastAPIAdapter], @@ -124,7 +187,7 @@ async def update_zone( await adapter.update_zone(data) -@dns_router.delete("/zone") +@dns_router.delete("/zone", error_map=error_map) async def 
delete_zone( data: DNSServiceZoneDeleteRequest, adapter: FromDishka[DNSFastAPIAdapter], @@ -133,7 +196,7 @@ async def delete_zone( await adapter.delete_zone(data) -@dns_router.post("/forward_check") +@dns_router.post("/forward_check", error_map=error_map) async def check_dns_forward_zone( data: DNSServiceForwardZoneCheckRequest, adapter: FromDishka[DNSFastAPIAdapter], @@ -142,7 +205,7 @@ async def check_dns_forward_zone( return await adapter.check_dns_forward_zone(data) -@dns_router.get("/zone/reload/") +@dns_router.get("/zone/reload/", error_map=error_map) async def reload_zone( data: DNSServiceReloadZoneRequest, adapter: FromDishka[DNSFastAPIAdapter], diff --git a/app/api/main/krb5_router.py b/app/api/main/krb5_router.py index cdec9a002..91f64a5b6 100644 --- a/app/api/main/krb5_router.py +++ b/app/api/main/krb5_router.py @@ -8,28 +8,72 @@ from annotated_types import Len from dishka import FromDishka -from dishka.integrations.fastapi import DishkaRoute -from fastapi import Body, Request, Response +from fastapi import Body, Request, Response, status from fastapi.params import Depends from fastapi.responses import StreamingResponse -from fastapi.routing import APIRouter +from fastapi_error_map.routing import ErrorAwareRouter +from fastapi_error_map.rules import rule from pydantic import SecretStr from api.auth.adapters.auth import AuthFastAPIAdapter from api.auth.utils import verify_auth +from api.error_routing import ( + ERROR_MAP_TYPE, + DishkaErrorAwareRoute, + DomainErrorTranslator, +) from api.main.adapters.kerberos import KerberosFastAPIAdapter from api.main.schema import KerberosSetupRequest +from enums import DomainCodes from ldap_protocol.dialogue import LDAPSession from ldap_protocol.kerberos import KerberosState +from ldap_protocol.kerberos.exceptions import ( + KerberosBaseDnNotFoundError, + KerberosConflictError, + KerberosDependencyError, + KerberosNotFoundError, + KerberosUnavailableError, + KRBAPIConnectionError, +) from 
ldap_protocol.ldap_requests.contexts import LDAPAddRequestContext from ldap_protocol.utils.const import EmailStr from .utils import get_ldap_session -krb5_router = APIRouter( +translator = DomainErrorTranslator(DomainCodes.KERBEROS) + + +error_map: ERROR_MAP_TYPE = { + KerberosBaseDnNotFoundError: rule( + status=status.HTTP_500_INTERNAL_SERVER_ERROR, + translator=translator, + ), + KerberosConflictError: rule( + status=status.HTTP_400_BAD_REQUEST, + translator=translator, + ), + KerberosDependencyError: rule( + status=status.HTTP_400_BAD_REQUEST, + translator=translator, + ), + KerberosNotFoundError: rule( + status=status.HTTP_400_BAD_REQUEST, + translator=translator, + ), + KerberosUnavailableError: rule( + status=status.HTTP_500_INTERNAL_SERVER_ERROR, + translator=translator, + ), + KRBAPIConnectionError: rule( + status=status.HTTP_500_INTERNAL_SERVER_ERROR, + translator=translator, + ), +} + +krb5_router = ErrorAwareRouter( prefix="/kerberos", tags=["KRB5 API"], - route_class=DishkaRoute, + route_class=DishkaErrorAwareRoute, ) KERBEROS_POLICY_NAME = "Kerberos Access Policy" @@ -37,6 +81,7 @@ @krb5_router.post( "/setup/tree", response_class=Response, + error_map=error_map, dependencies=[Depends(verify_auth)], ) async def setup_krb_catalogue( @@ -61,7 +106,7 @@ async def setup_krb_catalogue( ) -@krb5_router.post("/setup", response_class=Response) +@krb5_router.post("/setup", response_class=Response, error_map=error_map) async def setup_kdc( data: KerberosSetupRequest, identity_adapter: FromDishka[AuthFastAPIAdapter], @@ -92,7 +137,11 @@ async def setup_kdc( ] -@krb5_router.post("/ktadd", dependencies=[Depends(verify_auth)]) +@krb5_router.post( + "/ktadd", + dependencies=[Depends(verify_auth)], + error_map=error_map, +) async def ktadd( names: Annotated[LIMITED_LIST, Body()], kerberos_adapter: FromDishka[KerberosFastAPIAdapter], @@ -105,7 +154,11 @@ async def ktadd( return await kerberos_adapter.ktadd(names) -@krb5_router.get("/status", 
dependencies=[Depends(verify_auth)]) +@krb5_router.get( + "/status", + dependencies=[Depends(verify_auth)], + error_map=error_map, +) async def get_krb_status( kerberos_adapter: FromDishka[KerberosFastAPIAdapter], ) -> KerberosState: @@ -118,7 +171,11 @@ async def get_krb_status( return await kerberos_adapter.get_status() -@krb5_router.post("/principal/add", dependencies=[Depends(verify_auth)]) +@krb5_router.post( + "/principal/add", + dependencies=[Depends(verify_auth)], + error_map=error_map, +) async def add_principal( primary: Annotated[LIMITED_STR, Body()], instance: Annotated[LIMITED_STR, Body()], @@ -137,6 +194,7 @@ async def add_principal( @krb5_router.patch( "/principal/rename", dependencies=[Depends(verify_auth)], + error_map=error_map, ) async def rename_principal( principal_name: Annotated[LIMITED_STR, Body()], @@ -160,6 +218,7 @@ async def rename_principal( @krb5_router.patch( "/principal/reset", dependencies=[Depends(verify_auth)], + error_map=error_map, ) async def reset_principal_pw( principal_name: Annotated[LIMITED_STR, Body()], @@ -180,6 +239,7 @@ async def reset_principal_pw( @krb5_router.delete( "/principal/delete", dependencies=[Depends(verify_auth)], + error_map=error_map, ) async def delete_principal( principal_name: Annotated[LIMITED_STR, Body(embed=True)], diff --git a/app/api/main/router.py b/app/api/main/router.py index 197e2cad2..f4df578e8 100644 --- a/app/api/main/router.py +++ b/app/api/main/router.py @@ -5,12 +5,19 @@ """ from dishka import FromDishka -from dishka.integrations.fastapi import DishkaRoute -from fastapi import Depends, HTTPException, Request -from fastapi.routing import APIRouter +from fastapi import Depends, HTTPException, Request, status +from fastapi_error_map.routing import ErrorAwareRouter +from fastapi_error_map.rules import rule from sqlalchemy.exc import IntegrityError from sqlalchemy.ext.asyncio import AsyncSession +from api.error_routing import ( + ERROR_MAP_TYPE, + DishkaErrorAwareRoute, + 
DomainErrorTranslator, +) +from enums import DomainCodes +from ldap_protocol.identity.exceptions import UnauthorizedError from ldap_protocol.ldap_requests import ( AddRequest, DeleteRequest, @@ -28,15 +35,25 @@ ) from .utils import get_ldap_session -entry_router = APIRouter( +translator = DomainErrorTranslator(DomainCodes.LDAP) + + +error_map: ERROR_MAP_TYPE = { + UnauthorizedError: rule( + status=status.HTTP_401_UNAUTHORIZED, + translator=translator, + ), +} + +entry_router = ErrorAwareRouter( prefix="/entry", tags=["LDAP API"], - route_class=DishkaRoute, + route_class=DishkaErrorAwareRoute, dependencies=[Depends(get_ldap_session)], ) -@entry_router.post("/search") +@entry_router.post("/search", error_map=error_map) async def search( request: SearchRequest, req: Request, @@ -55,7 +72,7 @@ async def search( ) -@entry_router.post("/add") +@entry_router.post("/add", error_map=error_map) async def add( request: AddRequest, req: Request, @@ -64,7 +81,7 @@ async def add( return await request.handle_api(req.state.dishka_container) -@entry_router.patch("/update") +@entry_router.patch("/update", error_map=error_map) async def modify( request: ModifyRequest, req: Request, @@ -73,7 +90,7 @@ async def modify( return await request.handle_api(req.state.dishka_container) -@entry_router.patch("/update_many") +@entry_router.patch("/update_many", error_map=error_map) async def modify_many( requests: list[ModifyRequest], req: Request, @@ -85,7 +102,7 @@ async def modify_many( return results -@entry_router.put("/update/dn") +@entry_router.put("/update/dn", error_map=error_map) async def modify_dn( request: ModifyDNRequest, req: Request, @@ -94,7 +111,7 @@ async def modify_dn( return await request.handle_api(req.state.dishka_container) -@entry_router.delete("/delete") +@entry_router.delete("/delete", error_map=error_map) async def delete( request: DeleteRequest, req: Request, @@ -103,7 +120,7 @@ async def delete( return await request.handle_api(req.state.dishka_container) 
-@entry_router.post("/delete_many") +@entry_router.post("/delete_many", error_map=error_map) async def delete_many( requests: list[DeleteRequest], req: Request, diff --git a/app/api/network/adapters/network.py b/app/api/network/adapters/network.py index 74313944d..9316b8c7e 100644 --- a/app/api/network/adapters/network.py +++ b/app/api/network/adapters/network.py @@ -19,16 +19,11 @@ PolicyUpdate, SwapResponse, ) -from ldap_protocol.policies.network.dto import ( +from ldap_protocol.policies.network import ( NetworkPolicyDTO, NetworkPolicyUpdateDTO, + NetworkPolicyUseCase, ) -from ldap_protocol.policies.network.exceptions import ( - LastActivePolicyError, - NetworkPolicyAlreadyExistsError, - NetworkPolicyNotFoundError, -) -from ldap_protocol.policies.network.use_cases import NetworkPolicyUseCase def _convert_netmasks( @@ -68,12 +63,6 @@ def _convert_raw(dto: NetworkPolicyDTO[int]) -> list[str | dict]: class NetworkPolicyFastAPIAdapter(BaseAdapter[NetworkPolicyUseCase]): """Network policy adapter.""" - _exceptions_map: dict[type[Exception], int] = { - NetworkPolicyAlreadyExistsError: status.HTTP_422_UNPROCESSABLE_ENTITY, - LastActivePolicyError: status.HTTP_422_UNPROCESSABLE_ENTITY, - NetworkPolicyNotFoundError: status.HTTP_404_NOT_FOUND, - } - async def create(self, policy: Policy) -> PolicyResponse: """Create network policy.""" policy_dto = await self._service.create( diff --git a/app/api/network/router.py b/app/api/network/router.py index 554148f59..bc65ed858 100644 --- a/app/api/network/router.py +++ b/app/api/network/router.py @@ -5,14 +5,25 @@ """ from dishka import FromDishka -from dishka.integrations.fastapi import DishkaRoute from fastapi import Request, status from fastapi.params import Depends from fastapi.responses import RedirectResponse -from fastapi.routing import APIRouter +from fastapi_error_map.routing import ErrorAwareRouter +from fastapi_error_map.rules import rule from api.auth.utils import verify_auth +from api.error_routing import ( + 
ERROR_MAP_TYPE, + DishkaErrorAwareRoute, + DomainErrorTranslator, +) from api.network.adapters.network import NetworkPolicyFastAPIAdapter +from enums import DomainCodes +from ldap_protocol.policies.network.exceptions import ( + LastActivePolicyError, + NetworkPolicyAlreadyExistsError, + NetworkPolicyNotFoundError, +) from .schema import ( Policy, @@ -22,15 +33,38 @@ SwapResponse, ) -network_router = APIRouter( +translator = DomainErrorTranslator(DomainCodes.NETWORK) + + +error_map: ERROR_MAP_TYPE = { + NetworkPolicyAlreadyExistsError: rule( + status=status.HTTP_422_UNPROCESSABLE_ENTITY, + translator=translator, + ), + NetworkPolicyNotFoundError: rule( + status=status.HTTP_400_BAD_REQUEST, + translator=translator, + ), + LastActivePolicyError: rule( + status=status.HTTP_422_UNPROCESSABLE_ENTITY, + translator=translator, + ), +} + + +network_router = ErrorAwareRouter( prefix="/policy", tags=["Network policy"], - route_class=DishkaRoute, + route_class=DishkaErrorAwareRoute, dependencies=[Depends(verify_auth)], ) -@network_router.post("", status_code=status.HTTP_201_CREATED) +@network_router.post( + "", + status_code=status.HTTP_201_CREATED, + error_map=error_map, +) async def add_network_policy( policy: Policy, adapter: FromDishka[NetworkPolicyFastAPIAdapter], @@ -46,7 +80,7 @@ async def add_network_policy( return await adapter.create(policy) -@network_router.get("", name="policy") +@network_router.get("", name="policy", error_map=error_map) async def get_list_network_policies( adapter: FromDishka[NetworkPolicyFastAPIAdapter], ) -> list[PolicyResponse]: @@ -62,6 +96,7 @@ async def get_list_network_policies( "/{policy_id}", response_class=RedirectResponse, status_code=status.HTTP_303_SEE_OTHER, + error_map=error_map, ) async def delete_network_policy( policy_id: int, @@ -79,7 +114,7 @@ async def delete_network_policy( return await adapter.delete(request, policy_id) # type: ignore -@network_router.patch("/{policy_id}") +@network_router.patch("/{policy_id}", 
error_map=error_map) async def switch_network_policy( policy_id: int, adapter: FromDishka[NetworkPolicyFastAPIAdapter], @@ -98,7 +133,7 @@ async def switch_network_policy( return await adapter.switch_network_policy(policy_id) -@network_router.put("") +@network_router.put("", error_map=error_map) async def update_network_policy( request: PolicyUpdate, adapter: FromDishka[NetworkPolicyFastAPIAdapter], @@ -115,7 +150,7 @@ async def update_network_policy( return await adapter.update(request) -@network_router.post("/swap") +@network_router.post("/swap", error_map=error_map) async def swap_network_policy( swap: SwapRequest, adapter: FromDishka[NetworkPolicyFastAPIAdapter], diff --git a/app/api/password_policy/__init__.py b/app/api/password_policy/__init__.py index fbb1affbf..c66ab654a 100644 --- a/app/api/password_policy/__init__.py +++ b/app/api/password_policy/__init__.py @@ -6,8 +6,10 @@ from .password_ban_word_router import password_ban_word_router from .password_policy_router import password_policy_router +from .user_password_history_router import user_password_history_router __all__ = [ "password_ban_word_router", "password_policy_router", + "user_password_history_router", ] diff --git a/app/api/password_policy/adapter.py b/app/api/password_policy/adapter.py index 4fae20e5d..81b5d7be1 100644 --- a/app/api/password_policy/adapter.py +++ b/app/api/password_policy/adapter.py @@ -7,26 +7,19 @@ import io from adaptix.conversion import get_converter -from fastapi import UploadFile, status +from fastapi import UploadFile from fastapi.responses import StreamingResponse from api.base_adapter import BaseAdapter from api.password_policy.schemas import PasswordPolicySchema, PriorityT from ldap_protocol.policies.password.dataclasses import PasswordPolicyDTO from ldap_protocol.policies.password.exceptions import ( - PasswordBanWordFileHasDuplicatesError, PasswordBanWordWrongFileExtensionError, - PasswordPolicyAgeDaysError, - PasswordPolicyAlreadyExistsError, - 
PasswordPolicyBaseDnNotFoundError, - PasswordPolicyCantChangeDefaultDomainError, - PasswordPolicyDirIsNotUserError, - PasswordPolicyNotFoundError, - PasswordPolicyPriorityError, ) from ldap_protocol.policies.password.use_cases import ( PasswordBanWordUseCases, PasswordPolicyUseCases, + UserPasswordHistoryUseCases, ) _convert_schema_to_dto = get_converter(PasswordPolicySchema, PasswordPolicyDTO) @@ -36,19 +29,18 @@ ) +class UserPasswordHistoryResetFastAPIAdapter( + BaseAdapter[UserPasswordHistoryUseCases], +): + """Adapter for clearing user password history.""" + + async def clear(self, identity: str) -> None: + await self._service.clear(identity) + + class PasswordPolicyFastAPIAdapter(BaseAdapter[PasswordPolicyUseCases]): """Adapter for password policies.""" - _exceptions_map: dict[type[Exception], int] = { - PasswordPolicyBaseDnNotFoundError: status.HTTP_404_NOT_FOUND, - PasswordPolicyNotFoundError: status.HTTP_404_NOT_FOUND, - PasswordPolicyDirIsNotUserError: status.HTTP_404_NOT_FOUND, - PasswordPolicyAlreadyExistsError: status.HTTP_409_CONFLICT, - PasswordPolicyCantChangeDefaultDomainError: status.HTTP_400_BAD_REQUEST, # noqa: E501 - PasswordPolicyPriorityError: status.HTTP_400_BAD_REQUEST, - PasswordPolicyAgeDaysError: status.HTTP_400_BAD_REQUEST, - } - async def get_all(self) -> list[PasswordPolicySchema[int]]: """Get all Password Policies.""" dtos = await self._service.get_all() @@ -64,9 +56,7 @@ async def get_password_policy_by_dir_path_dn( path_dn: str, ) -> PasswordPolicySchema[int]: """Get one Password Policy for one Directory by its path.""" - dto = await self._service.get_password_policy_by_dir_path_dn( - path_dn, - ) + dto = await self._service.get_password_policy_by_dir_path_dn(path_dn) return _convert_dto_to_schema(dto) async def update( @@ -86,11 +76,6 @@ async def reset_domain_policy_to_default_config(self) -> None: class PasswordBanWordsFastAPIAdapter(BaseAdapter[PasswordBanWordUseCases]): """Adapter for password ban words.""" - _exceptions_map: 
dict[type[Exception], int] = { - PasswordBanWordWrongFileExtensionError: status.HTTP_400_BAD_REQUEST, - PasswordBanWordFileHasDuplicatesError: status.HTTP_409_CONFLICT, - } - async def upload_ban_words_txt(self, file: UploadFile) -> None: if ( file diff --git a/app/api/password_policy/error_utils.py b/app/api/password_policy/error_utils.py new file mode 100644 index 000000000..201f2b823 --- /dev/null +++ b/app/api/password_policy/error_utils.py @@ -0,0 +1,64 @@ +"""Password policy error utils. + +Copyright (c) 2025 MultiFactor +License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE +""" + +from fastapi import status +from fastapi_error_map.rules import rule + +from api.error_routing import ERROR_MAP_TYPE, DomainErrorTranslator +from enums import DomainCodes +from ldap_protocol.permissions_checker import AuthorizationError +from ldap_protocol.policies.password.exceptions import ( + PasswordBanWordWrongFileExtensionError, + PasswordPolicyAgeDaysError, + PasswordPolicyAlreadyExistsError, + PasswordPolicyBaseDnNotFoundError, + PasswordPolicyCantChangeDefaultDomainError, + PasswordPolicyDirIsNotUserError, + PasswordPolicyNotFoundError, + PasswordPolicyPriorityError, +) + +translator = DomainErrorTranslator(DomainCodes.PASSWORD_POLICY) + + +error_map: ERROR_MAP_TYPE = { + PasswordPolicyBaseDnNotFoundError: rule( + status=status.HTTP_400_BAD_REQUEST, + translator=translator, + ), + PasswordPolicyNotFoundError: rule( + status=status.HTTP_400_BAD_REQUEST, + translator=translator, + ), + PasswordPolicyDirIsNotUserError: rule( + status=status.HTTP_400_BAD_REQUEST, + translator=translator, + ), + PasswordPolicyAlreadyExistsError: rule( + status=status.HTTP_400_BAD_REQUEST, + translator=translator, + ), + PasswordPolicyCantChangeDefaultDomainError: rule( + status=status.HTTP_400_BAD_REQUEST, + translator=translator, + ), + PasswordPolicyPriorityError: rule( + status=status.HTTP_400_BAD_REQUEST, + translator=translator, + ), + PasswordPolicyAgeDaysError: 
rule( + status=status.HTTP_400_BAD_REQUEST, + translator=translator, + ), + PasswordBanWordWrongFileExtensionError: rule( + status=status.HTTP_400_BAD_REQUEST, + translator=translator, + ), + AuthorizationError: rule( + status=status.HTTP_401_UNAUTHORIZED, + translator=translator, + ), +} diff --git a/app/api/password_policy/password_ban_word_router.py b/app/api/password_policy/password_ban_word_router.py index a774fd6eb..a0c06a04e 100644 --- a/app/api/password_policy/password_ban_word_router.py +++ b/app/api/password_policy/password_ban_word_router.py @@ -5,24 +5,27 @@ """ from dishka import FromDishka -from dishka.integrations.fastapi import DishkaRoute -from fastapi import APIRouter, Depends, UploadFile, status +from fastapi import Depends, UploadFile, status from fastapi.responses import StreamingResponse +from fastapi_error_map.routing import ErrorAwareRouter from api.auth.utils import verify_auth +from api.error_routing import DishkaErrorAwareRoute from api.password_policy.adapter import PasswordBanWordsFastAPIAdapter +from api.password_policy.error_utils import error_map -password_ban_word_router = APIRouter( +password_ban_word_router = ErrorAwareRouter( prefix="/password_ban_word", tags=["Password Ban Word"], dependencies=[Depends(verify_auth)], - route_class=DishkaRoute, + route_class=DishkaErrorAwareRoute, ) @password_ban_word_router.post( "/upload_txt", status_code=status.HTTP_201_CREATED, + error_map=error_map, ) async def upload_ban_words_txt( file: UploadFile, @@ -43,6 +46,7 @@ async def upload_ban_words_txt( "/download_txt", response_class=StreamingResponse, status_code=status.HTTP_200_OK, + error_map=error_map, ) async def download_ban_words_txt( password_ban_word_adapter: FromDishka[PasswordBanWordsFastAPIAdapter], diff --git a/app/api/password_policy/password_policy_router.py b/app/api/password_policy/password_policy_router.py index f74025d86..812777ecd 100644 --- a/app/api/password_policy/password_policy_router.py +++ 
b/app/api/password_policy/password_policy_router.py @@ -5,25 +5,27 @@ """ from dishka import FromDishka -from dishka.integrations.fastapi import DishkaRoute -from fastapi import APIRouter, Depends +from fastapi import Depends +from fastapi_error_map.routing import ErrorAwareRouter from api.auth.utils import verify_auth +from api.error_routing import DishkaErrorAwareRoute from api.password_policy.adapter import PasswordPolicyFastAPIAdapter +from api.password_policy.error_utils import error_map from api.password_policy.schemas import PasswordPolicySchema from ldap_protocol.utils.const import GRANT_DN_STRING from .schemas import PriorityT -password_policy_router = APIRouter( +password_policy_router = ErrorAwareRouter( prefix="/password-policy", dependencies=[Depends(verify_auth)], tags=["Password Policy"], - route_class=DishkaRoute, + route_class=DishkaErrorAwareRoute, ) -@password_policy_router.get("/all") +@password_policy_router.get("/all", error_map=error_map) async def get_all( adapter: FromDishka[PasswordPolicyFastAPIAdapter], ) -> list[PasswordPolicySchema[int]]: @@ -31,7 +33,7 @@ async def get_all( return await adapter.get_all() -@password_policy_router.get("/{id_}") +@password_policy_router.get("/{id_}", error_map=error_map) async def get( id_: int, adapter: FromDishka[PasswordPolicyFastAPIAdapter], @@ -40,7 +42,7 @@ async def get( return await adapter.get(id_) -@password_policy_router.get("/by_dir_path_dn/{path_dn}") +@password_policy_router.get("/by_dir_path_dn/{path_dn}", error_map=error_map) async def get_password_policy_by_dir_path_dn( path_dn: GRANT_DN_STRING, adapter: FromDishka[PasswordPolicyFastAPIAdapter], @@ -49,7 +51,7 @@ async def get_password_policy_by_dir_path_dn( return await adapter.get_password_policy_by_dir_path_dn(path_dn) -@password_policy_router.put("/{id_}") +@password_policy_router.put("/{id_}", error_map=error_map) async def update( id_: int, policy: PasswordPolicySchema[PriorityT], @@ -59,7 +61,7 @@ async def update( await 
adapter.update(id_, policy) -@password_policy_router.put("/reset/domain_policy") +@password_policy_router.put("/reset/domain_policy", error_map=error_map) async def reset_domain_policy_to_default_config( adapter: FromDishka[PasswordPolicyFastAPIAdapter], ) -> None: diff --git a/app/api/password_policy/user_password_history_router.py b/app/api/password_policy/user_password_history_router.py new file mode 100644 index 000000000..2285c3cdd --- /dev/null +++ b/app/api/password_policy/user_password_history_router.py @@ -0,0 +1,53 @@ +"""User Password history router. + +Copyright (c) 2024 MultiFactor +License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE +""" + +from typing import Annotated + +from dishka import FromDishka +from fastapi import Body, Depends, status +from fastapi_error_map.routing import ErrorAwareRouter +from fastapi_error_map.rules import rule + +from api.auth.utils import verify_auth +from api.error_routing import ( + ERROR_MAP_TYPE, + DishkaErrorAwareRoute, + DomainErrorTranslator, +) +from api.password_policy.adapter import UserPasswordHistoryResetFastAPIAdapter +from enums import DomainCodes +from ldap_protocol.identity.exceptions import ( + AuthorizationError, + UserNotFoundError, +) + +translator = DomainErrorTranslator(DomainCodes.PASSWORD_POLICY) + +error_map: ERROR_MAP_TYPE = { + UserNotFoundError: rule( + status=status.HTTP_400_BAD_REQUEST, + translator=translator, + ), + AuthorizationError: rule( + status=status.HTTP_401_UNAUTHORIZED, + translator=translator, + ), +} + +user_password_history_router = ErrorAwareRouter( + prefix="/user/password_history", + dependencies=[Depends(verify_auth)], + tags=["User Password history"], + route_class=DishkaErrorAwareRoute, +) + + +@user_password_history_router.post("/clear", error_map=error_map) +async def clear( + identity: Annotated[str, Body(examples=["admin"])], + adapter: FromDishka[UserPasswordHistoryResetFastAPIAdapter], +) -> None: + await adapter.clear(identity) diff 
--git a/app/api/shadow/adapter.py b/app/api/shadow/adapter.py index 51b4ac271..3062a0bb4 100644 --- a/app/api/shadow/adapter.py +++ b/app/api/shadow/adapter.py @@ -5,34 +5,14 @@ """ from ipaddress import IPv4Address -from typing import ParamSpec, TypeVar - -from fastapi import status from api.base_adapter import BaseAdapter from ldap_protocol.auth import AuthManager, MFAManager -from ldap_protocol.auth.exceptions.mfa import ( - AuthenticationError, - InvalidCredentialsError, - NetworkPolicyError, -) -from ldap_protocol.identity.exceptions import PasswordPolicyError - -P = ParamSpec("P") -R = TypeVar("R") class ShadowAdapter(BaseAdapter): """Adapter for shadow api with FastAPI.""" - _exceptions_map: dict[type[Exception], int] = { - InvalidCredentialsError: status.HTTP_404_NOT_FOUND, - NetworkPolicyError: status.HTTP_403_FORBIDDEN, - AuthenticationError: status.HTTP_401_UNAUTHORIZED, - PasswordPolicyError: status.HTTP_422_UNPROCESSABLE_ENTITY, - PermissionError: status.HTTP_403_FORBIDDEN, - } - def __init__( self, mfa_manager: MFAManager, diff --git a/app/api/shadow/router.py b/app/api/shadow/router.py index cb844a27a..ee8938a18 100644 --- a/app/api/shadow/router.py +++ b/app/api/shadow/router.py @@ -8,18 +8,56 @@ from typing import Annotated from dishka import FromDishka -from dishka.integrations.fastapi import DishkaRoute -from fastapi import APIRouter, Body +from fastapi import Body, status +from fastapi_error_map.routing import ErrorAwareRouter +from fastapi_error_map.rules import rule +from api.error_routing import ( + ERROR_MAP_TYPE, + DishkaErrorAwareRoute, + DomainErrorTranslator, +) +from enums import DomainCodes +from ldap_protocol.auth.exceptions.mfa import ( + AuthenticationError, + InvalidCredentialsError, + NetworkPolicyError, +) +from ldap_protocol.policies.password.exceptions import PasswordPolicyError from ldap_protocol.rootdse.dto import DomainControllerInfo from ldap_protocol.rootdse.reader import DCInfoReader from .adapter import ShadowAdapter 
-shadow_router = APIRouter(route_class=DishkaRoute) +translator = DomainErrorTranslator(DomainCodes.SHADOW) -@shadow_router.post("/mfa/push") +error_map: ERROR_MAP_TYPE = { + InvalidCredentialsError: rule( + status=status.HTTP_400_BAD_REQUEST, + translator=translator, + ), + NetworkPolicyError: rule( + status=status.HTTP_400_BAD_REQUEST, + translator=translator, + ), + AuthenticationError: rule( + status=status.HTTP_401_UNAUTHORIZED, + translator=translator, + ), + PasswordPolicyError: rule( + status=status.HTTP_422_UNPROCESSABLE_ENTITY, + translator=translator, + ), + PermissionError: rule( + status=status.HTTP_400_BAD_REQUEST, + translator=translator, + ), +} +shadow_router = ErrorAwareRouter(route_class=DishkaErrorAwareRoute) + + +@shadow_router.post("/mfa/push", error_map=error_map) async def proxy_request( principal: Annotated[str, Body(embed=True)], ip: Annotated[IPv4Address, Body(embed=True)], @@ -29,7 +67,7 @@ async def proxy_request( return await adapter.proxy_request(principal, ip) -@shadow_router.post("/sync/password") +@shadow_router.post("/sync/password", error_map=error_map) async def change_password( principal: Annotated[str, Body(embed=True)], new_password: Annotated[str, Body(embed=True)], diff --git a/app/config.py b/app/config.py index e5b82da31..423eb2bf8 100644 --- a/app/config.py +++ b/app/config.py @@ -41,6 +41,8 @@ class Settings(BaseModel): PORT: int = 389 TLS_PORT: int = 636 HTTP_PORT: int = 8000 + GLOBAL_LDAP_PORT: int = 3268 + GLOBAL_LDAP_TLS_PORT: int = 3269 USE_CORE_TLS: bool = False LDAP_LOAD_SSL_CERT: bool = False @@ -197,6 +199,23 @@ def get_copy_4_tls(self) -> "Settings": tls_settings.PORT = tls_settings.TLS_PORT return tls_settings + def get_copy_4_global(self) -> "Settings": + """Create a copy for global LDAP server.""" + from copy import copy + + global_settings = copy(self) + global_settings.PORT = global_settings.GLOBAL_LDAP_PORT + return global_settings + + def get_copy_4_global_tls(self) -> "Settings": + """Create a copy for 
global LDAP server with TLS.""" + from copy import copy + + global_tls_settings = copy(self) + global_tls_settings.USE_CORE_TLS = True + global_tls_settings.PORT = global_tls_settings.GLOBAL_LDAP_TLS_PORT + return global_tls_settings + def check_certs_exist(self) -> bool: """Check if certs exist.""" return os.path.exists(self.SSL_CERT) and os.path.exists(self.SSL_KEY) diff --git a/app/constants.py b/app/constants.py index 136335ffc..f54d78a35 100644 --- a/app/constants.py +++ b/app/constants.py @@ -4,6 +4,21 @@ License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE """ +from typing import TypedDict + +from enums import EntityTypeNames + +GROUPS_CONTAINER_NAME = "groups" +COMPUTERS_CONTAINER_NAME = "computers" +USERS_CONTAINER_NAME = "users" + +READ_ONLY_GROUP_NAME = "read-only" + +DOMAIN_ADMIN_GROUP_NAME = "domain admins" +DOMAIN_USERS_GROUP_NAME = "domain users" +DOMAIN_COMPUTERS_GROUP_NAME = "domain computers" + + group_attrs = { "objectClass": ["top"], "groupType": ["-2147483646"], @@ -117,7 +132,7 @@ }, }, { - "name": "users", + "name": USERS_CONTAINER_NAME, "object_class": "container", "attributes": {"objectClass": ["top"]}, "children": [ @@ -207,24 +222,37 @@ ] -ENTITY_TYPE_DATAS = [ - { - "name": "Domain", - "object_class_names": ["top", "domain", "domainDNS"], - }, - {"name": "Computer", "object_class_names": ["top", "computer"]}, - {"name": "Container", "object_class_names": ["top", "container"]}, - { - "name": "Organizational Unit", - "object_class_names": ["top", "container", "organizationalUnit"], - }, - { - "name": "Group", - "object_class_names": ["top", "group", "posixGroup"], - }, - { - "name": "User", - "object_class_names": [ +class EntityTypeData(TypedDict): + """Entity Type data.""" + + name: EntityTypeNames + object_class_names: list[str] + + +ENTITY_TYPE_DATAS: tuple[EntityTypeData, ...] 
= ( + EntityTypeData( + name=EntityTypeNames.DOMAIN, + object_class_names=["top", "domain", "domainDNS"], + ), + EntityTypeData( + name=EntityTypeNames.COMPUTER, + object_class_names=["top", "computer"], + ), + EntityTypeData( + name=EntityTypeNames.CONTAINER, + object_class_names=["top", "container"], + ), + EntityTypeData( + name=EntityTypeNames.ORGANIZATIONAL_UNIT, + object_class_names=["top", "container", "organizationalUnit"], + ), + EntityTypeData( + name=EntityTypeNames.GROUP, + object_class_names=["top", "group", "posixGroup"], + ), + EntityTypeData( + name=EntityTypeNames.USER, + object_class_names=[ "top", "user", "person", @@ -233,32 +261,39 @@ "shadowAccount", "inetOrgPerson", ], - }, - {"name": "KRB Container", "object_class_names": ["krbContainer"]}, - { - "name": "KRB Principal", - "object_class_names": [ + ), + EntityTypeData( + name=EntityTypeNames.CONTACT, + object_class_names=[ + "top", + "person", + "organizationalPerson", + "contact", + "mailRecipient", + ], + ), + EntityTypeData( + name=EntityTypeNames.KRB_CONTAINER, + object_class_names=["krbContainer"], + ), + EntityTypeData( + name=EntityTypeNames.KRB_PRINCIPAL, + object_class_names=[ "krbprincipal", "krbprincipalaux", "krbTicketPolicyAux", ], - }, - { - "name": "KRB Realm Container", - "object_class_names": [ - "top", - "krbrealmcontainer", - "krbticketpolicyaux", - ], - }, -] -PRIMARY_ENTITY_TYPE_NAMES = { - entity_type_data["name"] for entity_type_data in ENTITY_TYPE_DATAS -} + ), + EntityTypeData( + name=EntityTypeNames.KRB_REALM_CONTAINER, + object_class_names=["top", "krbrealmcontainer", "krbticketpolicyaux"], + ), +) + FIRST_SETUP_DATA = [ { - "name": "groups", + "name": GROUPS_CONTAINER_NAME, "object_class": "container", "attributes": { "objectClass": ["top"], @@ -266,52 +301,52 @@ }, "children": [ { - "name": "domain admins", + "name": DOMAIN_ADMIN_GROUP_NAME, "object_class": "group", "attributes": { "objectClass": ["top", "posixGroup"], "groupType": ["-2147483646"], 
"instanceType": ["4"], - "sAMAccountName": ["domain admins"], + "sAMAccountName": [DOMAIN_ADMIN_GROUP_NAME], "sAMAccountType": ["268435456"], "gidNumber": ["512"], }, "objectSid": 512, }, { - "name": "domain users", + "name": DOMAIN_USERS_GROUP_NAME, "object_class": "group", "attributes": { "objectClass": ["top", "posixGroup"], "groupType": ["-2147483646"], "instanceType": ["4"], - "sAMAccountName": ["domain users"], + "sAMAccountName": [DOMAIN_USERS_GROUP_NAME], "sAMAccountType": ["268435456"], "gidNumber": ["513"], }, "objectSid": 513, }, { - "name": "read-only", + "name": READ_ONLY_GROUP_NAME, "object_class": "group", "attributes": { "objectClass": ["top", "posixGroup"], "groupType": ["-2147483646"], "instanceType": ["4"], - "sAMAccountName": ["read-only"], + "sAMAccountName": [READ_ONLY_GROUP_NAME], "sAMAccountType": ["268435456"], "gidNumber": ["521"], }, "objectSid": 521, }, { - "name": "domain computers", + "name": DOMAIN_COMPUTERS_GROUP_NAME, "object_class": "group", "attributes": { "objectClass": ["top", "posixGroup"], "groupType": ["-2147483646"], "instanceType": ["4"], - "sAMAccountName": ["domain computers"], + "sAMAccountName": [DOMAIN_COMPUTERS_GROUP_NAME], "sAMAccountType": ["268435456"], "gidNumber": ["515"], }, @@ -320,12 +355,13 @@ ], }, { - "name": "computers", + "name": COMPUTERS_CONTAINER_NAME, "object_class": "container", "attributes": {"objectClass": ["top"]}, "children": [], }, ] + DEFAULT_DC_POSTFIX = "DC1" UNC_PREFIX = "\\\\" diff --git a/app/entities.py b/app/entities.py index 417c18c93..535da02f6 100644 --- a/app/entities.py +++ b/app/entities.py @@ -190,6 +190,7 @@ class Directory: id: int = field(init=False) name: str + is_system: bool = field(default=False) object_sid: str = field(default="") object_guid: uuid.UUID = field(default_factory=uuid.uuid4) parent_id: int | None = None @@ -201,7 +202,6 @@ class Directory: ) updated_at: datetime | None = field(default=None) depth: int = field(default=0) - password_policy_id: int | None = None 
path: list[str] = field(default_factory=list) parent: Directory | None = field(default=None, repr=False, compare=False) diff --git a/app/enums.py b/app/enums.py index fb95e3a2f..f482b928e 100644 --- a/app/enums.py +++ b/app/enums.py @@ -45,6 +45,25 @@ class MFAChallengeStatuses(StrEnum): PENDING = "pending" +class EntityTypeNames(StrEnum): + """Enum of base (system) Entity Types. + + Used for system objects. + Custom Entity Types aren't included here. + """ + + DOMAIN = "Domain" + COMPUTER = "Computer" + CONTAINER = "Container" + ORGANIZATIONAL_UNIT = "Organizational Unit" + GROUP = "Group" + USER = "User" + CONTACT = "Contact" + KRB_CONTAINER = "KRB Container" + KRB_PRINCIPAL = "KRB Principal" + KRB_REALM_CONTAINER = "KRB Realm Container" + + class KindType(StrEnum): """Object kind types.""" @@ -189,6 +208,16 @@ class AuthorizationRules(IntFlag): SESSION_CLEAR_USER_SESSIONS = auto() SESSION_DELETE = auto() + NETWORK_POLICY_VALIDATOR_GET_BY_PROTOCOL = auto() + NETWORK_POLICY_VALIDATOR_GET_USER_NETWORK_POLICY = auto() + NETWORK_POLICY_VALIDATOR_GET_USER_HTTP_POLICY = auto() + NETWORK_POLICY_VALIDATOR_GET_USER_KERBEROS_POLICY = auto() + NETWORK_POLICY_VALIDATOR_GET_USER_LDAP_POLICY = auto() + NETWORK_POLICY_VALIDATOR_IS_USER_GROUP_VALID = auto() + NETWORK_POLICY_VALIDATOR_CHECK_MFA_GROUP = auto() + + USER_CLEAR_PASSWORD_HISTORY = auto() + @classmethod def get_all(cls) -> Self: return cls(sum(cls)) @@ -198,3 +227,30 @@ def combine( permissions: Iterable[AuthorizationRules], ) -> AuthorizationRules: return reduce(or_, permissions, AuthorizationRules(0)) + + +class ProtocolType(StrEnum): + """Protocol fields.""" + + LDAP = "is_ldap" + HTTP = "is_http" + KERBEROS = "is_kerberos" + + +class DomainCodes(IntEnum): + """Error code parts.""" + + AUDIT = 1 + AUTH = 2 + SESSION = 3 + DNS = 4 + GENERAL = 5 + KERBEROS = 6 + LDAP = 7 + MFA = 8 + NETWORK = 9 + PASSWORD_POLICY = 10 + ROLES = 11 + DHCP = 12 + LDAP_SCHEMA = 13 + SHADOW = 14 diff --git a/app/errors/__init__.py 
b/app/errors/__init__.py new file mode 100644 index 000000000..ef8e2bde4 --- /dev/null +++ b/app/errors/__init__.py @@ -0,0 +1,11 @@ +"""Errors package. + +Copyright (c) 2025 MultiFactor +License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE +""" + +from .base import BaseDomainException + +__all__ = [ + "BaseDomainException", +] diff --git a/app/errors/base.py b/app/errors/base.py new file mode 100644 index 000000000..18419ccba --- /dev/null +++ b/app/errors/base.py @@ -0,0 +1,20 @@ +"""Errors base. + +Copyright (c) 2025 MultiFactor +License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE +""" + +from enum import IntEnum + + +class BaseDomainException(Exception): # noqa: N818 + """Base exception.""" + + code: IntEnum + + def __init_subclass__(cls) -> None: + """Initialize subclass.""" + super().__init_subclass__() + + if not hasattr(cls, "code"): + raise AttributeError("code must be set") diff --git a/app/extra/alembic_utils.py b/app/extra/alembic_utils.py index 2afc558ab..ac8cfffd8 100644 --- a/app/extra/alembic_utils.py +++ b/app/extra/alembic_utils.py @@ -1,37 +1,39 @@ """Alembic utils.""" -from typing import Callable +from typing import Any, Callable import sqlalchemy as sa from alembic import op -def temporary_stub_entity_type_name(func: Callable) -> Callable: - """Add and drop the 'entity_type_name' column in the 'Directory' table. +def temporary_stub_column(column_name: str, type_: Any) -> Callable: + """Add and drop a temporary column in the 'Directory' table. State of the database at the time of migration - doesn't contain 'entity_type_name' column in the 'Directory' table, + doesn't contain the specified column in the 'Directory' table, but 'Directory' model has the column. - Before starting the migration, add 'entity_type_name' column. - Then migration completed, delete 'entity_type_name' column. + Before starting the migration, add the specified column. + Then migration completed, delete the column. 
Don`t like excluding columns with Deferred(), because you will need to refactor SQL queries - that precede the 'ba78cef9700a_initial_entity_type.py' migration - and include working with the Directory. + that precede migrations and include working with the Directory. - :param Callable func: any function - :return Callable: any function + :param str column_name: column name to temporarily add + :return Callable: decorator function """ - def wrapper(*args: tuple, **kwargs: dict) -> None: - op.add_column( - "Directory", - sa.Column("entity_type_id", sa.Integer(), nullable=True), - ) - func(*args, **kwargs) - op.drop_column("Directory", "entity_type_id") - return None + def decorator(func: Callable) -> Callable: + def wrapper(*args: tuple, **kwargs: dict) -> None: + op.add_column( + "Directory", + sa.Column(column_name, type_, nullable=True), + ) + func(*args, **kwargs) + op.drop_column("Directory", column_name) + return None - return wrapper + return wrapper + + return decorator diff --git a/app/extra/scripts/uac_sync.py b/app/extra/scripts/uac_sync.py index dd7b8514d..f0623e1b1 100644 --- a/app/extra/scripts/uac_sync.py +++ b/app/extra/scripts/uac_sync.py @@ -49,12 +49,10 @@ async def disable_accounts( String, ) conditions = [ - ( - cast(Attribute.value, Integer).op("&")( - UserAccountControlFlag.ACCOUNTDISABLE, - ) - == 0 - ), + cast(Attribute.value, Integer).op("&")( + UserAccountControlFlag.ACCOUNTDISABLE, + ) + == 0, qa(Attribute.directory_id).in_(subquery), qa(Attribute.name) == "userAccountControl", ] diff --git a/app/extra/scripts/update_krb5_config.py b/app/extra/scripts/update_krb5_config.py index c83b6ef07..b0ecda0f6 100644 --- a/app/extra/scripts/update_krb5_config.py +++ b/app/extra/scripts/update_krb5_config.py @@ -1,40 +1,69 @@ -"""Kerberos update config. +"""Kerberos configuration update script. 
Copyright (c) 2025 MultiFactor License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE """ +from pathlib import Path + from loguru import logger from sqlalchemy.ext.asyncio import AsyncSession from config import Settings -from ldap_protocol.kerberos import AbstractKadmin +from ldap_protocol.kerberos.utils import get_system_container_dn from ldap_protocol.utils.queries import get_base_directories +KRB5_CONF_PATH = Path("/etc/krb5kdc/krb5.conf") +KDC_CONF_PATH = Path("/etc/krb5kdc/kdc.conf") +STASH_FILE_PATH = Path("/etc/krb5kdc/krb5.d/stash.keyfile") + + +def _migrate_legacy_dns(content: str) -> str: + """Replace legacy DN formats with current ones. + + :param content: File content to migrate. + :return: Migrated content. + """ + return content.replace("ou=services", "ou=System").replace( + "ou=users", + "cn=users", + ) + async def update_krb5_config( - kadmin: AbstractKadmin, session: AsyncSession, settings: Settings, ) -> None: - """Update kerberos config.""" - if not (await kadmin.get_status(wait_for_positive=True)): - logger.error("kadmin_api is not running") - return + """Update Kerberos configuration files via direct write to shared volume. - base_dn_list = await get_base_directories(session) - base_dn = base_dn_list[0].path_dn - domain: str = base_dn_list[0].name + Renders krb5.conf and kdc.conf from templates and writes them directly + to the shared volume. Also migrates legacy DN formats in stash.keyfile + if present (ou=services -> ou=System, ou=users -> cn=users). - krbadmin = "cn=krbadmin,cn=users," + base_dn - services_container = "ou=services," + base_dn + :param session: Database session for fetching base directories. + :param settings: Application settings with template environment. + :raises Exception: If config rendering or writing fails. 
+ """ + if not KRB5_CONF_PATH.parent.exists(): + logger.error( + f"Config directory {KRB5_CONF_PATH.parent} not found, " + "kerberos volume not mounted", + ) + return - krb5_template = settings.TEMPLATES.get_template("krb5.conf") - kdc_template = settings.TEMPLATES.get_template("kdc.conf") + base_dn_list = await get_base_directories(session) + if not base_dn_list: + logger.error("No base directories found") + return - kdc_config = await kdc_template.render_async(domain=domain) + base_dn = base_dn_list[0].path_dn + domain = base_dn_list[0].name + krbadmin = f"cn=krbadmin,cn=users,{base_dn}" + services_container = get_system_container_dn(base_dn) - krb5_config = await krb5_template.render_async( + krb5_config = await settings.TEMPLATES.get_template( + "krb5.conf", + ).render_async( domain=domain, krbadmin=krbadmin, services_container=services_container, @@ -42,5 +71,19 @@ async def update_krb5_config( mfa_push_url=settings.KRB5_MFA_PUSH_URL, sync_password_url=settings.KRB5_SYNC_PASSWORD_URL, ) + kdc_config = await settings.TEMPLATES.get_template( + "kdc.conf", + ).render_async( + domain=domain, + ) + + KRB5_CONF_PATH.write_text(krb5_config, encoding="utf-8") + KDC_CONF_PATH.write_text(kdc_config, encoding="utf-8") - await kadmin.setup_configs(krb5_config, kdc_config) + if STASH_FILE_PATH.exists(): + stash_content = STASH_FILE_PATH.read_text(encoding="utf-8") + if "ou=services" in stash_content or "ou=users" in stash_content: + STASH_FILE_PATH.write_text( + _migrate_legacy_dns(stash_content), + encoding="utf-8", + ) diff --git a/app/ioc.py b/app/ioc.py index 2c733e6a8..d6489f842 100644 --- a/app/ioc.py +++ b/app/ioc.py @@ -12,6 +12,7 @@ from fastapi import Request from loguru import logger from sqlalchemy.ext.asyncio import ( + AsyncConnection, AsyncEngine, AsyncSession, async_sessionmaker, @@ -34,6 +35,7 @@ from api.password_policy.adapter import ( PasswordBanWordsFastAPIAdapter, PasswordPolicyFastAPIAdapter, + UserPasswordHistoryResetFastAPIAdapter, ) from 
api.shadow.adapter import ShadowAdapter from authorization_provider_protocol import AuthorizationProviderProtocol @@ -65,6 +67,7 @@ from ldap_protocol.kerberos.service import KerberosService from ldap_protocol.kerberos.template_render import KRBTemplateRenderer from ldap_protocol.ldap_requests.contexts import ( + LDAPAbandonRequestContext, LDAPAddRequestContext, LDAPBindRequestContext, LDAPDeleteRequestContext, @@ -78,6 +81,9 @@ from ldap_protocol.ldap_schema.attribute_type_use_case import ( AttributeTypeUseCase, ) +from ldap_protocol.ldap_schema.attribute_value_validator import ( + AttributeValueValidator, +) from ldap_protocol.ldap_schema.entity_type_dao import EntityTypeDAO from ldap_protocol.ldap_schema.entity_type_use_case import EntityTypeUseCase from ldap_protocol.ldap_schema.object_class_dao import ObjectClassDAO @@ -112,8 +118,13 @@ ) from ldap_protocol.policies.audit.policies_dao import AuditPoliciesDAO from ldap_protocol.policies.audit.service import AuditService -from ldap_protocol.policies.network.gateway import NetworkPolicyGateway -from ldap_protocol.policies.network.use_cases import NetworkPolicyUseCase +from ldap_protocol.policies.network import ( + NetworkPolicyGateway, + NetworkPolicyUseCase, + NetworkPolicyValidatorGateway, + NetworkPolicyValidatorProtocol, + NetworkPolicyValidatorUseCase, +) from ldap_protocol.policies.password import ( PasswordPolicyDAO, PasswordPolicyUseCases, @@ -123,7 +134,10 @@ PasswordBanWordRepository, ) from ldap_protocol.policies.password.settings import PasswordValidatorSettings -from ldap_protocol.policies.password.use_cases import PasswordBanWordUseCases +from ldap_protocol.policies.password.use_cases import ( + PasswordBanWordUseCases, + UserPasswordHistoryUseCases, +) from ldap_protocol.roles.access_manager import AccessManager from ldap_protocol.roles.ace_dao import AccessControlEntryDAO from ldap_protocol.roles.role_dao import RoleDAO @@ -205,7 +219,7 @@ async def get_kadmin_http( yield KadminHTTPClient(client) 
@provide(scope=Scope.REQUEST) - async def get_kadmin( + def get_kadmin( self, client: KadminHTTPClient, kadmin_class: type[AbstractKadmin], @@ -260,14 +274,14 @@ async def get_dns_http_client( yield DNSManagerHTTPClient(client) @provide(scope=Scope.REQUEST) - async def get_dns_mngr( + def get_dns_mngr( self, settings: DNSManagerSettings, dns_manager_class: type[AbstractDNSManager], http_client: DNSManagerHTTPClient, - ) -> AsyncIterator[AbstractDNSManager]: + ) -> AbstractDNSManager: """Get DNSManager class.""" - yield dns_manager_class(settings=settings, http_client=http_client) + return dns_manager_class(settings=settings, http_client=http_client) @provide(scope=Scope.APP) async def get_redis_for_sessions( @@ -284,7 +298,7 @@ async def get_redis_for_sessions( await client.aclose() @provide(scope=Scope.APP) - async def get_session_storage( + def get_session_storage( self, client: SessionStorageClient, settings: Settings, @@ -297,7 +311,7 @@ async def get_session_storage( ) @provide() - async def get_normalized_audit_event( + def get_normalized_audit_event( self, ) -> type[NormalizedAuditEvent]: """Get normalized audit event class.""" @@ -318,13 +332,13 @@ async def get_audit_redis_client( await client.aclose() @provide(scope=Scope.APP) - async def get_raw_audit_manager( + def get_raw_audit_manager( self, client: AuditRedisClient, settings: Settings, - ) -> AsyncIterator[RawAuditManager]: + ) -> RawAuditManager: """Get raw audit manager.""" - yield RawAuditManager( + return RawAuditManager( client, settings.RAW_EVENT_STREAM_NAME, settings.EVENT_HANDLER_GROUP, @@ -333,13 +347,13 @@ async def get_raw_audit_manager( ) @provide(scope=Scope.APP) - async def get_normalized_audit_manager( + def get_normalized_audit_manager( self, client: AuditRedisClient, settings: Settings, - ) -> AsyncIterator[NormalizedAuditManager]: + ) -> NormalizedAuditManager: """Get raw audit manager.""" - yield NormalizedAuditManager( + return NormalizedAuditManager( client, 
settings.NORMALIZED_EVENT_STREAM_NAME, settings.EVENT_SENDER_GROUP, @@ -352,7 +366,7 @@ async def get_normalized_audit_manager( audit_destination_dao = provide(AuditDestinationDAO, scope=Scope.REQUEST) @provide(scope=Scope.REQUEST) - async def get_dhcp_manager_repository( + def get_dhcp_manager_repository( self, session: AsyncSession, ) -> DHCPManagerRepository: @@ -368,20 +382,20 @@ async def get_dhcp_manager_state( return await dhcp_manager_repository.ensure_state() @provide(scope=Scope.REQUEST) - async def get_dhcp_mngr_class( + def get_dhcp_mngr_class( self, dhcp_state: DHCPManagerState, ) -> type[AbstractDHCPManager]: """Get DHCP manager type.""" - return await get_dhcp_manager_class(dhcp_state) + return get_dhcp_manager_class(dhcp_state) @provide(scope=Scope.REQUEST) - async def get_dhcp_api_repository_class( + def get_dhcp_api_repository_class( self, dhcp_state: DHCPManagerState, ) -> type[DHCPAPIRepository]: """Get DHCP API repository type.""" - return await get_dhcp_api_repository_class(dhcp_state) + return get_dhcp_api_repository_class(dhcp_state) @provide(scope=Scope.APP) async def get_dhcp_http_client( @@ -395,7 +409,7 @@ async def get_dhcp_http_client( yield DHCPManagerHTTPClient(http_client) @provide(scope=Scope.REQUEST) - async def get_dhcp_api_repository( + def get_dhcp_api_repository( self, http_client: DHCPManagerHTTPClient, dhcp_api_repository_class: type[DHCPAPIRepository], @@ -404,7 +418,7 @@ async def get_dhcp_api_repository( return dhcp_api_repository_class(http_client) @provide(scope=Scope.REQUEST) - async def get_dhcp_mngr( + def get_dhcp_mngr( self, dhcp_manager_class: type[AbstractDHCPManager], dhcp_api_repository: DHCPAPIRepository, @@ -416,6 +430,10 @@ async def get_dhcp_mngr( kea_dhcp_repository=dhcp_api_repository, ) + attribute_value_validator = provide( + AttributeValueValidator, + scope=Scope.RUNTIME, + ) attribute_type_dao = provide(AttributeTypeDAO, scope=Scope.REQUEST) object_class_dao = provide(ObjectClassDAO, 
scope=Scope.REQUEST) entity_type_dao = provide(EntityTypeDAO, scope=Scope.REQUEST) @@ -425,6 +443,10 @@ async def get_dhcp_mngr( ) object_class_use_case = provide(ObjectClassUseCase, scope=Scope.REQUEST) + user_password_history_use_cases = provide( + UserPasswordHistoryUseCases, + scope=Scope.REQUEST, + ) password_policy_validator = provide( PasswordPolicyValidator, scope=Scope.REQUEST, @@ -445,7 +467,7 @@ async def get_dhcp_mngr( ) password_utils = provide(PasswordUtils, scope=Scope.RUNTIME) - access_manager = provide(AccessManager, scope=Scope.REQUEST) + access_manager = provide(AccessManager, scope=Scope.RUNTIME) role_dao = provide(RoleDAO, scope=Scope.REQUEST) ace_dao = provide(AccessControlEntryDAO, scope=Scope.REQUEST) role_use_case = provide(RoleUseCase, scope=Scope.REQUEST) @@ -490,12 +512,16 @@ class LDAPContextProvider(Provider): LDAPModifyDNRequestContext, scope=Scope.REQUEST, ) + unbind_request_context = provide( + LDAPUnbindRequestContext, + scope=Scope.REQUEST, + ) search_request_context = provide( LDAPSearchRequestContext, scope=Scope.REQUEST, ) - unbind_request_context = provide( - LDAPUnbindRequestContext, + abandon_request_context = provide( + LDAPAbandonRequestContext, scope=Scope.REQUEST, ) @@ -506,9 +532,23 @@ class HTTPProvider(LDAPContextProvider): scope = Scope.REQUEST request = from_context(provides=Request, scope=Scope.REQUEST) monitor_use_case = provide(AuditMonitorUseCase, scope=Scope.REQUEST) + network_policy_gateway = provide(NetworkPolicyGateway, scope=Scope.REQUEST) + network_policy_use_case = provide( + NetworkPolicyUseCase, + scope=Scope.REQUEST, + ) + network_policy_validator_gateway = provide( + NetworkPolicyValidatorGateway, + provides=NetworkPolicyValidatorProtocol, + scope=Scope.REQUEST, + ) + network_policy_validator_use_case = provide( + NetworkPolicyValidatorUseCase, + scope=Scope.REQUEST, + ) @provide() - async def get_audit_monitor( + def get_audit_monitor( self, session: AsyncSession, audit_use_case: "AuditUseCase", @@ 
-536,6 +576,10 @@ async def get_audit_monitor( scope=Scope.REQUEST, ) + user_password_history_reset_adapter = provide( + UserPasswordHistoryResetFastAPIAdapter, + scope=Scope.REQUEST, + ) password_policies_adapter = provide( PasswordPolicyFastAPIAdapter, scope=Scope.REQUEST, @@ -568,7 +612,7 @@ def get_permissions_provider( return auth_provider @provide() - async def get_identity_provider( + def get_identity_provider( self, request: Request, session_storage: SessionStorage, @@ -643,14 +687,41 @@ def get_krb_template_render( NetworkPolicyFastAPIAdapter, scope=Scope.REQUEST, ) - network_policy_use_case = provide( - NetworkPolicyUseCase, + + +class LDAPServerProvider(LDAPContextProvider): + """Provider with session scope.""" + + scope = Scope.SESSION + + network_policy_validator_gateway = provide( + NetworkPolicyValidatorGateway, scope=Scope.REQUEST, ) - network_policy_gateway = provide(NetworkPolicyGateway, scope=Scope.REQUEST) + network_policy_validator = provide( + NetworkPolicyValidatorGateway, + provides=NetworkPolicyValidatorProtocol, + scope=Scope.REQUEST, + ) + network_policy_validator_use_case = provide( + NetworkPolicyValidatorUseCase, + scope=Scope.REQUEST, + ) -class LDAPServerProvider(LDAPContextProvider): + @provide(scope=Scope.SESSION, provides=LDAPSession) + async def get_session( + self, + storage: SessionStorage, + ) -> AsyncIterator[LDAPSession]: + """Create ldap session.""" + session = LDAPSession(storage=storage) + await session.start() + yield session + await session.disconnect() + + +class GlobalLDAPServerProvider(Provider): """Provider with session scope.""" scope = Scope.SESSION @@ -666,6 +737,29 @@ async def get_session( yield session await session.disconnect() + bind_request_context = provide( + LDAPBindRequestContext, + scope=Scope.REQUEST, + ) + search_request_context = provide( + LDAPSearchRequestContext, + scope=Scope.REQUEST, + ) + unbind_request_context = provide( + LDAPUnbindRequestContext, + scope=Scope.REQUEST, + ) + + 
network_policy_validator = provide( + NetworkPolicyValidatorGateway, + provides=NetworkPolicyValidatorProtocol, + scope=Scope.REQUEST, + ) + network_policy_validator_use_case = provide( + NetworkPolicyValidatorUseCase, + scope=Scope.REQUEST, + ) + class MFACredsProvider(Provider): """Creds provider.""" @@ -739,7 +833,7 @@ async def get_client( yield MFAHTTPClient(client) @provide(provides=MultifactorAPI) - async def get_http_mfa( + def get_http_mfa( self, credentials: MFA_HTTP_Creds, client: MFAHTTPClient, @@ -761,7 +855,7 @@ async def get_http_mfa( ) @provide(provides=LDAPMultiFactorAPI) - async def get_ldap_mfa( + def get_ldap_mfa( self, credentials: MFA_LDAP_Creds, client: MFAHTTPClient, @@ -783,3 +877,26 @@ async def get_ldap_mfa( settings, ), ) + + +class MigrationProvider(Provider): + """Provider for migrations.""" + + scope = Scope.APP + + @provide(scope=Scope.APP) + def get_session_factory( + self, + connection: AsyncConnection, + ) -> AsyncSession: + """Create session factory.""" + return AsyncSession(connection) + + @provide(scope=Scope.APP) + async def get_conn_factory( + self, + engine: AsyncEngine, + ) -> AsyncIterator[AsyncConnection]: + """Create session factory.""" + async with engine.connect() as connection: + yield connection diff --git a/app/ldap_protocol/auth/auth_manager.py b/app/ldap_protocol/auth/auth_manager.py index 47dc4df3a..61c336f56 100644 --- a/app/ldap_protocol/auth/auth_manager.py +++ b/app/ldap_protocol/auth/auth_manager.py @@ -32,10 +32,7 @@ from ldap_protocol.multifactor import MultifactorAPI from ldap_protocol.objects import UserAccountControlFlag from ldap_protocol.policies.audit.monitor import AuditMonitorUseCase -from ldap_protocol.policies.network_policy import ( - check_mfa_group, - get_user_network_policy, -) +from ldap_protocol.policies.network import NetworkPolicyValidatorUseCase from ldap_protocol.policies.password import PasswordPolicyUseCases from ldap_protocol.session_storage import SessionStorage from 
ldap_protocol.session_storage.repository import SessionRepository @@ -61,6 +58,7 @@ def __init__( mfa_manager: MFAManager, setup_use_case: SetupUseCase, identity_provider: IdentityProvider, + network_policy_validator: NetworkPolicyValidatorUseCase, ) -> None: """Initialize dependencies of the manager (via DI). @@ -84,6 +82,7 @@ def __init__( self._mfa_manager = mfa_manager self._setup_use_case = setup_use_case self._identity_provider = identity_provider + self._network_policy_validator = network_policy_validator def __getattribute__(self, name: str) -> object: """Intercept attribute access.""" @@ -147,11 +146,11 @@ async def login( if user.is_expired(): raise LoginFailedError("User account is expired") - network_policy = await get_user_network_policy( - ip, - user, - self._session, - policy_type="is_http", + network_policy = ( + await self._network_policy_validator.get_user_http_policy( + ip, + user, + ) ) if network_policy is None: raise LoginFailedError("User not part of network policy") @@ -162,10 +161,11 @@ async def login( ): request_2fa = True if network_policy.mfa_status == MFAFlags.WHITELIST: - request_2fa = await check_mfa_group( - network_policy, - user, - self._session, + request_2fa = ( + await self._network_policy_validator.check_mfa_group( + network_policy, + user, + ) ) if request_2fa: ( diff --git a/app/ldap_protocol/auth/exceptions/__init__.py b/app/ldap_protocol/auth/exceptions/__init__.py index 2f47c8dcb..3fba3e212 100644 --- a/app/ldap_protocol/auth/exceptions/__init__.py +++ b/app/ldap_protocol/auth/exceptions/__init__.py @@ -5,6 +5,7 @@ """ from .mfa import ( + AuthenticationError, InvalidCredentialsError, MFARequiredError, MFATokenError, @@ -20,4 +21,5 @@ "InvalidCredentialsError", "NetworkPolicyError", "NotFoundError", + "AuthenticationError", ] diff --git a/app/ldap_protocol/auth/exceptions/mfa.py b/app/ldap_protocol/auth/exceptions/mfa.py index 53d1dfd58..144f752b3 100644 --- a/app/ldap_protocol/auth/exceptions/mfa.py +++ 
b/app/ldap_protocol/auth/exceptions/mfa.py @@ -4,46 +4,88 @@ License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE """ +from enum import IntEnum -class MFAIdentityError(Exception): +from errors import BaseDomainException + + +class ErrorCodes(IntEnum): + """Error codes.""" + + BASE_ERROR = 0 + FORBIDDEN_ERROR = 1 + MFA_REQUIRED_ERROR = 2 + MFA_TOKEN_ERROR = 3 + MFA_API_ERROR = 4 + MFA_CONNECT_ERROR = 5 + MISSING_MFA_CREDENTIALS_ERROR = 6 + INVALID_CREDENTIALS_ERROR = 7 + NETWORK_POLICY_ERROR = 8 + NOT_FOUND_ERROR = 9 + AUTHENTICATION_ERROR = 10 + + +class MFAError(BaseDomainException): """Base exception for MFA identity-related errors.""" + code: ErrorCodes = ErrorCodes.BASE_ERROR + -class ForbiddenError(MFAIdentityError): +class ForbiddenError(MFAError): """Raised when an action is forbidden.""" + code = ErrorCodes.FORBIDDEN_ERROR -class MFARequiredError(MFAIdentityError): + +class MFARequiredError(MFAError): """Raised when MFA is required for authentication.""" + code = ErrorCodes.MFA_REQUIRED_ERROR + -class MFATokenError(MFAIdentityError): +class MFATokenError(MFAError): """Raised when an MFA token is invalid or missing.""" + code = ErrorCodes.MFA_TOKEN_ERROR -class MFAAPIError(MFAIdentityError): + +class MFAAPIError(MFAError): """Raised when an MFA API error occurs.""" + code = ErrorCodes.MFA_API_ERROR + -class MFAConnectError(MFAIdentityError): +class MFAConnectError(MFAError): """Raised when an MFA connect error occurs.""" + code = ErrorCodes.MFA_CONNECT_ERROR -class MissingMFACredentialsError(MFAIdentityError): + +class MissingMFACredentialsError(MFAError): """Raised when MFA credentials are missing or not configured.""" + code = ErrorCodes.MISSING_MFA_CREDENTIALS_ERROR + -class InvalidCredentialsError(MFAIdentityError): +class InvalidCredentialsError(MFAError): """Raised when provided credentials are invalid.""" + code = ErrorCodes.INVALID_CREDENTIALS_ERROR -class NetworkPolicyError(MFAIdentityError): + +class 
NetworkPolicyError(MFAError): """Raised when a network policy violation occurs.""" + code = ErrorCodes.NETWORK_POLICY_ERROR + -class NotFoundError(MFAIdentityError): +class NotFoundError(MFAError): """Raised when a required resource is not found user, MFA config.""" + code = ErrorCodes.NOT_FOUND_ERROR -class AuthenticationError(MFAIdentityError): + +class AuthenticationError(MFAError): """Raised when an authentication attempt fails.""" + + code = ErrorCodes.AUTHENTICATION_ERROR diff --git a/app/ldap_protocol/auth/mfa_manager.py b/app/ldap_protocol/auth/mfa_manager.py index 50d7024d7..334a66a44 100644 --- a/app/ldap_protocol/auth/mfa_manager.py +++ b/app/ldap_protocol/auth/mfa_manager.py @@ -46,10 +46,7 @@ MultifactorAPI, ) from ldap_protocol.policies.audit.monitor import AuditMonitorUseCase -from ldap_protocol.policies.network_policy import ( - check_mfa_group, - get_user_network_policy, -) +from ldap_protocol.policies.network import NetworkPolicyValidatorUseCase from ldap_protocol.session_storage import SessionStorage from ldap_protocol.session_storage.repository import SessionRepository from password_utils import PasswordUtils @@ -72,6 +69,7 @@ def __init__( monitor: AuditMonitorUseCase, password_utils: PasswordUtils, identity_provider: IdentityProvider, + network_policy_validator: NetworkPolicyValidatorUseCase, ) -> None: """Initialize dependencies via DI. 
@@ -90,6 +88,7 @@ def __init__( self._monitor = monitor self._password_utils = password_utils self._identity_provider = identity_provider + self._network_policy_validator = network_policy_validator def __getattribute__(self, name: str) -> object: """Intercept attribute access.""" @@ -328,11 +327,11 @@ async def proxy_request(self, principal: str, ip: IPv4Address) -> None: f"User {principal} not found in the database.", ) - network_policy = await get_user_network_policy( - ip, - user, - self._session, - policy_type="is_kerberos", + network_policy = ( + await self._network_policy_validator.get_user_kerberos_policy( + ip, + user, + ) ) if network_policy is None or not network_policy.is_kerberos: @@ -351,10 +350,9 @@ async def proxy_request(self, principal: str, ip: IPv4Address) -> None: ): if ( network_policy.mfa_status == MFAFlags.WHITELIST - and not await check_mfa_group( + and not await self._network_policy_validator.check_mfa_group( network_policy, user, - self._session, ) ): return diff --git a/app/ldap_protocol/auth/setup_gateway.py b/app/ldap_protocol/auth/setup_gateway.py index d294eb0c3..b5bfe580a 100644 --- a/app/ldap_protocol/auth/setup_gateway.py +++ b/app/ldap_protocol/auth/setup_gateway.py @@ -12,6 +12,9 @@ from sqlalchemy.ext.asyncio import AsyncSession from entities import Attribute, Directory, Group, NetworkPolicy, User +from ldap_protocol.ldap_schema.attribute_value_validator import ( + AttributeValueValidator, +) from ldap_protocol.ldap_schema.entity_type_dao import EntityTypeDAO from ldap_protocol.utils.helpers import create_object_sid, generate_domain_sid from ldap_protocol.utils.queries import get_domain_object_class @@ -27,6 +30,7 @@ def __init__( session: AsyncSession, password_utils: PasswordUtils, entity_type_dao: EntityTypeDAO, + attribute_value_validator: AttributeValueValidator, ) -> None: """Initialize Setup use case. 
@@ -37,6 +41,7 @@ def __init__( self._session = session self._password_utils = password_utils self._entity_type_dao = entity_type_dao + self._attribute_value_validator = attribute_value_validator async def is_setup(self) -> bool: """Check if setup is performed. @@ -44,8 +49,9 @@ async def is_setup(self) -> bool: :return: bool (True if setup is performed, False otherwise) """ query = select( - exists(Directory).where(qa(Directory.parent_id).is_(None)), - ) + exists(Directory) + .where(qa(Directory.parent_id).is_(None)), + ) # fmt: skip retval = await self._session.scalars(query) return retval.one() @@ -53,6 +59,7 @@ async def setup_enviroment( self, *, data: list, + is_system: bool, dn: str = "multifactor.dev", ) -> None: """Create directories and users for enviroment.""" @@ -61,10 +68,8 @@ async def setup_enviroment( logger.warning("dev data already set up") return - domain = Directory( - name=dn, - object_class="domain", - ) + domain = Directory(name=dn, object_class="domain") + domain.is_system = True domain.object_sid = generate_domain_sid() domain.path = [f"dc={path}" for path in reversed(dn.split("."))] domain.depth = len(domain.path) @@ -94,14 +99,19 @@ async def setup_enviroment( directory=domain, is_system_entity_type=True, ) + if not self._attribute_value_validator.is_directory_valid(domain): + raise ValueError( + "Invalid directory attribute values during environment setup", # noqa: E501 + ) await self._session.flush() try: for unit in data: await self.create_dir( unit, - domain, - domain, + is_system=is_system, + domain=domain, + parent=domain, ) except Exception: @@ -113,11 +123,13 @@ async def setup_enviroment( async def create_dir( self, data: dict, + is_system: bool, domain: Directory, parent: Directory | None = None, ) -> None: """Create data recursively.""" dir_ = Directory( + is_system=is_system, object_class=data["object_class"], name=data["name"], parent=parent, @@ -199,21 +211,24 @@ async def create_dir( await self._session.refresh( 
instance=dir_, - attribute_names=["attributes"], + attribute_names=["attributes", "user"], with_for_update=None, ) await self._entity_type_dao.attach_entity_type_to_directory( directory=dir_, is_system_entity_type=True, ) + if not self._attribute_value_validator.is_directory_valid(dir_): + raise ValueError("Invalid directory attribute values") await self._session.flush() if "children" in data: for n_data in data["children"]: await self.create_dir( n_data, - domain, - dir_, + is_system=is_system, + domain=domain, + parent=dir_, ) async def _get_group(self, name: str) -> Group: diff --git a/app/ldap_protocol/auth/use_cases.py b/app/ldap_protocol/auth/use_cases.py index 81957b5da..b9a53414e 100644 --- a/app/ldap_protocol/auth/use_cases.py +++ b/app/ldap_protocol/auth/use_cases.py @@ -9,7 +9,11 @@ from sqlalchemy.exc import IntegrityError from sqlalchemy.ext.asyncio import AsyncSession -from constants import FIRST_SETUP_DATA +from constants import ( + DOMAIN_ADMIN_GROUP_NAME, + FIRST_SETUP_DATA, + USERS_CONTAINER_NAME, +) from ldap_protocol.auth.dto import SetupDTO from ldap_protocol.auth.setup_gateway import SetupGateway from ldap_protocol.identity.exceptions import ( @@ -79,7 +83,7 @@ def _create_user_data(self, dto: SetupDTO) -> dict: :return: dict with user data """ return { - "name": "users", + "name": USERS_CONTAINER_NAME, "object_class": "container", "attributes": {"objectClass": ["top"]}, "children": [ @@ -92,7 +96,7 @@ def _create_user_data(self, dto: SetupDTO) -> dict: "mail": dto.mail, "display_name": dto.display_name, "password": dto.password, - "groups": ["domain admins"], + "groups": [DOMAIN_ADMIN_GROUP_NAME], }, "attributes": { "objectClass": [ @@ -109,6 +113,7 @@ def _create_user_data(self, dto: SetupDTO) -> dict: "gidNumber": ["513"], "userAccountControl": ["512"], "primaryGroupID": ["512"], + "givenName": [dto.username], }, "objectSid": 500, }, @@ -126,6 +131,7 @@ async def _create(self, dto: SetupDTO, data: list) -> None: await 
self._setup_gateway.setup_enviroment( data=data, dn=dto.domain, + is_system=True, ) await self._password_use_cases.create_default_domain_policy() diff --git a/app/ldap_protocol/dhcp/__init__.py b/app/ldap_protocol/dhcp/__init__.py index 27df7d0c0..cf26f1903 100644 --- a/app/ldap_protocol/dhcp/__init__.py +++ b/app/ldap_protocol/dhcp/__init__.py @@ -26,7 +26,7 @@ from .stub import StubDHCPAPIRepository, StubDHCPManager -async def get_dhcp_manager_class( +def get_dhcp_manager_class( dhcp_state: DHCPManagerState, ) -> type[AbstractDHCPManager]: """Get an instance of the DHCP manager.""" @@ -35,7 +35,7 @@ async def get_dhcp_manager_class( return StubDHCPManager -async def get_dhcp_api_repository_class( +def get_dhcp_api_repository_class( dhcp_state: DHCPManagerState, ) -> type[DHCPAPIRepository]: """Get an instance of the DHCP API repository.""" diff --git a/app/ldap_protocol/dhcp/exceptions.py b/app/ldap_protocol/dhcp/exceptions.py index ce77e867c..4b29a4514 100644 --- a/app/ldap_protocol/dhcp/exceptions.py +++ b/app/ldap_protocol/dhcp/exceptions.py @@ -4,46 +4,88 @@ License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE """ +from enum import IntEnum -class DHCPError(Exception): +from errors import BaseDomainException + + +class ErrorCodes(IntEnum): + """Error codes.""" + + BASE_ERROR = 0 + DHCP_API_ERROR = 1 + DHCP_VALIDATION_ERROR = 2 + DHCP_CONNECTION_ERROR = 3 + DHCP_OPERATION_ERROR = 4 + DHCP_ENTRY_ADD_ERROR = 5 + DHCP_ENTRY_NOT_FOUND_ERROR = 6 + DHCP_ENTRY_DELETE_ERROR = 7 + DHCP_ENTRY_UPDATE_ERROR = 8 + DHCP_CONFLICT_ERROR = 9 + DHCP_UNSUPPORTED_ERROR = 10 + + +class DHCPError(BaseDomainException): """DHCP base exception.""" + code: ErrorCodes = ErrorCodes.BASE_ERROR + class DHCPAPIError(DHCPError): """DHCP API error.""" + code = ErrorCodes.DHCP_API_ERROR + class DHCPValidatonError(DHCPError): """DHCP validation error.""" + code = ErrorCodes.DHCP_VALIDATION_ERROR + class DHCPConnectionError(ConnectionError): """DHCP connection error.""" 
+ code = ErrorCodes.DHCP_CONNECTION_ERROR + class DHCPOperationError(DHCPError): """DHCP operation error.""" + code = ErrorCodes.DHCP_OPERATION_ERROR + class DHCPEntryAddError(DHCPError): """DHCP entry addition error.""" + code = ErrorCodes.DHCP_ENTRY_ADD_ERROR + class DHCPEntryNotFoundError(DHCPError): """DHCP entry not found error.""" + code = ErrorCodes.DHCP_ENTRY_NOT_FOUND_ERROR + class DHCPEntryDeleteError(DHCPError): """DHCP entry deletion error.""" + code = ErrorCodes.DHCP_ENTRY_DELETE_ERROR + class DHCPEntryUpdateError(DHCPError): """DHCP entry update error.""" + code = ErrorCodes.DHCP_ENTRY_UPDATE_ERROR + class DHCPConflictError(DHCPError): """DHCP conflict error.""" + code = ErrorCodes.DHCP_CONFLICT_ERROR + class DHCPUnsupportedError(DHCPError): """DHCP unsupported error.""" + + code = ErrorCodes.DHCP_UNSUPPORTED_ERROR diff --git a/app/ldap_protocol/dialogue.py b/app/ldap_protocol/dialogue.py index c15a53051..a594b628c 100644 --- a/app/ldap_protocol/dialogue.py +++ b/app/ldap_protocol/dialogue.py @@ -16,10 +16,10 @@ from typing import TYPE_CHECKING, AsyncIterator import gssapi -from sqlalchemy.ext.asyncio import AsyncSession from entities import NetworkPolicy, User -from ldap_protocol.policies.network_policy import build_policy_query +from enums import ProtocolType +from ldap_protocol.policies.network import NetworkPolicyValidatorUseCase from .session_storage import SessionStorage @@ -142,21 +142,16 @@ async def lock(self) -> AsyncIterator[UserSchema | None]: async with self._lock: yield self._user - @staticmethod - async def _get_policy( - ip: IPv4Address, - session: AsyncSession, - ) -> NetworkPolicy | None: - query = build_policy_query(ip, "is_ldap") - return await session.scalar(query) - async def validate_conn( self, ip: IPv4Address | IPv6Address, - session: AsyncSession, + network_policy_use_case: NetworkPolicyValidatorUseCase, ) -> None: """Validate network policies.""" - policy = await self._get_policy(ip, session) # type: ignore + policy = await 
network_policy_use_case.get_by_protocol( + ip, + ProtocolType.LDAP, + ) if policy is not None: self.policy = policy await self.bind_session() diff --git a/app/ldap_protocol/dns/__init__.py b/app/ldap_protocol/dns/__init__.py index 235fc83f9..f9c97fba7 100644 --- a/app/ldap_protocol/dns/__init__.py +++ b/app/ldap_protocol/dns/__init__.py @@ -3,8 +3,6 @@ DNS_MANAGER_STATE_NAME, DNS_MANAGER_ZONE_NAME, AbstractDNSManager, - DNSConnectionError, - DNSError, DNSForwardServerStatus, DNSForwardZone, DNSManagerSettings, @@ -19,6 +17,7 @@ DNSZoneType, ) from .dns_gateway import DNSStateGateway +from .exceptions import DNSConnectionError, DNSError from .remote import RemoteDNSManager from .selfhosted import SelfHostedDNSManager from .stub import StubDNSManager diff --git a/app/ldap_protocol/dns/base.py b/app/ldap_protocol/dns/base.py index 275167b60..01fe71c8c 100644 --- a/app/ldap_protocol/dns/base.py +++ b/app/ldap_protocol/dns/base.py @@ -46,14 +46,6 @@ class DNSForwarderServerStatus(StrEnum): NOT_FOUND = "not found" -class DNSConnectionError(ConnectionError): - """API Error.""" - - -class DNSError(Exception): - """DNS Error.""" - - class DNSNotImplementedError(NotImplementedError): """API Not Implemented Error.""" diff --git a/app/ldap_protocol/dns/exceptions.py b/app/ldap_protocol/dns/exceptions.py index e7474e1a3..5b9da9f5e 100644 --- a/app/ldap_protocol/dns/exceptions.py +++ b/app/ldap_protocol/dns/exceptions.py @@ -4,38 +4,88 @@ License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE """ +from enum import IntEnum -class DNSError(Exception): +from errors import BaseDomainException + + +class ErrorCodes(IntEnum): + """Error codes.""" + + BASE_ERROR = 0 + DNS_SETUP_ERROR = 1 + DNS_RECORD_CREATE_ERROR = 2 + DNS_RECORD_UPDATE_ERROR = 3 + DNS_RECORD_DELETE_ERROR = 4 + DNS_ZONE_CREATE_ERROR = 5 + DNS_ZONE_UPDATE_ERROR = 6 + DNS_ZONE_DELETE_ERROR = 7 + DNS_UPDATE_SERVER_OPTIONS_ERROR = 8 + DNS_CONNECTION_ERROR = 9 + DNS_NOT_IMPLEMENTED_ERROR = 10 + + 
+class DNSError(BaseDomainException): """DNS Error.""" + code: ErrorCodes = ErrorCodes.BASE_ERROR + class DNSSetupError(DNSError): """DNS setup error.""" + code = ErrorCodes.DNS_SETUP_ERROR + class DNSRecordCreateError(DNSError): """DNS record create error.""" + code = ErrorCodes.DNS_RECORD_CREATE_ERROR + class DNSRecordUpdateError(DNSError): """DNS record update error.""" + code = ErrorCodes.DNS_RECORD_UPDATE_ERROR + class DNSRecordDeleteError(DNSError): """DNS record delete error.""" + code = ErrorCodes.DNS_RECORD_DELETE_ERROR + class DNSZoneCreateError(DNSError): """DNS zone create error.""" + code = ErrorCodes.DNS_ZONE_CREATE_ERROR + class DNSZoneUpdateError(DNSError): """DNS zone update error.""" + code = ErrorCodes.DNS_ZONE_UPDATE_ERROR + class DNSZoneDeleteError(DNSError): """DNS zone delete error.""" + code = ErrorCodes.DNS_ZONE_DELETE_ERROR + class DNSUpdateServerOptionsError(DNSError): """DNS update server options error.""" + + code = ErrorCodes.DNS_UPDATE_SERVER_OPTIONS_ERROR + + +class DNSConnectionError(DNSError): + """DNS connection error.""" + + code = ErrorCodes.DNS_CONNECTION_ERROR + + +class DNSNotImplementedError(DNSError): + """DNS not implemented error.""" + + code = ErrorCodes.DNS_NOT_IMPLEMENTED_ERROR diff --git a/app/ldap_protocol/dns/remote.py b/app/ldap_protocol/dns/remote.py index a44c69c0e..1c2cb25fd 100644 --- a/app/ldap_protocol/dns/remote.py +++ b/app/ldap_protocol/dns/remote.py @@ -15,7 +15,8 @@ from dns.update import Update from dns.zone import Zone -from .base import AbstractDNSManager, DNSConnectionError, DNSRecord, DNSRecords +from .base import AbstractDNSManager, DNSRecord, DNSRecords +from .exceptions import DNSConnectionError from .utils import logger_wraps diff --git a/app/ldap_protocol/dns/utils.py b/app/ldap_protocol/dns/utils.py index 005d9f5d4..9adc21fe9 100644 --- a/app/ldap_protocol/dns/utils.py +++ b/app/ldap_protocol/dns/utils.py @@ -9,7 +9,8 @@ from dns.asyncresolver import Resolver as AsyncResolver -from .base 
import DNSConnectionError, log +from .base import log +from .exceptions import DNSConnectionError def logger_wraps(is_stub: bool = False) -> Callable: diff --git a/app/ldap_protocol/filter_interpreter.py b/app/ldap_protocol/filter_interpreter.py index f7cc36f7f..c456fac00 100644 --- a/app/ldap_protocol/filter_interpreter.py +++ b/app/ldap_protocol/filter_interpreter.py @@ -31,6 +31,7 @@ User, ) from ldap_protocol.utils.helpers import ft_to_dt +from ldap_protocol.utils.queries import get_path_filter, get_search_path from repo.pg.tables import groups_table, queryable_attr as qa, users_table from .asn1parser import ASN1Row, TagNumbers @@ -398,6 +399,14 @@ def _cast_item(self, item: ASN1Row) -> UnaryExpression | ColumnElement: # noqa: is_substring = item.tag_id == TagNumbers.SUBSTRING + if attr == "distinguishedname" and not is_substring: + try: + dn_search_path = get_search_path(right.value) + except Exception: # noqa: S110 + pass + else: + return get_path_filter(dn_search_path) + if attr == "anr": if is_substring: expr = right.value[0] diff --git a/app/ldap_protocol/identity/exceptions.py b/app/ldap_protocol/identity/exceptions.py index ebd5274a9..568ec1f7f 100644 --- a/app/ldap_protocol/identity/exceptions.py +++ b/app/ldap_protocol/identity/exceptions.py @@ -4,34 +4,74 @@ License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE """ +from enum import IntEnum -class IdentityError(Exception): +from api.error_routing import BaseDomainException + + +class ErrorCodes(IntEnum): + """Identity error codes.""" + + BASE_ERROR = 0 + UNAUTHORIZED_ERROR = 1 + ALREADY_CONFIGURED_ERROR = 2 + FORBIDDEN_ERROR = 3 + LOGIN_FAILED_ERROR = 4 + PASSWORD_POLICY_ERROR = 5 + USER_NOT_FOUND_ERROR = 6 + AUTH_VALIDATION_ERROR = 7 + AUTHORIZATION_ERROR = 8 + + +class AuthError(BaseDomainException): """Base exception for authentication identity-related errors.""" + code: ErrorCodes = ErrorCodes.BASE_ERROR + -class UnauthorizedError(IdentityError): +class 
UnauthorizedError(AuthError): """Raised when authentication fails due to invalid credentials.""" + code = ErrorCodes.UNAUTHORIZED_ERROR -class AlreadyConfiguredError(IdentityError): + +class AlreadyConfiguredError(AuthError): """Raised when setup is attempted but already performed.""" + code = ErrorCodes.ALREADY_CONFIGURED_ERROR + -class ForbiddenError(IdentityError): +class ForbiddenError(AuthError): """Raised when access is forbidden due to policy or group membership.""" + code = ErrorCodes.FORBIDDEN_ERROR -class LoginFailedError(IdentityError): + +class LoginFailedError(AuthError): """Raised when login fails for reasons other than invalid credentials.""" + code = ErrorCodes.LOGIN_FAILED_ERROR + -class PasswordPolicyError(IdentityError): +class PasswordPolicyError(AuthError): """Raised when a password does not meet policy requirements.""" + code = ErrorCodes.PASSWORD_POLICY_ERROR -class UserNotFoundError(IdentityError): + +class UserNotFoundError(AuthError): """Raised when a user is not found in the system.""" + code = ErrorCodes.USER_NOT_FOUND_ERROR + -class AuthValidationError(IdentityError): +class AuthValidationError(AuthError): """Raised when there is a validation error during authentication.""" + + code = ErrorCodes.AUTH_VALIDATION_ERROR + + +class AuthorizationError(AuthError): + """Authorization error.""" + + code = ErrorCodes.AUTHORIZATION_ERROR diff --git a/app/ldap_protocol/identity/utils.py b/app/ldap_protocol/identity/utils.py new file mode 100644 index 000000000..844c0f4bf --- /dev/null +++ b/app/ldap_protocol/identity/utils.py @@ -0,0 +1,63 @@ +"""Identity utility functions for authentication and user management. 
+ +Copyright (c) 2024 MultiFactor +License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE +""" + +from ipaddress import IPv4Address, IPv6Address, ip_address + +from fastapi import HTTPException, Request, status +from sqlalchemy.ext.asyncio import AsyncSession + +from entities import User +from ldap_protocol.utils.queries import get_user +from password_utils import PasswordUtils + + +async def authenticate_user( + session: AsyncSession, + username: str, + password: str, + password_utils: PasswordUtils, +) -> User | None: + """Get user and verify password. + + :param AsyncSession session: sa session + :param str username: any str + :param str password: any str + :return User | None: User model (pydantic). + """ + user = await get_user(session, username) + + if not user or not user.password or not password: + return None + if not password_utils.verify_password(password, user.password): + return None + return user + + +def get_ip_from_request(request: Request) -> IPv4Address | IPv6Address: + """Get IP address from request. + + :param Request request: The incoming request object. + :return IPv4Address | None: The IP address or None. + """ + forwarded_for = request.headers.get("X-Forwarded-For") + if forwarded_for: + client_ip = forwarded_for.split(",")[0] + else: + if request.client is None: + raise HTTPException(status.HTTP_400_BAD_REQUEST) + client_ip = request.client.host + + return ip_address(client_ip) + + +def get_user_agent_from_request(request: Request) -> str: + """Get user agent from request. + + :param Request request: The incoming request object. + :return str: The user agent header. 
+ """ + user_agent_header = request.headers.get("User-Agent") + return user_agent_header if user_agent_header else "" diff --git a/app/ldap_protocol/kerberos/exceptions.py b/app/ldap_protocol/kerberos/exceptions.py index d98ca1bda..735149eff 100644 --- a/app/ldap_protocol/kerberos/exceptions.py +++ b/app/ldap_protocol/kerberos/exceptions.py @@ -4,86 +4,159 @@ License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE """ - -class KerberosError(Exception): +from enum import IntEnum + +from api.error_routing import BaseDomainException + + +class ErrorCodes(IntEnum): + """Error codes.""" + + BASE_ERROR = 0 + KERBEROS_BASE_DN_NOT_FOUND_ERROR = 1 + KERBEROS_CONFLICT_ERROR = 2 + KERBEROS_NOT_FOUND_ERROR = 3 + KERBEROS_DEPENDENCY_ERROR = 4 + KERBEROS_UNAVAILABLE_ERROR = 5 + KERBEROS_API_ERROR = 6 + KERBEROS_API_CONFLICT_ERROR = 7 + KERBEROS_API_NOT_FOUND_ERROR = 8 + KERBEROS_API_DEPENDENCY_ERROR = 9 + KERBEROS_API_UNAVAILABLE_ERROR = 10 + KERBEROS_API_SETUP_CONFIGS_ERROR = 11 + KERBEROS_API_SETUP_STASH_ERROR = 12 + KERBEROS_API_SETUP_TREE_ERROR = 13 + KERBEROS_API_PRINCIPAL_NOT_FOUND_ERROR = 14 + KERBEROS_API_ADD_PRINCIPAL_ERROR = 15 + KERBEROS_API_GET_PRINCIPAL_ERROR = 16 + KERBEROS_API_DELETE_PRINCIPAL_ERROR = 17 + KERBEROS_API_CHANGE_PASSWORD_ERROR = 18 + KERBEROS_API_RENAME_PRINCIPAL_ERROR = 19 + KERBEROS_API_LOCK_PRINCIPAL_ERROR = 20 + KERBEROS_API_FORCE_PASSWORD_CHANGE_ERROR = 21 + KERBEROS_API_STATUS_NOT_FOUND_ERROR = 22 + KERBEROS_API_CONNECTION_ERROR = 23 + + +class KerberosError(BaseDomainException): """Base exception for authentication kerberos-related errors.""" + code: ErrorCodes = ErrorCodes.BASE_ERROR + class KerberosConflictError(KerberosError): """Raised when a conflict occurs.""" + code = ErrorCodes.KERBEROS_CONFLICT_ERROR + class KerberosNotFoundError(KerberosError): """Raised when a resource is not found.""" + code = ErrorCodes.KERBEROS_NOT_FOUND_ERROR + class KerberosDependencyError(KerberosError): """Raised when a dependency 
fails.""" + code = ErrorCodes.KERBEROS_DEPENDENCY_ERROR + class KerberosUnavailableError(KerberosError): """Raised when the service is unavailable.""" + code = ErrorCodes.KERBEROS_UNAVAILABLE_ERROR + class KerberosBaseDnNotFoundError(KerberosError): """Raised when no base DN is found in the LDAP directory.""" + code = ErrorCodes.KERBEROS_BASE_DN_NOT_FOUND_ERROR -class KRBAPIError(Exception): + +class KRBAPIError(KerberosError): """API Error.""" class KRBAPIConflictError(KRBAPIError): """Conflict error.""" + code = ErrorCodes.KERBEROS_API_CONFLICT_ERROR + class KRBAPISetupConfigsError(KRBAPIError): """Setup configs error.""" + code = ErrorCodes.KERBEROS_API_SETUP_CONFIGS_ERROR + class KRBAPISetupStashError(KRBAPIError): """Setup stash error.""" + code = ErrorCodes.KERBEROS_API_SETUP_STASH_ERROR + class KRBAPISetupTreeError(KRBAPIError): """Setup tree error.""" + code = ErrorCodes.KERBEROS_API_SETUP_TREE_ERROR + class KRBAPIPrincipalNotFoundError(KRBAPIError): """Principal not found error.""" + code = ErrorCodes.KERBEROS_API_PRINCIPAL_NOT_FOUND_ERROR + class KRBAPIAddPrincipalError(KRBAPIError): """Add principal error.""" + code = ErrorCodes.KERBEROS_API_ADD_PRINCIPAL_ERROR + class KRBAPIGetPrincipalError(KRBAPIError): """Get principal error.""" + code = ErrorCodes.KERBEROS_API_GET_PRINCIPAL_ERROR + class KRBAPIDeletePrincipalError(KRBAPIError): """Delete principal error.""" + code = ErrorCodes.KERBEROS_API_DELETE_PRINCIPAL_ERROR + class KRBAPIChangePasswordError(KRBAPIError): """Change password error.""" + code = ErrorCodes.KERBEROS_API_CHANGE_PASSWORD_ERROR + class KRBAPIRenamePrincipalError(KRBAPIError): """Rename principal error.""" + code = ErrorCodes.KERBEROS_API_RENAME_PRINCIPAL_ERROR + class KRBAPILockPrincipalError(KRBAPIError): """Lock principal error.""" + code = ErrorCodes.KERBEROS_API_LOCK_PRINCIPAL_ERROR + class KRBAPIForcePasswordChangeError(KRBAPIError): """Force password change error.""" + code = ErrorCodes.KERBEROS_API_FORCE_PASSWORD_CHANGE_ERROR + 
class KRBAPIStatusNotFoundError(KRBAPIError): """Status not found error.""" + code = ErrorCodes.KERBEROS_API_STATUS_NOT_FOUND_ERROR + class KRBAPIConnectionError(KRBAPIError): """Connection error.""" + + code = ErrorCodes.KERBEROS_API_CONNECTION_ERROR diff --git a/app/ldap_protocol/kerberos/ldap_structure.py b/app/ldap_protocol/kerberos/ldap_structure.py index cba89fe39..45228a3c8 100644 --- a/app/ldap_protocol/kerberos/ldap_structure.py +++ b/app/ldap_protocol/kerberos/ldap_structure.py @@ -57,20 +57,20 @@ async def create_kerberos_structure( :return None. """ async with self._session.begin_nested(): - results = ( - await anext(services.handle(ctx)), - await anext(group.handle(ctx)), - await anext(krb_user.handle(ctx)), - ) - await self._session.flush() + service_result = await anext(services.handle(ctx)) + if service_result.result_code != 0: + raise KerberosConflictError("Service error") - if not all(result.result_code == 0 for result in results): - await self._session.rollback() - raise KerberosConflictError( - "Error creating Kerberos structure in directory", - ) + async with self._session.begin_nested(): + group_result = await anext(group.handle(ctx)) + if group_result.result_code != 0: + raise KerberosConflictError("Group error") + + async with self._session.begin_nested(): await self._role_use_case.create_kerberos_system_role() - await self._session.commit() + user_result = await anext(krb_user.handle(ctx)) + if user_result.result_code != 0: + raise KerberosConflictError("User error") async def rollback_kerberos_structure( self, diff --git a/app/ldap_protocol/kerberos/service.py b/app/ldap_protocol/kerberos/service.py index aea808d22..f6a0aae05 100644 --- a/app/ldap_protocol/kerberos/service.py +++ b/app/ldap_protocol/kerberos/service.py @@ -32,6 +32,7 @@ from .base import AbstractKadmin from .exceptions import ( KRBAPIAddPrincipalError, + KRBAPIConnectionError, KRBAPIDeletePrincipalError, KRBAPIPrincipalNotFoundError, KRBAPIRenamePrincipalError, @@ -43,7 
+44,12 @@ from .ldap_structure import KRBLDAPStructureManager from .schemas import AddRequests, KDCContext, KerberosAdminDnGroup, TaskStruct from .template_render import KRBTemplateRenderer -from .utils import KerberosState, get_krb_server_state, set_state +from .utils import ( + KerberosState, + get_krb_server_state, + get_system_container_dn, + set_state, +) class KerberosService(AbstractService): @@ -140,7 +146,7 @@ def _build_kerberos_admin_dns(self, base_dn: str) -> KerberosAdminDnGroup: dataclass with DN for krbadmin, services_container, krbadmin_group. """ krbadmin = f"cn=krbadmin,cn=users,{base_dn}" - services_container = f"ou=services,{base_dn}" + services_container = get_system_container_dn(base_dn) krbgroup = f"cn=krbadmin,cn=groups,{base_dn}" return KerberosAdminDnGroup( krbadmin_dn=krbadmin, @@ -172,10 +178,12 @@ def _build_add_requests( "description": ["Kerberos administrator's group."], "gidNumber": ["800"], }, + is_system=True, ) services = AddRequest.from_dict( dns.services_container_dn, {"objectClass": ["organizationalUnit", "top", "container"]}, + is_system=True, ) krb_user = AddRequest.from_dict( dns.krbadmin_dn, @@ -209,6 +217,7 @@ def _build_add_requests( ), ], }, + is_system=True, ) return AddRequests( group=group, @@ -262,6 +271,7 @@ async def setup_kdc( KRBAPISetupStashError, KRBAPISetupTreeError, KerberosDependencyError, + KRBAPIConnectionError, ) as err: await self._ldap_manager.rollback_kerberos_structure( context.krbadmin, @@ -288,7 +298,7 @@ async def _get_kdc_context(self) -> KDCContext: base_dn, domain = await self._get_base_dn() krbadmin = f"cn=krbadmin,cn=users,{base_dn}" krbgroup = f"cn=krbadmin,cn=groups,{base_dn}" - services_container = f"ou=services,{base_dn}" + services_container = get_system_container_dn(base_dn) return KDCContext( base_dn=base_dn, domain=domain, diff --git a/app/ldap_protocol/kerberos/utils.py b/app/ldap_protocol/kerberos/utils.py index 1f43443a3..c6278ed95 100644 --- a/app/ldap_protocol/kerberos/utils.py 
+++ b/app/ldap_protocol/kerberos/utils.py @@ -9,7 +9,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from entities import Attribute, CatalogueSetting, Directory, EntityType -from enums import StrEnum +from enums import EntityTypeNames, StrEnum from repo.pg.tables import queryable_attr as qa from .exceptions import KRBAPIConnectionError, KRBAPIError @@ -122,7 +122,7 @@ async def unlock_principal(name: str, session: AsyncSession) -> None: .outerjoin(qa(Directory.entity_type)) .where( qa(Directory.name).ilike(name), - qa(EntityType.name) == "KRB Principal", + qa(EntityType.name) == EntityTypeNames.KRB_PRINCIPAL, ) .scalar_subquery() ) @@ -131,3 +131,8 @@ async def unlock_principal(name: str, session: AsyncSession) -> None: .filter_by(directory_id=subquery, name="krbprincipalexpiration") .execution_options(synchronize_session=False), ) + + +def get_system_container_dn(base_dn: str) -> str: + """Get System container DN for services.""" + return f"ou=System,{base_dn}" diff --git a/app/ldap_protocol/ldap_requests/abandon.py b/app/ldap_protocol/ldap_requests/abandon.py index 3facb0562..b9569ca0e 100644 --- a/app/ldap_protocol/ldap_requests/abandon.py +++ b/app/ldap_protocol/ldap_requests/abandon.py @@ -8,6 +8,7 @@ from typing import AsyncGenerator, ClassVar from ldap_protocol.asn1parser import ASN1Row +from ldap_protocol.ldap_requests.contexts import LDAPAbandonRequestContext from ldap_protocol.objects import ProtocolRequests from .base import BaseRequest @@ -16,6 +17,7 @@ class AbandonRequest(BaseRequest): """Abandon protocol.""" + CONTEXT_TYPE: ClassVar[type] = LDAPAbandonRequestContext PROTOCOL_OP: ClassVar[int] = ProtocolRequests.ABANDON message_id: int @@ -27,7 +29,7 @@ def from_data( """Create structure from ASN1Row dataclass list.""" return cls(message_id=1) - async def handle(self) -> AsyncGenerator: + async def handle(self, ctx: LDAPAbandonRequestContext) -> AsyncGenerator: # noqa: ARG002 """Handle message with current user.""" await asyncio.sleep(0) return 
diff --git a/app/ldap_protocol/ldap_requests/add.py b/app/ldap_protocol/ldap_requests/add.py index 4e15bc75d..75be3f6fc 100644 --- a/app/ldap_protocol/ldap_requests/add.py +++ b/app/ldap_protocol/ldap_requests/add.py @@ -4,18 +4,22 @@ License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE """ +import contextlib from typing import AsyncGenerator, ClassVar from pydantic import Field, SecretStr from sqlalchemy import select from sqlalchemy.exc import IntegrityError +from constants import DOMAIN_COMPUTERS_GROUP_NAME, DOMAIN_USERS_GROUP_NAME from entities import Attribute, Directory, Group, User -from enums import AceType +from enums import AceType, EntityTypeNames from ldap_protocol.asn1parser import ASN1Row from ldap_protocol.kerberos.exceptions import ( KRBAPIAddPrincipalError, KRBAPIConnectionError, + KRBAPIDeletePrincipalError, + KRBAPIPrincipalNotFoundError, ) from ldap_protocol.ldap_codes import LDAPCodes from ldap_protocol.ldap_responses import INVALID_ACCESS_RESPONSE, AddResponse @@ -24,10 +28,6 @@ ProtocolRequests, UserAccountControlFlag, ) -from ldap_protocol.utils.const import ( - DOMAIN_COMPUTERS_GROUP_NAME, - DOMAIN_USERS_GROUP_NAME, -) from ldap_protocol.utils.helpers import ( create_integer_hash, create_user_name, @@ -65,8 +65,14 @@ class AddRequest(BaseRequest): """ PROTOCOL_OP: ClassVar[int] = ProtocolRequests.ADD + CONTEXT_TYPE: ClassVar[type] = LDAPAddRequestContext entry: str = Field(..., description="Any `DistinguishedName`") + is_system: bool = Field( + False, + description="Mark as system directory (cannot be modified)", + ) + attributes: list[PartialAttribute] password: SecretStr | None = Field(None, examples=["password"]) @@ -158,10 +164,25 @@ async def handle( # noqa: C901 object_class_names=self.object_class_names, ) ) - if entity_type and entity_type.name == "Container": + if entity_type and entity_type.name == EntityTypeNames.CONTAINER: yield AddResponse(result_code=LDAPCodes.INSUFFICIENT_ACCESS_RIGHTS) return + if 
not ctx.attribute_value_validator.is_value_valid( + entity_type.name if entity_type else "", + "name", + name, + ) or not ctx.attribute_value_validator.is_value_valid( + entity_type.name if entity_type else "", + new_dn, + name, + ): + yield AddResponse( + result_code=LDAPCodes.UNDEFINED_ATTRIBUTE_TYPE, + errorMessage="Invalid attribute value(s)", + ) + return + can_add = ctx.access_manager.check_entity_level_access( aces=parent.access_control_entries, entity_type_id=entity_type.id if entity_type else None, @@ -189,6 +210,7 @@ async def handle( # noqa: C901 new_dir = Directory( object_class="", name=name, + is_system=self.is_system or bool(name == "kerberos"), parent=parent, ) @@ -399,10 +421,22 @@ async def handle( # noqa: C901 ), ) + if not ctx.attribute_value_validator.is_directory_attributes_valid( + entity_type.name if entity_type else "", + attributes, + ) or (user and not ctx.attribute_value_validator.is_user_valid(user)): + await ctx.session.rollback() + yield AddResponse( + result_code=LDAPCodes.UNDEFINED_ATTRIBUTE_TYPE, + errorMessage="Invalid attribute value(s)", + ) + return + try: items_to_add.extend(attributes) ctx.session.add_all(items_to_add) await ctx.session.flush() + await ctx.entity_type_dao.attach_entity_type_to_directory( directory=new_dir, is_system_entity_type=False, @@ -413,7 +447,7 @@ async def handle( # noqa: C901 parent_directory=parent, directory=new_dir, ) - await ctx.session.flush() + await ctx.session.commit() except IntegrityError: await ctx.session.rollback() yield AddResponse(result_code=LDAPCodes.ENTRY_ALREADY_EXISTS) @@ -422,13 +456,23 @@ async def handle( # noqa: C901 # in case server is not available: raise error and rollback # stub cannot raise error if user: + # NOTE: Try to delete existing principal if any + with contextlib.suppress( + KRBAPIDeletePrincipalError, + KRBAPIPrincipalNotFoundError, + ): + await ctx.kadmin.del_principal( + user.get_upn_prefix(), + ) + pw = ( self.password.get_secret_value() if self.password else 
None ) await ctx.kadmin.add_principal(user.get_upn_prefix(), pw) - if is_computer: + + elif is_computer: await ctx.kadmin.add_principal( f"{new_dir.host_principal}.{base_dn.name}", None, @@ -453,6 +497,7 @@ def from_dict( entry: str, attributes: dict[str, list[str]], password: str | None = None, + is_system: bool = False, ) -> "AddRequest": """Create AddRequest from dict. @@ -462,6 +507,7 @@ def from_dict( """ return AddRequest( entry=entry, + is_system=is_system, password=password, attributes=[ PartialAttribute(type=name, vals=vals) diff --git a/app/ldap_protocol/ldap_requests/base.py b/app/ldap_protocol/ldap_requests/base.py index 3123e6247..445ce3bae 100644 --- a/app/ldap_protocol/ldap_requests/base.py +++ b/app/ldap_protocol/ldap_requests/base.py @@ -24,6 +24,7 @@ from ldap_protocol.dependency import resolve_deps from ldap_protocol.dialogue import LDAPSession from ldap_protocol.ldap_responses import BaseResponse, LDAPResult +from ldap_protocol.objects import ProtocolRequests from ldap_protocol.policies.audit.audit_use_case import AuditUseCase from ldap_protocol.policies.audit.events.factory import ( RawAuditEventBuilderRedis, @@ -62,6 +63,7 @@ class _APIProtocol: ... 
class BaseRequest(ABC, _APIProtocol, BaseModel): """Base request builder.""" + CONTEXT_TYPE: ClassVar[type] handle: ClassVar[handler] from_data: ClassVar[serializer] __event_data: dict = {} @@ -113,38 +115,39 @@ async def handle_tcp( container: AsyncContainer, ) -> AsyncIterator[BaseResponse]: """Hanlde response with tcp.""" - kwargs = await resolve_deps(func=self.handle, container=container) - responses = [] + ctx = await container.get(self.CONTEXT_TYPE) # type: ignore - async for response in self.handle(**kwargs): + responses = [] + async for response in self.handle(ctx=ctx): responses.append(response) yield response - ldap_session = await container.get(LDAPSession) - settings = await container.get(Settings) - audit_use_case = await container.get(AuditUseCase) - - if await audit_use_case.check_event_processing_enabled( - self.PROTOCOL_OP, - ): - username = getattr( - ldap_session.user, - "user_principal_name", - "ANONYMOUS", - ) - event = RawAuditEventBuilderRedis.from_ldap_request( - self, - responses=responses, - username=username, - ip=ldap_session.ip, - protocol="TCP_LDAP", - settings=settings, - context=self.get_event_data(), - ) + if self.PROTOCOL_OP != ProtocolRequests.SEARCH: + ldap_session = await container.get(LDAPSession) + settings = await container.get(Settings) + audit_use_case = await container.get(AuditUseCase) + + if await audit_use_case.check_event_processing_enabled( + self.PROTOCOL_OP, + ): + username = getattr( + ldap_session.user, + "user_principal_name", + "ANONYMOUS", + ) + event = RawAuditEventBuilderRedis.from_ldap_request( + self, + responses=responses, + username=username, + ip=ldap_session.ip, + protocol="TCP_LDAP", + settings=settings, + context=self.get_event_data(), + ) - ldap_session.event_task_group.create_task( - audit_use_case.manager.send_event(event), - ) + ldap_session.event_task_group.create_task( + audit_use_case.manager.send_event(event), + ) async def _handle_api( self, @@ -156,7 +159,8 @@ async def _handle_api( :param 
AsyncSession session: db session :return list[BaseResponse]: list of handled responses """ - kwargs = await resolve_deps(func=self.handle, container=container) + ctx = await container.get(self.CONTEXT_TYPE) # type: ignore + ldap_session = await container.get(LDAPSession) settings = await container.get(Settings) audit_use_case = await container.get(AuditUseCase) @@ -168,7 +172,7 @@ async def _handle_api( else: log_api.info(f"{get_class_name(self)}[{un}]") - responses = [response async for response in self.handle(**kwargs)] + responses = [response async for response in self.handle(ctx=ctx)] if settings.DEBUG: for response in responses: diff --git a/app/ldap_protocol/ldap_requests/bind.py b/app/ldap_protocol/ldap_requests/bind.py index d64294e72..445b2f25c 100644 --- a/app/ldap_protocol/ldap_requests/bind.py +++ b/app/ldap_protocol/ldap_requests/bind.py @@ -8,12 +8,10 @@ from typing import AsyncGenerator, ClassVar from pydantic import Field -from sqlalchemy.ext.asyncio import AsyncSession -from entities import NetworkPolicy, User +from entities import NetworkPolicy from enums import MFAFlags from ldap_protocol.asn1parser import ASN1Row -from ldap_protocol.dialogue import LDAPSession from ldap_protocol.kerberos.exceptions import ( KRBAPIAddPrincipalError, KRBAPIConnectionError, @@ -34,10 +32,6 @@ from ldap_protocol.ldap_responses import BaseResponse, BindResponse from ldap_protocol.multifactor import MultifactorAPI from ldap_protocol.objects import ProtocolRequests, UserAccountControlFlag -from ldap_protocol.policies.network_policy import ( - check_mfa_group, - is_user_group_valid, -) from ldap_protocol.user_account_control import get_check_uac from ldap_protocol.utils.queries import set_user_logon_attrs @@ -49,6 +43,7 @@ class BindRequest(BaseRequest): """Bind request fields mapping.""" PROTOCOL_OP: ClassVar[int] = ProtocolRequests.BIND + CONTEXT_TYPE: ClassVar[type] = LDAPBindRequestContext version: int name: str @@ -93,15 +88,6 @@ def from_data(cls, data: 
list[ASN1Row]) -> "BindRequest": AuthenticationChoice=auth_choice, ) - @staticmethod - async def is_user_group_valid( - user: User, - ldap_session: LDAPSession, - session: AsyncSession, - ) -> bool: - """Test compability.""" - return await is_user_group_valid(user, ldap_session.policy, session) - @staticmethod async def check_mfa( api: MultifactorAPI | None, @@ -173,11 +159,12 @@ async def handle( if uac_check(UserAccountControlFlag.ACCOUNTDISABLE): yield get_bad_response(LDAPBindErrors.ACCOUNT_DISABLED) return - - if not await self.is_user_group_valid( + policy = getattr(ctx.ldap_session, "policy", None) + if ( + policy is not None + ) and not await ctx.network_policy_validator.is_user_group_valid( user, - ctx.ldap_session, - ctx.session, + policy, ): yield get_bad_response(LDAPBindErrors.LOGON_FAILURE) return @@ -192,14 +179,18 @@ async def handle( yield get_bad_response(LDAPBindErrors.PASSWORD_MUST_CHANGE) return - if ( - (policy := getattr(ctx.ldap_session, "policy", None)) - and policy.mfa_status in (MFAFlags.ENABLED, MFAFlags.WHITELIST) + if (policy is not None) and ( + policy.mfa_status in (MFAFlags.ENABLED, MFAFlags.WHITELIST) and ctx.mfa is not None ): request_2fa = True if policy.mfa_status == MFAFlags.WHITELIST: - request_2fa = await check_mfa_group(policy, user, ctx.session) + request_2fa = ( + await ctx.network_policy_validator.check_mfa_group( + policy, + user, + ) + ) if request_2fa: mfa_status = await self.check_mfa( @@ -240,6 +231,7 @@ class UnbindRequest(BaseRequest): """Remove user from ldap_session.""" PROTOCOL_OP: ClassVar[int] = ProtocolRequests.UNBIND + CONTEXT_TYPE: ClassVar[type] = LDAPUnbindRequestContext @classmethod def from_data( diff --git a/app/ldap_protocol/ldap_requests/contexts.py b/app/ldap_protocol/ldap_requests/contexts.py index df4918f8d..98f6e1a9b 100644 --- a/app/ldap_protocol/ldap_requests/contexts.py +++ b/app/ldap_protocol/ldap_requests/contexts.py @@ -11,8 +11,12 @@ from config import Settings from ldap_protocol.dialogue 
import LDAPSession from ldap_protocol.kerberos import AbstractKadmin +from ldap_protocol.ldap_schema.attribute_value_validator import ( + AttributeValueValidator, +) from ldap_protocol.ldap_schema.entity_type_dao import EntityTypeDAO from ldap_protocol.multifactor import LDAPMultiFactorAPI +from ldap_protocol.policies.network import NetworkPolicyValidatorUseCase from ldap_protocol.policies.password import PasswordPolicyUseCases from ldap_protocol.roles.access_manager import AccessManager from ldap_protocol.roles.role_use_case import RoleUseCase @@ -33,6 +37,7 @@ class LDAPAddRequestContext: password_utils: PasswordUtils access_manager: AccessManager role_use_case: RoleUseCase + attribute_value_validator: AttributeValueValidator @dataclass @@ -48,6 +53,7 @@ class LDAPModifyRequestContext: access_manager: AccessManager password_use_cases: PasswordPolicyUseCases password_utils: PasswordUtils + attribute_value_validator: AttributeValueValidator @dataclass @@ -61,6 +67,7 @@ class LDAPBindRequestContext: password_use_cases: PasswordPolicyUseCases password_utils: PasswordUtils mfa: LDAPMultiFactorAPI + network_policy_validator: NetworkPolicyValidatorUseCase @dataclass @@ -115,3 +122,8 @@ class LDAPModifyDNRequestContext: entity_type_dao: EntityTypeDAO access_manager: AccessManager role_use_case: RoleUseCase + attribute_value_validator: AttributeValueValidator + + +@dataclass +class LDAPAbandonRequestContext: ... 
diff --git a/app/ldap_protocol/ldap_requests/delete.py b/app/ldap_protocol/ldap_requests/delete.py index 5c731e9b7..e2b127331 100644 --- a/app/ldap_protocol/ldap_requests/delete.py +++ b/app/ldap_protocol/ldap_requests/delete.py @@ -43,6 +43,7 @@ class DeleteRequest(BaseRequest): """ PROTOCOL_OP: ClassVar[int] = ProtocolRequests.DELETE + CONTEXT_TYPE: ClassVar[type] = LDAPDeleteRequestContext entry: str @@ -97,6 +98,12 @@ async def handle( # noqa: C901 yield DeleteResponse(result_code=LDAPCodes.NO_SUCH_OBJECT) return + if directory.is_system: + yield DeleteResponse( + result_code=LDAPCodes.UNWILLING_TO_PERFORM, + ) + return + self.set_event_data( {"before_attrs": self.get_directory_attrs(directory)}, ) diff --git a/app/ldap_protocol/ldap_requests/extended.py b/app/ldap_protocol/ldap_requests/extended.py index 85ca1f31b..c3967889e 100644 --- a/app/ldap_protocol/ldap_requests/extended.py +++ b/app/ldap_protocol/ldap_requests/extended.py @@ -308,6 +308,7 @@ class ExtendedRequest(BaseRequest): """ PROTOCOL_OP: ClassVar[int] = ProtocolRequests.EXTENDED + CONTEXT_TYPE: ClassVar[type] = LDAPExtendedRequestContext request_name: LDAPOID request_value: SerializeAsAny[BaseExtendedValue] diff --git a/app/ldap_protocol/ldap_requests/modify.py b/app/ldap_protocol/ldap_requests/modify.py index 5161ae754..676550e3e 100644 --- a/app/ldap_protocol/ldap_requests/modify.py +++ b/app/ldap_protocol/ldap_requests/modify.py @@ -14,9 +14,9 @@ from sqlalchemy.orm import joinedload, selectinload from config import Settings -from constants import PRIMARY_ENTITY_TYPE_NAMES +from constants import DOMAIN_ADMIN_GROUP_NAME from entities import Attribute, Directory, Group, User -from enums import AceType +from enums import AceType, EntityTypeNames from ldap_protocol.asn1parser import ASN1Row from ldap_protocol.dialogue import UserSchema from ldap_protocol.kerberos import AbstractKadmin, unlock_principal @@ -37,17 +37,11 @@ from ldap_protocol.policies.password import PasswordPolicyUseCases from 
ldap_protocol.session_storage import SessionStorage from ldap_protocol.utils.cte import check_root_group_membership_intersection -from ldap_protocol.utils.helpers import ( - create_user_name, - ft_to_dt, - is_dn_in_base_directory, - validate_entry, -) +from ldap_protocol.utils.helpers import ft_to_dt, validate_entry from ldap_protocol.utils.queries import ( add_lock_and_expire_attributes, clear_group_membership, extend_group_membership, - get_base_directories, get_directories, get_directory_by_rid, get_filter_from_path, @@ -81,8 +75,6 @@ class ModifyForbiddenError(Exception): KRBAPIForcePasswordChangeError, ) -_DOMAIN_ADMIN_NAME = "domain admins" - class ModifyRequest(BaseRequest): """Modify request. @@ -103,6 +95,7 @@ class ModifyRequest(BaseRequest): """ PROTOCOL_OP: ClassVar[int] = ProtocolRequests.MODIFY + CONTEXT_TYPE: ClassVar[type] = LDAPModifyRequestContext object: str changes: list[Changes] @@ -193,9 +186,7 @@ async def handle( names = {change.get_name() for change in self.changes} - password_change_requested = self._check_password_change_requested( - names, - ) + password_change_requested = self._is_password_change_requested(names) self_modify = directory.id == ctx.ldap_session.user.directory_id if ( @@ -210,7 +201,7 @@ async def handle( return before_attrs = self.get_directory_attrs(directory) - + entity_type = directory.entity_type try: if not can_modify and not ( password_change_requested and self_modify @@ -224,6 +215,17 @@ async def handle( if change.modification.type.lower() in Directory.ro_fields: continue + if not ctx.attribute_value_validator.is_partial_attribute_valid( # noqa: E501 + entity_type.name if entity_type else "", + change.modification, + ): + await ctx.session.rollback() + yield ModifyResponse( + result_code=LDAPCodes.UNDEFINED_ATTRIBUTE_TYPE, + message="Invalid attribute value(s)", + ) + return + await self._update_password_expiration( change, directory.user, @@ -267,9 +269,7 @@ async def handle( await self._add(*add_args) await 
ctx.session.flush() - await ctx.session.execute( - update(Directory).filter_by(id=directory.id), - ) + except MODIFY_EXCEPTION_STACK as err: await ctx.session.rollback() result_code, message = self._match_bad_response(err) @@ -289,8 +289,10 @@ async def handle( directory=directory, is_system_entity_type=False, ) + await ctx.session.commit() yield ModifyResponse(result_code=LDAPCodes.SUCCESS) + finally: query = self._get_dir_query() directory = await ctx.session.scalar(query) @@ -351,7 +353,7 @@ def _get_dir_query(self) -> Select[tuple[Directory]]: .filter(get_filter_from_path(self.object)) ) - def _check_password_change_requested( + def _is_password_change_requested( self, names: set[str], ) -> bool: @@ -437,7 +439,7 @@ async def _can_delete_group_from_directory( if operation == Operation.REPLACE: for group in directory.groups: if ( - group.directory.name == _DOMAIN_ADMIN_NAME + group.directory.name == DOMAIN_ADMIN_GROUP_NAME and directory.path_dn == user.dn and group not in groups ): @@ -448,7 +450,7 @@ async def _can_delete_group_from_directory( elif operation == Operation.DELETE: for group in groups: if ( - group.directory.name == _DOMAIN_ADMIN_NAME + group.directory.name == DOMAIN_ADMIN_GROUP_NAME and directory.path_dn == user.dn ): raise ModifyForbiddenError( @@ -482,7 +484,7 @@ async def _can_delete_member_from_directory( operation == Operation.DELETE and user.dn in modified_members_dns ) - if directory.name == _DOMAIN_ADMIN_NAME and ( + if directory.name == DOMAIN_ADMIN_GROUP_NAME and ( is_user_in_deleted or is_user_not_in_replaced ): raise ModifyForbiddenError("Can't delete yourself from group.") @@ -591,7 +593,7 @@ async def _validate_object_class_modification( ) -> None: if not ( directory.entity_type - and directory.entity_type.name in PRIMARY_ENTITY_TYPE_NAMES + and directory.entity_type.name in EntityTypeNames ): return @@ -847,17 +849,12 @@ async def _add( # noqa: C901 await session.execute( delete(Attribute) - .filter_by( - name="nsAccountLock", - 
directory=directory, - ), - ) # fmt: skip - - await session.execute( - delete(Attribute) - .filter_by( - name="shadowExpire", - directory=directory, + .where( + or_( + qa(Attribute.name) == "nsAccountLock", + qa(Attribute.name) == "shadowExpire", + ), + qa(Attribute.directory) == directory, ), ) # fmt: skip @@ -881,30 +878,6 @@ async def _add( # noqa: C901 ) elif name in User.search_fields: - if not directory.user: - path_dn = directory.path_dn - for base_directory in await get_base_directories(session): - if is_dn_in_base_directory(base_directory, path_dn): - base_dn = base_directory - break - - sam_account_name = create_user_name(directory.id) - user_principal_name = f"{sam_account_name}@{base_dn.name}" - user = User( - sam_account_name=sam_account_name, - user_principal_name=user_principal_name, - directory_id=directory.id, - ) - uac_attr = Attribute( - name="userAccountControl", - value=str(UserAccountControlFlag.NORMAL_ACCOUNT), - directory_id=directory.id, - ) - - session.add_all([user, uac_attr]) - await session.flush() - await session.refresh(directory) - if name == "accountexpires": new_value = ft_to_dt(int(value)) if value != "0" else None else: @@ -915,14 +888,6 @@ async def _add( # noqa: C901 .filter_by(directory=directory) .values({name: new_value}), ) - - elif name in Group.search_fields and directory.group: - await session.execute( - update(Group) - .filter_by(directory=directory) - .values({name: value}), - ) - elif name in ("userpassword", "unicodepwd") and directory.user: if not settings.USE_CORE_TLS: raise PermissionError("TLS required") diff --git a/app/ldap_protocol/ldap_requests/modify_dn.py b/app/ldap_protocol/ldap_requests/modify_dn.py index d17120540..7c315eadd 100644 --- a/app/ldap_protocol/ldap_requests/modify_dn.py +++ b/app/ldap_protocol/ldap_requests/modify_dn.py @@ -8,7 +8,7 @@ from sqlalchemy import delete, func, select, text, update from sqlalchemy.exc import IntegrityError -from sqlalchemy.orm import selectinload +from 
sqlalchemy.orm import joinedload, selectinload from entities import AccessControlEntry, Attribute, Directory from enums import AceType @@ -68,6 +68,7 @@ class ModifyDNRequest(BaseRequest): """ PROTOCOL_OP: ClassVar[int] = ProtocolRequests.MODIFY_DN + CONTEXT_TYPE: ClassVar[type] = LDAPModifyDNRequestContext entry: str newrdn: str @@ -112,7 +113,8 @@ async def handle( query = ( select(Directory) .options( - selectinload(qa(Directory.parent)), + joinedload(qa(Directory.parent)), + joinedload(qa(Directory.entity_type)), ) .filter(get_filter_from_path(self.entry)) ) @@ -134,6 +136,12 @@ async def handle( yield ModifyDNResponse(result_code=LDAPCodes.UNWILLING_TO_PERFORM) return + if directory.is_system: + yield ModifyDNResponse( + result_code=LDAPCodes.UNWILLING_TO_PERFORM, + ) + return + old_name = directory.name new_dn, new_name = self.newrdn.split("=") directory.name = new_name @@ -143,6 +151,21 @@ async def handle( old_depth = directory.depth + if ( + directory.entity_type + and not ctx.attribute_value_validator.is_value_valid( + entity_type_name=directory.entity_type.name, + attr_name="name", + attr_value=new_name, + ) + ): + await ctx.session.rollback() + yield ModifyDNResponse( + result_code=LDAPCodes.UNDEFINED_ATTRIBUTE_TYPE, + message="Invalid attribute value(s)", + ) + return + if ( self.new_superior and directory.parent diff --git a/app/ldap_protocol/ldap_requests/search.py b/app/ldap_protocol/ldap_requests/search.py index 01ec77169..c6505322a 100644 --- a/app/ldap_protocol/ldap_requests/search.py +++ b/app/ldap_protocol/ldap_requests/search.py @@ -14,7 +14,12 @@ from pydantic import Field, PrivateAttr, field_serializer from sqlalchemy import func, or_, select from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.orm import joinedload, selectinload, with_loader_criteria +from sqlalchemy.orm import ( + contains_eager, + joinedload, + selectinload, + with_loader_criteria, +) from sqlalchemy.sql.elements import ColumnElement, UnaryExpression from 
sqlalchemy.sql.expression import Select @@ -100,6 +105,7 @@ class SearchRequest(BaseRequest): """ PROTOCOL_OP: ClassVar[int] = ProtocolRequests.SEARCH + CONTEXT_TYPE: ClassVar[type] = LDAPSearchRequestContext base_object: str = Field("", description="Any `DistinguishedName`") scope: Scope @@ -151,6 +157,10 @@ def is_sid_requested(self) -> bool: def is_guid_requested(self) -> bool: return self.all_attrs or "objectguid" in self.requested_attrs + @property + def is_objectclass_requested(self) -> bool: + return self.all_attrs or "objectclass" in self.requested_attrs + @cached_property def all_attrs(self) -> bool: return "*" in self.requested_attrs or not self.requested_attrs @@ -339,7 +349,7 @@ def _mutate_query_with_attributes_to_load( if self.entity_type_name: query = ( query.join(qa(Directory.entity_type)) - .options(selectinload(qa(Directory.entity_type))) + .options(contains_eager(qa(Directory.entity_type))) ) # fmt: skip if self.all_attrs: @@ -351,11 +361,16 @@ def _mutate_query_with_attributes_to_load( if attr not in _ATTRS_TO_CLEAN } + cond = or_( + func.lower(Attribute.name).in_(attrs), + func.lower(Attribute.name) == "objectclass", + ) + return query.options( selectinload(qa(Directory.attributes)), with_loader_criteria( Attribute, - func.lower(Attribute.name).in_(attrs), + cond, ), ) @@ -369,8 +384,8 @@ def _build_query( query = ( select(Directory) .join(qa(Directory.user), isouter=True) - .options(joinedload(qa(Directory.user))) - .options(selectinload(qa(Directory.group))) + .options(contains_eager(qa(Directory.user))) + .options(joinedload(qa(Directory.group))) ) query = self._mutate_query_with_attributes_to_load(query) @@ -423,7 +438,7 @@ def _build_query( if self.member: query = query.options( - selectinload(qa(Directory.group)).selectinload( + joinedload(qa(Directory.group)).selectinload( qa(Group.members), ), ) @@ -468,7 +483,7 @@ async def _fill_attrs( attrs: dict[str, list[str]], session: AsyncSession, ) -> None: - if "distinguishedname" not in 
self.requested_attrs or self.all_attrs: + if "distinguishedname" in self.requested_attrs or self.all_attrs: attrs["distinguishedName"].append(distinguished_name) if "whenCreated" in self.requested_attrs or self.all_attrs: @@ -501,15 +516,10 @@ async def _fill_attrs( ) if self.member_of: - logger.debug(f"Member of group: {directory.groups}") for group in directory.groups: attrs["memberOf"].append(group.directory.path_dn) if self.token_groups and "user" in obj_classes: - attrs["tokenGroups"].append( - str(string_to_sid(directory.object_sid)), - ) - group_directories = await get_all_parent_group_directories( directory.groups, session, @@ -518,7 +528,7 @@ async def _fill_attrs( if group_directories is not None: async for directory_ in group_directories: attrs["tokenGroups"].append( - str(string_to_sid(directory_.object_sid)), + string_to_sid(directory_.object_sid), # type: ignore ) if self.member and "group" in obj_classes and directory.group: @@ -541,9 +551,9 @@ async def tree_view( # noqa: C901 access_manager: AccessManager, ) -> AsyncGenerator[SearchResultEntry, None]: """Yield all resulted directories.""" - directories = await session.stream_scalars(query) + directories = await session.scalars(query) - async for directory in directories: + for directory in directories: attrs = defaultdict(list) obj_classes = [] @@ -572,6 +582,9 @@ async def tree_view( # noqa: C901 if attr.name.lower() == "objectclass": obj_classes.append(value) + if self.is_objectclass_requested: + attrs[attr.name].append(value) + continue attrs[attr.name].append(value) diff --git a/app/ldap_protocol/ldap_schema/attribute_value_validator.py b/app/ldap_protocol/ldap_schema/attribute_value_validator.py new file mode 100644 index 000000000..9c7c77f56 --- /dev/null +++ b/app/ldap_protocol/ldap_schema/attribute_value_validator.py @@ -0,0 +1,270 @@ +"""Attribute Value Validator. 
+ +Copyright (c) 2025 MultiFactor +License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE +""" + +import re +from collections import defaultdict +from typing import Callable, cast as tcast + +from entities import Attribute, Directory, User +from enums import EntityTypeNames +from ldap_protocol.objects import PartialAttribute + +type _AttrNameType = str +type _ValueType = str +type _ValueValidatorType = Callable[[_ValueType], bool] +type _CompiledValidatorsType = dict[ + EntityTypeNames, + dict[_AttrNameType, _ValueValidatorType], +] + + +class AttributeValueValidatorError(Exception): ... + + +# NOTE: Not validate `distinguishedName`, `member` and `memberOf` attributes, +# because it doesn't exist. +_ENTITY_NAME_AND_ATTR_NAME_VALIDATION_MAP: dict[ + tuple[EntityTypeNames, _AttrNameType], + tuple[str, ...], +] = { + (EntityTypeNames.ORGANIZATIONAL_UNIT, "name"): ( + "not_start_with_space", + "not_start_with_hash", + "not_end_with_space", + "not_contains_symbols", + ), + (EntityTypeNames.GROUP, "name"): ( + "not_start_with_space", + "not_start_with_hash", + "not_end_with_space", + "not_contains_symbols", + ), + (EntityTypeNames.USER, "name"): ( + "not_start_with_space", + "not_start_with_hash", + "not_end_with_space", + "not_contains_symbols", + ), + (EntityTypeNames.USER, "sAMAccountName"): ( + "not_contains_symbols_ext", + "not_end_with_dot", + "not_contains_control_characters", + "not_contains_at", + ), + (EntityTypeNames.COMPUTER, "name"): ( + "not_start_with_space", + "not_start_with_hash", + "not_end_with_space", + "not_contains_symbols", + ), + (EntityTypeNames.COMPUTER, "sAMAccountName"): ( + "not_contains_symbols_ext", + "not_end_with_dot", + "not_contains_control_characters", + "not_contains_spaces_and_dots", + "not_only_numbers", + "not_start_with_number", + ), +} + + +class _ValValidators: + @staticmethod + def not_start_with_space(value: _ValueType) -> bool: + return not value.startswith(" ") + + @staticmethod + def 
not_only_numbers(value: _ValueType) -> bool: + return not value.isdigit() + + @staticmethod + def not_contains_at(value: _ValueType) -> bool: + return "@" not in value + + @staticmethod + def not_start_with_number(value: _ValueType) -> bool: + return bool(value and not value[0].isdigit()) + + @staticmethod + def not_start_with_hash(value: _ValueType) -> bool: + return not value.startswith("#") + + @staticmethod + def not_end_with_space(value: _ValueType) -> bool: + return not value.endswith(" ") + + @staticmethod + def not_contains_control_characters(value: _ValueType) -> bool: + return all(ord(char) >= 32 and ord(char) != 127 for char in value) + + @staticmethod + def not_contains_spaces_and_dots(value: _ValueType) -> bool: + return " " not in value and "." not in value + + @staticmethod + def not_contains_symbols(value: _ValueType) -> bool: + return not re.search(r'[,+"\\<>;=]', value) + + @staticmethod + def not_contains_symbols_ext(value: _ValueType) -> bool: + return not re.search(r'["/\\\[\]:;\|=,\+\*\?<>]', value) + + @staticmethod + def not_end_with_dot(value: _ValueType) -> bool: + return not value.endswith(".") + + +class AttributeValueValidator: + _compiled_validators: _CompiledValidatorsType + + def __init__(self) -> None: + self._compiled_validators: _CompiledValidatorsType = ( + self.__compile_validators() + ) + + def __compile_validators(self) -> _CompiledValidatorsType: + res: _CompiledValidatorsType = defaultdict(dict) + + for ( + key, + validator_names, + ) in _ENTITY_NAME_AND_ATTR_NAME_VALIDATION_MAP.items(): + validators = [getattr(_ValValidators, n) for n in validator_names] + res[key[0]][key[1]] = self.__create_combined_validator(validators) + + return res + + def __create_combined_validator( + self, + funcs: list[_ValueValidatorType], + ) -> _ValueValidatorType: + def combined(value: _ValueType) -> bool: + return all(func(value) for func in funcs) + + return combined + + def _get_subset_validators( + self, + entity_type_name: EntityTypeNames 
| str, + ) -> dict[_AttrNameType, _ValueValidatorType] | None: + if entity_type_name in self._compiled_validators: + entity_type_name = tcast("EntityTypeNames", entity_type_name) + else: + return None + return self._compiled_validators.get(entity_type_name) + + def _get_validator( + self, + entity_type_name: EntityTypeNames | str, + attr_name: str, + ) -> _ValueValidatorType | None: + subset_validators = self._get_subset_validators(entity_type_name) + return subset_validators.get(attr_name) if subset_validators else None + + def is_value_valid( + self, + entity_type_name: EntityTypeNames | str, + attr_name: _AttrNameType, + attr_value: _ValueType, + ) -> bool: + validator = self._get_validator(entity_type_name, attr_name) + + if not validator: + return True + + return validator(attr_value) + + def is_partial_attribute_valid( + self, + entity_type_name: EntityTypeNames | str, + partial_attribute: PartialAttribute, + ) -> bool: + validator = self._get_validator( + entity_type_name, + partial_attribute.type, + ) + + if not validator: + return True + + for value in partial_attribute.vals: + if isinstance(value, str) and not validator(value): + return False + + return True + + def is_directory_attributes_valid( + self, + entity_type_name: EntityTypeNames | str, + attributes: list[Attribute], + ) -> bool: + subset_validators = self._get_subset_validators(entity_type_name) + if not subset_validators: + return True + + for attribute in attributes: + if not attribute.value: + continue + + validator = subset_validators.get(attribute.name) + if not validator: + continue + + if not validator(attribute.value): + return False + + return True + + def is_directory_valid(self, directory: Directory) -> bool: + if not directory.entity_type: + raise AttributeValueValidatorError( + "Directory must have an entity type", + ) + + entity_type_name = directory.entity_type.name + + if entity_type_name and not self.is_value_valid( + entity_type_name, + "name", + directory.name, + ): + return 
False + + if entity_type_name == EntityTypeNames.USER: + if not directory.user: + raise AttributeValueValidatorError( + "User directory must have associated User", + ) + + if not self.is_user_valid(directory.user): + return False + + if not self.is_directory_attributes_valid( # noqa: SIM103 + entity_type_name, + directory.attributes, + ): + return False + + return True + + def is_user_valid(self, user: User) -> bool: + user_entity_type_name = EntityTypeNames.USER + + if not self.is_value_valid( + user_entity_type_name, + "sAMAccountName", + user.sam_account_name, + ): + return False + + if not self.is_value_valid( # noqa: SIM103 + user_entity_type_name, + "userPrincipalName", + user.user_principal_name, + ): + return False + + return True diff --git a/app/ldap_protocol/ldap_schema/dto.py b/app/ldap_protocol/ldap_schema/dto.py index 0430d9a2e..118a6e1e8 100644 --- a/app/ldap_protocol/ldap_schema/dto.py +++ b/app/ldap_protocol/ldap_schema/dto.py @@ -7,7 +7,7 @@ from dataclasses import dataclass, field from typing import Generic, TypeVar -from enums import KindType +from enums import EntityTypeNames, KindType _IdT = TypeVar("_IdT", int, None) @@ -49,7 +49,7 @@ class ObjectClassDTO(Generic[_IdT, _LinkT]): class EntityTypeDTO(Generic[_IdT]): """Entity Type DTO.""" - name: str + name: EntityTypeNames | str is_system: bool object_class_names: list[str] id: _IdT = None # type: ignore diff --git a/app/ldap_protocol/ldap_schema/entity_type_dao.py b/app/ldap_protocol/ldap_schema/entity_type_dao.py index 6e30a0989..abfdc49d1 100644 --- a/app/ldap_protocol/ldap_schema/entity_type_dao.py +++ b/app/ldap_protocol/ldap_schema/entity_type_dao.py @@ -16,6 +16,10 @@ from abstract_dao import AbstractDAO from entities import Attribute, Directory, EntityType, ObjectClass +from ldap_protocol.ldap_schema.attribute_value_validator import ( + AttributeValueValidator, + AttributeValueValidatorError, +) from ldap_protocol.ldap_schema.dto import EntityTypeDTO from 
ldap_protocol.ldap_schema.exceptions import ( EntityTypeAlreadyExistsError, @@ -44,15 +48,18 @@ class EntityTypeDAO(AbstractDAO[EntityTypeDTO, str]): __session: AsyncSession __object_class_dao: ObjectClassDAO + __attribute_value_validator: AttributeValueValidator def __init__( self, session: AsyncSession, object_class_dao: ObjectClassDAO, + attribute_value_validator: AttributeValueValidator, ) -> None: """Initialize Entity Type DAO with a database session.""" self.__session = session self.__object_class_dao = object_class_dao + self.__attribute_value_validator = attribute_value_validator async def get_all(self) -> list[EntityTypeDTO[int]]: """Get all Entity Types.""" @@ -120,11 +127,20 @@ async def update(self, _id: str, dto: EntityTypeDTO[int]) -> None: for directory in result.scalars(): for object_class_name in entity_type.object_class_names: + if not self.__attribute_value_validator.is_value_valid( + entity_type.name, + "objectClass", + object_class_name, + ): + raise AttributeValueValidatorError( + f"Invalid objectClass value '{object_class_name}' for entity type '{entity_type.name}'.", # noqa: E501 + ) + self.__session.add( Attribute( directory_id=directory.id, - value=object_class_name, name="objectClass", + value=object_class_name, ), ) diff --git a/app/ldap_protocol/ldap_schema/entity_type_use_case.py b/app/ldap_protocol/ldap_schema/entity_type_use_case.py index bb6fb2729..5958e6a99 100644 --- a/app/ldap_protocol/ldap_schema/entity_type_use_case.py +++ b/app/ldap_protocol/ldap_schema/entity_type_use_case.py @@ -7,8 +7,8 @@ from typing import ClassVar from abstract_service import AbstractService -from constants import ENTITY_TYPE_DATAS, PRIMARY_ENTITY_TYPE_NAMES -from enums import AuthorizationRules +from constants import ENTITY_TYPE_DATAS +from enums import AuthorizationRules, EntityTypeNames from ldap_protocol.ldap_schema.dto import EntityTypeDTO from ldap_protocol.ldap_schema.entity_type_dao import EntityTypeDAO from ldap_protocol.ldap_schema.exceptions 
import ( @@ -65,7 +65,7 @@ async def _validate_name( self, name: str, ) -> None: - if name in PRIMARY_ENTITY_TYPE_NAMES: + if name in EntityTypeNames: raise EntityTypeCantModifyError( f"Can't change entity type name {name}", ) @@ -93,7 +93,7 @@ async def create_for_first_setup(self) -> None: for entity_type_data in ENTITY_TYPE_DATAS: await self.create( EntityTypeDTO( - name=entity_type_data["name"], # type: ignore + name=entity_type_data["name"], object_class_names=list( entity_type_data["object_class_names"], ), diff --git a/app/ldap_protocol/ldap_schema/exceptions.py b/app/ldap_protocol/ldap_schema/exceptions.py index de6ab4544..02a9d43f3 100644 --- a/app/ldap_protocol/ldap_schema/exceptions.py +++ b/app/ldap_protocol/ldap_schema/exceptions.py @@ -4,50 +4,81 @@ License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE """ +from enum import IntEnum + +from errors import BaseDomainException + + +class ErrorCodes(IntEnum): + """Error codes for LDAP Schema operations.""" + + BASE_ERROR = 0 + ATTRIBUTE_TYPE_NOT_FOUND_ERROR = 1 + ATTRIBUTE_TYPE_CANT_MODIFY_ERROR = 2 + ATTRIBUTE_TYPE_ALREADY_EXISTS_ERROR = 3 + OBJECT_CLASS_NOT_FOUND_ERROR = 4 + OBJECT_CLASS_CANT_MODIFY_ERROR = 5 + OBJECT_CLASS_ALREADY_EXISTS_ERROR = 6 + ENTITY_TYPE_NOT_FOUND_ERROR = 7 + ENTITY_TYPE_CANT_MODIFY_ERROR = 8 + ENTITY_TYPE_ALREADY_EXISTS_ERROR = 9 -class AttributeTypeError(Exception): - """Raised when an attribute type is not found.""" +class LdapSchemaError(BaseDomainException): + """Raised when an LDAP Schema error occurs.""" -class AttributeTypeNotFoundError(AttributeTypeError): + code: ErrorCodes = ErrorCodes.BASE_ERROR + + +class AttributeTypeNotFoundError(LdapSchemaError): """Raised when an attribute type is not found.""" + code = ErrorCodes.ATTRIBUTE_TYPE_NOT_FOUND_ERROR -class AttributeTypeCantModifyError(AttributeTypeError): + +class AttributeTypeCantModifyError(LdapSchemaError): """Raised when an attribute type cannot be modified.""" + code = 
ErrorCodes.ATTRIBUTE_TYPE_CANT_MODIFY_ERROR -class AttributeTypeAlreadyExistsError(AttributeTypeError): - """Raised when an attribute type already exists.""" +class AttributeTypeAlreadyExistsError(LdapSchemaError): + """Raised when an attribute type already exists.""" -class ObjectClassTypeError(Exception): - """Raised when an object class type is not found.""" + code = ErrorCodes.ATTRIBUTE_TYPE_ALREADY_EXISTS_ERROR -class ObjectClassNotFoundError(ObjectClassTypeError): +class ObjectClassNotFoundError(LdapSchemaError): """Raised when an object class is not found.""" + code = ErrorCodes.OBJECT_CLASS_NOT_FOUND_ERROR + -class ObjectClassCantModifyError(ObjectClassTypeError): +class ObjectClassCantModifyError(LdapSchemaError): """Raised when an object class cannot be modified.""" + code = ErrorCodes.OBJECT_CLASS_CANT_MODIFY_ERROR -class ObjectClassAlreadyExistsError(ObjectClassTypeError): - """Raised when an object class already exists.""" +class ObjectClassAlreadyExistsError(LdapSchemaError): + """Raised when an object class already exists.""" -class EntityTypeTypeError(Exception): - """Raised when an entity type is not found.""" + code = ErrorCodes.OBJECT_CLASS_ALREADY_EXISTS_ERROR -class EntityTypeNotFoundError(EntityTypeTypeError): +class EntityTypeNotFoundError(LdapSchemaError): """Raised when an entity type is not found.""" + code = ErrorCodes.ENTITY_TYPE_NOT_FOUND_ERROR + -class EntityTypeCantModifyError(EntityTypeTypeError): +class EntityTypeCantModifyError(LdapSchemaError): """Raised when an entity type cannot be modified.""" + code = ErrorCodes.ENTITY_TYPE_CANT_MODIFY_ERROR -class EntityTypeAlreadyExistsError(EntityTypeTypeError): + +class EntityTypeAlreadyExistsError(LdapSchemaError): """Raised when an entity type already exists.""" + + code = ErrorCodes.ENTITY_TYPE_ALREADY_EXISTS_ERROR diff --git a/app/ldap_protocol/permissions_checker.py b/app/ldap_protocol/permissions_checker.py index be5a752c3..e41ae3cbf 100644 --- 
a/app/ldap_protocol/permissions_checker.py +++ b/app/ldap_protocol/permissions_checker.py @@ -5,15 +5,12 @@ from enums import AuthorizationRules from ldap_protocol.identity import IdentityProvider +from ldap_protocol.identity.exceptions import AuthorizationError _P = ParamSpec("_P") _R = TypeVar("_R") -class AuthorizationError(Exception): - """Authorization error.""" - - class AuthorizationProvider: """API permissions checker.""" diff --git a/app/ldap_protocol/policies/audit/events/service_senders/rfc5424_serializer.py b/app/ldap_protocol/policies/audit/events/service_senders/rfc5424_serializer.py new file mode 100644 index 000000000..8ff8272ce --- /dev/null +++ b/app/ldap_protocol/policies/audit/events/service_senders/rfc5424_serializer.py @@ -0,0 +1,170 @@ +"""RFC 5424 Syslog message serializer. + +Copyright (c) 2025 MultiFactor +License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE +""" + +import socket +from datetime import datetime, timezone +from typing import Any + +from ldap_protocol.policies.audit.events.dataclasses import ( + NormalizedAuditEvent, +) + + +class RFC5424Serializer: + """Serializer for RFC 5424 compliant syslog messages.""" + + NILVALUE: str = "-" + UTF8_BOM: str = "\ufeff" + + # SD-ID suffix for STRUCTURED-DATA: audit@32473 + # Change to your registered Private Enterprise Number (PEN) + STRUCTURED_DATA_ID_SUFFIX: str = "32473" + + SYSLOG_FACILITIES: dict[str, int] = { + "kernel": 0, + "user": 1, + "mail": 2, + "system": 3, + "security": 4, + "syslog": 5, + "printer": 6, + "network": 7, + "uucp": 8, + "clock": 9, + "authpriv": 10, + "ftp": 11, + "ntp": 12, + "audit": 13, + "alert": 14, + "cron": 15, + "local0": 16, + "local1": 17, + "local2": 18, + "local3": 19, + "local4": 20, + "local5": 21, + "local6": 22, + "local7": 23, + } + + def __init__( + self, + app_name: str, + facility: str, + ) -> None: + """Initialize RFC 5424 serializer.""" + self.app_name = app_name + self.facility = facility + + def serialize( + 
self, + event: NormalizedAuditEvent, + structured_data: dict[str, Any], + syslog_version: int, + ) -> str: + """Serialize audit event to RFC 5424 format.""" + severity = self._format_severity(event.severity) + timestamp = self._format_timestamp(event.timestamp) + hostname = self._format_hostname(event.hostname) + app_name = self._format_field(self.app_name, 48) + proc_id = self._format_field(event.service_name, 128) + msg_id = self._format_field(event.event_type, 32) + sd_str = self._format_structured_data(structured_data) + msg = self._format_message(event.syslog_message) + + return ( + f"<{severity}>{syslog_version} " + f"{timestamp} {hostname} {app_name} {proc_id} {msg_id} " + f"{sd_str}{msg}" + ) + + def _format_severity(self, severity: int) -> int: + """Calculate PRIORITY value (RFC 5424 section 6.2.1).""" + if not 0 <= severity <= 7: + raise NotImplementedError(f"Severity must be 0-7, got {severity}") + + facility_code = self.SYSLOG_FACILITIES.get( + self.facility.lower(), + self.SYSLOG_FACILITIES["authpriv"], + ) + + return (facility_code << 3) | severity + + def _format_timestamp(self, timestamp: float) -> str: + """Format TIMESTAMP field (RFC 5424 section 6.2.3).""" + dt = datetime.fromtimestamp(timestamp, tz=timezone.utc) + return dt.isoformat(timespec="milliseconds").replace("+00:00", "Z") + + def _format_hostname(self, hostname: str | None) -> str: + """Format HOSTNAME field (RFC 5424 section 6.2.4).""" + if not hostname: + hostname = socket.gethostname() + + return self._format_field(hostname, 255) + + def _format_field( + self, + value: str | None, + max_length: int, + ) -> str: + """Format generic RFC 5424 field.""" + if not value: + return self.NILVALUE + + sanitized = "".join(c for c in value if 33 <= ord(c) <= 126)[ + :max_length + ] + + return sanitized or self.NILVALUE + + def _format_structured_data( + self, + structured_data: dict[str, Any], + ) -> str: + """Format STRUCTURED-DATA field (RFC 5424 section 6.3).""" + if not structured_data: + 
return self.NILVALUE + + params = [] + for key, value in structured_data.items(): + param_name = self._sanitize_param_name(str(key)) + if not param_name: + continue + + param_value = self._escape_param_value(str(value)) + params.append(f'{param_name}="{param_value}"') + + if not params: + return self.NILVALUE + + sd_id = f"audit@{self.STRUCTURED_DATA_ID_SUFFIX}" + return f"[{sd_id} {' '.join(params)}]" + + def _sanitize_param_name(self, name: str) -> str: + """Sanitize PARAM-NAME for STRUCTURED-DATA. + + RFC 5424 allows only printable ASCII (33-126) + except: =, space, ], " + Max length: 32 characters + """ + return "".join( + c + for c in name + if 33 <= ord(c) <= 126 and c not in ("=", " ", "]", '"') + )[:32] + + def _escape_param_value(self, value: str) -> str: + """Escape PARAM-VALUE for STRUCTURED-DATA.""" + return ( + value.replace("\\", "\\\\").replace('"', r"\"").replace("]", r"\]") + ) + + def _format_message(self, msg: str | None) -> str: + """Format MSG field (RFC 5424 section 6.4).""" + if not msg: + return "" + + return f" {self.UTF8_BOM}{msg}" diff --git a/app/ldap_protocol/policies/audit/events/service_senders/syslog.py b/app/ldap_protocol/policies/audit/events/service_senders/syslog.py index 1ba1bbac7..3f30f323b 100644 --- a/app/ldap_protocol/policies/audit/events/service_senders/syslog.py +++ b/app/ldap_protocol/policies/audit/events/service_senders/syslog.py @@ -5,10 +5,7 @@ """ import asyncio -import socket -import uuid from copy import deepcopy -from datetime import datetime, timezone from typing import Any from loguru import logger @@ -19,43 +16,29 @@ ) from .base import AuditDestinationSenderABC +from .rfc5424_serializer import RFC5424Serializer class SyslogSender(AuditDestinationSenderABC): - """Syslog sender.""" + """Syslog sender with RFC 5424 support. + + Sends audit events to syslog servers using RFC 5424 format. + Supports both TCP and UDP protocols. 
+ """ service_name: AuditDestinationServiceType = ( AuditDestinationServiceType.SYSLOG ) - SYSLOG_VERSION: int = 1 DEFAULT_TIMEOUT: int = 10 - DEFAULT_FACILITY = "authpriv" - SYSLOG_FACILITIES: dict[str, int] = { - "kernel": 0, - "user": 1, - "mail": 2, - "system": 3, - "security": 4, - "syslog": 5, - "printer": 6, - "network": 7, - "uucp": 8, - "clock": 9, - "authpriv": 10, - "ftp": 11, - "ntp": 12, - "audit": 13, - "alert": 14, - "cron": 15, - "local0": 16, - "local1": 17, - "local2": 18, - "local3": 19, - "local4": 20, - "local5": 21, - "local6": 22, - "local7": 23, - } + SYSLOG_VERSION: int = 1 + + def __init__(self, *args: Any, **kwargs: Any) -> None: + """Initialize syslog sender with RFC 5424 serializer.""" + super().__init__(*args, **kwargs) + self.__rfc_serializer = RFC5424Serializer( + app_name=self.DEFAULT_APP_NAME, + facility="authpriv", + ) async def _send_udp(self, message: str) -> None: """Send UDP.""" @@ -84,107 +67,24 @@ async def _send_tcp(self, message: str) -> None: writer.close() await writer.wait_closed() - def generate_rfc5424_message( - self, - event: NormalizedAuditEvent, - structured_data: dict[str, Any], - ) -> str: - """Generate a syslog message according to RFC 5424.""" - severity_code = event.severity - facility = self.DEFAULT_FACILITY - app_name = self.DEFAULT_APP_NAME - msg_id = str(uuid.uuid4()) - message = event.syslog_message - hostname = event.hostname - proc_id = event.service_name - - if not 0 <= severity_code <= 7: - raise ValueError("Severity code must be between 0 and 7") - - facility_code = self.SYSLOG_FACILITIES.get( - facility.lower(), - self.SYSLOG_FACILITIES[self.DEFAULT_FACILITY], - ) - priority = (facility_code << 3) | severity_code - - # TIMESTAMP (RFC 5424 section 6.2.3) - dt = datetime.fromtimestamp(event.timestamp, tz=timezone.utc) - timestamp = dt.isoformat( - timespec="milliseconds", - ).replace("+00:00", "Z") - - # HOSTNAME (section 6.2.4) - hostname = (hostname or socket.gethostname() or "-")[:255] - - # 
APP-NAME (section 6.2.5) - app_name = app_name or "-" - if len(app_name) > 48: - app_name = app_name[:48] - - # PROCID (section 6.2.6) - proc_id = proc_id or "-" - - # MSGID (section 6.2.7) - msg_id = msg_id or "-" - - # STRUCTURED-DATA (section 6.3) - sd_str = self._format_structured_data(app_name, structured_data) or "-" - - # MSG (section 6.4) - message = self._escape_message(message) if message else "" - - return ( - f"<{priority}>{self.SYSLOG_VERSION} {timestamp} " - f"{hostname} {app_name} {proc_id} {msg_id} " - f"{sd_str} {message}" - ) - - def _escape_message(self, msg: str) -> str: - """Escape special chars in message (RFC 5424 section 6.4).""" - return " " + msg.replace("\n", " ").replace("\r", " ") - - def _format_structured_data( - self, - app_name: str, - structured_data: dict[str, Any], - ) -> str: - """Format structured data according to RFC 5424 section 6.3.""" - if not structured_data: - return "" - - def escape_param_value(value: str) -> str: - return ( - value.replace("\\", "\\\\") - .replace('"', '\\"') - .replace("]", "\\]") - ) - - sd_id = f"{app_name}@{uuid.uuid4()}" - params = [] - - for k, v in structured_data.items(): - if not k or "=" in k or " " in k or '"' in k: - continue - escaped_value = escape_param_value(str(v)) - params.append(f'{k}="{escaped_value}"') - - if not params: - return "" - - return f"[{sd_id} {' '.join(params)}]" - async def send(self, event: NormalizedAuditEvent) -> None: """Send event.""" structured_data = deepcopy(event.destination_dict) - syslog_message = self.generate_rfc5424_message( + syslog_message = self.__rfc_serializer.serialize( event=event, structured_data=structured_data, + syslog_version=self.SYSLOG_VERSION, ) + if self._destination.protocol == AuditDestinationProtocolType.UDP: callback = self._send_udp elif self._destination.protocol == AuditDestinationProtocolType.TCP: callback = self._send_tcp + else: + raise NotImplementedError( + f"Unsupported protocol: {self._destination.protocol}", + ) try: await 
callback(syslog_message) diff --git a/app/ldap_protocol/policies/audit/exception.py b/app/ldap_protocol/policies/audit/exception.py index c32647dd5..7f9ca4d3a 100644 --- a/app/ldap_protocol/policies/audit/exception.py +++ b/app/ldap_protocol/policies/audit/exception.py @@ -4,10 +4,32 @@ License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE """ +from enum import IntEnum -class AuditNotFoundError(Exception): +from errors import BaseDomainException + + +class ErrorCodes(IntEnum): + """Error codes.""" + + BASE_ERROR = 0 + AUDIT_NOT_FOUND_ERROR = 1 + AUDIT_ALREADY_EXISTS_ERROR = 2 + + +class AuditError(BaseDomainException): + """Audit error.""" + + code: ErrorCodes = ErrorCodes.BASE_ERROR + + +class AuditNotFoundError(AuditError): """Exception raised when an audit model is not found.""" + code = ErrorCodes.AUDIT_NOT_FOUND_ERROR -class AuditAlreadyExistsError(Exception): + +class AuditAlreadyExistsError(AuditError): """Exception raised when an audit model already exists.""" + + code = ErrorCodes.AUDIT_ALREADY_EXISTS_ERROR diff --git a/app/ldap_protocol/policies/audit/monitor.py b/app/ldap_protocol/policies/audit/monitor.py index 0719c4149..5ce08d0ab 100644 --- a/app/ldap_protocol/policies/audit/monitor.py +++ b/app/ldap_protocol/policies/audit/monitor.py @@ -24,6 +24,7 @@ ) from ldap_protocol.auth.schemas import OAuth2Form from ldap_protocol.identity.exceptions import ( + AuthorizationError, AuthValidationError, LoginFailedError, PasswordPolicyError, @@ -33,7 +34,6 @@ from ldap_protocol.kerberos.exceptions import KRBAPIChangePasswordError from ldap_protocol.multifactor import MFA_HTTP_Creds from ldap_protocol.objects import OperationEvent -from ldap_protocol.permissions_checker import AuthorizationError from ldap_protocol.policies.audit.audit_use_case import AuditUseCase from ldap_protocol.policies.audit.events.factory import ( RawAuditEventBuilderRedis, diff --git a/app/ldap_protocol/policies/network/__init__.py 
b/app/ldap_protocol/policies/network/__init__.py index 09538aa8d..c4ea76e70 100644 --- a/app/ldap_protocol/policies/network/__init__.py +++ b/app/ldap_protocol/policies/network/__init__.py @@ -1,13 +1,26 @@ """Network policies module.""" -from .dto import NetworkPolicyDTO -from .exceptions import NetworkPolicyAlreadyExistsError +from .dto import NetworkPolicyDTO, NetworkPolicyUpdateDTO, SwapPrioritiesDTO +from .exceptions import ( + LastActivePolicyError, + NetworkPolicyAlreadyExistsError, + NetworkPolicyNotFoundError, +) from .gateway import NetworkPolicyGateway -from .use_cases import NetworkPolicyUseCase +from .use_cases import NetworkPolicyUseCase, NetworkPolicyValidatorUseCase +from .validator_gateway import NetworkPolicyValidatorGateway +from .validator_protocol import NetworkPolicyValidatorProtocol __all__ = [ "NetworkPolicyDTO", + "NetworkPolicyUpdateDTO", + "SwapPrioritiesDTO", "NetworkPolicyAlreadyExistsError", + "LastActivePolicyError", + "NetworkPolicyNotFoundError", "NetworkPolicyGateway", "NetworkPolicyUseCase", + "NetworkPolicyValidatorUseCase", + "NetworkPolicyValidatorGateway", + "NetworkPolicyValidatorProtocol", ] diff --git a/app/ldap_protocol/policies/network/exceptions.py b/app/ldap_protocol/policies/network/exceptions.py index ea63f3c2a..81ff561e8 100644 --- a/app/ldap_protocol/policies/network/exceptions.py +++ b/app/ldap_protocol/policies/network/exceptions.py @@ -4,18 +4,39 @@ License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE """ +from enum import IntEnum -class NetworkPolicyError(Exception): +from api.error_routing import BaseDomainException + + +class ErrorCodes(IntEnum): + """Error codes.""" + + BASE_ERROR = 0 + NETWORK_POLICY_ALREADY_EXISTS_ERROR = 1 + NETWORK_POLICY_NOT_FOUND_ERROR = 2 + LAST_ACTIVE_POLICY_ERROR = 3 + + +class NetworkPolicyError(BaseDomainException): """Network policy error.""" + code: ErrorCodes = ErrorCodes.BASE_ERROR + class NetworkPolicyAlreadyExistsError(NetworkPolicyError): """Network 
policy already exists error.""" + code = ErrorCodes.NETWORK_POLICY_ALREADY_EXISTS_ERROR + class NetworkPolicyNotFoundError(NetworkPolicyError): """Network policy not found error.""" + code = ErrorCodes.NETWORK_POLICY_NOT_FOUND_ERROR + class LastActivePolicyError(NetworkPolicyError): """Last active policy error.""" + + code = ErrorCodes.LAST_ACTIVE_POLICY_ERROR diff --git a/app/ldap_protocol/policies/network/use_cases.py b/app/ldap_protocol/policies/network/use_cases.py index a069b02d8..cde4294d6 100644 --- a/app/ldap_protocol/policies/network/use_cases.py +++ b/app/ldap_protocol/policies/network/use_cases.py @@ -4,6 +4,7 @@ License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE """ +from ipaddress import IPv4Address, IPv6Address from typing import ClassVar from adaptix import P @@ -11,8 +12,8 @@ from sqlalchemy.ext.asyncio import AsyncSession from abstract_service import AbstractService -from entities import NetworkPolicy -from enums import AuthorizationRules +from entities import NetworkPolicy, User +from enums import AuthorizationRules, ProtocolType from ldap_protocol.policies.network.dto import ( NetworkPolicyDTO, NetworkPolicyUpdateDTO, @@ -22,8 +23,10 @@ LastActivePolicyError, NetworkPolicyAlreadyExistsError, ) - -from .gateway import NetworkPolicyGateway +from ldap_protocol.policies.network.gateway import NetworkPolicyGateway +from ldap_protocol.policies.network.validator_protocol import ( + NetworkPolicyValidatorProtocol, +) def _convert_groups(policy: NetworkPolicy) -> list[str]: @@ -197,3 +200,103 @@ async def swap_priorities(self, id1: int, id2: int) -> SwapPrioritiesDTO: update.__name__: AuthorizationRules.NETWORK_POLICY_UPDATE, swap_priorities.__name__: AuthorizationRules.NETWORK_POLICY_SWAP_PRIORITIES, # noqa: E501 } + + +class NetworkPolicyValidatorUseCase(AbstractService): + """Network policies validator use cases.""" + + def __init__( + self, + network_policy_validator_gateway: NetworkPolicyValidatorProtocol, + ): + 
"""Initialize Network policies validator use cases.""" + self._gateway = network_policy_validator_gateway + + async def get_by_protocol( + self, + ip: IPv4Address | IPv6Address, + protocol_type: ProtocolType, + ) -> NetworkPolicy | None: + """Get network policy by protocol.""" + return await self._gateway.get_by_protocol( + ip, + protocol_type, + ) + + async def get_user_network_policy( + self, + ip: IPv4Address | IPv6Address, + user: User, + policy_type: ProtocolType, + ) -> NetworkPolicy | None: + """Get user network policy.""" + return await self._gateway.get_user_network_policy( + ip, + user, + policy_type, + ) + + async def get_user_http_policy( + self, + ip: IPv4Address | IPv6Address, + user: User, + ) -> NetworkPolicy | None: + """Get user HTTP policy.""" + return await self._gateway.get_user_http_policy( + ip, + user, + ) + + async def get_user_kerberos_policy( + self, + ip: IPv4Address | IPv6Address, + user: User, + ) -> NetworkPolicy | None: + """Get user Kerberos policy.""" + return await self._gateway.get_user_kerberos_policy( + ip, + user, + ) + + async def get_user_ldap_policy( + self, + ip: IPv4Address | IPv6Address, + user: User, + ) -> NetworkPolicy | None: + """Get user LDAP policy.""" + return await self._gateway.get_user_ldap_policy( + ip, + user, + ) + + async def is_user_group_valid( + self, + user: User | None, + policy: NetworkPolicy | None, + ) -> bool: + """Validate that the user's groups are included in the policy.""" + return await self._gateway.is_user_group_valid( + user, + policy, + ) + + async def check_mfa_group( + self, + policy: NetworkPolicy, + user: User, + ) -> bool: + """Check if user is in a group with MFA policy.""" + return await self._gateway.check_mfa_group( + policy, + user, + ) + + PERMISSIONS: ClassVar[dict[str, AuthorizationRules]] = { + get_by_protocol.__name__: AuthorizationRules.NETWORK_POLICY_VALIDATOR_GET_BY_PROTOCOL, # noqa: E501 + get_user_network_policy.__name__:
AuthorizationRules.NETWORK_POLICY_VALIDATOR_GET_USER_NETWORK_POLICY, # noqa: E501 + get_user_http_policy.__name__: AuthorizationRules.NETWORK_POLICY_VALIDATOR_GET_USER_HTTP_POLICY, # noqa: E501 + get_user_kerberos_policy.__name__: AuthorizationRules.NETWORK_POLICY_VALIDATOR_GET_USER_KERBEROS_POLICY, # noqa: E501 + get_user_ldap_policy.__name__: AuthorizationRules.NETWORK_POLICY_VALIDATOR_GET_USER_LDAP_POLICY, # noqa: E501 + is_user_group_valid.__name__: AuthorizationRules.NETWORK_POLICY_VALIDATOR_IS_USER_GROUP_VALID, # noqa: E501 + check_mfa_group.__name__: AuthorizationRules.NETWORK_POLICY_VALIDATOR_CHECK_MFA_GROUP, # noqa: E501 + } diff --git a/app/ldap_protocol/policies/network/validator_gateway.py b/app/ldap_protocol/policies/network/validator_gateway.py new file mode 100644 index 000000000..e04cb0dbe --- /dev/null +++ b/app/ldap_protocol/policies/network/validator_gateway.py @@ -0,0 +1,178 @@ +"""Network policy validator gateway. + +Copyright (c) 2025 MultiFactor +License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE +""" + +from ipaddress import IPv4Address, IPv6Address + +from sqlalchemy import exists, or_, select, text +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import selectinload +from sqlalchemy.sql.expression import Select, true + +from entities import Group, NetworkPolicy, User +from enums import ProtocolType +from repo.pg.tables import queryable_attr as qa + + +class NetworkPolicyValidatorGateway: + """Gateway for validating network policies.""" + + def __init__( + self, + session: AsyncSession, + ): + """Initialize validator gateway.""" + self._session = session + + def _build_base_query( + self, + ip: IPv4Address | IPv6Address, + protocol_type: ProtocolType, + ) -> Select: + """Build a base query for network policies. 
+ + :param IPv4Address | IPv6Address ip: IP address to filter + :param ProtocolType protocol_type: Protocol to filter + Filters enabled policies whose netmasks contain the IP and whose + protocol flag is set; ordered by priority, at most one row. + :return: Select query + """ + protocol_field = getattr(NetworkPolicy, protocol_type) + query = ( + select(NetworkPolicy) + .options( + selectinload(qa(NetworkPolicy.groups)), + selectinload(qa(NetworkPolicy.mfa_groups)), + ) + .filter( + qa(NetworkPolicy.enabled).is_(True), + text(':ip <<= ANY("Policies".netmasks)').bindparams(ip=ip), + protocol_field == true(), + ) + .order_by(qa(NetworkPolicy.priority).asc()) + .limit(1) + ) + + return query + + async def get_by_protocol( + self, + ip: IPv4Address | IPv6Address, + protocol_type: ProtocolType, + ) -> NetworkPolicy | None: + """Get network policy by protocol.""" + query = self._build_base_query(ip, protocol_type) + return await self._session.scalar(query) + + async def get_user_network_policy( + self, + ip: IPv4Address | IPv6Address, + user: User, + policy_type: ProtocolType, + ) -> NetworkPolicy | None: + """Get the highest priority network policy for user, ip and protocol.
+ + :param User user: user object + :return NetworkPolicy | None: a NetworkPolicy object + """ + user_group_ids = [group.id for group in user.groups] + + query = self._build_base_query(ip, policy_type) + + if user_group_ids is not None: + query = query.filter( + or_( + qa(NetworkPolicy.groups) == None, # noqa + qa(NetworkPolicy.groups).any( + qa(Group.id).in_(user_group_ids), + ), + ), + ) + + return await self._session.scalar(query) + + async def get_user_http_policy( + self, + ip: IPv4Address | IPv6Address, + user: User, + ) -> NetworkPolicy | None: + """Get user HTTP policy.""" + return await self.get_user_network_policy( + ip, + user, + ProtocolType.HTTP, + ) + + async def get_user_kerberos_policy( + self, + ip: IPv4Address | IPv6Address, + user: User, + ) -> NetworkPolicy | None: + """Get user Kerberos policy.""" + return await self.get_user_network_policy( + ip, + user, + ProtocolType.KERBEROS, + ) + + async def get_user_ldap_policy( + self, + ip: IPv4Address | IPv6Address, + user: User, + ) -> NetworkPolicy | None: + """Get user LDAP policy.""" + return await self.get_user_network_policy( + ip, + user, + ProtocolType.LDAP, + ) + + async def is_user_group_valid( + self, + user: User | None, + policy: NetworkPolicy | None, + ) -> bool: + """Validate user groups, is it including to policy. + + :param User user: db user + :param NetworkPolicy policy: db policy + :return bool: status + """ + if not (user and policy): + return False + + if not policy.groups: + return True + query = select( + select(Group) + .join(qa(Group.users)) + .join(qa(Group.policies), isouter=True) + .exists() + .where(qa(Group.users).contains(user)) + .where(qa(Group.policies).contains(policy)), + ) + group = await self._session.scalar(query) + + return bool(group) + + async def check_mfa_group( + self, + policy: NetworkPolicy, + user: User, + ) -> bool: + """Check if user is in a group with MFA policy. 
+ + :param NetworkPolicy policy: policy object + :param User user: user object + :return bool: status + """ + return await self._session.scalar( + select( + exists().where( # type: ignore + qa(Group.mfa_policies).contains(policy), + qa(Group.users).contains(user), + ), + ), + ) diff --git a/app/ldap_protocol/policies/network/validator_protocol.py b/app/ldap_protocol/policies/network/validator_protocol.py new file mode 100644 index 000000000..5ead88a2f --- /dev/null +++ b/app/ldap_protocol/policies/network/validator_protocol.py @@ -0,0 +1,72 @@ +"""Network policy validator protocol. + +Copyright (c) 2025 MultiFactor +License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE +""" + +from ipaddress import IPv4Address, IPv6Address +from typing import Protocol + +from entities import NetworkPolicy, User +from enums import ProtocolType + + +class NetworkPolicyValidatorProtocol(Protocol): + """Protocol for validating network policies.""" + + async def get_user_http_policy( + self, + ip: IPv4Address | IPv6Address, + user: User, + ) -> NetworkPolicy | None: + """Get user HTTP policy.""" + ... + + async def get_user_kerberos_policy( + self, + ip: IPv4Address | IPv6Address, + user: User, + ) -> NetworkPolicy | None: + """Get user Kerberos policy.""" + ... + + async def get_user_ldap_policy( + self, + ip: IPv4Address | IPv6Address, + user: User, + ) -> NetworkPolicy | None: + """Get user LDAP policy.""" + ... + + async def check_mfa_group( + self, + policy: NetworkPolicy, + user: User, + ) -> bool: + """Check if user is in a group with MFA policy.""" + ... + + async def is_user_group_valid( + self, + user: User | None, + policy: NetworkPolicy | None, + ) -> bool: + """Validate user groups, is it including to policy.""" + ... + + async def get_user_network_policy( + self, + ip: IPv4Address | IPv6Address, + user: User, + policy_type: ProtocolType, + ) -> NetworkPolicy | None: + """Get the highest priority network policy.""" + ... 
+ + async def get_by_protocol( + self, + ip: IPv4Address | IPv6Address, + protocol_type: ProtocolType, + ) -> NetworkPolicy | None: + """Get network policy by protocol.""" + ... diff --git a/app/ldap_protocol/policies/network_policy.py b/app/ldap_protocol/policies/network_policy.py deleted file mode 100644 index 264616ab7..000000000 --- a/app/ldap_protocol/policies/network_policy.py +++ /dev/null @@ -1,130 +0,0 @@ -"""Network policy manager. - -Copyright (c) 2024 MultiFactor -License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE -""" - -from ipaddress import IPv4Address, IPv6Address -from typing import Literal - -from sqlalchemy import exists, or_, select, text -from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.orm import selectinload -from sqlalchemy.sql.expression import Select, true - -from entities import Group, NetworkPolicy, User -from repo.pg.tables import queryable_attr as qa - - -def build_policy_query( - ip: IPv4Address | IPv6Address, - protocol_field_name: Literal["is_http", "is_ldap", "is_kerberos"], - user_group_ids: list[int] | None = None, -) -> Select: - """Build a base query for network policies with optional group filtering. 
- - :param IPv4Address ip: IP address to filter - :param Literal["is_http", "is_ldap", "is_kerberos"] protocol_field_name - protocol: Protocol to filter - :param list[int] | None user_group_ids: List of user group IDs, optional - :return: Select query - """ - protocol_field = getattr(NetworkPolicy, protocol_field_name) - query = ( - select(NetworkPolicy) - .filter_by(enabled=True) - .options( - selectinload(qa(NetworkPolicy.groups)), - selectinload(qa(NetworkPolicy.mfa_groups)), - ) - .filter( - text(':ip <<= ANY("Policies".netmasks)').bindparams(ip=ip), - protocol_field == true(), - ) - .order_by(qa(NetworkPolicy.priority).asc()) - .limit(1) - ) - - if user_group_ids is not None: - return query.filter( - or_( - qa(NetworkPolicy.groups) == None, # noqa - qa(NetworkPolicy.groups).any( - qa(Group.id).in_(user_group_ids), - ), - ), - ) - - return query - - -async def check_mfa_group( - policy: NetworkPolicy, - user: User, - session: AsyncSession, -) -> bool: - """Check if user is in a group with MFA policy. - - :param NetworkPolicy policy: policy object - :param User user: user object - :param AsyncSession session: db session - :return bool: status - """ - return await session.scalar( - select( - exists().where( # type: ignore - qa(Group.mfa_policies).contains(policy), - qa(Group.users).contains(user), - ), - ), - ) - - -async def get_user_network_policy( - ip: IPv4Address | IPv6Address, - user: User, - session: AsyncSession, - policy_type: Literal["is_http", "is_ldap", "is_kerberos"], -) -> NetworkPolicy | None: - """Get the highest priority network policy for user, ip and protocol. 
- - :param User user: user object - :param AsyncSession session: db session - :return NetworkPolicy | None: a NetworkPolicy object - """ - user_group_ids = [group.id for group in user.groups] - - query = build_policy_query(ip, policy_type, user_group_ids) - - return await session.scalar(query) - - -async def is_user_group_valid( - user: User | None, - policy: NetworkPolicy | None, - session: AsyncSession, -) -> bool: - """Validate user groups, is it including to policy. - - :param User user: db user - :param NetworkPolicy policy: db policy - :param AsyncSession session: db - :return bool: status - """ - if user is None or policy is None: - return False - - if not policy.groups: - return True - - query = ( - select(Group) - .join(qa(Group.users)) - .join(qa(Group.policies), isouter=True) - .where(qa(Group.users).contains(user)) - .where(qa(Group.policies).contains(policy)) - .limit(1) - ) - - group = await session.scalar(query) - return bool(group) diff --git a/app/ldap_protocol/policies/password/dao.py b/app/ldap_protocol/policies/password/dao.py index 95a759bf7..5c818ca0a 100644 --- a/app/ldap_protocol/policies/password/dao.py +++ b/app/ldap_protocol/policies/password/dao.py @@ -10,10 +10,15 @@ from adaptix.conversion import get_converter, link_function from sqlalchemy import Integer, String, cast, exists, func, select, update from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.orm import selectinload +from sqlalchemy.orm import attributes, selectinload from abstract_dao import AbstractDAO from entities import Attribute, Group, PasswordPolicy, User +from enums import EntityTypeNames +from ldap_protocol.ldap_schema.attribute_value_validator import ( + AttributeValueValidator, + AttributeValueValidatorError, +) from ldap_protocol.objects import UserAccountControlFlag as UacFlag from ldap_protocol.policies.password.exceptions import ( PasswordPolicyAlreadyExistsError, @@ -63,13 +68,16 @@ class PasswordPolicyDAO(AbstractDAO[PasswordPolicyDTO, int]): 
"""Password Policy DAO.""" _session: AsyncSession + __attribute_value_validator: AttributeValueValidator def __init__( self, session: AsyncSession, + attribute_value_validator: AttributeValueValidator, ) -> None: """Initialize Password Policy DAO with a database session.""" self._session = session + self.__attribute_value_validator = attribute_value_validator async def _get_total_count(self) -> int: """Count all Password Policies.""" @@ -392,6 +400,13 @@ async def get_or_create_pwd_last_set( ) # fmt: skip if not plset_attribute: + if not self.__attribute_value_validator.is_value_valid( + EntityTypeNames.USER, + "pwdLastSet", + ft_now(), + ): + raise AttributeValueValidatorError("Invalid pwdLastSet value") + plset_attribute = Attribute( directory_id=directory_id, name="pwdLastSet", @@ -425,6 +440,7 @@ async def post_save_password_actions(self, user: User) -> None: await self._session.execute(query) user.password_history.append(tcast("str", user.password)) + attributes.flag_modified(user, "password_history") await self._session.flush() async def is_password_change_restricted( diff --git a/app/ldap_protocol/policies/password/exceptions.py b/app/ldap_protocol/policies/password/exceptions.py index f3a5f0414..e8c707765 100644 --- a/app/ldap_protocol/policies/password/exceptions.py +++ b/app/ldap_protocol/policies/password/exceptions.py @@ -4,50 +4,96 @@ License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE """ +from enum import IntEnum -class PasswordPolicyBaseError(Exception): +from api.error_routing import BaseDomainException + + +class ErrorCodes(IntEnum): + """Error codes.""" + + BASE_ERROR = 0 + PASSWORD_POLICY_ALREADY_EXISTS_ERROR = 1 + PASSWORD_POLICY_NOT_FOUND_ERROR = 2 + PASSWORD_POLICY_DIR_IS_NOT_USER_ERROR = 3 + PASSWORD_POLICY_BASE_DN_NOT_FOUND_ERROR = 4 + PASSWORD_POLICY_CANT_CHANGE_DEFAULT_DOMAIN_ERROR = 5 + PASSWORD_POLICY_PRIORITY_ERROR = 6 + PASSWORD_POLICY_AGE_DAYS_ERROR = 7 + + PASSWORD_BAN_WORD_ERROR = 8 + 
PASSWORD_BAN_WORD_FILE_HAS_DUPLICATES_ERROR = 9 + PASSWORD_BAN_WORD_TOO_LONG_ERROR = 10 + PASSWORD_BAN_WORD_WRONG_FILE_EXTENSION_ERROR = 11 + + +class PasswordPolicyError(BaseDomainException): """Base exception class for Password Policy service errors.""" + code: ErrorCodes = ErrorCodes.BASE_ERROR + -class PasswordPolicyAlreadyExistsError(PasswordPolicyBaseError): +class PasswordPolicyAlreadyExistsError(PasswordPolicyError): """Exception raised when a Password Policy already exists.""" + code = ErrorCodes.PASSWORD_POLICY_ALREADY_EXISTS_ERROR -class PasswordPolicyNotFoundError(PasswordPolicyBaseError): + +class PasswordPolicyNotFoundError(PasswordPolicyError): """Exception raised when a Password Policy not found.""" + code = ErrorCodes.PASSWORD_POLICY_NOT_FOUND_ERROR + -class PasswordPolicyDirIsNotUserError(PasswordPolicyBaseError): +class PasswordPolicyDirIsNotUserError(PasswordPolicyError): """Exception raised when the directory is not a user.""" + code = ErrorCodes.PASSWORD_POLICY_DIR_IS_NOT_USER_ERROR -class PasswordPolicyBaseDnNotFoundError(PasswordPolicyBaseError): + +class PasswordPolicyBaseDnNotFoundError(PasswordPolicyError): """Exception raised when a Base DN not found.""" + code = ErrorCodes.PASSWORD_POLICY_BASE_DN_NOT_FOUND_ERROR + -class PasswordPolicyCantChangeDefaultDomainError(PasswordPolicyBaseError): +class PasswordPolicyCantChangeDefaultDomainError(PasswordPolicyError): """Cannot change the name of the default domain Password Policy.""" + code = ErrorCodes.PASSWORD_POLICY_CANT_CHANGE_DEFAULT_DOMAIN_ERROR + -class PasswordPolicyPriorityError(PasswordPolicyBaseError): +class PasswordPolicyPriorityError(PasswordPolicyError): """Exception raised when there is a priority error.""" + code = ErrorCodes.PASSWORD_POLICY_PRIORITY_ERROR -class PasswordPolicyAgeDaysError(PasswordPolicyBaseError): + +class PasswordPolicyAgeDaysError(PasswordPolicyError): """Exception raised when the age days are invalid.""" + code = ErrorCodes.PASSWORD_POLICY_AGE_DAYS_ERROR + 
-class PasswordBanWordError(Exception): +class PasswordBanWordError(PasswordPolicyError): """Base exception class for password policy service errors.""" + code = ErrorCodes.PASSWORD_BAN_WORD_ERROR + class PasswordBanWordFileHasDuplicatesError(PasswordBanWordError): """Exception raised when a ban word already exists.""" + code = ErrorCodes.PASSWORD_BAN_WORD_FILE_HAS_DUPLICATES_ERROR + class PasswordBanWordTooLongError(PasswordBanWordError): """Exception raised when a ban word too long.""" + code = ErrorCodes.PASSWORD_BAN_WORD_TOO_LONG_ERROR + class PasswordBanWordWrongFileExtensionError(PasswordBanWordError): """Exception raised when a ban words file has wrong extension.""" + + code = ErrorCodes.PASSWORD_BAN_WORD_WRONG_FILE_EXTENSION_ERROR diff --git a/app/ldap_protocol/policies/password/use_cases.py b/app/ldap_protocol/policies/password/use_cases.py index fe22e9b45..e0518e4a0 100644 --- a/app/ldap_protocol/policies/password/use_cases.py +++ b/app/ldap_protocol/policies/password/use_cases.py @@ -6,9 +6,12 @@ from typing import ClassVar, Iterable +from sqlalchemy.ext.asyncio import AsyncSession + from abstract_service import AbstractService from entities import User from enums import AuthorizationRules +from ldap_protocol.identity.exceptions import UserNotFoundError from ldap_protocol.policies.password.ban_word_repository import ( PasswordBanWordRepository, ) @@ -16,12 +19,37 @@ MAX_BANWORD_LENGTH, MIN_LENGTH_FOR_TRGM, ) +from ldap_protocol.utils.queries import get_user from .dao import PasswordPolicyDAO from .dataclasses import PasswordPolicyDTO, PriorityT from .validator import PasswordPolicyValidator +class UserPasswordHistoryUseCases(AbstractService): + """User Password History Use Cases.""" + + _session: AsyncSession + + def __init__(self, session: AsyncSession) -> None: + self._session = session + + async def clear(self, identity: str) -> None: + user = await get_user(self._session, identity) + + if not user: + raise UserNotFoundError( + f"User {identity} not 
found in the database.", + ) + + user.password_history = [] + await self._session.flush() + + PERMISSIONS: ClassVar[dict[str, AuthorizationRules]] = { + clear.__name__: AuthorizationRules.USER_CLEAR_PASSWORD_HISTORY, + } + + class PasswordPolicyUseCases(AbstractService): """Password Policy Use Cases.""" diff --git a/app/ldap_protocol/roles/role_use_case.py b/app/ldap_protocol/roles/role_use_case.py index 1e978a3f1..d9c2921e0 100644 --- a/app/ldap_protocol/roles/role_use_case.py +++ b/app/ldap_protocol/roles/role_use_case.py @@ -8,6 +8,7 @@ from entities import AccessControlEntry, AceType, Directory, Role from enums import AuthorizationRules, RoleConstants, RoleScope +from ldap_protocol.kerberos.utils import get_system_container_dn from ldap_protocol.utils.queries import get_base_directories from repo.pg.tables import ( access_control_entries_table, @@ -211,7 +212,7 @@ async def create_kerberos_system_role(self) -> None: aces = self._get_full_access_aces( role_id=self._role_dao.get_last_id(), - base_dn="ou=services," + base_dn_list[0].path_dn, + base_dn=get_system_container_dn(base_dn_list[0].path_dn), ) await self._access_control_entry_dao.create_bulk(aces) diff --git a/app/ldap_protocol/server.py b/app/ldap_protocol/server.py index 597b107a6..39b633a26 100644 --- a/app/ldap_protocol/server.py +++ b/app/ldap_protocol/server.py @@ -11,33 +11,26 @@ from io import BytesIO from ipaddress import IPv4Address, IPv6Address, ip_address from traceback import format_exc -from typing import Literal, cast, overload +from typing import Literal, NewType, cast, overload from dishka import AsyncContainer, Scope from loguru import logger from proxyprotocol import ProxyProtocolIncompleteError from proxyprotocol.v2 import ProxyProtocolV2 from pydantic import ValidationError -from sqlalchemy.ext.asyncio import AsyncSession from config import Settings from ldap_protocol import LDAPRequestMessage, LDAPSession from ldap_protocol.ldap_requests.bind_methods import GSSAPISL +from 
ldap_protocol.policies.network import NetworkPolicyValidatorUseCase from .data_logger import DataLogger -log = logger.bind(name="ldap") -log.add( - "logs/ldap_{time:DD-MM-YYYY}.log", - filter=lambda rec: rec["extra"].get("name") == "ldap", - retention="10 days", - rotation="1d", - colorize=False, -) - infinity = cast("int", math.inf) pp_v2 = ProxyProtocolV2() +ServerLogger = NewType("ServerLogger", type[logger]) # type: ignore + class PoolClientHandler: """Async client handler. @@ -52,14 +45,20 @@ class PoolClientHandler: ssl_context: ssl.SSLContext | None = None - def __init__(self, settings: Settings, container: AsyncContainer): + def __init__( + self, + settings: Settings, + container: AsyncContainer, + log: ServerLogger, + ): """Set workers number for single client concurrent handling.""" self.container = container self.settings = settings self.num_workers = self.settings.COROUTINES_NUM_PER_CLIENT self._size = self.settings.TCP_PACKET_SIZE - self.logger = DataLogger(log, is_full=self.settings.DEBUG) + self.log = log + self.logger = DataLogger(self.log, is_full=self.settings.DEBUG) self._load_ssl_context() @@ -78,15 +77,20 @@ async def __call__( ) ldap_session.ip = addr - logger.info(f"Connection {addr} opened") + self.log.info(f"Connection {addr} opened") try: async with session_scope(scope=Scope.REQUEST) as r: try: - session = await r.get(AsyncSession) - await ldap_session.validate_conn(addr, session) + network_policy_use_case = await r.get( + NetworkPolicyValidatorUseCase, + ) + await ldap_session.validate_conn( + addr, + network_policy_use_case, + ) except PermissionError: - log.warning(f"Whitelist violation from {addr}") + self.log.warning(f"Whitelist violation from {addr}") return async with asyncio.TaskGroup() as tg: @@ -111,7 +115,9 @@ async def __call__( ) except* RuntimeError as err: - log.error(f"Response handling error {err}: {format_exc()}") + self.log.error( + f"Response handling error {err}: {format_exc()}", + ) finally: await 
session_scope.close() @@ -120,18 +126,18 @@ async def __call__( writer.close() await writer.wait_closed() - logger.info(f"Connection {addr} closed") + self.log.info(f"Connection {addr} closed") def _load_ssl_context(self) -> None: """Load SSL context for LDAPS.""" if self.settings.USE_CORE_TLS and self.settings.LDAP_LOAD_SSL_CERT: if not self.settings.check_certs_exist(): - log.critical("Certs not found, exiting...") + self.log.critical("Certs not found, exiting...") raise SystemExit(1) cert_name = self.settings.SSL_CERT key_name = self.settings.SSL_KEY - log.success("Found existing cert and key, loading...") + self.log.success("Found existing cert and key, loading...") self.ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2) self.ssl_context.load_cert_chain(cert_name, key_name) @@ -160,7 +166,7 @@ def _extract_proxy_protocol_address( header_length = int.from_bytes(data[14:16], "big") return addr, data[16 + header_length :] except (ValueError, ProxyProtocolIncompleteError) as err: - log.error(f"Proxy Protocol processing error: {err}") + self.log.error(f"Proxy Protocol processing error: {err}") return peer_addr, data @overload @@ -273,7 +279,7 @@ async def _handle_request( request = LDAPRequestMessage.from_bytes(data) except (ValidationError, IndexError, KeyError, ValueError) as err: - log.error(f"Invalid schema {format_exc()}") + self.log.error(f"Invalid schema {format_exc()}") writer.write(LDAPRequestMessage.from_err(data, err).encode()) await writer.drain() @@ -434,15 +440,14 @@ async def _run_server(server: asyncio.base_events.Server) -> None: async with server: await server.serve_forever() - @staticmethod - def log_addrs(server: asyncio.base_events.Server) -> None: + def log_addrs(self, server: asyncio.base_events.Server) -> None: addrs = ", ".join(str(sock.getsockname()) for sock in server.sockets) - log.info(f"Server on {addrs}") + self.log.info(f"Server on {addrs}") async def start(self) -> None: """Run and log tcp server.""" server = await self._get_server() 
- log.info( + self.log.info( f"started {'DEBUG' if self.settings.DEBUG else 'PROD'} " f"{'LDAPS' if self.settings.USE_CORE_TLS else 'LDAP'} server", ) diff --git a/app/ldap_protocol/session_storage/base.py b/app/ldap_protocol/session_storage/base.py index cc5724159..f4f9b3ac7 100644 --- a/app/ldap_protocol/session_storage/base.py +++ b/app/ldap_protocol/session_storage/base.py @@ -9,8 +9,6 @@ from secrets import token_hex from typing import Literal, Self -from loguru import logger - from config import Settings from .exceptions import ( @@ -175,9 +173,6 @@ async def get_user_id( raise SessionStorageInvalidKeyError("Invalid payload key") try: - logger.debug( - f"Retrieving session data for session_id: {session_id}", - ) data = await self.get(session_id) except KeyError: raise SessionStorageInvalidKeyError("Invalid session key") diff --git a/app/ldap_protocol/session_storage/exceptions.py b/app/ldap_protocol/session_storage/exceptions.py index fcbf820fe..76a453997 100644 --- a/app/ldap_protocol/session_storage/exceptions.py +++ b/app/ldap_protocol/session_storage/exceptions.py @@ -4,30 +4,67 @@ License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE """ +from enum import IntEnum -class SessionStorageError(Exception): +from errors import BaseDomainException + + +class ErrorCodes(IntEnum): + """Error codes.""" + + BASE_ERROR = 0 + INVALID_KEY_ERROR = 1 + MISSING_DATA_ERROR = 2 + INVALID_IP_ERROR = 3 + INVALID_USER_AGENT_ERROR = 4 + INVALID_SIGNATURE_ERROR = 5 + INVALID_DATA_ERROR = 6 + USER_NOT_FOUND_ERROR = 7 + + +class SessionStorageError(BaseDomainException): """Session storage error.""" + code: ErrorCodes = ErrorCodes.BASE_ERROR + class SessionStorageInvalidKeyError(SessionStorageError): """Session storage invalid key error.""" + code = ErrorCodes.INVALID_KEY_ERROR + class SessionStorageMissingDataError(SessionStorageError): """Session storage missing data error.""" + code = ErrorCodes.MISSING_DATA_ERROR + class 
SessionStorageInvalidIpError(SessionStorageError): """Session storage invalid ip error.""" + code = ErrorCodes.INVALID_IP_ERROR + class SessionStorageInvalidUserAgentError(SessionStorageError): """Session storage invalid user agent error.""" + code = ErrorCodes.INVALID_USER_AGENT_ERROR + class SessionStorageInvalidSignatureError(SessionStorageError): """Session storage invalid signature error.""" + code = ErrorCodes.INVALID_SIGNATURE_ERROR + class SessionStorageInvalidDataError(SessionStorageError): """Session storage invalid data error.""" + + code = ErrorCodes.INVALID_DATA_ERROR + + +class SessionUserNotFoundError(SessionStorageError): + """Session storage user not found error.""" + + code = ErrorCodes.USER_NOT_FOUND_ERROR diff --git a/app/ldap_protocol/session_storage/repository.py b/app/ldap_protocol/session_storage/repository.py index ccb842980..84366faee 100644 --- a/app/ldap_protocol/session_storage/repository.py +++ b/app/ldap_protocol/session_storage/repository.py @@ -12,6 +12,7 @@ from enums import AuthorizationRules from ldap_protocol.utils.queries import get_user, set_user_logon_attrs +from .exceptions import SessionUserNotFoundError from .redis import SessionStorage @@ -102,7 +103,7 @@ async def get_user_sessions( user = await get_user(self.session, upn) if not user: - raise LookupError("User not found.") + raise SessionUserNotFoundError("User not found.") sessions = await self.storage.get_user_sessions(user.id) @@ -121,7 +122,7 @@ async def clear_user_sessions(self, identity: str | User) -> None: ) if not user: - raise LookupError("User not found.") + raise SessionUserNotFoundError("User not found.") await self.storage.clear_user_sessions(user.id) diff --git a/app/ldap_protocol/udp_server.py b/app/ldap_protocol/udp_server.py index b889d0264..84c69b16e 100644 --- a/app/ldap_protocol/udp_server.py +++ b/app/ldap_protocol/udp_server.py @@ -10,10 +10,10 @@ from dishka import AsyncContainer, Scope from loguru import logger from pydantic import 
ValidationError -from sqlalchemy.ext.asyncio import AsyncSession from config import Settings from ldap_protocol import LDAPRequestMessage, LDAPSession +from ldap_protocol.policies.network import NetworkPolicyValidatorUseCase from .data_logger import DataLogger from .utils.udp import create_udp_socket @@ -50,8 +50,13 @@ async def _handle( ldap_session.ip = ip_address(addr[0]) try: - session = await container.get(AsyncSession) - await ldap_session.validate_conn(ldap_session.ip, session) + network_policy_use_case = await container.get( + NetworkPolicyValidatorUseCase, + ) + await ldap_session.validate_conn( + ldap_session.ip, + network_policy_use_case, + ) except PermissionError: log.warning(f"Whitelist violation from UDP {addr_str}") raise ConnectionAbortedError diff --git a/app/ldap_protocol/utils/const.py b/app/ldap_protocol/utils/const.py index e2afe48f7..c8a44a03f 100644 --- a/app/ldap_protocol/utils/const.py +++ b/app/ldap_protocol/utils/const.py @@ -31,5 +31,3 @@ def _type_validate_email(email: str) -> str: GRANT_DN_STRING = Annotated[str, AfterValidator(_type_validate_entry)] EmailStr = Annotated[str, AfterValidator(_type_validate_email)] -DOMAIN_USERS_GROUP_NAME = "domain users" -DOMAIN_COMPUTERS_GROUP_NAME = "domain computers" diff --git a/app/ldap_protocol/utils/pagination.py b/app/ldap_protocol/utils/pagination.py index f82bb7c96..5e4ef6e4b 100644 --- a/app/ldap_protocol/utils/pagination.py +++ b/app/ldap_protocol/utils/pagination.py @@ -104,7 +104,7 @@ async def get( session: AsyncSession, ) -> Self: """Get paginator.""" - if query._order_by_clause is None or len(query._order_by_clause) == 0: # noqa SLF001 + if query._order_by_clause is None or len(query._order_by_clause) == 0: # noqa: SLF001 raise ValueError("Select query must have an order_by clause.") metadata = PaginationMetadata( diff --git a/app/ldap_protocol/utils/queries.py b/app/ldap_protocol/utils/queries.py index 9db9b889e..a1f9243de 100644 --- a/app/ldap_protocol/utils/queries.py +++ 
b/app/ldap_protocol/utils/queries.py @@ -15,6 +15,10 @@ from sqlalchemy.sql.expression import ColumnElement from entities import Attribute, Directory, Group, User +from ldap_protocol.ldap_schema.attribute_value_validator import ( + AttributeValueValidator, + AttributeValueValidatorError, +) from repo.pg.tables import ( directory_memberships_table, directory_table, @@ -336,6 +340,7 @@ def get_domain_object_class(domain: Directory) -> Iterator[Attribute]: async def create_group( name: str, sid: int | None, + attribute_value_validator: AttributeValueValidator, session: AsyncSession, ) -> tuple[Directory, Group]: """Create group in default groups path. @@ -388,9 +393,18 @@ async def create_group( for name, attr in attributes.items(): for val in attr: session.add(Attribute(name=name, value=val, directory_id=dir_.id)) - await session.flush() - await session.refresh(dir_) + + await session.refresh( + instance=dir_, + attribute_names=["attributes", "user"], + with_for_update=None, + ) + if not attribute_value_validator.is_directory_valid(dir_): + raise AttributeValueValidatorError( + "Invalid directory attributes values", + ) + await session.refresh(group) return dir_, group @@ -496,22 +510,22 @@ async def set_or_update_primary_group( f"group '{group_dn}'.", ) - existing_attr = await session.scalar( - select(Attribute) - .filter_by( - name="primaryGroupID", - directory_id=directory.id, - ), - ) # fmt: skip + updated_attribute = await session.scalar( + update(Attribute) + .values(value=group.directory.relative_id) + .where( + qa(Attribute.name) == "primaryGroupID", + qa(Attribute.directory_id) == directory.id, + ) + .returning(qa(Attribute.directory_id)), + ) - if existing_attr: - existing_attr.value = group.directory.relative_id - else: + if not updated_attribute: session.add( Attribute( name="primaryGroupID", - value=group.directory.relative_id, directory_id=directory.id, + value=group.directory.relative_id, ), ) diff --git a/app/multidirectory.py b/app/multidirectory.py 
index f88537105..22a19259d 100644 --- a/app/multidirectory.py +++ b/app/multidirectory.py @@ -15,7 +15,6 @@ from alembic.config import Config, command from dishka import Scope, make_async_container from dishka.integrations.fastapi import setup_dishka -from dns.exception import DNSException from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware from loguru import logger @@ -35,18 +34,15 @@ password_policy_router, session_router, shadow_router, + user_password_history_router, ) -from api.exception_handlers import ( - handle_db_connect_error, - handle_dns_api_error, - handle_dns_error, - handle_not_implemented_error, -) +from api.exception_handlers import handle_auth_error, handle_db_connect_error from api.middlewares import proc_time_header_middleware, set_key_middleware from config import Settings from extra.dump_acme_certs import dump_acme_cert from ioc import ( EventSenderProvider, + GlobalLDAPServerProvider, HTTPProvider, LDAPServerProvider, MainProvider, @@ -54,14 +50,10 @@ MFAProvider, ) from ldap_protocol.dependency import resolve_deps -from ldap_protocol.dns import ( - DNSConnectionError, - DNSError, - DNSNotImplementedError, -) +from ldap_protocol.identity.exceptions import UnauthorizedError from ldap_protocol.policies.audit.events.handler import AuditEventHandler from ldap_protocol.policies.audit.events.sender import AuditEventSenderManager -from ldap_protocol.server import PoolClientHandler +from ldap_protocol.server import PoolClientHandler, ServerLogger from ldap_protocol.udp_server import CLDAPUDPServer from schedule import scheduler_factory @@ -92,6 +84,7 @@ def _create_basic_app(settings: Settings) -> FastAPI: app.include_router(password_policy_router) app.include_router(krb5_router) app.include_router(dns_router) + app.include_router(user_password_history_router) app.include_router(session_router) app.include_router(ldap_schema_router) app.include_router(dhcp_router) @@ -108,14 +101,7 @@ def _create_basic_app(settings: 
Settings) -> FastAPI: app.middleware("http")(set_key_middleware) app.add_exception_handler(sa_exc.TimeoutError, handle_db_connect_error) app.add_exception_handler(sa_exc.InterfaceError, handle_db_connect_error) - app.add_exception_handler(DNSException, handle_dns_error) - app.add_exception_handler(DNSConnectionError, handle_dns_error) - app.add_exception_handler(DNSError, handle_dns_api_error) - app.add_exception_handler( - DNSNotImplementedError, - handle_not_implemented_error, - ) - + app.add_exception_handler(UnauthorizedError, handle_auth_error) return app @@ -213,7 +199,17 @@ async def ldap_factory(settings: Settings) -> None: ) settings = await container.get(Settings) - servers.append(PoolClientHandler(settings, container).start()) + log: ServerLogger = logger.bind(name="ldap") # type: ignore + log.add( + "logs/ldap_{time:DD-MM-YYYY}.log", + filter=lambda rec: rec["extra"].get("name") == "ldap", + retention="10 days", + rotation="1d", + colorize=False, + enqueue=True, + ) + + servers.append(PoolClientHandler(settings, container, log).start()) await asyncio.gather(*servers) @@ -231,6 +227,38 @@ async def cldap_factory(settings: Settings) -> None: await CLDAPUDPServer(settings, container).start() +async def global_ldap_server_factory(settings: Settings) -> None: + """Run global_ldap_server_factory.""" + servers = [] + + for setting in ( + settings.get_copy_4_global(), + settings.get_copy_4_global_tls(), + ): + container = make_async_container( + GlobalLDAPServerProvider(), + MainProvider(), + MFAProvider(), + MFACredsProvider(), + context={Settings: setting}, + ) + + settings = await container.get(Settings) + log: ServerLogger = logger.bind(name="global_catalog") # type: ignore + log.add( + "logs/global_catalog_{time:DD-MM-YYYY}.log", + filter=lambda rec: rec["extra"].get("name") == "global_catalog", + retention="10 days", + rotation="1d", + colorize=False, + enqueue=True, + ) + + servers.append(PoolClientHandler(settings, container, log).start()) + + await 
asyncio.gather(*servers) + + async def event_handler_factory(settings: Settings) -> None: """Run event handler.""" main_container = make_async_container( @@ -261,6 +289,10 @@ async def event_sender_factory(settings: Settings) -> None: ldap = partial(run_entrypoint, factory=ldap_factory) cldap = partial(run_entrypoint, factory=cldap_factory) +global_ldap_server = partial( + run_entrypoint, + factory=global_ldap_server_factory, +) scheduler = partial(run_entrypoint, factory=scheduler_factory) create_shadow_app = partial(create_prod_app, factory=_create_shadow_app) event_handler = partial(run_entrypoint, factory=event_handler_factory) @@ -274,6 +306,11 @@ async def event_sender_factory(settings: Settings) -> None: group = parser.add_mutually_exclusive_group(required=True) group.add_argument("--ldap", action="store_true", help="Run ldap") group.add_argument("--cldap", action="store_true", help="Run cldap") + group.add_argument( + "--global_ldap_server", + action="store_true", + help="Run global_ldap_server", + ) group.add_argument("--http", action="store_true", help="Run http") group.add_argument("--shadow", action="store_true", help="Run http") group.add_argument("--scheduler", action="store_true", help="Run tasks") @@ -303,9 +340,12 @@ async def event_sender_factory(settings: Settings) -> None: if args.ldap: ldap(settings=settings) - if args.cldap: + elif args.cldap: cldap(settings=settings) + elif args.global_ldap_server: + global_ldap_server(settings=settings) + elif args.event_sender: event_sender(settings=settings) diff --git a/app/repo/pg/tables.py b/app/repo/pg/tables.py index db63e6e30..5391c95d5 100644 --- a/app/repo/pg/tables.py +++ b/app/repo/pg/tables.py @@ -112,6 +112,7 @@ def _compile_create_uc( "Directory", metadata, Column("id", Integer, primary_key=True), + Column("is_system", Boolean, nullable=False, default=False), Column( "parentId", Integer, @@ -146,12 +147,6 @@ def _compile_create_uc( ), Column("depth", Integer, nullable=True), 
Column("objectSid", String, nullable=True, key="object_sid"), - Column( - "password_policy_id", - Integer, - ForeignKey("PasswordPolicies.id"), - nullable=True, - ), Column( "objectGUID", PG_UUID(as_uuid=True), @@ -668,6 +663,12 @@ def _compile_create_uc( "PasswordBanWords", metadata, Column("word", String(255), primary_key=True), + Index( + "idx_password_ban_words_word_gin_trgm", + "word", + postgresql_ops={"word": "gin_trgm_ops"}, + postgresql_using="gin", + ), ) dedicated_servers_table = Table( diff --git a/docker-compose.dev.yml b/docker-compose.dev.yml index a562245b7..a28c8eae4 100644 --- a/docker-compose.dev.yml +++ b/docker-compose.dev.yml @@ -121,6 +121,59 @@ services: - traefik.udp.routers.cldap.service=cldap - traefik.udp.services.cldap.loadbalancer.server.port=389 + global_ldap_server: + build: + context: . + dockerfile: ./.docker/dev.Dockerfile + args: + DOCKER_BUILDKIT: 1 + target: runtime + image: multidirectory + restart: unless-stopped + deploy: + mode: replicated + replicas: 2 + endpoint_mode: dnsrr + resources: + reservations: + cpus: "0.25" + memory: 100M + environment: + - SERVICE_NAME=global_ldap_server + volumes: + - ./app:/app + - ./certs:/certs + - ldap_keytab:/LDAP_keytab/ + env_file: local.env + command: python -OO multidirectory.py --global_ldap_server + tty: true + depends_on: + migrations: + condition: service_completed_successfully + cert_local_check: + condition: service_completed_successfully + healthcheck: + test: ["CMD-SHELL", "nc -zv 127.0.0.1 3268 3269"] + interval: 30s + timeout: 10s + retries: 10 + start_period: 3s + labels: + - traefik.enable=true + + - traefik.tcp.routers.global_ldap.rule=HostSNI(`*`) + - traefik.tcp.routers.global_ldap.entrypoints=global_ldap + - traefik.tcp.routers.global_ldap.service=global_ldap + - traefik.tcp.services.global_ldap.loadbalancer.server.port=3268 + - traefik.tcp.services.global_ldap.loadbalancer.proxyprotocol.version=2 + + - traefik.tcp.routers.global_ldap_tls.rule=HostSNI(`*`) + - 
traefik.tcp.routers.global_ldap_tls.entrypoints=global_ldap_tls + - traefik.tcp.routers.global_ldap_tls.service=global_ldap_tls + - traefik.tcp.routers.global_ldap_tls.tls=true + - traefik.tcp.routers.global_ldap_tls.tls.certresolver=md-resolver + - traefik.tcp.services.global_ldap_tls.loadbalancer.server.port=3269 + - traefik.tcp.services.global_ldap_tls.loadbalancer.proxyprotocol.version=2 cert_local_check: image: multidirectory @@ -260,6 +313,7 @@ services: - ./certs:/certs - ./app:/app - ldap_keytab:/LDAP_keytab/ + - kdc:/etc/krb5kdc/ env_file: local.env command: python multidirectory.py --scheduler @@ -313,7 +367,7 @@ services: condition: service_healthy restart: true command: krb5kdc -n - + ports: - "88:88" - "88:88/udp" @@ -447,4 +501,4 @@ volumes: dns_server_file: dns_server_config: ldap_keytab: - dragonflydata: \ No newline at end of file + dragonflydata: diff --git a/docker-compose.remote.test.yml b/docker-compose.remote.test.yml index 96b9c8ce1..bd5659b02 100644 --- a/docker-compose.remote.test.yml +++ b/docker-compose.remote.test.yml @@ -9,7 +9,7 @@ services: POSTGRES_PASSWORD: password123 SECRET_KEY: 6a0452ae20cab4e21b6e9d18fa4b7bf397dd66ec3968b2d7407694278fd84cce POSTGRES_HOST: postgres - command: sh -c "python -m pytest -n auto -W ignore::DeprecationWarning -vv" + command: sh -c "python -m pytest -n auto -W ignore::DeprecationWarning -W ignore::coverage.exceptions.CoverageWarning -vv" postgres: image: postgres:16 diff --git a/docker-compose.test.yml b/docker-compose.test.yml index 221a4e4bf..96076b657 100644 --- a/docker-compose.test.yml +++ b/docker-compose.test.yml @@ -20,7 +20,7 @@ services: POSTGRES_HOST: postgres # PYTHONTRACEMALLOC: 1 PYTHONDONTWRITEBYTECODE: 1 - command: sh -c "python -B -m pytest -n auto -x -W ignore::DeprecationWarning -vv" + command: sh -c "python -B -m pytest -n auto -x -W ignore::DeprecationWarning -W ignore::coverage.exceptions.CoverageWarning -vv" tty: true postgres: @@ -34,12 +34,16 @@ services: - "5432" logging: 
driver: "none" + tmpfs: + - /var/lib/postgresql/data dragonfly: image: 'docker.dragonflydb.io/dragonflydb/dragonfly' container_name: dragonfly-test expose: - "6379" + tmpfs: + - /data deploy: resources: limits: diff --git a/docker-compose.yml b/docker-compose.yml index 891095345..543042d34 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -12,6 +12,8 @@ services: - "8080:8080" - "389:389" - "389:389/udp" + - "3268:3268" + - "3269:3269" - "636:636" - "749:749" - "464:464" @@ -113,6 +115,59 @@ services: - traefik.udp.routers.cldap.service=cldap - traefik.udp.services.cldap.loadbalancer.server.port=389 + global_ldap_server: + build: + context: . + dockerfile: ./.docker/dev.Dockerfile + args: + DOCKER_BUILDKIT: 1 + target: runtime + image: multidirectory + restart: unless-stopped + deploy: + mode: replicated + replicas: 2 + endpoint_mode: dnsrr + resources: + reservations: + cpus: "0.25" + memory: 100M + environment: + - SERVICE_NAME=global_ldap_server + volumes: + - ./app:/app + - ./certs:/certs + - ldap_keytab:/LDAP_keytab/ + env_file: local.env + command: python -OO multidirectory.py --global_ldap_server + tty: true + depends_on: + migrations: + condition: service_completed_successfully + cert_local_check: + condition: service_completed_successfully + healthcheck: + test: ["CMD-SHELL", "nc -zv 127.0.0.1 3268 3269"] + interval: 30s + timeout: 10s + retries: 10 + start_period: 3s + labels: + - traefik.enable=true + + - traefik.tcp.routers.global_ldap.rule=HostSNI(`*`) + - traefik.tcp.routers.global_ldap.entrypoints=global_ldap + - traefik.tcp.routers.global_ldap.service=global_ldap + - traefik.tcp.services.global_ldap.loadbalancer.server.port=3268 + - traefik.tcp.services.global_ldap.loadbalancer.proxyprotocol.version=2 + + - traefik.tcp.routers.global_ldap_tls.rule=HostSNI(`*`) + - traefik.tcp.routers.global_ldap_tls.entrypoints=global_ldap_tls + - traefik.tcp.routers.global_ldap_tls.service=global_ldap_tls + - traefik.tcp.routers.global_ldap_tls.tls=true + 
- traefik.tcp.services.global_ldap_tls.loadbalancer.server.port=3269 + - traefik.tcp.services.global_ldap_tls.loadbalancer.proxyprotocol.version=2 + api: image: multidirectory container_name: multidirectory_api @@ -159,7 +214,7 @@ services: postgres: condition: service_healthy restart: true - + cert_check: image: multidirectory container_name: multidirectory_certs_check @@ -229,6 +284,8 @@ services: restart: true cert_check: condition: service_completed_successfully + kdc: + condition: service_started ports: - 8000:8000 working_dir: /server @@ -369,6 +426,7 @@ services: - ./certs:/certs - ./app:/app - ldap_keytab:/LDAP_keytab/ + - kdc:/etc/krb5kdc/ env_file: local.env command: python multidirectory.py --scheduler tty: true diff --git a/interface b/interface index 21b31fed4..f31962020 160000 --- a/interface +++ b/interface @@ -1 +1 @@ -Subproject commit 21b31fed42a5082311a458da4d475c839f99a717 +Subproject commit f31962020a6689e6a4c61fb3349db5b5c7895f92 diff --git a/pyproject.toml b/pyproject.toml index d50c41bf7..f7adf0e26 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,6 +20,7 @@ dependencies = [ "dishka>=1.6.0", "dnspython>=2.7.0", "fastapi>=0.115.0", + "fastapi-error-map>=0.9.8", "gssapi>=1.9.0", "httpx>=0.28.1", "jinja2>=3.1.4", @@ -91,6 +92,9 @@ show_missing = true [tool.coverage.run] concurrency = ["thread", "gevent"] +omit = [ + "*/__dishka_factory_*", +] # RUFF # Ruff is a linter, not a type checker. 
@@ -217,6 +221,7 @@ known-first-party = [ "schedule", "extra", "enums", + "errors", ] known-third-party = [ "alembic", # https://github.com/astral-sh/ruff/issues/10519 diff --git a/syslog-ng.conf b/syslog-ng.conf index b8ffcadb0..152b4cd70 100644 --- a/syslog-ng.conf +++ b/syslog-ng.conf @@ -2,25 +2,30 @@ @include "scl.conf" source s_network { - tcp(ip("0.0.0.0") port(514) flags(no-parse)); - udp(ip("0.0.0.0") port(514) flags(no-parse)); + tcp( + ip("0.0.0.0") + port(514) + flags(syslog-protocol) + ); + + udp( + ip("0.0.0.0") + port(514) + flags(syslog-protocol) + ); }; -destination d_local { - file("/var/log/messages.log" - template("${MESSAGE}\n") +destination d_audit { + file("/var/log/audit/audit.log" + template("$ISODATE ${HOST} ${PROGRAM}[${PID}]: ${MSGID} ${SDATA} ${MSGONLY}\n") create_dirs(yes) perm(0644) - dir_perm(0755)); -}; - -destination d_files { - file("/var/log/${HOST}/${PROGRAM}.log"); + dir_perm(0755) + ); }; log { source(s_network); - destination(d_local); - destination(d_files); + destination(d_audit); }; diff --git a/tests/api_datasets.py b/tests/api_datasets.py index 975199bd6..a9f430ef4 100644 --- a/tests/api_datasets.py +++ b/tests/api_datasets.py @@ -20,7 +20,7 @@ "contains:colon", "contains+plus", "contains*asterisk", - "contains\"doublequotes", # noqa: Q003 + 'contains"doublequotes', "multiple#forbidden=chars<>here", "#starts_with_hash", "ends_with_semicolon;", @@ -39,8 +39,8 @@ ":", "+", "*", - "\"", # noqa: Q003 - "#=<>\\;:+*\"", # noqa: Q003 + '"', + '#=<>\\;:+*"', "", " ", " ", diff --git a/tests/conftest.py b/tests/conftest.py index 58937dbd3..c9ba0f8ff 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -31,6 +31,7 @@ ) from dishka.integrations.fastapi import setup_dishka from fastapi import FastAPI, Request, Response +from loguru import logger from multidirectory import _create_basic_app from sqlalchemy import schema, text from sqlalchemy.ext.asyncio import ( @@ -87,6 +88,7 @@ from ldap_protocol.kerberos.template_render 
import KRBTemplateRenderer from ldap_protocol.ldap_requests.bind import BindRequest from ldap_protocol.ldap_requests.contexts import ( + LDAPAbandonRequestContext, LDAPAddRequestContext, LDAPBindRequestContext, LDAPDeleteRequestContext, @@ -100,6 +102,9 @@ from ldap_protocol.ldap_schema.attribute_type_use_case import ( AttributeTypeUseCase, ) +from ldap_protocol.ldap_schema.attribute_value_validator import ( + AttributeValueValidator, +) from ldap_protocol.ldap_schema.dto import EntityTypeDTO from ldap_protocol.ldap_schema.entity_type_dao import EntityTypeDAO from ldap_protocol.ldap_schema.entity_type_use_case import EntityTypeUseCase @@ -119,8 +124,13 @@ ) from ldap_protocol.policies.audit.policies_dao import AuditPoliciesDAO from ldap_protocol.policies.audit.service import AuditService -from ldap_protocol.policies.network.gateway import NetworkPolicyGateway -from ldap_protocol.policies.network.use_cases import NetworkPolicyUseCase +from ldap_protocol.policies.network import ( + NetworkPolicyGateway, + NetworkPolicyUseCase, + NetworkPolicyValidatorGateway, + NetworkPolicyValidatorProtocol, + NetworkPolicyValidatorUseCase, +) from ldap_protocol.policies.password import ( PasswordPolicyDAO, PasswordPolicyUseCases, @@ -130,7 +140,10 @@ PasswordBanWordRepository, ) from ldap_protocol.policies.password.settings import PasswordValidatorSettings -from ldap_protocol.policies.password.use_cases import PasswordBanWordUseCases +from ldap_protocol.policies.password.use_cases import ( + PasswordBanWordUseCases, + UserPasswordHistoryUseCases, +) from ldap_protocol.roles.access_manager import AccessManager from ldap_protocol.roles.ace_dao import AccessControlEntryDAO from ldap_protocol.roles.dataclasses import RoleDTO @@ -153,7 +166,7 @@ class TestProvider(Provider): __test__ = False scope = Scope.RUNTIME - settings = from_context(provides=Settings, scope=Scope.RUNTIME) + settings = from_context(provides=Settings, scope=scope) _cached_session: AsyncSession | None = None 
_cached_kadmin: Mock | None = None _cached_audit_service: Mock | None = None @@ -302,6 +315,10 @@ def get_object_class_dao(self, session: AsyncSession) -> ObjectClassDAO: ) object_class_use_case = provide(ObjectClassUseCase, scope=Scope.REQUEST) + user_password_history_use_cases = provide( + UserPasswordHistoryUseCases, + scope=Scope.REQUEST, + ) password_ban_word_repository = provide( PasswordBanWordRepository, scope=Scope.REQUEST, @@ -328,13 +345,23 @@ def get_object_class_dao(self, session: AsyncSession) -> ObjectClassDAO: PasswordBanWordsFastAPIAdapter, scope=Scope.REQUEST, ) - password_utils = provide(PasswordUtils, scope=Scope.RUNTIME) + password_utils = provide(PasswordUtils, scope=scope) dns_fastapi_adapter = provide(DNSFastAPIAdapter, scope=Scope.REQUEST) dns_use_case = provide(DNSUseCase, scope=Scope.REQUEST) dns_state_gateway = provide(DNSStateGateway, scope=Scope.REQUEST) + network_policy_gateway = provide(NetworkPolicyGateway, scope=Scope.SESSION) + network_policy_validator_gateway = provide( + NetworkPolicyValidatorGateway, + provides=NetworkPolicyValidatorProtocol, + scope=Scope.SESSION, + ) + network_policy_validator = provide( + NetworkPolicyValidatorUseCase, + scope=Scope.SESSION, + ) - @provide(scope=Scope.RUNTIME, provides=AsyncEngine) + @provide(scope=scope, provides=AsyncEngine) def get_engine(self, settings: Settings) -> AsyncEngine: """Get async engine.""" return settings.engine @@ -518,6 +545,7 @@ def get_krb_template_render( audit_policy_dao = provide(AuditPoliciesDAO, scope=Scope.REQUEST) audit_use_case = provide(AuditUseCase, scope=Scope.REQUEST) audit_destination_dao = provide(AuditDestinationDAO, scope=Scope.REQUEST) + attribute_value_validator = provide(AttributeValueValidator, scope=scope) @provide(scope=Scope.REQUEST, provides=AuditService) async def get_audit_service(self) -> AsyncIterator[AsyncMock]: @@ -543,7 +571,7 @@ async def get_audit_service(self) -> AsyncIterator[AsyncMock]: audit_adapter = provide(AuditPoliciesAdapter, 
scope=Scope.REQUEST) - @provide(scope=Scope.RUNTIME) + @provide(scope=scope) async def get_audit_redis_client( self, settings: Settings, @@ -650,6 +678,11 @@ async def get_audit_monitor( LDAPSearchRequestContext, scope=Scope.REQUEST, ) + abandon_request_context = provide( + LDAPAbandonRequestContext, + scope=Scope.REQUEST, + ) + unbind_request_context = provide( LDAPUnbindRequestContext, scope=Scope.REQUEST, @@ -678,7 +711,6 @@ async def get_audit_monitor( NetworkPolicyUseCase, scope=Scope.REQUEST, ) - network_policy_gateway = provide(NetworkPolicyGateway, scope=Scope.REQUEST) @provide( provides=AuthorizationProviderProtocol, @@ -820,6 +852,20 @@ async def add_schema( ) +class TestMigrationProvider(Provider): + """Provider for migrations.""" + + async_conn = from_context(provides=AsyncConnection, scope=Scope.RUNTIME) + + @provide(scope=Scope.APP, cache=False) + def get_session_factory( + self, + async_conn: AsyncConnection, + ) -> AsyncSession: + """Create session factory.""" + return AsyncSession(async_conn) + + @pytest_asyncio.fixture(scope="session", autouse=True) async def _migrations( add_schema: None, # noqa: ARG001 @@ -840,13 +886,26 @@ def downgrade(conn: AsyncConnection) -> None: config.attributes["connection"] = conn command.downgrade(config, "base") + test_migration_provider = TestMigrationProvider() async with engine.begin() as conn: config.attributes["connection"] = conn + config.attributes["dishka_container"] = make_async_container( + TestProvider(), + test_migration_provider, + context={Settings: settings, AsyncConnection: conn}, + start_scope=Scope.RUNTIME, + ) await conn.run_sync(upgrade) # type: ignore yield async with engine.begin() as conn: + config.attributes["dishka_container"] = make_async_container( + TestProvider(), + test_migration_provider, + context={Settings: settings, AsyncConnection: conn}, + start_scope=Scope.RUNTIME, + ) await conn.run_sync(downgrade) # type: ignore @@ -879,13 +938,18 @@ async def setup_session( ) -> None: """Get 
session and acquire after completion.""" object_class_dao = ObjectClassDAO(session) - entity_type_dao = EntityTypeDAO(session, object_class_dao=object_class_dao) + attribute_value_validator = AttributeValueValidator() + entity_type_dao = EntityTypeDAO( + session, + object_class_dao=object_class_dao, + attribute_value_validator=attribute_value_validator, + ) for entity_type_data in ENTITY_TYPE_DATAS: await entity_type_dao.create( dto=EntityTypeDTO( id=None, - name=entity_type_data["name"], # type: ignore - object_class_names=entity_type_data["object_class_names"], # type: ignore + name=entity_type_data["name"], + object_class_names=entity_type_data["object_class_names"], is_system=True, ), ) @@ -899,7 +963,10 @@ async def setup_session( audit_destination_dao, raw_audit_manager, ) - password_policy_dao = PasswordPolicyDAO(session) + password_policy_dao = PasswordPolicyDAO( + session, + attribute_value_validator=attribute_value_validator, + ) password_policy_validator = PasswordPolicyValidator( PasswordValidatorSettings(), password_utils, @@ -910,9 +977,18 @@ async def setup_session( password_policy_validator, password_ban_word_repository, ) - setup_gateway = SetupGateway(session, password_utils, entity_type_dao) + setup_gateway = SetupGateway( + session, + password_utils, + entity_type_dao, + attribute_value_validator=attribute_value_validator, + ) await audit_use_case.create_policies() - await setup_gateway.setup_enviroment(dn="md.test", data=TEST_DATA) + await setup_gateway.setup_enviroment( + dn="md.test", + data=TEST_DATA, + is_system=False, + ) # NOTE: after setup environment we need base DN to be created await password_use_cases.create_default_domain_policy() @@ -979,6 +1055,24 @@ async def ldap_bound_session( return +@pytest_asyncio.fixture(scope="function") +async def network_policy_gateway( + container: AsyncContainer, +) -> AsyncIterator[NetworkPolicyGateway]: + """Get network policy gateway.""" + async with container(scope=Scope.SESSION) as container: + 
yield await container.get(NetworkPolicyGateway) + + +@pytest_asyncio.fixture(scope="function") +async def network_policy_validator( + container: AsyncContainer, +) -> AsyncIterator[NetworkPolicyValidatorUseCase]: + """Get network policy validator.""" + async with container(scope=Scope.SESSION) as container: + yield await container.get(NetworkPolicyValidatorUseCase) + + @pytest_asyncio.fixture(scope="session") async def handler( settings: Settings, @@ -986,8 +1080,9 @@ async def handler( ) -> AsyncIterator[PoolClientHandler]: """Create test handler.""" settings.set_test_port() + test_log = logger.bind(name="ldap_test") async with container(scope=Scope.APP) as app_scope: - yield PoolClientHandler(settings, app_scope) + yield PoolClientHandler(settings, app_scope, test_log) # type: ignore @pytest_asyncio.fixture(scope="function") @@ -998,7 +1093,14 @@ async def entity_type_dao( async with container(scope=Scope.APP) as container: session = await container.get(AsyncSession) object_class_dao = ObjectClassDAO(session) - yield EntityTypeDAO(session, object_class_dao) + attribute_value_validator = await container.get( + AttributeValueValidator, + ) + yield EntityTypeDAO( + session, + object_class_dao, + attribute_value_validator=attribute_value_validator, + ) @pytest_asyncio.fixture(scope="function") @@ -1008,7 +1110,13 @@ async def password_policy_dao( """Get session and acquire after completion.""" async with container(scope=Scope.APP) as container: session = await container.get(AsyncSession) - yield PasswordPolicyDAO(session) + attribute_value_validator = await container.get( + AttributeValueValidator, + ) + yield PasswordPolicyDAO( + session, + attribute_value_validator=attribute_value_validator, + ) @pytest_asyncio.fixture(scope="function") diff --git a/tests/constants.py b/tests/constants.py index 548e681ec..5542e0742 100644 --- a/tests/constants.py +++ b/tests/constants.py @@ -4,11 +4,18 @@ License: 
https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE """ +from constants import ( + DOMAIN_ADMIN_GROUP_NAME, + DOMAIN_COMPUTERS_GROUP_NAME, + DOMAIN_USERS_GROUP_NAME, + GROUPS_CONTAINER_NAME, + USERS_CONTAINER_NAME, +) from ldap_protocol.objects import UserAccountControlFlag TEST_DATA = [ { - "name": "groups", + "name": GROUPS_CONTAINER_NAME, "object_class": "container", "attributes": { "objectClass": ["top"], @@ -16,20 +23,20 @@ }, "children": [ { - "name": "domain admins", + "name": DOMAIN_ADMIN_GROUP_NAME, "object_class": "group", "attributes": { "objectClass": ["top", "posixGroup"], "groupType": ["-2147483646"], "instanceType": ["4"], - "sAMAccountName": ["domain admins"], + "sAMAccountName": [DOMAIN_ADMIN_GROUP_NAME], "sAMAccountType": ["268435456"], }, }, { "name": "developers", "object_class": "group", - "groups": ["domain admins"], + "groups": [DOMAIN_ADMIN_GROUP_NAME], "attributes": { "objectClass": ["top", "posixGroup"], "groupType": ["-2147483646"], @@ -50,31 +57,31 @@ }, }, { - "name": "domain users", + "name": DOMAIN_USERS_GROUP_NAME, "object_class": "group", "attributes": { "objectClass": ["top", "posixGroup"], "groupType": ["-2147483646"], "instanceType": ["4"], - "sAMAccountName": ["domain users"], + "sAMAccountName": [DOMAIN_USERS_GROUP_NAME], "sAMAccountType": ["268435456"], }, }, { - "name": "domain computers", + "name": DOMAIN_COMPUTERS_GROUP_NAME, "object_class": "group", "attributes": { "objectClass": ["top", "posixGroup"], "groupType": ["-2147483646"], "instanceType": ["4"], - "sAMAccountName": ["domain computers"], + "sAMAccountName": [DOMAIN_COMPUTERS_GROUP_NAME], "sAMAccountType": ["268435456"], }, }, ], }, { - "name": "users", + "name": USERS_CONTAINER_NAME, "object_class": "container", "attributes": {"objectClass": ["top"]}, "children": [ @@ -87,9 +94,7 @@ "mail": "user0@mail.com", "display_name": "user0", "password": "password", - "groups": [ - "domain admins", - ], + "groups": [DOMAIN_ADMIN_GROUP_NAME], }, 
"attributes": { "givenName": ["John"], @@ -119,9 +124,7 @@ "mail": "user_admin@mail.com", "display_name": "user_admin", "password": "password", - "groups": [ - "domain admins", - ], + "groups": [DOMAIN_ADMIN_GROUP_NAME], }, "attributes": { "objectClass": [ @@ -148,9 +151,7 @@ "mail": "user_admin_for_roles@mail.com", "display_name": "user_admin_for_roles", "password": "password", - "groups": [ - "admin login only", - ], + "groups": ["admin login only"], }, "attributes": { "objectClass": [ @@ -177,7 +178,7 @@ "mail": "user_non_admin@mail.com", "display_name": "user_non_admin", "password": "password", - "groups": ["domain users"], + "groups": [DOMAIN_USERS_GROUP_NAME], }, "attributes": { "objectClass": [ @@ -407,3 +408,29 @@ ], }, ] + +TEST_SYSTEM_ADMIN_DATA = { + "name": "System Administrator", + "object_class": "user", + "organizationalPerson": { + "sam_account_name": "system_admin", + "user_principal_name": "system_admin", + "mail": "system_admin@mail.com", + "display_name": "system_admin", + "password": "password", + "groups": [DOMAIN_ADMIN_GROUP_NAME], + }, + "attributes": { + "objectClass": [ + "top", + "person", + "organizationalPerson", + "posixAccount", + "inetOrgPerson", + "shadowAccount", + ], + "posixEmail": ["abctest@mail.com"], + "attr_with_bvalue": [b"any"], + "userAccountControl": [str(UserAccountControlFlag.NORMAL_ACCOUNT)], + }, +} diff --git a/tests/test_api/test_auth/test_identity_provider.py b/tests/test_api/test_auth/test_identity_provider.py index 4fff16484..788da08e5 100644 --- a/tests/test_api/test_auth/test_identity_provider.py +++ b/tests/test_api/test_auth/test_identity_provider.py @@ -13,7 +13,7 @@ make_async_container, provide, ) -from fastapi import HTTPException, status +from fastapi import status from httpx import AsyncClient from starlette.requests import Request @@ -21,7 +21,7 @@ from config import Settings from ldap_protocol.dialogue import UserSchema from ldap_protocol.identity import IdentityProvider -from 
ldap_protocol.identity.exceptions import UnauthorizedError +from ldap_protocol.identity.exceptions import ErrorCodes, UnauthorizedError from ldap_protocol.identity.provider_gateway import IdentityProviderGateway from ldap_protocol.session_storage.base import SessionStorage from ldap_protocol.session_storage.exceptions import ( @@ -113,10 +113,7 @@ async def invalid_user_provider( ) as cont: provider = await cont.get(IdentityProvider) provider.get_user_id = AsyncMock( # type: ignore - side_effect=HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Could not validate credentials", - ), + side_effect=UnauthorizedError(ErrorCodes.UNAUTHORIZED_ERROR), ) yield provider diff --git a/tests/test_api/test_auth/test_router.py b/tests/test_api/test_auth/test_router.py index 30c852578..c13c0a5a6 100644 --- a/tests/test_api/test_auth/test_router.py +++ b/tests/test_api/test_auth/test_router.py @@ -13,7 +13,6 @@ from fastapi import status from httpx import AsyncClient from jose import jwt -from password_utils import PasswordUtils from sqlalchemy import select, update from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import joinedload @@ -26,6 +25,7 @@ from ldap_protocol.ldap_requests.modify import Operation from ldap_protocol.session_storage import SessionStorage from ldap_protocol.utils.queries import get_search_path +from password_utils import PasswordUtils from repo.pg.tables import queryable_attr as qa from tests.conftest import TestCreds @@ -495,7 +495,7 @@ async def test_auth_disabled_user( }, ) - assert response.status_code == 403 + assert response.status_code == 400 @pytest.mark.asyncio diff --git a/tests/test_api/test_auth/test_sessions.py b/tests/test_api/test_auth/test_sessions.py index a52692197..59b11208c 100644 --- a/tests/test_api/test_auth/test_sessions.py +++ b/tests/test_api/test_auth/test_sessions.py @@ -155,7 +155,7 @@ async def test_session_api_get( assert storage_data[k]["sign"] == data["sign"] response = await 
http_client.get(f"sessions/{creds.un}123") - assert response.status_code == 404 + assert response.status_code == 400 assert response.json()["detail"] == "User not found." @@ -175,7 +175,7 @@ async def test_session_api_delete( assert len(storage_data) == 1 response = await http_client.delete(f"sessions/{creds.un}123") - assert response.status_code == 404 + assert response.status_code == 400 response = await http_client.delete(f"sessions/{creds.un}") assert response.status_code == 204 diff --git a/tests/test_api/test_dhcp/test_adapter.py b/tests/test_api/test_dhcp/test_adapter.py index d8816d887..5d2dd4b26 100644 --- a/tests/test_api/test_dhcp/test_adapter.py +++ b/tests/test_api/test_dhcp/test_adapter.py @@ -8,9 +8,9 @@ from unittest.mock import Mock import pytest -from authorization_provider_protocol import AuthorizationProviderProtocol from api.dhcp.adapter import DHCPAdapter +from authorization_provider_protocol import AuthorizationProviderProtocol from ldap_protocol.dhcp.dataclasses import ( DHCPLease, DHCPOptionData, diff --git a/tests/test_api/test_dhcp/test_router.py b/tests/test_api/test_dhcp/test_router.py index b912310c4..fbb739816 100644 --- a/tests/test_api/test_dhcp/test_router.py +++ b/tests/test_api/test_dhcp/test_router.py @@ -166,7 +166,7 @@ async def test_create_subnet_api_error( json=sample_subnet_data, ) - assert response.status_code == status.HTTP_409_CONFLICT + assert response.status_code == status.HTTP_400_BAD_REQUEST assert "Subnet already exists" in response.json()["detail"] @@ -235,7 +235,7 @@ async def test_update_subnet_not_found( json=sample_subnet_data, ) - assert response.status_code == status.HTTP_404_NOT_FOUND + assert response.status_code == status.HTTP_400_BAD_REQUEST @pytest.mark.asyncio @@ -262,7 +262,7 @@ async def test_delete_subnet_not_found( response = await http_client.delete("/dhcp/subnet/999") - assert response.status_code == status.HTTP_404_NOT_FOUND + assert response.status_code == status.HTTP_400_BAD_REQUEST 
@pytest.mark.asyncio @@ -316,7 +316,7 @@ async def test_create_lease_api_error( json=sample_lease_data, ) - assert response.status_code == status.HTTP_409_CONFLICT + assert response.status_code == status.HTTP_400_BAD_REQUEST assert "IP already in use" in response.json()["detail"] @@ -451,7 +451,7 @@ async def test_delete_lease_not_found( response = await http_client.delete("/dhcp/lease/192.168.1.128") - assert response.status_code == status.HTTP_404_NOT_FOUND + assert response.status_code == status.HTTP_400_BAD_REQUEST @pytest.mark.asyncio @@ -505,7 +505,7 @@ async def test_create_reservation_api_error( json=sample_reservation_data, ) - assert response.status_code == status.HTTP_409_CONFLICT + assert response.status_code == status.HTTP_400_BAD_REQUEST assert "IP already reserved" in response.json()["detail"] @@ -587,7 +587,7 @@ async def test_delete_reservation_not_found( }, ) - assert response.status_code == status.HTTP_404_NOT_FOUND + assert response.status_code == status.HTTP_400_BAD_REQUEST @pytest.mark.asyncio @@ -666,7 +666,7 @@ async def test_lease_to_reservation_not_found( json=[sample_reservation_data], ) - assert response.status_code == status.HTTP_404_NOT_FOUND + assert response.status_code == status.HTTP_400_BAD_REQUEST @pytest.mark.asyncio @@ -720,4 +720,4 @@ async def test_update_reservation_not_found( json=sample_reservation_data, ) - assert response.status_code == status.HTTP_404_NOT_FOUND + assert response.status_code == status.HTTP_400_BAD_REQUEST diff --git a/tests/test_api/test_ldap_schema/test_attribute_type_router.py b/tests/test_api/test_ldap_schema/test_attribute_type_router.py index 763c34514..bc9018948 100644 --- a/tests/test_api/test_ldap_schema/test_attribute_type_router.py +++ b/tests/test_api/test_ldap_schema/test_attribute_type_router.py @@ -73,7 +73,7 @@ async def test_create_attribute_type_conflict_when_already_exists( "/schema/attribute_type", json=schema.model_dump(), ) - assert response.status_code == status.HTTP_409_CONFLICT + 
assert response.status_code == status.HTTP_400_BAD_REQUEST @pytest.mark.asyncio @@ -110,7 +110,7 @@ async def test_modify_one_attribute_type_raise_404( "/schema/attribute_type/testAttributeType12345", json=schema.model_dump(), ) - assert response.status_code == status.HTTP_404_NOT_FOUND + assert response.status_code == status.HTTP_400_BAD_REQUEST @pytest.mark.parametrize( @@ -176,4 +176,4 @@ async def test_delete_bulk_attribute_types( response = await http_client.get( f"/schema/attribute_type/{attribute_type_name}", ) - assert response.status_code == status.HTTP_404_NOT_FOUND + assert response.status_code == status.HTTP_400_BAD_REQUEST diff --git a/tests/test_api/test_ldap_schema/test_attribute_type_router_datasets.py b/tests/test_api/test_ldap_schema/test_attribute_type_router_datasets.py index f4085e44e..bcfea7210 100644 --- a/tests/test_api/test_ldap_schema/test_attribute_type_router_datasets.py +++ b/tests/test_api/test_ldap_schema/test_attribute_type_router_datasets.py @@ -41,7 +41,7 @@ "no_user_modification": False, "is_included_anr": False, }, - "status_code": status.HTTP_404_NOT_FOUND, + "status_code": status.HTTP_400_BAD_REQUEST, }, { "attribute_type_name": "testAttributeType2", diff --git a/tests/test_api/test_ldap_schema/test_entity_type_router.py b/tests/test_api/test_ldap_schema/test_entity_type_router.py index c72d8c991..c130c2067 100644 --- a/tests/test_api/test_ldap_schema/test_entity_type_router.py +++ b/tests/test_api/test_ldap_schema/test_entity_type_router.py @@ -5,6 +5,7 @@ from httpx import AsyncClient from constants import ENTITY_TYPE_DATAS +from enums import EntityTypeNames from .test_entity_type_router_datasets import ( test_create_one_entity_type_dataset, @@ -60,7 +61,7 @@ async def test_create_one_entity_type_value_400( "is_system": False, }, ) - assert response.status_code == status.HTTP_404_NOT_FOUND + assert response.status_code == status.HTTP_400_BAD_REQUEST @pytest.mark.asyncio @@ -153,14 +154,14 @@ async def 
test_modify_entity_type_with_duplicate_data( f"/schema/entity_type/{update_entity}", json=update_data, ) - assert response.status_code == status.HTTP_403_FORBIDDEN + assert response.status_code == status.HTTP_400_BAD_REQUEST update_entity, update_data = new_statements["duplicate_name"] response = await http_client.patch( f"/schema/entity_type/{update_entity}", json=update_data, ) - assert response.status_code == status.HTTP_403_FORBIDDEN + assert response.status_code == status.HTTP_400_BAD_REQUEST @pytest.mark.parametrize( @@ -223,7 +224,7 @@ async def test_modify_primary_entity_type_name( "object_class_names": entity_type_data["object_class_names"], }, ) - assert response.status_code == status.HTTP_403_FORBIDDEN + assert response.status_code == status.HTTP_400_BAD_REQUEST response = await http_client.get( f"/schema/entity_type/{entity_type_data['name']}", @@ -267,14 +268,14 @@ async def test_delete_bulk_entries( response = await http_client.get( f"/schema/entity_type/{entity_type_name}", ) - assert response.status_code == status.HTTP_404_NOT_FOUND + assert response.status_code == status.HTTP_400_BAD_REQUEST @pytest.mark.asyncio @pytest.mark.usefixtures("session") async def test_delete_entry_with_directory(http_client: AsyncClient) -> None: """Test deleting entry with directory.""" - entity_type_name = "User" + entity_type_name = EntityTypeNames.USER response = await http_client.post( "/schema/entity_type/delete", json={"entity_type_names": [entity_type_name]}, diff --git a/tests/test_api/test_ldap_schema/test_object_class_router.py b/tests/test_api/test_ldap_schema/test_object_class_router.py index e371fea40..6e04cecdc 100644 --- a/tests/test_api/test_ldap_schema/test_object_class_router.py +++ b/tests/test_api/test_ldap_schema/test_object_class_router.py @@ -5,6 +5,7 @@ from httpx import AsyncClient from api.ldap_schema.schema import ObjectClassUpdateSchema +from enums import EntityTypeNames from .test_object_class_router_datasets import ( 
test_create_one_object_class_dataset, @@ -15,7 +16,7 @@ @pytest.mark.asyncio -async def test_get_one_extended_object_class( +async def test_get_extended_object_classes( http_client: AsyncClient, ) -> None: """Test getting a single extended object class.""" @@ -25,7 +26,10 @@ async def test_get_one_extended_object_class( assert response.status_code == status.HTTP_200_OK data = response.json() assert isinstance(data, dict) - assert data.get("entity_type_names") == ["User"] + assert set(data.get("entity_type_names")) == { # type: ignore + EntityTypeNames.CONTACT, + EntityTypeNames.USER, + } @pytest.mark.parametrize( @@ -87,7 +91,7 @@ async def test_create_object_class_type_conflict_when_already_exists( "/schema/object_class", json=dataset["object_class"], ) - assert response.status_code == status.HTTP_409_CONFLICT + assert response.status_code == status.HTTP_400_BAD_REQUEST @pytest.mark.asyncio @@ -108,7 +112,7 @@ async def test_modify_system_object_class(http_client: AsyncClient) -> None: f"/schema/object_class/{object_class_name}", json=request_data.model_dump(), ) - assert response.status_code == status.HTTP_403_FORBIDDEN + assert response.status_code == status.HTTP_400_BAD_REQUEST break else: pytest.fail("No system object class") @@ -203,7 +207,7 @@ async def test_delete_bulk_object_classes( response = await http_client.get( f"/schema/object_class/{object_class_name}", ) - assert response.status_code == status.HTTP_404_NOT_FOUND + assert response.status_code == status.HTTP_400_BAD_REQUEST @pytest.mark.parametrize( diff --git a/tests/test_api/test_main/conftest.py b/tests/test_api/test_main/conftest.py index bdf9e3f4e..8f1b58dea 100644 --- a/tests/test_api/test_main/conftest.py +++ b/tests/test_api/test_main/conftest.py @@ -47,9 +47,21 @@ async def adding_test_user( "type": "testing_attr", "vals": ["test"], }, + { + "type": "sAMAccountName", + "vals": ["test"], + }, { "type": "objectClass", - "vals": ["organization", "top", "user"], + "vals": [ + "top", + "user", + 
"person", + "organizationalPerson", + "posixAccount", + "shadowAccount", + "inetOrgPerson", + ], }, ], }, @@ -62,13 +74,6 @@ async def adding_test_user( json={ "object": test_user_dn, "changes": [ - { - "operation": Operation.ADD, - "modification": { - "type": "sAMAccountName", - "vals": ["Test"], - }, - }, { "operation": Operation.ADD, "modification": { diff --git a/tests/test_api/test_main/test_kadmin.py b/tests/test_api/test_main/test_kadmin.py index 0bb9c1879..0ffbd6ebe 100644 --- a/tests/test_api/test_main/test_kadmin.py +++ b/tests/test_api/test_main/test_kadmin.py @@ -60,7 +60,7 @@ def _create_test_user_data( async def test_tree_creation( http_client: AsyncClient, ctx_bind: LDAPBindRequestContext, - password_utils: PasswordUtils, + password_utils: PasswordUtils, # noqa: ARG001 ) -> None: """Test tree creation.""" krbadmin_pw = "Password123" @@ -77,7 +77,7 @@ async def test_tree_creation( response = await http_client.post( "entry/search", json={ - "base_object": "ou=services,dc=md,dc=test", + "base_object": "ou=System,dc=md,dc=test", "scope": 0, "deref_aliases": 0, "size_limit": 1000, @@ -90,7 +90,7 @@ async def test_tree_creation( ) assert ( response.json()["search_result"][0]["object_name"] - == "ou=services,dc=md,dc=test" + == "ou=System,dc=md,dc=test" ) bind = MutePolicyBindRequest( @@ -125,7 +125,7 @@ async def test_tree_collision(http_client: AsyncClient) -> None: }, ) - assert response.status_code == status.HTTP_409_CONFLICT + assert response.status_code == status.HTTP_400_BAD_REQUEST @pytest.mark.asyncio @@ -157,13 +157,13 @@ async def test_setup_call( kdc_doc = kadmin.setup.call_args.kwargs.pop("kdc_config").encode() # NOTE: Asserting documents integrity, tests template rendering - assert blake2b(krb_doc, digest_size=8).hexdigest() == "f433bbc7df5a236b" + assert blake2b(krb_doc, digest_size=8).hexdigest() == "0567ec28b8ccca51" assert blake2b(kdc_doc, digest_size=8).hexdigest() == "79e43649d34fe577" assert kadmin.setup.call_args.kwargs == { "domain": 
"md.test", "admin_dn": "cn=user0,cn=users,dc=md,dc=test", - "services_dn": "ou=services,dc=md,dc=test", + "services_dn": "ou=System,dc=md,dc=test", "krbadmin_dn": "cn=krbadmin,cn=users,dc=md,dc=test", "krbadmin_password": "Password123", "ldap_keytab_path": "/LDAP_keytab/ldap.keytab", @@ -228,21 +228,21 @@ async def test_ktadd( @pytest.mark.asyncio @pytest.mark.usefixtures("session") -async def test_ktadd_404( +async def test_ktadd_400( http_client: AsyncClient, kadmin: AbstractKadmin, ) -> None: """Test ktadd failure. - :param AsyncClient http_client: http cl - :param LDAPSession ldap_session: ldap + :param AsyncClient http_client: http client + :param AbstractKadmin kadmin: kadmin """ kadmin.ktadd.side_effect = KRBAPIPrincipalNotFoundError() # type: ignore names = ["test1", "test2"] response = await http_client.post("/kerberos/ktadd", json=names) - assert response.status_code == status.HTTP_404_NOT_FOUND + assert response.status_code == status.HTTP_400_BAD_REQUEST @pytest.mark.asyncio @@ -528,4 +528,4 @@ async def test_update_password( "old_password": "password", }, ) - assert response.status_code == status.HTTP_424_FAILED_DEPENDENCY + assert response.status_code == status.HTTP_400_BAD_REQUEST diff --git a/tests/test_api/test_main/test_router/conftest.py b/tests/test_api/test_main/test_router/conftest.py new file mode 100644 index 000000000..5ec37b884 --- /dev/null +++ b/tests/test_api/test_main/test_router/conftest.py @@ -0,0 +1,49 @@ +"""Test router config. 
+ +Copyright (c) 2026 MultiFactor +License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE +""" + +import pytest_asyncio +from sqlalchemy.ext.asyncio import AsyncSession + +from ldap_protocol.auth.setup_gateway import SetupGateway +from ldap_protocol.ldap_schema.attribute_value_validator import ( + AttributeValueValidator, +) +from ldap_protocol.ldap_schema.entity_type_dao import EntityTypeDAO +from ldap_protocol.ldap_schema.object_class_dao import ObjectClassDAO +from ldap_protocol.utils.queries import get_base_directories +from password_utils import PasswordUtils +from tests.constants import TEST_SYSTEM_ADMIN_DATA + + +@pytest_asyncio.fixture(scope="function") +async def add_system_administrator( + session: AsyncSession, + password_utils: PasswordUtils, + setup_session: None, # noqa: ARG001 +) -> None: + """Create system administrator user for tests that require it.""" + object_class_dao = ObjectClassDAO(session) + attribute_value_validator = AttributeValueValidator() + entity_type_dao = EntityTypeDAO( + session, + object_class_dao=object_class_dao, + attribute_value_validator=attribute_value_validator, + ) + + setup_gateway = SetupGateway( + session, + password_utils, + entity_type_dao, + attribute_value_validator=attribute_value_validator, + ) + + domain = (await get_base_directories(session))[0] + await setup_gateway.create_dir( + data=TEST_SYSTEM_ADMIN_DATA, + is_system=True, + domain=domain, + parent=domain, + ) diff --git a/tests/test_api/test_main/test_router/test_add.py b/tests/test_api/test_main/test_router/test_add.py index 8efd2d98c..3050bedec 100644 --- a/tests/test_api/test_main/test_router/test_add.py +++ b/tests/test_api/test_main/test_router/test_add.py @@ -23,24 +23,87 @@ async def test_api_correct_add(http_client: AsyncClient) -> None: "entry": "cn=test,dc=md,dc=test", "password": None, "attributes": [ + {"type": "name", "vals": ["test"]}, + {"type": "cn", "vals": ["test"]}, + {"type": "objectClass", "vals": 
["organization", "top"]}, { - "type": "name", - "vals": ["test"], + "type": "memberOf", + "vals": ["cn=domain admins,cn=groups,dc=md,dc=test"], }, + ], + }, + ) + + data = response.json() + + assert isinstance(data, dict) + assert response.status_code == status.HTTP_200_OK + assert data.get("resultCode") == LDAPCodes.SUCCESS + assert data.get("errorMessage") == "" + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("session") +async def test_api_add_incorrect_computer_name( + http_client: AsyncClient, +) -> None: + """Test api incorrect (name) add.""" + response = await http_client.post( + "/entry/add", + json={ + "entry": "cn=test,dc=md,dc=test", + "password": None, + "attributes": [ + {"type": "name", "vals": [" test;incorrect"]}, + {"type": "cn", "vals": ["test"]}, + {"type": "objectClass", "vals": ["computer", "top"]}, { - "type": "cn", - "vals": ["test"], + "type": "memberOf", + "vals": ["cn=domain admins,cn=groups,dc=md,dc=test"], }, + ], + }, + ) + + data = response.json() + + assert isinstance(data, dict) + assert data.get("resultCode") == LDAPCodes.UNDEFINED_ATTRIBUTE_TYPE + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("session") +async def test_api_add_incorrect_user_samaccount_with_dot( + http_client: AsyncClient, +) -> None: + """Test api incorrect (sAMAccountName) add.""" + un = "test0" + + response = await http_client.post( + "/entry/add", + json={ + "entry": "cn=test0,dc=md,dc=test", + "password": "P@ssw0rd", + "attributes": [ + {"type": "name", "vals": [un]}, + {"type": "cn", "vals": [un]}, { "type": "objectClass", - "vals": ["organization", "top"], - }, - { - "type": "memberOf", "vals": [ - "cn=domain admins,cn=groups,dc=md,dc=test", + "top", + "user", + "person", + "organizationalPerson", + "posixAccount", + "shadowAccount", + "inetOrgPerson", ], }, + {"type": "sAMAccountName", "vals": ["test0."]}, + {"type": "userPrincipalName", "vals": [f"{un}@md.ru"]}, + {"type": "mail", "vals": [f"{un}@md.ru"]}, + {"type": "displayName", "vals": [un]}, + 
{"type": "userAccountControl", "vals": ["516"]}, ], }, ) @@ -48,9 +111,7 @@ async def test_api_correct_add(http_client: AsyncClient) -> None: data = response.json() assert isinstance(data, dict) - assert response.status_code == status.HTTP_200_OK - assert data.get("resultCode") == LDAPCodes.SUCCESS - assert data.get("errorMessage") == "" + assert data.get("resultCode") == LDAPCodes.UNDEFINED_ATTRIBUTE_TYPE @pytest.mark.asyncio diff --git a/tests/test_api/test_main/test_router/test_delete.py b/tests/test_api/test_main/test_router/test_delete.py index 59b714c95..0aa777cec 100644 --- a/tests/test_api/test_main/test_router/test_delete.py +++ b/tests/test_api/test_main/test_router/test_delete.py @@ -29,7 +29,26 @@ async def test_api_correct_delete(http_client: AsyncClient) -> None: @pytest.mark.asyncio -@pytest.mark.usefixtures("adding_test_user") +@pytest.mark.usefixtures("setup_session") +@pytest.mark.usefixtures("session") +@pytest.mark.usefixtures("add_system_administrator") +async def test_api_cant_delete_system_directory( + http_client: AsyncClient, +) -> None: + """Test API for delete system directory.""" + response = await http_client.request( + "delete", + "/entry/delete", + json={"entry": "cn=System Administrator,dc=md,dc=test"}, + ) + + data = response.json() + + assert isinstance(data, dict) + assert data.get("resultCode") == LDAPCodes.UNWILLING_TO_PERFORM + + +@pytest.mark.asyncio @pytest.mark.usefixtures("setup_session") @pytest.mark.usefixtures("session") async def test_api_delete_with_incorrect_dn(http_client: AsyncClient) -> None: @@ -49,7 +68,6 @@ async def test_api_delete_with_incorrect_dn(http_client: AsyncClient) -> None: @pytest.mark.asyncio -@pytest.mark.usefixtures("adding_test_user") @pytest.mark.usefixtures("setup_session") @pytest.mark.usefixtures("session") async def test_api_delete_non_exist_object(http_client: AsyncClient) -> None: diff --git a/tests/test_api/test_main/test_router/test_login.py 
b/tests/test_api/test_main/test_router/test_login.py index 216e8448b..e83f1cb64 100644 --- a/tests/test_api/test_main/test_router/test_login.py +++ b/tests/test_api/test_main/test_router/test_login.py @@ -53,7 +53,7 @@ async def test_api_auth_after_change_account_exp( }, ) - assert auth.status_code == status.HTTP_403_FORBIDDEN + assert auth.status_code == status.HTTP_400_BAD_REQUEST await http_client.patch( "/entry/update", diff --git a/tests/test_api/test_main/test_router/test_modify.py b/tests/test_api/test_main/test_router/test_modify.py index b5359f51a..3e46e879d 100644 --- a/tests/test_api/test_main/test_router/test_modify.py +++ b/tests/test_api/test_main/test_router/test_modify.py @@ -67,7 +67,6 @@ async def test_api_correct_modify(http_client: AsyncClient) -> None: @pytest.mark.asyncio -@pytest.mark.usefixtures("adding_test_user") @pytest.mark.usefixtures("setup_session") @pytest.mark.usefixtures("session") async def test_api_duplicate_with_spaces_modify( @@ -204,7 +203,6 @@ async def test_api_modify_many(http_client: AsyncClient) -> None: @pytest.mark.asyncio -@pytest.mark.usefixtures("adding_test_user") @pytest.mark.usefixtures("setup_session") @pytest.mark.usefixtures("session") async def test_api_modify_with_incorrect_dn(http_client: AsyncClient) -> None: @@ -258,7 +256,6 @@ async def test_api_modify_non_exist_object(http_client: AsyncClient) -> None: @pytest.mark.asyncio -@pytest.mark.usefixtures("adding_test_user") @pytest.mark.usefixtures("setup_session") @pytest.mark.usefixtures("session") async def test_api_correct_modify_replace_memberof( @@ -398,7 +395,6 @@ async def test_api_modify_replace_loop_detect_member( @pytest.mark.asyncio -@pytest.mark.usefixtures("adding_test_user") @pytest.mark.usefixtures("setup_session") @pytest.mark.usefixtures("session") async def test_api_modify_replace_loop_detect_memberof( @@ -429,7 +425,6 @@ async def test_api_modify_replace_loop_detect_memberof( @pytest.mark.asyncio 
-@pytest.mark.usefixtures("adding_test_user") @pytest.mark.usefixtures("session") async def test_api_modify_incorrect_uac(http_client: AsyncClient) -> None: """Test API for modify object attribute.""" @@ -454,7 +449,6 @@ async def test_api_modify_incorrect_uac(http_client: AsyncClient) -> None: @pytest.mark.asyncio -@pytest.mark.usefixtures("adding_test_user") @pytest.mark.usefixtures("setup_session") @pytest.mark.usefixtures("session") async def test_qpi_modify_primary_object_classes( diff --git a/tests/test_api/test_main/test_router/test_modify_dn.py b/tests/test_api/test_main/test_router/test_modify_dn.py index 6e5d71cfc..b27360dae 100644 --- a/tests/test_api/test_main/test_router/test_modify_dn.py +++ b/tests/test_api/test_main/test_router/test_modify_dn.py @@ -11,7 +11,6 @@ @pytest.mark.asyncio -@pytest.mark.usefixtures("adding_test_user") @pytest.mark.usefixtures("setup_session") @pytest.mark.usefixtures("session") async def test_api_modify_dn_without_level_change( @@ -80,7 +79,6 @@ async def test_api_modify_dn_without_level_change( @pytest.mark.asyncio -@pytest.mark.usefixtures("adding_test_user") @pytest.mark.usefixtures("setup_session") @pytest.mark.usefixtures("session") async def test_api_modify_dn_with_level_down( @@ -93,7 +91,7 @@ async def test_api_modify_dn_with_level_down( response = await http_client.post( "entry/search", json={ - "base_object": "cn=testGroup1,ou=testModifyDn2,ou=testModifyDn1,dc=md,dc=test", + "base_object": "cn=testGroup1,ou=testModifyDn2,ou=testModifyDn1,dc=md,dc=test", # noqa: E501 "scope": 0, "deref_aliases": 0, "size_limit": 1000, @@ -149,7 +147,6 @@ async def test_api_modify_dn_with_level_down( @pytest.mark.asyncio -@pytest.mark.usefixtures("adding_test_user") @pytest.mark.usefixtures("setup_session") @pytest.mark.usefixtures("session") async def test_api_modify_dn_with_level_up( @@ -218,7 +215,6 @@ async def test_api_modify_dn_with_level_up( @pytest.mark.asyncio -@pytest.mark.usefixtures("adding_test_user") 
@pytest.mark.usefixtures("setup_session") @pytest.mark.usefixtures("session") async def test_api_correct_update_dn(http_client: AsyncClient) -> None: @@ -338,7 +334,6 @@ async def test_api_correct_update_dn(http_client: AsyncClient) -> None: @pytest.mark.asyncio -@pytest.mark.usefixtures("adding_test_user") @pytest.mark.usefixtures("setup_session") @pytest.mark.usefixtures("session") async def test_api_update_dn_with_parent(http_client: AsyncClient) -> None: @@ -436,7 +431,6 @@ async def test_api_update_dn_non_auth_user(http_client: AsyncClient) -> None: @pytest.mark.asyncio -@pytest.mark.usefixtures("adding_test_user") @pytest.mark.usefixtures("setup_session") @pytest.mark.usefixtures("session") async def test_api_update_dn_non_exist_superior( @@ -460,7 +454,30 @@ async def test_api_update_dn_non_exist_superior( @pytest.mark.asyncio -@pytest.mark.usefixtures("adding_test_user") +@pytest.mark.usefixtures("setup_session") +@pytest.mark.usefixtures("session") +@pytest.mark.usefixtures("add_system_administrator") +async def test_api_cant_update_system_directory( + http_client: AsyncClient, +) -> None: + """Test API for update DN of system directory.""" + response = await http_client.put( + "/entry/update/dn", + json={ + "entry": "cn=System Administrator,dc=md,dc=test", + "newrdn": "cn=New System Administrator", + "deleteoldrdn": True, + "new_superior": "dc=non-exist,dc=test", + }, + ) + + data = response.json() + + assert isinstance(data, dict) + assert data.get("resultCode") == LDAPCodes.UNWILLING_TO_PERFORM + + +@pytest.mark.asyncio @pytest.mark.usefixtures("setup_session") @pytest.mark.usefixtures("session") async def test_api_update_dn_non_exist_entry(http_client: AsyncClient) -> None: @@ -482,7 +499,6 @@ async def test_api_update_dn_non_exist_entry(http_client: AsyncClient) -> None: @pytest.mark.asyncio -@pytest.mark.usefixtures("adding_test_user") @pytest.mark.usefixtures("setup_session") @pytest.mark.usefixtures("session") async def 
test_api_update_dn_invalid_entry(http_client: AsyncClient) -> None: @@ -504,7 +520,6 @@ async def test_api_update_dn_invalid_entry(http_client: AsyncClient) -> None: @pytest.mark.asyncio -@pytest.mark.usefixtures("adding_test_user") @pytest.mark.usefixtures("setup_session") @pytest.mark.usefixtures("session") async def test_api_update_dn_invalid_new_superior( diff --git a/tests/test_api/test_main/test_router/test_search.py b/tests/test_api/test_main/test_router/test_search.py index 59f60221a..34a9377aa 100644 --- a/tests/test_api/test_main/test_router/test_search.py +++ b/tests/test_api/test_main/test_router/test_search.py @@ -7,6 +7,7 @@ import pytest from httpx import AsyncClient +from enums import EntityTypeNames from ldap_protocol.ldap_codes import LDAPCodes from tests.search_request_datasets import ( test_search_by_rule_anr_dataset, @@ -94,15 +95,15 @@ async def test_api_search(http_client: AsyncClient) -> None: assert response["resultCode"] == LDAPCodes.SUCCESS - sub_dirs = [ + sub_dirs = { "cn=groups,dc=md,dc=test", "cn=users,dc=md,dc=test", "ou=testModifyDn1,dc=md,dc=test", "ou=testModifyDn3,dc=md,dc=test", "ou=test_bit_rules,dc=md,dc=test", - ] - assert all( - obj["object_name"] in sub_dirs for obj in response["search_result"] + } + assert sub_dirs == set( + obj["object_name"] for obj in response["search_result"] ) @@ -432,7 +433,7 @@ async def test_api_search_by_entity_type_name( http_client: AsyncClient, ) -> None: """Test api search by entity type name.""" - entity_type_name = "User" + entity_type_name = EntityTypeNames.USER raw_response = await http_client.post( "entry/search", @@ -471,7 +472,7 @@ async def test_api_empty_search( http_client: AsyncClient, ) -> None: """Test api empty search.""" - entity_type_name = "User" + entity_type_name = EntityTypeNames.USER raw_response = await http_client.post( "entry/search", json={ diff --git a/tests/test_api/test_network/test_router.py b/tests/test_api/test_network/test_router.py index 82471448f..9155ff65d 
100644 --- a/tests/test_api/test_network/test_router.py +++ b/tests/test_api/test_network/test_router.py @@ -329,12 +329,12 @@ async def test_404(http_client: AsyncClient) -> None: response = await http_client.delete( f"/policy/{some_id}", ) - assert response.status_code == 404 + assert response.status_code == status.HTTP_400_BAD_REQUEST response = await http_client.patch( f"/policy/{some_id}", ) - assert response.status_code == 404 + assert response.status_code == status.HTTP_400_BAD_REQUEST response = await http_client.put( "/policy", @@ -343,7 +343,7 @@ async def test_404(http_client: AsyncClient) -> None: "name": "123", }, ) - assert response.status_code == 404 + assert response.status_code == status.HTTP_400_BAD_REQUEST @pytest.mark.asyncio diff --git a/tests/test_api/test_password_policy/conftest.py b/tests/test_api/test_password_policy/conftest.py index f5a7b3841..157a94725 100644 --- a/tests/test_api/test_password_policy/conftest.py +++ b/tests/test_api/test_password_policy/conftest.py @@ -16,10 +16,16 @@ provide, ) -from api.password_policy.adapter import PasswordPolicyFastAPIAdapter +from api.password_policy.adapter import ( + PasswordPolicyFastAPIAdapter, + UserPasswordHistoryResetFastAPIAdapter, +) from config import Settings from ldap_protocol.policies.password import PasswordPolicyUseCases from ldap_protocol.policies.password.dataclasses import PasswordPolicyDTO +from ldap_protocol.policies.password.use_cases import ( + UserPasswordHistoryUseCases, +) from tests.conftest import TestProvider @@ -35,11 +41,18 @@ class TestLocalProvider(Provider): """Test provider for local scope.""" _cached_policy_use_cases: PasswordPolicyUseCases | None = None + _cached_user_password_history_use_cases: ( + UserPasswordHistoryUseCases | None + ) = None password_policies_adapter = provide( PasswordPolicyFastAPIAdapter, scope=Scope.REQUEST, ) + user_password_history_reset_adapter = provide( + UserPasswordHistoryResetFastAPIAdapter, + scope=Scope.REQUEST, + ) 
@provide(scope=Scope.REQUEST, provides=PasswordPolicyUseCases) async def get_password_use_cases( @@ -120,6 +133,22 @@ async def get_password_use_cases( yield self._cached_policy_use_cases self._cached_policy_use_cases = None + @provide( + scope=Scope.REQUEST, + provides=UserPasswordHistoryUseCases, + ) + async def get_user_password_history_use_cases( + self, + ) -> AsyncIterator[UserPasswordHistoryUseCases]: + if self._cached_user_password_history_use_cases is None: + session = Mock() + use_cases = UserPasswordHistoryUseCases(session) + use_cases.clear = make_mock("clear") # type: ignore + self._cached_user_password_history_use_cases = use_cases + + yield self._cached_user_password_history_use_cases + self._cached_user_password_history_use_cases = None + @pytest_asyncio.fixture(scope="session") async def container(settings: Settings) -> AsyncIterator[AsyncContainer]: @@ -141,3 +170,12 @@ async def password_use_cases( """Get di password_use_cases.""" async with container(scope=Scope.REQUEST) as container: yield await container.get(PasswordPolicyUseCases) + + +@pytest_asyncio.fixture +async def user_password_history_use_cases( + container: AsyncContainer, +) -> AsyncIterator[UserPasswordHistoryUseCases]: + """Get di user_password_history_use_cases.""" + async with container(scope=Scope.REQUEST) as container: + yield await container.get(UserPasswordHistoryUseCases) diff --git a/tests/test_api/test_password_policy/test_password_policy_router.py b/tests/test_api/test_password_policy/test_password_policy_router.py index 1ea090026..0e3dbba8c 100644 --- a/tests/test_api/test_password_policy/test_password_policy_router.py +++ b/tests/test_api/test_password_policy/test_password_policy_router.py @@ -22,7 +22,7 @@ async def test_get_all_with_error( ) -> None: """Test get all Password Policy endpoint.""" response = await http_client_with_login_perm.get("/password-policy/all") - assert response.status_code == status.HTTP_403_FORBIDDEN + assert response.status_code == 
status.HTTP_401_UNAUTHORIZED # NOTE to password_use_cases.get_all returned Mock, not wrapper password_use_cases._perm_checker = None # noqa: SLF001 @@ -50,7 +50,7 @@ async def test_get_with_error( ) -> None: """Test get one Password Policy endpoint.""" response = await http_client_with_login_perm.get("/password-policy/1") - assert response.status_code == status.HTTP_403_FORBIDDEN + assert response.status_code == status.HTTP_401_UNAUTHORIZED # NOTE to password_use_cases.get_all returned Mock, not wrapper password_use_cases._perm_checker = None # noqa: SLF001 @@ -81,7 +81,7 @@ async def test_get_password_policy_by_dir_path_dn_with_error( response = await http_client_with_login_perm.get( f"/password-policy/by_dir_path_dn/{path}", ) - assert response.status_code == status.HTTP_403_FORBIDDEN + assert response.status_code == status.HTTP_401_UNAUTHORIZED # NOTE to password_use_cases.get_all returned Mock, not wrapper password_use_cases._perm_checker = None # noqa: SLF001 @@ -136,7 +136,7 @@ async def test_update_with_error( "/password-policy/1", json=schema.model_dump(), ) - assert response.status_code == status.HTTP_403_FORBIDDEN + assert response.status_code == status.HTTP_401_UNAUTHORIZED # NOTE to password_use_cases.get_all returned Mock, not wrapper password_use_cases._perm_checker = None # noqa: SLF001 @@ -152,7 +152,7 @@ async def test_reset_domain_policy_to_default_config_with_error( response = await http_client_with_login_perm.put( "/password-policy/reset/domain_policy", ) - assert response.status_code == status.HTTP_403_FORBIDDEN + assert response.status_code == status.HTTP_401_UNAUTHORIZED # NOTE to password_use_cases.get_all returned Mock, not wrapper password_use_cases._perm_checker = None # noqa: SLF001 diff --git a/tests/test_api/test_password_policy/test_user_password_history_router.py b/tests/test_api/test_password_policy/test_user_password_history_router.py new file mode 100644 index 000000000..e1c1d8d57 --- /dev/null +++ 
b/tests/test_api/test_password_policy/test_user_password_history_router.py @@ -0,0 +1,45 @@ +"""Test User Password History router. + +Copyright (c) 2025 MultiFactor +License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE +""" + +from unittest.mock import Mock + +import pytest +from fastapi import status +from httpx import AsyncClient + + +@pytest.mark.asyncio +async def test_clear_success( + http_client: AsyncClient, + user_password_history_use_cases: Mock, +) -> None: + """Test clear user password history endpoint.""" + response = await http_client.post( + "/user/password_history/clear", + data={"identity": "testuser"}, + ) + + # NOTE to user_password_history_use_cases.reset returned Mock, not wrapper # noqa: E501 + user_password_history_use_cases._perm_checker = None # noqa: SLF001 + user_password_history_use_cases.clear.assert_called_once() + assert response.status_code == status.HTTP_200_OK + + +@pytest.mark.asyncio +async def test_clear_unauthorized( + http_client_with_login_perm: AsyncClient, + user_password_history_use_cases: Mock, +) -> None: + """Test clear user password history endpoint without permissions.""" + response = await http_client_with_login_perm.post( + "/user/password_history/clear", + data={"identity": "testuser"}, + ) + assert response.status_code == status.HTTP_401_UNAUTHORIZED + + # NOTE to user_password_history_use_cases.reset returned Mock, not wrapper # noqa: E501 + user_password_history_use_cases._perm_checker = None # noqa: SLF001 + user_password_history_use_cases.clear.assert_not_called() diff --git a/tests/test_api/test_shadow/test_router.py b/tests/test_api/test_shadow/test_router.py index 764404d42..815d19798 100644 --- a/tests/test_api/test_shadow/test_router.py +++ b/tests/test_api/test_shadow/test_router.py @@ -28,7 +28,7 @@ async def test_shadow_api_non_existent_user(http_client: AsyncClient) -> None: ).model_dump(), ) - assert response.status_code == status.HTTP_404_NOT_FOUND + assert 
response.status_code == status.HTTP_400_BAD_REQUEST @pytest.mark.asyncio @@ -46,7 +46,7 @@ async def test_shadow_api_without_network_policies( json=adding_mfa_user_and_group, ) - assert response.status_code == status.HTTP_403_FORBIDDEN + assert response.status_code == status.HTTP_400_BAD_REQUEST @pytest.mark.asyncio @@ -66,7 +66,7 @@ async def test_shadow_api_without_kerberos_protocol( json=adding_mfa_user_and_group, ) - assert response.status_code == status.HTTP_403_FORBIDDEN + assert response.status_code == status.HTTP_400_BAD_REQUEST @pytest.mark.asyncio diff --git a/tests/test_ldap/policies/test_audit/test_rfc5424_serializer.py b/tests/test_ldap/policies/test_audit/test_rfc5424_serializer.py new file mode 100644 index 000000000..34db31bd5 --- /dev/null +++ b/tests/test_ldap/policies/test_audit/test_rfc5424_serializer.py @@ -0,0 +1,210 @@ +"""Test RFC5424Serializer. + +Copyright (c) 2025 MultiFactor +License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE +""" + +import re +from datetime import datetime, timezone + +import pytest + +from ldap_protocol.policies.audit.events.service_senders.rfc5424_serializer import ( # noqa: E501 + RFC5424Serializer, +) + + +@pytest.fixture +def serializer() -> RFC5424Serializer: + """Create serializer instance.""" + return RFC5424Serializer( + app_name="TestApp", + facility="authpriv", + ) + + +@pytest.mark.parametrize( + ("facility", "severity", "expected_severity"), + [ + ("kernel", 5, 5), + ("user", 3, 11), + ("authpriv", 6, 86), + ("local0", 7, 135), + ("local7", 2, 186), + ], +) +def test_format_priority( + facility: str, + severity: int, + expected_severity: int, +) -> None: + """Test _format_priority with different facilities and severities.""" + serializer = RFC5424Serializer(app_name="Test", facility=facility) + severity = serializer._format_severity(severity) # noqa: SLF001 + assert severity == expected_severity + + +@pytest.mark.parametrize( + "invalid_severity", + [-1, 8, 10, 100], +) +def 
test_format_priority_invalid_severity( + serializer: RFC5424Serializer, + invalid_severity: int, +) -> None: + """Test _format_priority with invalid severity values.""" + with pytest.raises(NotImplementedError, match="Severity must be 0-7"): + serializer._format_severity(invalid_severity) # noqa: SLF001 + + +def test_format_timestamp(serializer: RFC5424Serializer) -> None: + """Test _format_timestamp formats timestamp correctly.""" + dt = datetime(2025, 12, 23, 10, 30, 45, 123000, tzinfo=timezone.utc) + timestamp = dt.timestamp() + + result = serializer._format_timestamp(timestamp) # noqa: SLF001 + + assert result == "2025-12-23T10:30:45.123Z" + assert re.match(r"\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d{3}Z", result) + + +@pytest.mark.parametrize( + ("hostname", "expected"), + [ + ("server01.example.com", "server01.example.com"), + ("a" * 300, "a" * 255), + ], +) +def test_format_hostname( + serializer: RFC5424Serializer, + hostname: str, + expected: str, +) -> None: + """Test _format_hostname with various inputs.""" + result = serializer._format_hostname(hostname) # noqa: SLF001 + assert result == expected + + +def test_format_hostname_with_none(serializer: RFC5424Serializer) -> None: + """Test _format_hostname with None uses system hostname.""" + result = serializer._format_hostname(None) # noqa: SLF001 + assert result != "-" + assert len(result) > 0 + + +@pytest.mark.parametrize( + ("value", "max_length", "expected"), + [ + ("test_value", 100, "test_value"), + (None, 100, "-"), + ("", 100, "-"), + ("abcdefghij", 5, "abcde"), + ("test\x00\x01value\n\r", 100, "testvalue"), + ("\x00\x01\x02\n\r", 100, "-"), + ], +) +def test_format_field( + serializer: RFC5424Serializer, + value: str | None, + max_length: int, + expected: str, +) -> None: + """Test _format_field with various inputs.""" + result = serializer._format_field(value, max_length) # noqa: SLF001 + assert result == expected + + +@pytest.mark.parametrize( + ("data", "expected_result"), + [ + ({}, "-"), + 
({"username": "admin"}, '[audit@32473 username="admin"]'), + ], +) +def test_format_structured_data( + serializer: RFC5424Serializer, + data: dict, + expected_result: str, +) -> None: + """Test _format_structured_data with various inputs.""" + result = serializer._format_structured_data(data) # noqa: SLF001 + assert result == expected_result + + +def test_format_structured_data_multiple_params( + serializer: RFC5424Serializer, +) -> None: + """Test _format_structured_data with multiple parameters.""" + data = { + "username": "admin", + "ip": "192.168.1.100", + "action": "login", + } + result = serializer._format_structured_data(data) # noqa: SLF001 + + assert result.startswith("[audit@32473") + assert result.endswith("]") + assert 'username="admin"' in result + assert 'ip="192.168.1.100"' in result + assert 'action="login"' in result + + +@pytest.mark.parametrize( + ("input_name", "expected"), + [ + ("valid_name123", "valid_name123"), + ("user name", "username"), + ("user=name", "username"), + ('user"name', "username"), + ("user]name", "username"), + ], +) +def test_sanitize_param_name( + serializer: RFC5424Serializer, + input_name: str, + expected: str, +) -> None: + """Test _sanitize_param_name with various inputs.""" + result = serializer._sanitize_param_name(input_name) # noqa: SLF001 + assert result == expected + + +@pytest.mark.parametrize( + ("input_value", "expected"), + [ + ("simple text", "simple text"), + ("path\\to\\file", "path\\\\to\\\\file"), + ('say "hello"', r"say \"hello\""), + ("array[index]", r"array[index\]"), + ( + 'Test "quote" and \\backslash and ]bracket', + r"Test \"quote\" and \\backslash and \]bracket", + ), + ], +) +def test_escape_param_value( + serializer: RFC5424Serializer, + input_value: str, + expected: str, +) -> None: + """Test _escape_param_value with various special characters.""" + result = serializer._escape_param_value(input_value) # noqa: SLF001 + assert result == expected + + +@pytest.mark.parametrize( + ("input_msg", 
"expected"), + [ + ("User logged in", " \ufeffUser logged in"), + (None, ""), + ("", ""), + ], +) +def test_format_message( + serializer: RFC5424Serializer, + input_msg: str | None, + expected: str, +) -> None: + """Test _format_message with various inputs.""" + result = serializer._format_message(input_msg) # noqa: SLF001 + assert result == expected diff --git a/tests/test_ldap/policies/test_network/test_pool_client_handler.py b/tests/test_ldap/policies/test_network/test_pool_client_handler.py index 95bad8cc3..9f212986c 100644 --- a/tests/test_ldap/policies/test_network/test_pool_client_handler.py +++ b/tests/test_ldap/policies/test_network/test_pool_client_handler.py @@ -10,8 +10,8 @@ from sqlalchemy.ext.asyncio import AsyncSession from entities import NetworkPolicy -from ldap_protocol.dialogue import LDAPSession -from ldap_protocol.policies.network_policy import is_user_group_valid +from enums import ProtocolType +from ldap_protocol.policies.network import NetworkPolicyValidatorUseCase from ldap_protocol.utils.queries import get_group, get_user @@ -19,18 +19,20 @@ @pytest.mark.usefixtures("setup_session") @pytest.mark.usefixtures("session") async def test_check_policy( - ldap_session: LDAPSession, - session: AsyncSession, + network_policy_validator: NetworkPolicyValidatorUseCase, ) -> None: """Check policy.""" - policy = await ldap_session._get_policy(IPv4Address("127.0.0.1"), session) + policy = await network_policy_validator.get_by_protocol( + IPv4Address("127.0.0.1"), + ProtocolType.LDAP, + ) assert policy assert policy.netmasks == [IPv4Network("0.0.0.0/0")] @pytest.mark.asyncio async def test_specific_policy_ok( - ldap_session: LDAPSession, + network_policy_validator: NetworkPolicyValidatorUseCase, session: AsyncSession, ) -> None: """Test specific ip.""" @@ -44,15 +46,15 @@ async def test_specific_policy_ok( ), ) await session.commit() - policy = await ldap_session._get_policy( + policy = await network_policy_validator.get_by_protocol( 
ip=IPv4Address("127.100.10.5"), - session=session, + protocol_type=ProtocolType.LDAP, ) assert policy assert policy.netmasks == [IPv4Network("127.100.10.5/32")] - assert not await ldap_session._get_policy( + assert not await network_policy_validator.get_by_protocol( ip=IPv4Address("127.100.10.4"), - session=session, + protocol_type=ProtocolType.LDAP, ) @@ -60,17 +62,20 @@ async def test_specific_policy_ok( @pytest.mark.usefixtures("setup_session") @pytest.mark.usefixtures("settings") async def test_check_policy_group( - ldap_session: LDAPSession, + network_policy_validator: NetworkPolicyValidatorUseCase, session: AsyncSession, ) -> None: """Check policy.""" user = await get_user(session, "user0") assert user - policy = await ldap_session._get_policy(IPv4Address("127.0.0.1"), session) + policy = await network_policy_validator.get_by_protocol( + IPv4Address("127.0.0.1"), + ProtocolType.LDAP, + ) assert policy - assert await is_user_group_valid(user, policy, session) + assert await network_policy_validator.is_user_group_valid(user, policy) group = await get_group( dn="cn=domain admins,cn=groups,dc=md,dc=test", @@ -80,4 +85,4 @@ async def test_check_policy_group( policy.groups.append(group) await session.commit() - assert await is_user_group_valid(user, policy, session) + assert await network_policy_validator.is_user_group_valid(user, policy) diff --git a/tests/test_ldap/test_access_manager/test_search_access.py b/tests/test_ldap/test_access_manager/test_search_access.py index dd6f20b67..c586e2fc4 100644 --- a/tests/test_ldap/test_access_manager/test_search_access.py +++ b/tests/test_ldap/test_access_manager/test_search_access.py @@ -123,11 +123,11 @@ def test_check_search_access( expected_result: tuple[bool, set[str], set[str]], ) -> None: """Test the check_search_access method of AccessManager.""" - filtered_aces = AccessManager._filter_aces_by_entity_type( + filtered_aces = AccessManager._filter_aces_by_entity_type( # noqa: SLF001 aces, entity_type_id, ) - result = 
AccessManager._check_search_access(filtered_aces) + result = AccessManager._check_search_access(filtered_aces) # noqa: SLF001 assert result == expected_result diff --git a/tests/test_ldap/test_attribute_value_validator.py b/tests/test_ldap/test_attribute_value_validator.py new file mode 100644 index 000000000..084ca0a72 --- /dev/null +++ b/tests/test_ldap/test_attribute_value_validator.py @@ -0,0 +1,589 @@ +"""Tests for AttributeValueValidator. + +Copyright (c) 2025 MultiFactor +License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE +""" + +import pytest + +from enums import EntityTypeNames +from ldap_protocol.ldap_schema.attribute_value_validator import ( + AttributeValueValidator, +) + + +@pytest.fixture +def validator() -> AttributeValueValidator: + """Create validator instance.""" + return AttributeValueValidator() + + +class TestOrganizationalUnitName: + """Tests for Organizational Unit name validation.""" + + _entity_type_name = EntityTypeNames.ORGANIZATIONAL_UNIT + + def test_valid_names(self, validator: AttributeValueValidator) -> None: + """Test valid organizational unit names.""" + valid_names = [ + "IT Department", + "Sales", + "Marketing-Team", + "HR_Department", + "Department123", + ] + for name in valid_names: + assert validator.is_value_valid( + self._entity_type_name, + "name", + name, + ) + + def test_invalid_names_with_space_at_start( + self, + validator: AttributeValueValidator, + ) -> None: + """Test organizational unit names starting with space.""" + invalid_names = [" IT", " Sales", " Marketing"] + for name in invalid_names: + assert not validator.is_value_valid( + self._entity_type_name, + "name", + name, + ) + + def test_invalid_names_with_hash_at_start( + self, + validator: AttributeValueValidator, + ) -> None: + """Test organizational unit names starting with hash.""" + invalid_names = ["#IT", "#Sales", "#Marketing"] + for name in invalid_names: + assert not validator.is_value_valid( + self._entity_type_name, + 
"name", + name, + ) + + def test_invalid_names_with_space_at_end( + self, + validator: AttributeValueValidator, + ) -> None: + """Test organizational unit names ending with space.""" + invalid_names = ["IT ", "Sales ", "Marketing "] + for name in invalid_names: + assert not validator.is_value_valid( + self._entity_type_name, + "name", + name, + ) + + def test_invalid_names_with_forbidden_symbols( + self, + validator: AttributeValueValidator, + ) -> None: + """Test organizational unit names with forbidden symbols.""" + invalid_names = [ + 'IT"Dept', + "Sales,Team", + "Marketing+", + "HR\\Group", + "Dept<1>", + "Team;A", + "Group=B", + ] + for name in invalid_names: + assert not validator.is_value_valid( + self._entity_type_name, + "name", + name, + ) + + +class TestGroupName: + """Tests for Group name validation.""" + + _entity_type_name = EntityTypeNames.GROUP + + def test_valid_names(self, validator: AttributeValueValidator) -> None: + """Test valid group names.""" + valid_names = [ + "Administrators", + "Users", + "Power_Users", + "Group-123", + "TeamA", + ] + for name in valid_names: + assert validator.is_value_valid( + self._entity_type_name, + "name", + name, + ) + + def test_invalid_names_with_space_at_start( + self, + validator: AttributeValueValidator, + ) -> None: + """Test group names starting with space.""" + invalid_names = [" Admins", " Users", " PowerUsers"] + for name in invalid_names: + assert not validator.is_value_valid( + self._entity_type_name, + "name", + name, + ) + + def test_invalid_names_with_hash_at_start( + self, + validator: AttributeValueValidator, + ) -> None: + """Test group names starting with hash.""" + invalid_names = ["#Admins", "#Users", "#PowerUsers"] + for name in invalid_names: + assert not validator.is_value_valid( + self._entity_type_name, + "name", + name, + ) + + def test_invalid_names_with_space_at_end( + self, + validator: AttributeValueValidator, + ) -> None: + """Test group names ending with space.""" + invalid_names = 
["Admins ", "Users ", "PowerUsers "] + for name in invalid_names: + assert not validator.is_value_valid( + self._entity_type_name, + "name", + name, + ) + + def test_invalid_names_with_forbidden_symbols( + self, + validator: AttributeValueValidator, + ) -> None: + """Test group names with forbidden symbols.""" + invalid_names = [ + 'Admins"Group', + "Users,Team", + "Power+Users", + "Group\\A", + "Team<1>", + "Users;B", + "Group=C", + ] + for name in invalid_names: + assert not validator.is_value_valid( + self._entity_type_name, + "name", + name, + ) + + +class TestUserName: + """Tests for User name validation.""" + + _entity_type_name = EntityTypeNames.USER + + def test_valid_names(self, validator: AttributeValueValidator) -> None: + """Test valid user names.""" + valid_names = [ + "John Doe", + "Jane_Smith", + "User-123", + "Administrator", + "User.Name", + ] + for name in valid_names: + assert validator.is_value_valid( + self._entity_type_name, + "name", + name, + ) + + def test_invalid_names_with_space_at_start( + self, + validator: AttributeValueValidator, + ) -> None: + """Test user names starting with space.""" + invalid_names = [" JohnDoe", " Jane", " User123"] + for name in invalid_names: + assert not validator.is_value_valid( + self._entity_type_name, + "name", + name, + ) + + def test_invalid_names_with_hash_at_start( + self, + validator: AttributeValueValidator, + ) -> None: + """Test user names starting with hash.""" + invalid_names = ["#JohnDoe", "#Jane", "#User123"] + for name in invalid_names: + assert not validator.is_value_valid( + self._entity_type_name, + "name", + name, + ) + + def test_invalid_names_with_space_at_end( + self, + validator: AttributeValueValidator, + ) -> None: + """Test user names ending with space.""" + invalid_names = ["JohnDoe ", "Jane ", "User123 "] + for name in invalid_names: + assert not validator.is_value_valid( + self._entity_type_name, + "name", + name, + ) + + def test_invalid_names_with_forbidden_symbols( + self, + 
validator: AttributeValueValidator, + ) -> None: + """Test user names with forbidden symbols.""" + invalid_names = [ + 'John"Doe', + "Jane,Smith", + "User+123", + "Name\\Test", + "User<1>", + "John;Doe", + "User=Name", + ] + for name in invalid_names: + assert not validator.is_value_valid( + self._entity_type_name, + "name", + name, + ) + + +class TestUserSAMAccountName: + """Tests for User sAMAccountName validation.""" + + _entity_type_name = EntityTypeNames.USER + + def test_valid_sam_account_names( + self, + validator: AttributeValueValidator, + ) -> None: + """Test valid sAMAccountName values.""" + valid_names = [ + "jdoe", + "john.doe", + "user123", + "admin_user", + "test-user", + ] + for name in valid_names: + assert validator.is_value_valid( + self._entity_type_name, + "sAMAccountName", + name, + ) + + def test_invalid_sam_account_names_with_forbidden_symbols( + self, + validator: AttributeValueValidator, + ) -> None: + """Test sAMAccountName with forbidden symbols.""" + invalid_names = [ + 'user"name', + "user/name", + "user\\name", + "user[name]", + "user:name", + "user;name", + "user|name", + "user=name", + "user,name", + "user+name", + "user*name", + "user?name", + "user", + ] + for name in invalid_names: + assert not validator.is_value_valid( + self._entity_type_name, + "sAMAccountName", + name, + ) + + def test_invalid_sam_account_names_ending_with_dot( + self, + validator: AttributeValueValidator, + ) -> None: + """Test sAMAccountName ending with dot.""" + invalid_names = ["user.", "john.doe.", "admin."] + for name in invalid_names: + assert not validator.is_value_valid( + self._entity_type_name, + "sAMAccountName", + name, + ) + + def test_invalid_sam_account_names_with_control_chars( + self, + validator: AttributeValueValidator, + ) -> None: + """Test sAMAccountName with control characters.""" + invalid_names = [ + "user\x00name", + "user\x01name", + "user\x1fname", + "user\x7fname", + ] + for name in invalid_names: + assert not 
validator.is_value_valid( + self._entity_type_name, + "sAMAccountName", + name, + ) + + def test_invalid_sam_account_names_with_at_symbol( + self, + validator: AttributeValueValidator, + ) -> None: + """Test sAMAccountName with @ symbol.""" + invalid_names = ["user@domain", "admin@test", "john@"] + for name in invalid_names: + assert not validator.is_value_valid( + self._entity_type_name, + "sAMAccountName", + name, + ) + + +class TestComputerName: + """Tests for Computer name validation.""" + + _entity_type_name = EntityTypeNames.COMPUTER + + def test_valid_names(self, validator: AttributeValueValidator) -> None: + """Test valid computer names.""" + valid_names = [ + "WORKSTATION01", + "Server-2024", + "PC_LAB_123", + "Desktop", + ] + for name in valid_names: + assert validator.is_value_valid( + self._entity_type_name, + "name", + name, + ) + + def test_invalid_names_with_space_at_start( + self, + validator: AttributeValueValidator, + ) -> None: + """Test computer names starting with space.""" + invalid_names = [" WORKSTATION", " Server", " PC123"] + for name in invalid_names: + assert not validator.is_value_valid( + self._entity_type_name, + "name", + name, + ) + + def test_invalid_names_with_hash_at_start( + self, + validator: AttributeValueValidator, + ) -> None: + """Test computer names starting with hash.""" + invalid_names = ["#WORKSTATION", "#Server", "#PC123"] + for name in invalid_names: + assert not validator.is_value_valid( + self._entity_type_name, + "name", + name, + ) + + def test_invalid_names_with_space_at_end( + self, + validator: AttributeValueValidator, + ) -> None: + """Test computer names ending with space.""" + invalid_names = ["WORKSTATION ", "Server ", "PC123 "] + for name in invalid_names: + assert not validator.is_value_valid( + self._entity_type_name, + "name", + name, + ) + + def test_invalid_names_with_forbidden_symbols( + self, + validator: AttributeValueValidator, + ) -> None: + """Test computer names with forbidden symbols.""" + 
invalid_names = [ + 'PC"01', + "Server,01", + "Work+Station", + "PC\\01", + "Server<1>", + "PC;01", + "Computer=01", + ] + for name in invalid_names: + assert not validator.is_value_valid( + self._entity_type_name, + "name", + name, + ) + + +class TestComputerSAMAccountName: + """Tests for Computer sAMAccountName validation.""" + + _entity_type_name = EntityTypeNames.COMPUTER + + def test_valid_sam_account_names( + self, + validator: AttributeValueValidator, + ) -> None: + """Test valid computer sAMAccountName values.""" + valid_names = [ + "WORKSTATION01$", + "SERVER-2024$", + "PC_LAB$", + ] + for name in valid_names: + assert validator.is_value_valid( + self._entity_type_name, + "sAMAccountName", + name, + ) + + def test_invalid_sam_account_names_with_forbidden_symbols( + self, + validator: AttributeValueValidator, + ) -> None: + """Test computer sAMAccountName with forbidden symbols.""" + invalid_names = [ + 'PC"01$', + "PC/01$", + "PC\\01$", + "PC[01]$", + "PC:01$", + "PC;01$", + "PC|01$", + "PC=01$", + "PC,01$", + "PC+01$", + "PC*01$", + "PC?01$", + "PC<01>$", + ] + for name in invalid_names: + assert not validator.is_value_valid( + self._entity_type_name, + "sAMAccountName", + name, + ) + + def test_invalid_sam_account_names_ending_with_dot( + self, + validator: AttributeValueValidator, + ) -> None: + """Test computer sAMAccountName ending with dot.""" + invalid_names = ["PC01.", "SERVER.", "WORKSTATION."] + for name in invalid_names: + assert not validator.is_value_valid( + self._entity_type_name, + "sAMAccountName", + name, + ) + + def test_invalid_sam_account_names_with_control_chars( + self, + validator: AttributeValueValidator, + ) -> None: + """Test computer sAMAccountName with control characters.""" + invalid_names = [ + "PC\x00NAME$", + "PC\x01NAME$", + "PC\x1fNAME$", + "PC\x7fNAME$", + ] + for name in invalid_names: + assert not validator.is_value_valid( + self._entity_type_name, + "sAMAccountName", + name, + ) + + def 
test_invalid_sam_account_names_with_spaces_and_dots( + self, + validator: AttributeValueValidator, + ) -> None: + """Test computer sAMAccountName with spaces and dots.""" + invalid_names = [ + "PC 01$", + "SERVER 2024$", + "WORK.STATION$", + "PC.01$", + ] + for name in invalid_names: + assert not validator.is_value_valid( + self._entity_type_name, + "sAMAccountName", + name, + ) + + def test_invalid_sam_account_names_only_numbers( + self, + validator: AttributeValueValidator, + ) -> None: + """Test computer sAMAccountName that are only numbers.""" + invalid_names = ["123", "456789", "0"] + for name in invalid_names: + assert not validator.is_value_valid( + self._entity_type_name, + "sAMAccountName", + name, + ) + + def test_invalid_sam_account_names_starting_with_number( + self, + validator: AttributeValueValidator, + ) -> None: + """Test computer sAMAccountName starting with number.""" + invalid_names = ["1PC$", "2SERVER$", "9WORKSTATION$"] + for name in invalid_names: + assert not validator.is_value_valid( + self._entity_type_name, + "sAMAccountName", + name, + ) + + +class TestNoValidationRules: + """Test validation for attributes without specific rules.""" + + def test_attributes_without_rules_always_valid( + self, + validator: AttributeValueValidator, + ) -> None: + """Test that attributes without validation rules always pass.""" + test_cases = [ + (EntityTypeNames.USER, "description", "Any value here!"), + (EntityTypeNames.GROUP, "description", " spaces and #symbols "), + (EntityTypeNames.COMPUTER, "location", "Building 1, Room 101"), + (EntityTypeNames.ORGANIZATIONAL_UNIT, "description", ""), + ] + + for entity_type, property_name, value in test_cases: + assert validator.is_value_valid( + entity_type, + property_name, + value, + ) diff --git a/tests/test_ldap/test_bind.py b/tests/test_ldap/test_bind.py index b2aa0a0e5..e31353880 100644 --- a/tests/test_ldap/test_bind.py +++ b/tests/test_ldap/test_bind.py @@ -79,12 +79,12 @@ async def 
mock_init_security_context( session: AsyncSession, # noqa: ARG001 settings: Settings, # noqa: ARG001 ) -> None: - auth_choice._ldap_session.gssapi_security_context = ( + auth_choice._ldap_session.gssapi_security_context = ( # noqa: SLF001 mock_security_context ) auth_choice = SaslGSSAPIAuthentication(ticket=b"ticket") - auth_choice._init_security_context = mock_init_security_context # type: ignore + auth_choice._init_security_context = mock_init_security_context # type: ignore # noqa: SLF001 bind = BindRequest( version=0, @@ -150,12 +150,12 @@ async def mock_init_security_context( session: AsyncSession, # noqa: ARG001 settings: Settings, # noqa: ARG001 ) -> None: - auth_choice._ldap_session.gssapi_security_context = ( + auth_choice._ldap_session.gssapi_security_context = ( # noqa: SLF001 mock_security_context ) auth_choice = SaslGSSAPIAuthentication(ticket=b"client_ticket") - auth_choice._init_security_context = mock_init_security_context # type: ignore + auth_choice._init_security_context = mock_init_security_context # type: ignore # noqa: SLF001 first_bind = BindRequest( version=0, @@ -218,12 +218,12 @@ async def mock_init_security_context( session: AsyncSession, # noqa: ARG001 settings: Settings, # noqa: ARG001 ) -> None: - auth_choice._ldap_session.gssapi_security_context = ( + auth_choice._ldap_session.gssapi_security_context = ( # noqa: SLF001 mock_security_context ) auth_choice = SaslSPNEGOAuthentication(ticket=b"client_ticket") - auth_choice._init_security_context = mock_init_security_context # type: ignore + auth_choice._init_security_context = mock_init_security_context # type: ignore # noqa: SLF001 first_bind = BindRequest( version=0, diff --git a/tests/test_ldap/test_roles/test_multiple_access.py b/tests/test_ldap/test_roles/test_multiple_access.py index 5c6f83c27..da8cc17bc 100644 --- a/tests/test_ldap/test_roles/test_multiple_access.py +++ b/tests/test_ldap/test_roles/test_multiple_access.py @@ -13,7 +13,7 @@ from config import Settings from entities 
import Directory -from enums import AceType, RoleScope +from enums import AceType, EntityTypeNames, RoleScope from ldap_protocol.ldap_schema.attribute_type_dao import AttributeTypeDAO from ldap_protocol.ldap_schema.entity_type_dao import EntityTypeDAO from ldap_protocol.roles.ace_dao import AccessControlEntryDAO @@ -37,7 +37,7 @@ async def test_multiple_access( custom_role: RoleDTO, ) -> None: """Test multiple access control entries in a role.""" - user_entity_type = await entity_type_dao.get("User") + user_entity_type = await entity_type_dao.get(EntityTypeNames.USER) assert user_entity_type posix_email_attr = await attribute_type_dao.get("posixEmail") diff --git a/tests/test_ldap/test_roles/test_search.py b/tests/test_ldap/test_roles/test_search.py index 35115bc1d..a20e8f0dd 100644 --- a/tests/test_ldap/test_roles/test_search.py +++ b/tests/test_ldap/test_roles/test_search.py @@ -7,7 +7,7 @@ import pytest from config import Settings -from enums import AceType, RoleScope +from enums import AceType, EntityTypeNames, RoleScope from ldap_protocol.ldap_schema.attribute_type_dao import AttributeTypeDAO from ldap_protocol.ldap_schema.entity_type_dao import EntityTypeDAO from ldap_protocol.roles.ace_dao import AccessControlEntryDAO @@ -169,7 +169,7 @@ async def test_role_search_5( User with a custom role should see all Users objects. """ - user_entity_type = await entity_type_dao.get("User") + user_entity_type = await entity_type_dao.get(EntityTypeNames.USER) assert user_entity_type ace = AccessControlEntryDTO( @@ -221,7 +221,7 @@ async def test_role_search_6( User with a custom role should see only the posixEmail attribute. """ - user_entity_type = await entity_type_dao.get("User") + user_entity_type = await entity_type_dao.get(EntityTypeNames.USER) assert user_entity_type posix_email_attr = await attribute_type_dao.get("posixEmail") @@ -270,7 +270,7 @@ async def test_role_search_7( User with a custom role should see all attributes except description. 
""" - user_entity_type = await entity_type_dao.get("User") + user_entity_type = await entity_type_dao.get(EntityTypeNames.USER) assert user_entity_type description_attr = await attribute_type_dao.get("description") @@ -330,7 +330,7 @@ async def test_role_search_8( User with a custom role should see only the description attribute. """ - user_entity_type = await entity_type_dao.get("User") + user_entity_type = await entity_type_dao.get(EntityTypeNames.USER) assert user_entity_type description_attr = await attribute_type_dao.get("description") @@ -390,7 +390,7 @@ async def test_role_search_9( User with a custom role should see only the posixEmail attribute. """ - user_entity_type = await entity_type_dao.get("User") + user_entity_type = await entity_type_dao.get(EntityTypeNames.USER) assert user_entity_type description_attr = await attribute_type_dao.get("description") diff --git a/tests/test_ldap/test_util/test_modify.py b/tests/test_ldap/test_util/test_modify.py index c141914a4..02b174b4b 100644 --- a/tests/test_ldap/test_util/test_modify.py +++ b/tests/test_ldap/test_util/test_modify.py @@ -777,8 +777,6 @@ async def try_modify() -> int: ] assert attributes["jpegPhoto"] == ["modme.jpeg"] - assert directory.user - assert directory.user.mail == "modme@student.of.life.edu" assert "posixEmail" not in attributes diff --git a/tests/test_ldap/test_util/test_search.py b/tests/test_ldap/test_util/test_search.py index 137508bf6..903fb2598 100644 --- a/tests/test_ldap/test_util/test_search.py +++ b/tests/test_ldap/test_util/test_search.py @@ -14,13 +14,13 @@ from config import Settings from entities import User -from enums import AceType, RoleScope +from enums import AceType, ProtocolType, RoleScope from ldap_protocol.asn1parser import ASN1Row, TagNumbers from ldap_protocol.dialogue import LDAPSession from ldap_protocol.ldap_requests import SearchRequest from ldap_protocol.ldap_requests.contexts import LDAPSearchRequestContext from ldap_protocol.ldap_responses import 
SearchResultEntry -from ldap_protocol.policies.network_policy import is_user_group_valid +from ldap_protocol.policies.network import NetworkPolicyValidatorUseCase from ldap_protocol.roles.ace_dao import AccessControlEntryDAO from ldap_protocol.roles.dataclasses import AccessControlEntryDTO, RoleDTO from ldap_protocol.roles.role_dao import RoleDAO @@ -307,10 +307,13 @@ async def test_bind_policy( session: AsyncSession, settings: Settings, creds: TestCreds, - ldap_session: LDAPSession, + network_policy_validator: NetworkPolicyValidatorUseCase, ) -> None: """Bind with policy.""" - policy = await ldap_session._get_policy(IPv4Address("127.0.0.1"), session) # noqa: SLF001 + policy = await network_policy_validator.get_by_protocol( + IPv4Address("127.0.0.1"), + ProtocolType.LDAP, + ) assert policy group = await get_group( @@ -345,12 +348,15 @@ async def test_bind_policy( @pytest.mark.usefixtures("setup_session") async def test_bind_policy_missing_group( session: AsyncSession, - ldap_session: LDAPSession, settings: Settings, creds: TestCreds, + network_policy_validator: NetworkPolicyValidatorUseCase, ) -> None: """Bind policy fail.""" - policy = await ldap_session._get_policy(IPv4Address("127.0.0.1"), session) # noqa: SLF001 + policy = await network_policy_validator.get_by_protocol( + IPv4Address("127.0.0.1"), + ProtocolType.LDAP, + ) assert policy @@ -368,7 +374,7 @@ async def test_bind_policy_missing_group( user.groups.clear() await session.commit() - assert not await is_user_group_valid(user, policy, session) + assert not await network_policy_validator.is_user_group_valid(user, policy) proc = await asyncio.create_subprocess_exec( "ldapsearch", diff --git a/tests/test_shedule.py b/tests/test_shedule.py index fde3d7a4a..dc5aaaf01 100644 --- a/tests/test_shedule.py +++ b/tests/test_shedule.py @@ -67,11 +67,9 @@ async def test_check_ldap_principal( async def test_update_krb5_config( session: AsyncSession, settings: Settings, - kadmin: AbstractKadmin, ) -> None: """Test 
update_krb5_config.""" await update_krb5_config( session=session, - kadmin=kadmin, settings=settings, ) diff --git a/traefik.yml b/traefik.yml index 7b2384086..f95bf72f3 100644 --- a/traefik.yml +++ b/traefik.yml @@ -22,6 +22,14 @@ entryPoints: address: ":636" proxyProtocol: insecure: true + global_ldap: + address: ":3268" + proxyProtocol: + insecure: true + global_ldap_tls: + address: ":3269" + proxyProtocol: + insecure: true kadmind: address: ":749" kpasswd: diff --git a/uv.lock b/uv.lock index c0207e4fa..85838a5a8 100644 --- a/uv.lock +++ b/uv.lock @@ -304,6 +304,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/45/7c/97d033faf771c9fe960c7b51eb78ab266bfa64cbc917601978963f0c3c7b/fastapi-0.118.2-py3-none-any.whl", hash = "sha256:d1f842612e6a305f95abe784b7f8d3215477742e7c67a16fccd20bd79db68150", size = 97954, upload-time = "2025-10-08T14:52:16.166Z" }, ] +[[package]] +name = "fastapi-error-map" +version = "0.9.8" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "fastapi" }, + { name = "orjson" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/84/9a/81aefff01594bfced5afdfb6c93de02a0f28fccc562f28c6bd721d7876a8/fastapi_error_map-0.9.8.tar.gz", hash = "sha256:894f6884598e4dd8b6c76cae59dee1522813ac3799ba0231b05465193c752f93", size = 386418, upload-time = "2025-11-02T01:58:06.546Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/37/07/850dc161f16d79ec86f61e90e1dda2bc95c70c462b2b3fb0e46455b1dc98/fastapi_error_map-0.9.8-py3-none-any.whl", hash = "sha256:1d54a5a40b4a7c8653266f0c3f1f3d6be4729e19fd6ec34c29addc85d3e27b58", size = 20462, upload-time = "2025-11-02T01:58:04.545Z" }, +] + [[package]] name = "fastapi-sqlalchemy-monitor" version = "1.1.3" @@ -532,6 +545,7 @@ dependencies = [ { name = "dishka" }, { name = "dnspython" }, { name = "fastapi" }, + { name = "fastapi-error-map" }, { name = "gssapi" }, { name = "httpx" }, { name = "jinja2" }, @@ -585,6 +599,7 @@ requires-dist = [ { name = 
"dishka", specifier = ">=1.6.0" }, { name = "dnspython", specifier = ">=2.7.0" }, { name = "fastapi", specifier = ">=0.115.0" }, + { name = "fastapi-error-map", specifier = ">=0.9.8" }, { name = "gssapi", specifier = ">=1.9.0" }, { name = "httpx", specifier = ">=0.28.1" }, { name = "jinja2", specifier = ">=3.1.4" }, @@ -654,6 +669,29 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/79/7b/2c79738432f5c924bef5071f933bcc9efd0473bac3b4aa584a6f7c1c8df8/mypy_extensions-1.1.0-py3-none-any.whl", hash = "sha256:1be4cccdb0f2482337c4743e60421de3a356cd97508abadd57d47403e94f5505", size = 4963, upload-time = "2025-04-22T14:54:22.983Z" }, ] +[[package]] +name = "orjson" +version = "3.11.4" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/c6/fe/ed708782d6709cc60eb4c2d8a361a440661f74134675c72990f2c48c785f/orjson-3.11.4.tar.gz", hash = "sha256:39485f4ab4c9b30a3943cfe99e1a213c4776fb69e8abd68f66b83d5a0b0fdc6d", size = 5945188, upload-time = "2025-10-24T15:50:38.027Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/23/15/c52aa7112006b0f3d6180386c3a46ae057f932ab3425bc6f6ac50431cca1/orjson-3.11.4-cp313-cp313-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:2d6737d0e616a6e053c8b4acc9eccea6b6cce078533666f32d140e4f85002534", size = 243525, upload-time = "2025-10-24T15:49:29.737Z" }, + { url = "https://files.pythonhosted.org/packages/ec/38/05340734c33b933fd114f161f25a04e651b0c7c33ab95e9416ade5cb44b8/orjson-3.11.4-cp313-cp313-macosx_15_0_arm64.whl", hash = "sha256:afb14052690aa328cc118a8e09f07c651d301a72e44920b887c519b313d892ff", size = 128871, upload-time = "2025-10-24T15:49:31.109Z" }, + { url = "https://files.pythonhosted.org/packages/55/b9/ae8d34899ff0c012039b5a7cb96a389b2476e917733294e498586b45472d/orjson-3.11.4-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:38aa9e65c591febb1b0aed8da4d469eba239d434c218562df179885c94e1a3ad", size = 
130055, upload-time = "2025-10-24T15:49:33.382Z" }, + { url = "https://files.pythonhosted.org/packages/33/aa/6346dd5073730451bee3681d901e3c337e7ec17342fb79659ec9794fc023/orjson-3.11.4-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:f2cf4dfaf9163b0728d061bebc1e08631875c51cd30bf47cb9e3293bfbd7dcd5", size = 129061, upload-time = "2025-10-24T15:49:34.935Z" }, + { url = "https://files.pythonhosted.org/packages/39/e4/8eea51598f66a6c853c380979912d17ec510e8e66b280d968602e680b942/orjson-3.11.4-cp313-cp313-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:89216ff3dfdde0e4070932e126320a1752c9d9a758d6a32ec54b3b9334991a6a", size = 136541, upload-time = "2025-10-24T15:49:36.923Z" }, + { url = "https://files.pythonhosted.org/packages/9a/47/cb8c654fa9adcc60e99580e17c32b9e633290e6239a99efa6b885aba9dbc/orjson-3.11.4-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9daa26ca8e97fae0ce8aa5d80606ef8f7914e9b129b6b5df9104266f764ce436", size = 137535, upload-time = "2025-10-24T15:49:38.307Z" }, + { url = "https://files.pythonhosted.org/packages/43/92/04b8cc5c2b729f3437ee013ce14a60ab3d3001465d95c184758f19362f23/orjson-3.11.4-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5c8b2769dc31883c44a9cd126560327767f848eb95f99c36c9932f51090bfce9", size = 136703, upload-time = "2025-10-24T15:49:40.795Z" }, + { url = "https://files.pythonhosted.org/packages/aa/fd/d0733fcb9086b8be4ebcfcda2d0312865d17d0d9884378b7cffb29d0763f/orjson-3.11.4-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1469d254b9884f984026bd9b0fa5bbab477a4bfe558bba6848086f6d43eb5e73", size = 136293, upload-time = "2025-10-24T15:49:42.347Z" }, + { url = "https://files.pythonhosted.org/packages/c2/d7/3c5514e806837c210492d72ae30ccf050ce3f940f45bf085bab272699ef4/orjson-3.11.4-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:68e44722541983614e37117209a194e8c3ad07838ccb3127d96863c95ec7f1e0", size = 140131, upload-time 
= "2025-10-24T15:49:43.638Z" }, + { url = "https://files.pythonhosted.org/packages/9c/dd/ba9d32a53207babf65bd510ac4d0faaa818bd0df9a9c6f472fe7c254f2e3/orjson-3.11.4-cp313-cp313-musllinux_1_2_armv7l.whl", hash = "sha256:8e7805fda9672c12be2f22ae124dcd7b03928d6c197544fe12174b86553f3196", size = 406164, upload-time = "2025-10-24T15:49:45.498Z" }, + { url = "https://files.pythonhosted.org/packages/8e/f9/f68ad68f4af7c7bde57cd514eaa2c785e500477a8bc8f834838eb696a685/orjson-3.11.4-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:04b69c14615fb4434ab867bf6f38b2d649f6f300af30a6705397e895f7aec67a", size = 149859, upload-time = "2025-10-24T15:49:46.981Z" }, + { url = "https://files.pythonhosted.org/packages/b6/d2/7f847761d0c26818395b3d6b21fb6bc2305d94612a35b0a30eae65a22728/orjson-3.11.4-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:639c3735b8ae7f970066930e58cf0ed39a852d417c24acd4a25fc0b3da3c39a6", size = 139926, upload-time = "2025-10-24T15:49:48.321Z" }, + { url = "https://files.pythonhosted.org/packages/9f/37/acd14b12dc62db9a0e1d12386271b8661faae270b22492580d5258808975/orjson-3.11.4-cp313-cp313-win32.whl", hash = "sha256:6c13879c0d2964335491463302a6ca5ad98105fc5db3565499dcb80b1b4bd839", size = 136007, upload-time = "2025-10-24T15:49:49.938Z" }, + { url = "https://files.pythonhosted.org/packages/c0/a9/967be009ddf0a1fffd7a67de9c36656b28c763659ef91352acc02cbe364c/orjson-3.11.4-cp313-cp313-win_amd64.whl", hash = "sha256:09bf242a4af98732db9f9a1ec57ca2604848e16f132e3f72edfd3c5c96de009a", size = 131314, upload-time = "2025-10-24T15:49:51.248Z" }, + { url = "https://files.pythonhosted.org/packages/cb/db/399abd6950fbd94ce125cb8cd1a968def95174792e127b0642781e040ed4/orjson-3.11.4-cp313-cp313-win_arm64.whl", hash = "sha256:a85f0adf63319d6c1ba06fb0dbf997fced64a01179cf17939a6caca662bf92de", size = 126152, upload-time = "2025-10-24T15:49:52.922Z" }, +] + [[package]] name = "packaging" version = "25.0"