Skip to content
Merged
5 changes: 3 additions & 2 deletions .package/docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -157,8 +157,6 @@ services:
- traefik.udp.services.cldap.loadbalancer.server.port=389

global_ldap_server:
image: ghcr.io/multidirectorylab/multidirectory:${VERSION:-latest}
restart: unless-stopped
deploy:
mode: replicated
replicas: 2
Expand All @@ -167,10 +165,13 @@ services:
reservations:
cpus: "0.25"
memory: 100M
image: ghcr.io/multidirectorylab/multidirectory:${VERSION:-latest}
restart: unless-stopped
environment:
- SERVICE_NAME=global_ldap_server
volumes:
- ./certs:/certs
- ./logs:/app/logs
- ldap_keytab:/LDAP_keytab/
env_file:
- .env
Expand Down
49 changes: 24 additions & 25 deletions app/ldap_protocol/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from io import BytesIO
from ipaddress import IPv4Address, IPv6Address, ip_address
from traceback import format_exc
from typing import Literal, cast, overload
from typing import Literal, NewType, cast, overload

from dishka import AsyncContainer, Scope
from loguru import logger
Expand All @@ -26,19 +26,11 @@

from .data_logger import DataLogger

log = logger.bind(name="ldap")
log.add(
"logs/ldap_{time:DD-MM-YYYY}.log",
filter=lambda rec: rec["extra"].get("name") == "ldap",
retention="10 days",
rotation="1d",
colorize=False,
enqueue=True,
)

infinity = cast("int", math.inf)
pp_v2 = ProxyProtocolV2()

ServerLogger = NewType("ServerLogger", type[logger]) # type: ignore


