| from contextlib import asynccontextmanager as asynccontextmanager | |
| from typing import AsyncGenerator, ContextManager, TypeVar | |
| import anyio | |
| from anyio import CapacityLimiter | |
| from starlette.concurrency import iterate_in_threadpool as iterate_in_threadpool # noqa | |
| from starlette.concurrency import run_in_threadpool as run_in_threadpool # noqa | |
| from starlette.concurrency import ( # noqa | |
| run_until_first_complete as run_until_first_complete, | |
| ) | |
_T = TypeVar("_T")


@asynccontextmanager
async def contextmanager_in_threadpool(
    cm: ContextManager[_T],
) -> AsyncGenerator[_T, None]:
    """Drive a synchronous context manager from async code via worker threads.

    ``cm.__enter__`` and ``cm.__exit__`` are each executed in a worker thread
    so that blocking setup/teardown does not run on the event loop. Usage::

        async with contextmanager_in_threadpool(cm) as value:
            ...

    Args:
        cm: The synchronous context manager to drive.

    Yields:
        Whatever ``cm.__enter__()`` returns.

    Raises:
        Any exception raised by ``cm.__enter__`` (``__exit__`` is then not
        called), and any ``Exception`` raised inside the ``async with`` body
        that ``cm.__exit__`` does not suppress (i.e. returns a falsy value).

    NOTE(review): only ``Exception`` subclasses are routed to ``__exit__``;
    other ``BaseException``s (e.g. cancellation) propagate without
    ``__exit__`` being called. Confirm this is acceptable for callers.
    """
    # Give __exit__ its own per-call limiter instead of making it wait for a
    # free thread alongside other threadpool work. Blocking __exit__ that way
    # can cause race conditions/deadlocks if the context manager has its own
    # internal pool (e.g. a database connection pool). Because a new limiter
    # is created for every call, any non-zero limit works (1 is arbitrary).
    exit_limiter = CapacityLimiter(1)
    # Enter outside the ``try``: per the ``with`` statement protocol,
    # ``__exit__`` must not be called when ``__enter__`` itself raises.
    value = await run_in_threadpool(cm.__enter__)
    try:
        yield value
    except Exception as e:
        # Pass the real traceback so __exit__ receives the same
        # (type, value, traceback) triple a plain ``with`` block would give it.
        ok = bool(
            await anyio.to_thread.run_sync(
                cm.__exit__, type(e), e, e.__traceback__, limiter=exit_limiter
            )
        )
        # A falsy return from __exit__ means "do not suppress": re-raise.
        if not ok:
            raise e
    else:
        await anyio.to_thread.run_sync(
            cm.__exit__, None, None, None, limiter=exit_limiter
        )