import sys
from functools import wraps

from ._impl import isasyncgenfunction


class aclosing:
    """Async context manager that awaits ``aclose()`` on exit.

    ``async with aclosing(agen) as it:`` binds ``it`` to *agen* itself and
    runs ``await agen.aclose()`` when the block exits, whether normally or
    through an exception. Exceptions from the body are never suppressed.
    """

    def __init__(self, aiter):
        self._aiter = aiter

    async def __aenter__(self):
        return self._aiter

    async def __aexit__(self, *args):
        await self._aiter.aclose()


# Very much derived from the one in contextlib, by copy/pasting and then
# asyncifying everything. (Also I dropped the obscure support for using
# context managers as function decorators. It could be re-added; I just
# couldn't be bothered.)
# So this is a derivative work licensed under the PSF License, which requires
# the following notice:
#
# Copyright © 2001-2017 Python Software Foundation; All Rights Reserved
class _AsyncGeneratorContextManager:
    """Async context manager driven by an async generator that yields once.

    Instances are produced by the wrapper that :func:`asynccontextmanager`
    returns. The generator's single yielded value becomes the ``as`` target.
    On exit, the generator is either resumed or has the body's exception
    thrown in at the ``yield``. After that it must finish without yielding
    again.
    """

    def __init__(self, func, args, kwds):
        # Only used to build the error message in __enter__.
        self._func_name = func.__name__
        self._agen = func(*args, **kwds).__aiter__()

    async def __aenter__(self):
        """Run the generator up to its first ``yield`` and return that value.

        Raises:
            RuntimeError: if the generator finishes without yielding.
        """
        if sys.version_info < (3, 5, 2):
            # Before 3.5.2, __aiter__ returned an awaitable that resolved to
            # the async iterator, rather than the iterator itself.
            self._agen = await self._agen
        try:
            return await self._agen.asend(None)
        except StopAsyncIteration:
            raise RuntimeError("async generator didn't yield") from None

    async def __aexit__(self, type, value, traceback):
        """Finish the generator after the ``async with`` body exits.

        On a clean exit the generator is resumed and must stop. On an
        exceptional exit the exception is thrown into the generator with
        ``athrow()``.

        Returns:
            True to suppress the body's exception (the generator caught it
            and then finished); False to let it propagate.

        Raises:
            RuntimeError: if the generator yields a second time.
            Any exception the generator raises other than the one that was
                thrown into it.
        """
        # aclosing() ensures the generator is closed on every exit path.
        async with aclosing(self._agen):
            if type is None:
                try:
                    await self._agen.asend(None)
                except StopAsyncIteration:
                    return False
                else:
                    raise RuntimeError("async generator didn't stop")
            else:
                # It used to be possible to have type != None, value == None:
                # https://bugs.python.org/issue1705170
                # but AFAICT this can't happen anymore.
                assert value is not None
                try:
                    await self._agen.athrow(type, value, traceback)
                    raise RuntimeError(
                        "async generator didn't stop after athrow()"
                    )
                except StopAsyncIteration as exc:
                    # athrow() raising StopAsyncIteration means the generator
                    # caught the exception and then finished, so suppress it.
                    # The exception is NOT suppressed if it is the very
                    # StopAsyncIteration passed to athrow(); that one was
                    # raised inside the "async with" body and must propagate.
                    return (exc is not value)
                except RuntimeError as exc:
                    # Don't re-raise the passed in exception. (issue27112)
                    if exc is value:
                        return False
                    # Likewise, avoid suppressing if a StopIteration or
                    # StopAsyncIteration was passed to athrow() and then got
                    # wrapped into a RuntimeError (see PEP 479 / PEP 525).
                    if (isinstance(value, (StopIteration, StopAsyncIteration))
                            and exc.__cause__ is value):
                        return False
                    raise
                except BaseException as exc:
                    # Only re-raise if it's *not* the exception that was
                    # passed to athrow(), because __aexit__() must not raise
                    # an exception unless __aexit__() itself failed. But
                    # athrow() has to raise the exception to signal
                    # propagation, so this fixes the impedance mismatch
                    # between the athrow() protocol and the __aexit__()
                    # protocol.
                    if exc is value:
                        return False
                    raise

    def __enter__(self):
        # Give a helpful error for the common mistake of a plain ``with``.
        raise RuntimeError(
            "use 'async with {func_name}(...)', not 'with {func_name}(...)'"
            .format(func_name=self._func_name)
        )

    def __exit__(self, *args):  # pragma: no cover
        # Never called, because __enter__ always raises. It must still exist:
        # the ``with`` statement looks up __exit__ before calling __enter__,
        # and without it the helpful error above would never be reached.
        assert False, """Never called, but should be defined"""


def asynccontextmanager(func):
    """Like @contextmanager, but async.

    Decorate an async generator function that yields exactly once. Calling
    the decorated function returns an object for use with ``async with``,
    whose ``as`` target is the generator's yielded value.

    Raises:
        TypeError: if *func* is not recognized as an async generator function
            by ``isasyncgenfunction`` (native, or built with
            ``@async_generator``).
    """
    if not isasyncgenfunction(func):
        raise TypeError(
            "must be an async generator (native or from async_generator; "
            "if using @async_generator then @asynccontextmanager must be on "
            "top)"
        )

    @wraps(func)
    def helper(*args, **kwds):
        return _AsyncGeneratorContextManager(func, args, kwds)

    # A hint for sphinxcontrib-trio:
    helper.__returns_acontextmanager__ = True
    return helper