class PoolClientHandler:
"""Async client handler.
Expand All @@ -53,14 +45,20 @@ class PoolClientHandler:

ssl_context: ssl.SSLContext | None = None

def __init__(self, settings: Settings, container: AsyncContainer):
def __init__(
self,
settings: Settings,
container: AsyncContainer,
log: ServerLogger,
):
"""Set workers number for single client concurrent handling."""
self.container = container
self.settings = settings
self.num_workers = self.settings.COROUTINES_NUM_PER_CLIENT
self._size = self.settings.TCP_PACKET_SIZE

self.logger = DataLogger(log, is_full=self.settings.DEBUG)
self.log = log
self.logger = DataLogger(self.log, is_full=self.settings.DEBUG)

self._load_ssl_context()

Expand All @@ -79,7 +77,7 @@ async def __call__(
)
ldap_session.ip = addr

logger.info(f"Connection {addr} opened")
self.log.info(f"Connection {addr} opened")

try:
async with session_scope(scope=Scope.REQUEST) as r:
Expand All @@ -92,7 +90,7 @@ async def __call__(
network_policy_use_case,
)
except PermissionError:
log.warning(f"Whitelist violation from {addr}")
self.log.warning(f"Whitelist violation from {addr}")
return

async with asyncio.TaskGroup() as tg:
Expand All @@ -117,7 +115,9 @@ async def __call__(
)

except* RuntimeError as err:
log.error(f"Response handling error {err}: {format_exc()}")
self.log.error(
f"Response handling error {err}: {format_exc()}",
)

finally:
await session_scope.close()
Expand All @@ -126,18 +126,18 @@ async def __call__(
writer.close()
await writer.wait_closed()

logger.info(f"Connection {addr} closed")
self.log.info(f"Connection {addr} closed")

def _load_ssl_context(self) -> None:
"""Load SSL context for LDAPS."""
if self.settings.USE_CORE_TLS and self.settings.LDAP_LOAD_SSL_CERT:
if not self.settings.check_certs_exist():
log.critical("Certs not found, exiting...")
self.log.critical("Certs not found, exiting...")
raise SystemExit(1)

cert_name = self.settings.SSL_CERT
key_name = self.settings.SSL_KEY
log.success("Found existing cert and key, loading...")
self.log.success("Found existing cert and key, loading...")
self.ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2)
self.ssl_context.load_cert_chain(cert_name, key_name)

Expand Down Expand Up @@ -166,7 +166,7 @@ def _extract_proxy_protocol_address(
header_length = int.from_bytes(data[14:16], "big")
return addr, data[16 + header_length :]
except (ValueError, ProxyProtocolIncompleteError) as err:
log.error(f"Proxy Protocol processing error: {err}")
self.log.error(f"Proxy Protocol processing error: {err}")
return peer_addr, data

@overload
Expand Down Expand Up @@ -279,7 +279,7 @@ async def _handle_request(
request = LDAPRequestMessage.from_bytes(data)

except (ValidationError, IndexError, KeyError, ValueError) as err:
log.error(f"Invalid schema {format_exc()}")
self.log.error(f"Invalid schema {format_exc()}")

writer.write(LDAPRequestMessage.from_err(data, err).encode())
await writer.drain()
Expand Down Expand Up @@ -440,15 +440,14 @@ async def _run_server(server: asyncio.base_events.Server) -> None:
async with server:
await server.serve_forever()

@staticmethod
def log_addrs(server: asyncio.base_events.Server) -> None:
def log_addrs(self, server: asyncio.base_events.Server) -> None:
addrs = ", ".join(str(sock.getsockname()) for sock in server.sockets)
log.info(f"Server on {addrs}")
self.log.info(f"Server on {addrs}")

async def start(self) -> None:
"""Run and log tcp server."""
server = await self._get_server()
log.info(
self.log.info(
f"started {'DEBUG' if self.settings.DEBUG else 'PROD'} "
f"{'LDAPS' if self.settings.USE_CORE_TLS else 'LDAP'} server",
)
Expand Down
26 changes: 23 additions & 3 deletions app/multidirectory.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@
from ldap_protocol.identity.exceptions import UnauthorizedError
from ldap_protocol.policies.audit.events.handler import AuditEventHandler
from ldap_protocol.policies.audit.events.sender import AuditEventSenderManager
from ldap_protocol.server import PoolClientHandler
from ldap_protocol.server import PoolClientHandler, ServerLogger
from ldap_protocol.udp_server import CLDAPUDPServer
from schedule import scheduler_factory

Expand Down Expand Up @@ -199,7 +199,17 @@ async def ldap_factory(settings: Settings) -> None:
)

settings = await container.get(Settings)
servers.append(PoolClientHandler(settings, container).start())
log: ServerLogger = logger.bind(name="ldap") # type: ignore
log.add(
"logs/ldap_{time:DD-MM-YYYY}.log",
filter=lambda rec: rec["extra"].get("name") == "ldap",
retention="10 days",
rotation="1d",
colorize=False,
enqueue=True,
)

servers.append(PoolClientHandler(settings, container, log).start())

await asyncio.gather(*servers)

Expand Down Expand Up @@ -234,7 +244,17 @@ async def global_ldap_server_factory(settings: Settings) -> None:
)

settings = await container.get(Settings)
servers.append(PoolClientHandler(settings, container).start())
log: ServerLogger = logger.bind(name="global_catalog") # type: ignore
log.add(
"logs/global_catalog_{time:DD-MM-YYYY}.log",
filter=lambda rec: rec["extra"].get("name") == "global_catalog",
retention="10 days",
rotation="1d",
colorize=False,
enqueue=True,
)

servers.append(PoolClientHandler(settings, container, log).start())

await asyncio.gather(*servers)

Expand Down
4 changes: 2 additions & 2 deletions docker-compose.dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -122,14 +122,14 @@ services:
- traefik.udp.services.cldap.loadbalancer.server.port=389

global_ldap_server:
image: multidirectory
restart: unless-stopped
build:
context: .
dockerfile: ./.docker/dev.Dockerfile
args:
DOCKER_BUILDKIT: 1
target: runtime
image: multidirectory
restart: unless-stopped
deploy:
mode: replicated
replicas: 2
Expand Down
4 changes: 2 additions & 2 deletions docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -116,14 +116,14 @@ services:
- traefik.udp.services.cldap.loadbalancer.server.port=389

global_ldap_server:
image: multidirectory
restart: unless-stopped
build:
context: .
dockerfile: ./.docker/dev.Dockerfile
args:
DOCKER_BUILDKIT: 1
target: runtime
image: multidirectory
restart: unless-stopped
deploy:
mode: replicated
replicas: 2
Expand Down
2 changes: 1 addition & 1 deletion interface
4 changes: 3 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
)
from dishka.integrations.fastapi import setup_dishka
from fastapi import FastAPI, Request, Response
from loguru import logger
from multidirectory import _create_basic_app
from sqlalchemy import schema, text
from sqlalchemy.ext.asyncio import (
Expand Down Expand Up @@ -1079,8 +1080,9 @@ async def handler(
) -> AsyncIterator[PoolClientHandler]:
"""Create test handler."""
settings.set_test_port()
test_log = logger.bind(name="ldap_test")
async with container(scope=Scope.APP) as app_scope:
yield PoolClientHandler(settings, app_scope)
yield PoolClientHandler(settings, app_scope, test_log) # type: ignore


@pytest_asyncio.fixture(scope="function")
Expand Down