from __future__ import annotations
import asyncio
import inspect
import threading
from collections.abc import AsyncGenerator, Awaitable, Callable, Generator
from contextlib import (
_AsyncGeneratorContextManager,
_GeneratorContextManager,
asynccontextmanager,
contextmanager,
)
from dataclasses import dataclass
from typing import Any, AsyncContextManager, ContextManager, Literal, TypeVar, overload
from picodi._internal import (
async_injection_context,
build_depend_tree,
sync_injection_context,
)
from picodi._scopes import AutoScope, ManualScope, NullScope, ScopeType
from picodi._types import (
DependencyCallable,
DependNode,
Depends,
InitDependencies,
LifespanScopeClass,
)
from picodi.support import (
ExitStack,
NullAwaitable,
call_cm_async,
call_cm_sync,
is_async_function,
)
# Value type produced by a resolved dependency; used by the ``resolve`` /
# ``aresolve`` overloads to type the returned context manager.
T = TypeVar("T")
# Callable type preserved by decorators such as ``Registry.set_scope``.
TC = TypeVar("TC", bound=Callable)
class Storage:
    """
    Internal bookkeeping behind a :class:`Registry`.

    Keeps the registered providers, the active overrides, the set of
    dependencies that have been looked up, and one scope instance per
    scope class.
    """

    def __init__(self) -> None:
        # dependency callable -> Provider describing how to resolve it
        self.deps: dict[DependencyCallable, Provider] = {}
        # original dependency -> dependency that replaces it
        self.overrides: dict[DependencyCallable, DependencyCallable] = {}
        # dependencies (after override substitution) requested through ``get``
        self.touched_dependencies: set[DependencyCallable] = set()
        # a single shared scope instance per scope class
        self.scopes: dict[type[ScopeType], ScopeType] = {}

    def add(
        self,
        dependency: DependencyCallable,
        scope_class: type[ScopeType] = NullScope,
        override: bool = False,
    ) -> None:
        """
        Register ``dependency`` with a scope of type ``scope_class``.

        The scope instance is created the first time its class is seen and is
        shared afterwards. An existing registration is left untouched unless
        ``override`` is true.
        """
        with lock:
            if scope_class not in self.scopes:
                self.scopes[scope_class] = scope_class()
            keep_existing = dependency in self.deps and not override
            if keep_existing:
                return
            self.deps[dependency] = Provider.from_dependency(
                dependency=dependency,
                scope=self.scopes[scope_class],
            )

    def get(self, dependency: DependencyCallable) -> Provider:
        """
        Return the provider for ``dependency`` after applying overrides.

        The substituted dependency is recorded as touched and is registered
        with the default scope when it has not been seen before.
        """
        effective = self.get_dep_or_override(dependency)
        self.touched_dependencies.add(effective)
        if effective not in self.deps:
            self.add(effective)
        return self.deps[effective]

    def get_dep_or_override(self, dependency: DependencyCallable) -> DependencyCallable:
        """Return the override for ``dependency``, or ``dependency`` itself."""
        return self.overrides.get(dependency, dependency)

    def get_override(self, dependency: DependencyCallable) -> DependencyCallable | None:
        """Return the override for ``dependency``, or ``None`` if there is none."""
        return self.overrides.get(dependency)

    def get_original(self, override: DependencyCallable) -> DependencyCallable | None:
        """Return the dependency that ``override`` replaces, or ``None``."""
        return next(
            (
                original
                for original, replacement in self.overrides.items()
                if replacement == override
            ),
            None,
        )

    def has_overrides(self) -> bool:
        """Return ``True`` when at least one override is active."""
        return len(self.overrides) > 0

    @overload
    def resolve(
        self,
        dependencies: list[DependencyCallable],
        registry: Registry,
        is_async: Literal[False],
    ) -> ContextManager:
        """
        Return sync context manager that will return tuple of results
        """

    @overload
    def resolve(
        self,
        dependencies: list[DependencyCallable],
        registry: Registry,
        is_async: Literal[True],
    ) -> AsyncContextManager:
        """
        Return async context manager that will return tuple of results
        """

    def resolve(
        self, dependencies: list[DependencyCallable], registry: Registry, is_async: bool
    ) -> AsyncContextManager | ContextManager:
        """
        Build a synthetic dependant whose parameters are ``dependencies`` and
        hand it to the sync or async injection context.

        A single dependency yields its bare value; several yield a tuple.
        """
        numbered = [
            (f"dep{position}", dep)
            for position, dep in enumerate(dependencies, start=1)
        ]

        parameters = []
        for param_name, dep in numbered:
            parameters.append(
                inspect.Parameter(
                    param_name,
                    inspect.Parameter.POSITIONAL_ONLY,
                    default=Depends(dep),
                )
            )
        signature = inspect.Signature(parameters=parameters)

        subtrees = []
        for param_name, dep in numbered:
            subtrees.append(build_depend_tree(dep, name=param_name, storage=self))
        dependant = DependNode(
            value=lambda *args: args if len(args) > 1 else args[0],
            name=None,
            dependencies=subtrees,
        )

        if is_async:
            resolver = async_injection_context
        else:
            resolver = sync_injection_context
        return resolver(
            dependant,
            signature,
            registry,
            args=(),
            kwargs={},
        )
