diff --git a/app/alembic/versions/16a9fa2c1f1e_rename_readonly_group.py b/app/alembic/versions/16a9fa2c1f1e_rename_readonly_group.py index b331dddd5..3cfbca4a2 100644 --- a/app/alembic/versions/16a9fa2c1f1e_rename_readonly_group.py +++ b/app/alembic/versions/16a9fa2c1f1e_rename_readonly_group.py @@ -43,7 +43,6 @@ def upgrade(container: AsyncContainer) -> None: # noqa: ARG001 return ro_dir.name = READ_ONLY_GROUP_NAME - ro_dir.create_path(ro_dir.parent, ro_dir.get_dn_prefix()) session.execute( @@ -91,7 +90,6 @@ def downgrade(container: AsyncContainer) -> None: # noqa: ARG001 return ro_dir.name = "readonly domain controllers" - ro_dir.create_path(ro_dir.parent, ro_dir.get_dn_prefix()) session.execute( diff --git a/app/alembic/versions/71e642808369_add_directory_is_system.py b/app/alembic/versions/71e642808369_add_directory_is_system.py index 2526190e4..48ece1bc4 100644 --- a/app/alembic/versions/71e642808369_add_directory_is_system.py +++ b/app/alembic/versions/71e642808369_add_directory_is_system.py @@ -56,8 +56,13 @@ async def _indicate_system_directories( if not base_dn_list: return - for base_dn in base_dn_list: - base_dn.is_system = True + await session.execute( + update(Directory) + .where( + qa(Directory.parent_id).is_(None), + ) + .values(is_system=True), + ) await session.flush() diff --git a/app/ldap_protocol/auth/setup_gateway.py b/app/ldap_protocol/auth/setup_gateway.py index b5bfe580a..4294066c5 100644 --- a/app/ldap_protocol/auth/setup_gateway.py +++ b/app/ldap_protocol/auth/setup_gateway.py @@ -16,6 +16,7 @@ AttributeValueValidator, ) from ldap_protocol.ldap_schema.entity_type_dao import EntityTypeDAO +from ldap_protocol.utils.async_cache import base_directories_cache from ldap_protocol.utils.helpers import create_object_sid, generate_domain_sid from ldap_protocol.utils.queries import get_domain_object_class from password_utils import PasswordUtils @@ -113,6 +114,7 @@ async def setup_enviroment( domain=domain, parent=domain, ) + base_directories_cache.clear() except Exception: import traceback @@ -132,13 +134,13 @@ async def create_dir( is_system=is_system, object_class=data["object_class"], name=data["name"], - parent=parent, ) dir_.groups = [] dir_.create_path(parent, dir_.get_dn_prefix()) self._session.add(dir_) await self._session.flush() + dir_.parent_id = parent.id if parent else None await self._session.refresh(dir_, ["id"]) self._session.add( diff --git a/app/ldap_protocol/utils/async_cache.py b/app/ldap_protocol/utils/async_cache.py new file mode 100644 index 000000000..f66f45cf3 --- /dev/null +++ b/app/ldap_protocol/utils/async_cache.py @@ -0,0 +1,43 @@ +"""Async cache implementation.""" + +import time +from functools import wraps +from typing import Callable, Generic, TypeVar + +from entities import Directory + +T = TypeVar("T") +DEFAULT_CACHE_TIME = 5 * 60 # 5 minutes + + +class AsyncTTLCache(Generic[T]): + def __init__(self, ttl: int | None = DEFAULT_CACHE_TIME) -> None: + self._ttl = ttl + self._value: T | None = None + self._expires_at: float | None = None + + def clear(self) -> None: + self._value = None + self._expires_at = None + + def __call__(self, func: Callable) -> Callable: + @wraps(func) + async def wrapper(*args: tuple, **kwargs: dict) -> T: + if self._value is not None: + if not self._expires_at or self._expires_at > time.monotonic(): + return self._value + self.clear() + + result = await func(*args, **kwargs) + + self._value = result + self._expires_at = ( + time.monotonic() + self._ttl if self._ttl else None + ) + + return result + + return wrapper + + +base_directories_cache = AsyncTTLCache[list[Directory]]() diff --git a/app/ldap_protocol/utils/queries.py b/app/ldap_protocol/utils/queries.py index a1f9243de..0038ea9e2 100644 --- a/app/ldap_protocol/utils/queries.py +++ b/app/ldap_protocol/utils/queries.py @@ -25,6 +25,7 @@ queryable_attr as qa, ) +from .async_cache import base_directories_cache from .const import EMAIL_RE, GRANT_DN_STRING from .helpers import ( create_integer_hash, @@ -35,13 +36,25 @@ ) +@base_directories_cache async def get_base_directories(session: AsyncSession) -> list[Directory]: """Get base domain directories.""" result = await session.execute( select(Directory) .filter(qa(Directory.parent_id).is_(None)), ) # fmt: skip - return list(result.scalars().all()) + res = [] + for dir_ in result.scalars(): + new_dir = Directory( + **{ + k: v + for k, v in dir_.__dict__.items() + if not k.startswith("_") and k != "id" + }, + ) + new_dir.id = dir_.id + res.append(new_dir) + return res async def get_user(session: AsyncSession, name: str) -> User | None: @@ -362,7 +375,7 @@ async def create_group( dir_ = Directory( object_class="", name=name, - parent=parent, + parent_id=parent.id, ) session.add(dir_) await session.flush() diff --git a/interface b/interface index f31962020..95ed5e191 160000 --- a/interface +++ b/interface @@ -1 +1 @@ -Subproject commit f31962020a6689e6a4c61fb3349db5b5c7895f92 +Subproject commit 95ed5e191cdafa07b1dfac96a1659926679ead97