from __future__ import annotations

import threading
from contextvars import ContextVar, Token
from typing import TYPE_CHECKING, Any, AsyncContextManager, ContextManager

import picodi
from picodi.support import ExitStack

if TYPE_CHECKING:
    from collections.abc import Awaitable, Hashable

    from starlette.types import ASGIApp, Receive, Scope, Send

# Holds a dict with two entries:
#   "store"      -> mapping of request-scoped dependency values
#   "exit_stack" -> picodi ExitStack owning the context managers entered for
#                   this scope
_context: ContextVar[dict[str, Any]] = ContextVar("picodi_starlette_context")
_lock = threading.Lock()


def _get_or_create_context() -> dict[str, Any]:
    """Return the scope dict bound to the current context, creating one if unset.

    A freshly created dict is bound via ``_new_context``; its reset token is
    discarded here, so such a fallback dict is never reset by this function.
    """
    existing = _context.get(None)
    if existing is not None:
        return existing
    # NOTE(review): ContextVar values are per-context, so it is unclear that
    # another thread could populate this variable concurrently; the lock and
    # re-check are preserved as-is — confirm before simplifying.
    with _lock:
        existing = _context.get(None)
        if existing is None:
            existing, _discarded_token = _new_context()
        return existing


def _new_context() -> tuple[dict[str, Any], Token]:
    """Bind a brand-new, empty scope dict to ``_context``.

    Returns:
        The new dict and the ``Token`` that ``_context.reset`` needs to undo
        the binding.
    """
    fresh: dict[str, Any] = {"store": {}, "exit_stack": ExitStack()}
    token = _context.set(fresh)
    return fresh, token


class RequestScope(picodi.ContextVarScope):
    """picodi scope whose values live in the context-variable-backed scope dict.

    Every method ignores ``global_key`` and works on the dict returned by
    ``_get_or_create_context``.
    """

    def get(self, key: Hashable, *, global_key: Hashable) -> Any:  # noqa: ARG002
        """Return the stored value for ``key``; raise ``KeyError(key)`` if absent."""
        store = _get_or_create_context()["store"]
        try:
            return store[key]
        except LookupError:
            raise KeyError(key) from None

    def set(
        self,
        key: Hashable,
        value: Any,
        *,
        global_key: Hashable,  # noqa: ARG002
    ) -> None:
        """Store ``value`` under ``key`` in the current scope dict."""
        _get_or_create_context()["store"][key] = value

    def enter(
        self,
        context_manager: AsyncContextManager | ContextManager,
        *,
        global_key: Hashable,  # noqa: ARG002
    ) -> Awaitable:
        """Enter ``context_manager`` on this scope's exit stack and return the result."""
        stack = _get_or_create_context()["exit_stack"]
        return stack.enter_context(context_manager)

    def shutdown(
        self, exc: BaseException | None = None, *, global_key: Hashable  # noqa: ARG002
    ) -> Any:
        """Empty the store, then close the exit stack, forwarding ``exc``.

        Returns whatever the exit stack's ``close`` returns.
        """
        current = _get_or_create_context()
        current["store"].clear()
        return current["exit_stack"].close(exc)
class RequestScopeMiddleware:
    """
    Starlette Pure ASGI Middleware for automatically initializing and closing
    request scoped dependencies.

    For every HTTP request it binds a fresh request-scope context, optionally
    initializes ``dependencies_for_init`` on the registry, runs the wrapped
    app, and then shuts down ``RequestScope`` and unbinds the context. Other
    ASGI scope types (e.g. ``lifespan``, ``websocket``) are passed straight
    through.

    Args:
        app: The wrapped ASGI application.
        registry: picodi registry to use; falls back to ``picodi.registry``.
        dependencies_for_init: Dependencies passed to ``registry.init`` at the
            start of every HTTP request (skipped when falsy).
    """

    def __init__(
        self,
        app: ASGIApp,
        *,
        registry: picodi.Registry | None = None,
        dependencies_for_init: picodi.InitDependencies | None = None,
    ) -> None:
        self.app = app
        self._registry = registry or picodi.registry
        self._dependencies_for_init = dependencies_for_init

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        # Initialize the context in the middleware's contextvars so it already
        # exists if Starlette runs the view in another thread with a copy of
        # the current contextvars.
        _, token = _new_context()
        try:
            # Init runs inside the try: if it fails part-way, shutdown still
            # closes whatever it already entered and the context is reset.
            if self._dependencies_for_init:
                await self._registry.init(self._dependencies_for_init)
            await self.app(scope, receive, send)
        finally:
            try:
                await self._registry.shutdown(scope_class=RequestScope)
            finally:
                # Always unbind this request's context, even if shutdown raised.
                _context.reset(token)