class Registry:
    """
    Manages dependencies and overrides.
    """

    def __init__(self, for_init: InitDependencies | None = None) -> None:
        self._storage = Storage()
        # Dependencies (or callables returning them) that ``init`` initializes
        # when called without explicit arguments.
        self._for_init: list[InitDependencies] = [for_init] if for_init else []

    def add(
        self,
        dependency: DependencyCallable,
        scope_class: type[ScopeType] = NullScope,
    ) -> None:
        """
        Add a dependency to the registry and set scope_class for it.
        """
        self._storage.add(dependency, scope_class, override=True)

    def add_for_init(self, dependencies: InitDependencies) -> None:
        """
        Add dependencies to the list of dependencies to initialize.
        Adding the same object twice has no effect.
        """
        if dependencies not in self._for_init:
            self._for_init.append(dependencies)

    def set_scope(
        self, scope_class: type[ScopeType], *, auto_init: bool = False
    ) -> Callable[[TC], TC]:
        """
        Decorator to declare a dependency.
        Should be placed last in the decorator chain (on top).

        :param scope_class: specify the scope class to use it for the dependency.
        :param auto_init: if set to ``True``, the dependency will be added to the list
            of dependencies to initialize. This is useful for dependencies that
            need to be initialized before the application starts.
        """

        def decorator(fn: TC) -> TC:
            self._storage.add(
                fn,
                scope_class=scope_class,
                override=True,
            )
            if auto_init:
                self.add_for_init([fn])
            return fn

        return decorator

    def init(self, dependencies: InitDependencies | None = None) -> Awaitable:
        """
        Call this method to init dependencies. Usually, it should be called
        when your application is starting up.

        This method works both for synchronous and asynchronous dependencies.
        If you call it without ``await``, it will initialize only sync dependencies.
        If you call it ``await init(...)``,
        it will initialize both sync and async dependencies.

        :param dependencies: dependencies to initialize. If this argument
            is passed - init dependencies specified in the registry will be ignored.
        :raises ValueError: if a dependency is not in a :class:`ManualScope`.
        """
        if dependencies is None:
            dependencies = self._for_init_list()
        elif callable(dependencies):
            dependencies = dependencies()

        sync_deps: list[DependencyCallable] = []
        async_deps: list[DependencyCallable] = []
        for dep in dependencies:
            provider = self._storage.get(dep)
            if not isinstance(provider.scope, ManualScope):
                raise ValueError(
                    f"Dependency {dep} is not in ManualScope, "
                    "you cannot initialize it manually."
                )
            if provider.is_async:
                async_deps.append(dep)
            else:
                sync_deps.append(dep)

        # Sync dependencies are initialized eagerly; async ones only once the
        # returned awaitable is awaited.
        if sync_deps:
            call_cm_sync(self._resolve_many(*sync_deps))
        if async_deps:
            return call_cm_async(self._aresolve_many(*async_deps))
        return NullAwaitable()

    def _for_init_list(self) -> list[DependencyCallable]:
        """Flatten ``_for_init``, calling any callable entries to get their items."""
        dependencies: list[DependencyCallable] = []
        for item in self._for_init:
            if callable(item):
                item = item()
            dependencies.extend(item)
        return dependencies

    def shutdown(self, scope_class: LifespanScopeClass = ManualScope) -> Awaitable:
        """
        Call this method to close dependencies. Usually, it should be called
        when your application is shut down.

        This method works both for synchronous and asynchronous dependencies.
        If you call it without ``await``, it will shutdown only sync dependencies.
        If you call it ``await shutdown()``, it will shutdown both
        sync and async dependencies.

        If you do not pass any arguments,
        it will shutdown subclasses of :class:`ManualScope`.

        :param scope_class: you can specify the scope class to shutdown. If passed -
            only dependencies of this scope class and its subclasses will be shutdown.
        """
        tasks = [
            instance.shutdown(global_key=self.shutdown)  # type: ignore[union-attr]
            for klass, instance in self._storage.scopes.items()
            if issubclass(klass, scope_class)
        ]
        # Scopes with nothing async to close return NullAwaitable; skip them.
        tasks = [task for task in tasks if not isinstance(task, NullAwaitable)]
        return asyncio.gather(*tasks) if tasks else NullAwaitable()

    @contextmanager
    def lifespan(self) -> Generator[None]:
        """
        Context manager to manage the lifespan of the application.
        It will automatically call init and shutdown methods.
        """
        self.init()
        try:
            yield
        finally:
            self.shutdown()

    @asynccontextmanager
    async def alifespan(self) -> AsyncGenerator[None]:
        """
        Async context manager to manage the lifespan of the application.
        It will automatically call init and shutdown methods.
        """
        await self.init()
        try:
            yield
        finally:
            await self.shutdown()

    @property
    def touched(self) -> frozenset[DependencyCallable]:
        """
        Get all dependencies that were used during the picodi lifecycle.

        This method will return a frozenset of dependencies that were resolved.
        When a dependency is overridden, the override is recorded rather than
        the original.

        Primarily used for testing purposes.
        For example, you can check that mongo
        database was used in the test and clear it after the test.
        """
        return frozenset(self._storage.touched_dependencies)

    def override(
        self,
        dependency: DependencyCallable,
        new_dependency: DependencyCallable | None,
    ) -> ContextManager[None]:
        """
        Override a dependency with a new one. It can be used as a context manager
        or as a regular method call. New dependency will be
        added to the registry.

        :param dependency: dependency to override
        :param new_dependency: new dependency to use. If explicitly set to ``None``,
            it will remove the override.
        :raises ValueError: if ``dependency`` is itself an override of another
            dependency, or if ``new_dependency`` is ``dependency``.

        Examples
        --------
        .. code-block:: python

            with registry.override(get_settings, real_settings):
                pass

            registry.override(get_settings, real_settings)
            registry.override(get_settings, None)  # clear override
        """
        with lock:
            if self._storage.get_original(dependency) is not None:
                raise ValueError("Cannot override an overridden dependency")
            # Remember the previous override so the context manager can restore it.
            previous = self._storage.overrides.get(dependency)
            if new_dependency is not None:
                # Validate before registering, so a rejected call leaves the
                # storage untouched.
                if dependency is new_dependency:
                    raise ValueError("Cannot override a dependency with itself")
                self._storage.add(new_dependency)
                self._storage.overrides[dependency] = new_dependency
            else:
                self._storage.overrides.pop(dependency, None)

        @contextmanager
        def manage_context() -> Generator[None]:
            try:
                yield
            finally:
                self.override(dependency, previous)

        return manage_context()

    @overload
    def resolve(self, dependency: Callable[..., Generator[T]]) -> ContextManager[T]:
        """
        Resolve a dependency that is a generator function synchronously.
        """

    @overload
    def resolve(self, dependency: Callable[..., T]) -> ContextManager[T]:
        """
        Resolve a dependency that is a regular function synchronously.
        """

    def resolve(self, dependency: DependencyCallable) -> ContextManager[Any]:
        """
        Resolve a dependency synchronously. Returns a context manager that will
        return the result of the dependency.

        :param dependency: dependency to resolve.
        :return: sync context manager.
        """
        return self._storage.resolve([dependency], self, is_async=False)

    @overload
    def aresolve(
        self, dependency: Callable[..., Generator[T]]
    ) -> AsyncContextManager[T]:
        """
        Resolve a dependency that is a generator function asynchronously.
        """

    @overload
    def aresolve(
        self, dependency: Callable[..., AsyncGenerator[T]]
    ) -> AsyncContextManager[T]:
        """
        Resolve a dependency that is an async generator function asynchronously.
        """

    @overload
    def aresolve(
        self, dependency: Callable[..., Awaitable[T]]
    ) -> AsyncContextManager[T]:
        """
        Resolve a dependency that is an awaitable asynchronously.
        """

    @overload
    def aresolve(self, dependency: Callable[..., T]) -> AsyncContextManager[T]:
        """
        Resolve a dependency that is a regular function asynchronously.
        """

    def aresolve(self, dependency: DependencyCallable) -> AsyncContextManager[Any]:
        """
        Resolve a dependency asynchronously. Returns a context manager that will
        return the result of the dependency.

        Also can resolve sync dependencies in async context.

        :param dependency: dependency to resolve.
        :return: async context manager.
        """
        return self._storage.resolve([dependency], self, is_async=True)

    def _resolve_many(self, *dependencies: DependencyCallable) -> ContextManager[Any]:
        """
        Internal method to resolve multiple dependencies synchronously.
        """
        return self._storage.resolve(list(dependencies), self, is_async=False)

    def _aresolve_many(
        self, *dependencies: DependencyCallable
    ) -> AsyncContextManager[Any]:
        """
        Internal method to resolve multiple dependencies asynchronously.
        """
        return self._storage.resolve(list(dependencies), self, is_async=True)

    def clear_overrides(self) -> None:
        """
        Clear all overrides. It will remove all overrides, but keep the dependencies.
        """
        self._storage.overrides.clear()

    def clear_touched(self) -> None:
        """
        Clear the touched dependencies.
        It will remove the list of all dependencies resolved during the picodi
        lifecycle.
        """
        self._storage.touched_dependencies.clear()

    def _clear(self) -> None:
        """
        Clear the registry. It will remove all dependencies, overrides,
        touched dependencies and init dependencies (scope instances are kept).

        This method will not close any dependencies. So you need to manually call
        :func:`shutdown` before this method.

        It is useful for testing purposes, when you want to clear the registry
        and start from scratch.
        """
        self._storage.deps.clear()
        self._storage.overrides.clear()
        self._storage.touched_dependencies.clear()
        self._for_init.clear()
