111 lines
4.3 KiB
Python
111 lines
4.3 KiB
Python
|
import sys
|
||
|
from functools import wraps
|
||
|
from ._impl import isasyncgenfunction
|
||
|
|
||
|
|
||
|
class aclosing:
    """Async context manager that closes an async iterator on exit.

    ``async with aclosing(agen) as it:`` binds ``it`` to the very object that
    was passed in and awaits its ``aclose()`` method when the block is left,
    whether it finished normally or raised.
    """

    def __init__(self, aiter):
        # Keep the wrapped object so it can be handed back on entry and
        # closed on exit.
        self._aiter = aiter

    async def __aenter__(self):
        wrapped = self._aiter
        return wrapped

    async def __aexit__(self, *args):
        # The exception details are ignored; falling off the end returns
        # None, so an in-flight exception keeps propagating.
        closing = self._aiter.aclose()
        await closing
|
||
|
|
||
|
|
||
|
# Very much derived from the one in contextlib, by copy/pasting and then
|
||
|
# asyncifying everything. (Also I dropped the obscure support for using
|
||
|
# context managers as function decorators. It could be re-added; I just
|
||
|
# couldn't be bothered.)
|
||
|
# So this is a derivative work licensed under the PSF License, which requires
|
||
|
# the following notice:
|
||
|
#
|
||
|
# Copyright © 2001-2017 Python Software Foundation; All Rights Reserved
|
||
|
class _AsyncGeneratorContextManager:
|
||
|
def __init__(self, func, args, kwds):
|
||
|
self._func_name = func.__name__
|
||
|
self._agen = func(*args, **kwds).__aiter__()
|
||
|
|
||
|
async def __aenter__(self):
|
||
|
if sys.version_info < (3, 5, 2):
|
||
|
self._agen = await self._agen
|
||
|
try:
|
||
|
return await self._agen.asend(None)
|
||
|
except StopAsyncIteration:
|
||
|
raise RuntimeError("async generator didn't yield") from None
|
||
|
|
||
|
async def __aexit__(self, type, value, traceback):
|
||
|
async with aclosing(self._agen):
|
||
|
if type is None:
|
||
|
try:
|
||
|
await self._agen.asend(None)
|
||
|
except StopAsyncIteration:
|
||
|
return False
|
||
|
else:
|
||
|
raise RuntimeError("async generator didn't stop")
|
||
|
else:
|
||
|
# It used to be possible to have type != None, value == None:
|
||
|
# https://bugs.python.org/issue1705170
|
||
|
# but AFAICT this can't happen anymore.
|
||
|
assert value is not None
|
||
|
try:
|
||
|
await self._agen.athrow(type, value, traceback)
|
||
|
raise RuntimeError(
|
||
|
"async generator didn't stop after athrow()"
|
||
|
)
|
||
|
except StopAsyncIteration as exc:
|
||
|
# Suppress StopIteration *unless* it's the same exception
|
||
|
# that was passed to throw(). This prevents a
|
||
|
# StopIteration raised inside the "with" statement from
|
||
|
# being suppressed.
|
||
|
return (exc is not value)
|
||
|
except RuntimeError as exc:
|
||
|
# Don't re-raise the passed in exception. (issue27112)
|
||
|
if exc is value:
|
||
|
return False
|
||
|
# Likewise, avoid suppressing if a StopIteration exception
|
||
|
# was passed to throw() and later wrapped into a
|
||
|
# RuntimeError (see PEP 479).
|
||
|
if (isinstance(value, (StopIteration, StopAsyncIteration))
|
||
|
and exc.__cause__ is value):
|
||
|
return False
|
||
|
raise
|
||
|
except:
|
||
|
# only re-raise if it's *not* the exception that was
|
||
|
# passed to throw(), because __exit__() must not raise an
|
||
|
# exception unless __exit__() itself failed. But throw()
|
||
|
# has to raise the exception to signal propagation, so
|
||
|
# this fixes the impedance mismatch between the throw()
|
||
|
# protocol and the __exit__() protocol.
|
||
|
#
|
||
|
if sys.exc_info()[1] is value:
|
||
|
return False
|
||
|
raise
|
||
|
|
||
|
def __enter__(self):
|
||
|
raise RuntimeError(
|
||
|
"use 'async with {func_name}(...)', not 'with {func_name}(...)'".
|
||
|
format(func_name=self._func_name)
|
||
|
)
|
||
|
|
||
|
def __exit__(self): # pragma: no cover
|
||
|
assert False, """Never called, but should be defined"""
|
||
|
|
||
|
|
||
|
def asynccontextmanager(func):
    """Like @contextmanager, but async.

    Wraps the async generator function ``func`` so that calling it returns
    an object usable in an ``async with`` statement instead of a bare async
    generator.

    Raises:
        TypeError: if ``func`` is not recognised as an async generator
            function by ``isasyncgenfunction``.
    """
    if not isasyncgenfunction(func):
        raise TypeError(
            "must be an async generator (native or from async_generator; "
            "if using @async_generator then @asynccontextmanager must be "
            "on top)"
        )

    @wraps(func)
    def helper(*args, **kwds):
        return _AsyncGeneratorContextManager(func, args, kwds)

    # A hint for sphinxcontrib-trio:
    helper.__returns_acontextmanager__ = True
    return helper
|