@dataclass(frozen=True)
class Provider:
    """
    Binds a dependency callable to the scope instance that manages the
    values it produces.
    """

    # The user-supplied dependency factory.
    dependency: DependencyCallable
    # Whether the dependency is async (computed by ``is_async_function`` in
    # ``from_dependency``); selects the async resolution path.
    is_async: bool
    # Scope instance that owns values produced by this dependency.
    scope: ScopeType

    @classmethod
    def from_dependency(
        cls,
        dependency: DependencyCallable,
        scope: ScopeType,
    ) -> Provider:
        """Build a provider, detecting whether ``dependency`` is async."""
        return cls(
            dependency=dependency,
            is_async=is_async_function(dependency),
            scope=scope,
        )

    def get_scope(self) -> ScopeType:
        """Return the scope instance bound to this provider."""
        return self.scope

    def resolve_value(
        self,
        exit_stack: ExitStack | None,
        registry: Registry,
        dependant: Callable,
        kwargs: dict[str, Any],
    ) -> Any:
        """
        Call the dependency and return its value.

        If the dependency declares a ``registry`` parameter without a default,
        ``registry`` is passed to it. Generators and ``@contextmanager``
        results are entered through ``exit_stack`` for :class:`AutoScope`, and
        through ``scope.enter`` (keyed by ``dependant``) otherwise.

        For async dependencies the return value is a coroutine that must be
        awaited. The caller's ``kwargs`` dict is never modified.
        """
        signature = inspect.signature(self.dependency)
        registry_param = signature.parameters.get("registry")
        if registry_param is not None and registry_param.default is signature.empty:
            # Build a new dict rather than mutating the caller's, so the
            # injected kwarg cannot leak into other calls that reuse it.
            kwargs = {**kwargs, "registry": registry}
        scope = self.get_scope()
        value_or_gen = self.dependency(**kwargs)

        if self.is_async:
            return self._resolve_async(value_or_gen, scope, exit_stack, dependant)

        if inspect.isgenerator(value_or_gen):
            gen = value_or_gen
            context_manager = contextmanager(lambda *args, **kwargs: gen)
            return self._enter(scope, exit_stack, context_manager(), dependant)
        if isinstance(value_or_gen, _GeneratorContextManager):
            return self._enter(
                scope, exit_stack, _recreate_cm(value_or_gen), dependant
            )
        return value_or_gen

    async def _resolve_async(
        self,
        value_or_gen: Any,
        scope: ScopeType,
        exit_stack: ExitStack | None,
        dependant: Callable,
    ) -> Any:
        """Async counterpart of the generator / context-manager handling."""
        value = value_or_gen
        if inspect.iscoroutine(value):
            value = await value
        if inspect.isasyncgen(value):
            agen = value
            context_manager = asynccontextmanager(lambda *args, **kwargs: agen)
            return await self._enter(scope, exit_stack, context_manager(), dependant)
        if isinstance(value, _AsyncGeneratorContextManager):
            return await self._enter(
                scope, exit_stack, _recreate_cm(value), dependant
            )
        return value

    @staticmethod
    def _enter(
        scope: ScopeType,
        exit_stack: ExitStack | None,
        context_manager: Any,
        dependant: Callable,
    ) -> Any:
        """Enter ``context_manager`` in ``exit_stack`` (AutoScope) or via ``scope``."""
        if isinstance(scope, AutoScope):
            assert exit_stack is not None, "exit_stack is required"
            return exit_stack.enter_context(context_manager)
        return scope.enter(context_manager, global_key=dependant)
def _recreate_cm(
    gen: AsyncContextManager | ContextManager,
) -> AsyncContextManager | ContextManager:
    """
    Return a fresh, not-yet-entered copy of a contextlib-generated context
    manager by delegating to its private ``_recreate_cm`` hook.
    """
    rebuild = gen._recreate_cm  # type: ignore[union-attr] # noqa: SLF001
    return rebuild()
# Module-wide lock used by the ``with lock:`` sections of this module.
# It is re-entrant because ``Registry.override`` holds it while calling
# ``Storage.add``, which acquires it again.
lock = threading.RLock()
# Default module-level registry instance.
registry = Registry()
registry.__doc__ = """
Picodi registry. You can use it to register dependencies, scopes, overrides,
initialize and shutdown dependencies.
"""