987 lines
39 KiB
Python
987 lines
39 KiB
Python
# Copyright 2019 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
|
||
"""Interceptors implementation of gRPC Asyncio Python."""
|
||
import asyncio
|
||
import collections
|
||
import functools
|
||
from abc import ABCMeta, abstractmethod
|
||
from typing import Callable, Optional, Iterator, Sequence, Union, Awaitable, AsyncIterable
|
||
|
||
import grpc
|
||
from grpc._cython import cygrpc
|
||
|
||
from . import _base_call
|
||
from ._call import UnaryUnaryCall, UnaryStreamCall, StreamUnaryCall, StreamStreamCall, AioRpcError
|
||
from ._call import _RPC_ALREADY_FINISHED_DETAILS, _RPC_HALF_CLOSED_DETAILS
|
||
from ._call import _API_STYLE_ERROR
|
||
from ._utils import _timeout_to_deadline
|
||
from ._typing import (RequestType, SerializingFunction, DeserializingFunction,
|
||
ResponseType, DoneCallbackType, RequestIterableType,
|
||
ResponseIterableType)
|
||
from ._metadata import Metadata
|
||
|
||
_LOCAL_CANCELLATION_DETAILS = 'Locally cancelled by application!'
|
||
|
||
|
||
class ServerInterceptor(metaclass=ABCMeta):
    """Hook for intercepting incoming RPCs on the server side.

    This is an EXPERIMENTAL API.
    """

    @abstractmethod
    async def intercept_service(
            self,
            continuation: Callable[[grpc.HandlerCallDetails],
                                   Awaitable[grpc.RpcMethodHandler]],
            handler_call_details: grpc.HandlerCallDetails,
    ) -> grpc.RpcMethodHandler:
        """Intercept an incoming RPC before it is handed to a handler.

        Args:
          continuation: A callable taking a HandlerCallDetails. It proceeds
            to the next interceptor in the chain or, when none remain, to
            the RPC handler lookup logic. It returns an RpcMethodHandler if
            the RPC is considered serviced, or None otherwise.
          handler_call_details: A HandlerCallDetails describing the RPC.

        Returns:
          An RpcMethodHandler that services the RPC if this interceptor
          chooses to service it, or None otherwise.
        """
|
||
|
||
|
||
class ClientCallDetails(
        collections.namedtuple(
            'ClientCallDetails',
            ['method', 'timeout', 'metadata', 'credentials', 'wait_for_ready']),
        grpc.ClientCallDetails):
    """Immutable description of an outgoing RPC, handed to interceptors.

    This is an EXPERIMENTAL API.

    Args:
      method: Name of the RPC method being invoked.
      timeout: Optional time budget for the RPC, in seconds.
      metadata: Optional metadata sent to the service side of the RPC.
      credentials: Optional CallCredentials used for the RPC.
      wait_for_ready: This is an EXPERIMENTAL argument. Optional flag that
        enables the :term:`wait_for_ready` mechanism.
    """

    method: str
    timeout: Optional[float]
    metadata: Optional[Metadata]
    credentials: Optional[grpc.CallCredentials]
    wait_for_ready: Optional[bool]
|
||
|
||
|
||
class ClientInterceptor(metaclass=ABCMeta):
    """Common base for every Aio client interceptor class"""
|
||
|
||
|
||
class UnaryUnaryClientInterceptor(ClientInterceptor, metaclass=ABCMeta):
    """Hook for intercepting unary-unary invocations."""

    @abstractmethod
    async def intercept_unary_unary(
            self,
            continuation: Callable[[ClientCallDetails, RequestType],
                                   UnaryUnaryCall],
            client_call_details: ClientCallDetails,
            request: RequestType,
    ) -> Union[UnaryUnaryCall, ResponseType]:
        """Intercept a unary-unary invocation asynchronously.

        Args:
          continuation: Coroutine that moves the invocation forward, either
            by running the next interceptor in the chain or by issuing the
            actual RPC on the underlying Channel. Calling it is up to the
            interceptor, if it decides the RPC should proceed, e.g.
            `call = await continuation(client_call_details, request)`.
            It returns the call for the RPC.
          client_call_details: ClientCallDetails describing the outgoing RPC.
          request: The request value for the RPC.

        Returns:
          An object carrying the RPC response.

        Raises:
          AioRpcError: The RPC terminated with a non-OK status.
          asyncio.CancelledError: The RPC was cancelled.
        """
|
||
|
||
|
||
class UnaryStreamClientInterceptor(ClientInterceptor, metaclass=ABCMeta):
    """Hook for intercepting unary-stream invocations."""

    @abstractmethod
    async def intercept_unary_stream(
            self,
            continuation: Callable[[ClientCallDetails, RequestType],
                                   UnaryStreamCall],
            client_call_details: ClientCallDetails,
            request: RequestType,
    ) -> Union[ResponseIterableType, UnaryStreamCall]:
        """Intercept a unary-stream invocation asynchronously.

        The interceptor may return either the call object or an asynchronous
        iterator. When an asynchronous iterator is returned, it becomes the
        source of the reads performed by the caller.

        Args:
          continuation: Coroutine that moves the invocation forward, either
            by running the next interceptor in the chain or by issuing the
            actual RPC on the underlying Channel. Calling it is up to the
            interceptor, if it decides the RPC should proceed, e.g.
            `call = await continuation(client_call_details, request)`.
            It returns the call for the RPC.
          client_call_details: ClientCallDetails describing the outgoing RPC.
          request: The request value for the RPC.

        Returns:
          The RPC call or an asynchronous iterator.

        Raises:
          AioRpcError: The RPC terminated with a non-OK status.
          asyncio.CancelledError: The RPC was cancelled.
        """
|
||
|
||
|
||
class StreamUnaryClientInterceptor(ClientInterceptor, metaclass=ABCMeta):
    """Affords intercepting stream-unary invocations."""

    @abstractmethod
    async def intercept_stream_unary(
            self,
            # The continuation receives the request *iterator* and produces a
            # StreamUnaryCall (see InterceptedStreamUnaryCall._invoke).
            continuation: Callable[[ClientCallDetails, RequestIterableType],
                                   StreamUnaryCall],
            client_call_details: ClientCallDetails,
            request_iterator: RequestIterableType,
    ) -> StreamUnaryCall:
        """Intercepts a stream-unary invocation asynchronously.

        Within the interceptor, methods of the call such as `write`, or even
        awaiting the call, should be used carefully, since the caller could
        be expecting an untouched call, for example to start writing messages
        to it.

        Args:
          continuation: A coroutine that proceeds with the invocation by
            executing the next interceptor in the chain or invoking the
            actual RPC on the underlying Channel. It is the interceptor's
            responsibility to call it if it decides to move the RPC forward.
            The interceptor can use
            `call = await continuation(client_call_details, request_iterator)`
            to continue with the RPC. `continuation` returns the call to the
            RPC.
          client_call_details: A ClientCallDetails object describing the
            outgoing RPC.
          request_iterator: The request iterator that will produce requests
            for the RPC.

        Returns:
          The RPC Call.

        Raises:
          AioRpcError: Indicating that the RPC terminated with non-OK status.
          asyncio.CancelledError: Indicating that the RPC was canceled.
        """
|
||
|
||
|
||
class StreamStreamClientInterceptor(ClientInterceptor, metaclass=ABCMeta):
    """Affords intercepting stream-stream invocations."""

    @abstractmethod
    async def intercept_stream_stream(
            self,
            # The continuation receives the request *iterator* and produces a
            # StreamStreamCall (see InterceptedStreamStreamCall._invoke).
            continuation: Callable[[ClientCallDetails, RequestIterableType],
                                   StreamStreamCall],
            client_call_details: ClientCallDetails,
            request_iterator: RequestIterableType,
    ) -> Union[ResponseIterableType, StreamStreamCall]:
        """Intercepts a stream-stream invocation asynchronously.

        Within the interceptor, methods of the call such as `write`, or even
        awaiting the call, should be used carefully, since the caller could
        be expecting an untouched call, for example to start writing messages
        to it.

        The function could return the call object or an asynchronous
        iterator. When an asynchronous iterator is returned, it becomes the
        source of the reads done by the caller.

        Args:
          continuation: A coroutine that proceeds with the invocation by
            executing the next interceptor in the chain or invoking the
            actual RPC on the underlying Channel. It is the interceptor's
            responsibility to call it if it decides to move the RPC forward.
            The interceptor can use
            `call = await continuation(client_call_details, request_iterator)`
            to continue with the RPC. `continuation` returns the call to the
            RPC.
          client_call_details: A ClientCallDetails object describing the
            outgoing RPC.
          request_iterator: The request iterator that will produce requests
            for the RPC.

        Returns:
          The RPC Call or an asynchronous iterator.

        Raises:
          AioRpcError: Indicating that the RPC terminated with non-OK status.
          asyncio.CancelledError: Indicating that the RPC was canceled.
        """
|
||
|
||
|
||
class InterceptedCall:
    """Base implementation for all intercepted call arities.

    Interceptors might have some work to do before the RPC invocation, with
    the capacity of changing the invocation parameters, and some work to do
    after the RPC invocation, with access to the wrapped call.

    It also handles early and late cancellations. These happen either when
    the RPC has not even started and execution is still held by the
    interceptors, or when the RPC has finished but execution is again still
    held by the interceptors.

    Once the interceptors task has produced a call, every method is
    delegated to that call, which is the same call object returned to the
    interceptors.

    As the base class for all intercepted calls, it implements the logic
    around final status, metadata and cancellation.
    """

    # Task running the interceptor chain; its result is the final call.
    _interceptors_task: asyncio.Task
    # Callbacks registered before the interceptors task finished (a list).
    _pending_add_done_callbacks: Sequence[DoneCallbackType]

    def __init__(self, interceptors_task: asyncio.Task) -> None:
        self._interceptors_task = interceptors_task
        self._pending_add_done_callbacks = []
        self._interceptors_task.add_done_callback(
            self._fire_or_add_pending_done_callbacks)

    def __del__(self):
        # A subclass may fail in its own __init__ (e.g. while creating the
        # interceptors task) before InterceptedCall.__init__ assigns
        # `_interceptors_task`. There is nothing to cancel then, and touching
        # the missing attribute would only raise an ignored AttributeError.
        if hasattr(self, '_interceptors_task'):
            self.cancel()

    def _fire_or_add_pending_done_callbacks(self,
                                            interceptors_task: asyncio.Task
                                           ) -> None:
        """Flush callbacks queued while the interceptors task was running.

        If the interceptors task failed, was cancelled, or produced a call
        that is already done, the callbacks are invoked right away with
        `self`. Otherwise they are attached to the produced call.
        """

        if not self._pending_add_done_callbacks:
            return

        call_completed = False

        try:
            call = interceptors_task.result()
            if call.done():
                call_completed = True
        except (AioRpcError, asyncio.CancelledError):
            call_completed = True

        if call_completed:
            for callback in self._pending_add_done_callbacks:
                callback(self)
        else:
            for callback in self._pending_add_done_callbacks:
                callback = functools.partial(self._wrap_add_done_callback,
                                             callback)
                call.add_done_callback(callback)

        self._pending_add_done_callbacks = []

    def _wrap_add_done_callback(self, callback: DoneCallbackType,
                                unused_call: _base_call.Call) -> None:
        # Present this intercepted call, not the inner one, to user callbacks.
        callback(self)

    def cancel(self) -> bool:
        if not self._interceptors_task.done():
            # The intercepted call is not available yet, so try to cancel
            # it by using the generic Asyncio cancellation method.
            return self._interceptors_task.cancel()

        try:
            call = self._interceptors_task.result()
        except AioRpcError:
            return False
        except asyncio.CancelledError:
            return False

        return call.cancel()

    def cancelled(self) -> bool:
        if not self._interceptors_task.done():
            return False

        try:
            call = self._interceptors_task.result()
        except AioRpcError as err:
            return err.code() == grpc.StatusCode.CANCELLED
        except asyncio.CancelledError:
            return True

        return call.cancelled()

    def done(self) -> bool:
        if not self._interceptors_task.done():
            return False

        try:
            call = self._interceptors_task.result()
        except (AioRpcError, asyncio.CancelledError):
            return True

        return call.done()

    def add_done_callback(self, callback: DoneCallbackType) -> None:
        if not self._interceptors_task.done():
            # Deferred until the interceptors task finishes; see
            # _fire_or_add_pending_done_callbacks.
            self._pending_add_done_callbacks.append(callback)
            return

        try:
            call = self._interceptors_task.result()
        except (AioRpcError, asyncio.CancelledError):
            callback(self)
            return

        if call.done():
            callback(self)
        else:
            callback = functools.partial(self._wrap_add_done_callback, callback)
            call.add_done_callback(callback)

    def time_remaining(self) -> Optional[float]:
        raise NotImplementedError()

    async def initial_metadata(self) -> Optional[Metadata]:
        try:
            call = await self._interceptors_task
        except AioRpcError as err:
            return err.initial_metadata()
        except asyncio.CancelledError:
            return None

        return await call.initial_metadata()

    async def trailing_metadata(self) -> Optional[Metadata]:
        try:
            call = await self._interceptors_task
        except AioRpcError as err:
            return err.trailing_metadata()
        except asyncio.CancelledError:
            return None

        return await call.trailing_metadata()

    async def code(self) -> grpc.StatusCode:
        try:
            call = await self._interceptors_task
        except AioRpcError as err:
            return err.code()
        except asyncio.CancelledError:
            return grpc.StatusCode.CANCELLED

        return await call.code()

    async def details(self) -> str:
        try:
            call = await self._interceptors_task
        except AioRpcError as err:
            return err.details()
        except asyncio.CancelledError:
            return _LOCAL_CANCELLATION_DETAILS

        return await call.details()

    async def debug_error_string(self) -> Optional[str]:
        try:
            call = await self._interceptors_task
        except AioRpcError as err:
            return err.debug_error_string()
        except asyncio.CancelledError:
            return ''

        return await call.debug_error_string()

    async def wait_for_connection(self) -> None:
        call = await self._interceptors_task
        return await call.wait_for_connection()
|
||
|
||
|
||
class _InterceptedUnaryResponseMixin:
|
||
|
||
def __await__(self):
|
||
call = yield from self._interceptors_task.__await__()
|
||
response = yield from call.__await__()
|
||
return response
|
||
|
||
|
||
class _InterceptedStreamResponseMixin:
    """Streaming-response reads (`__aiter__` / `read`) for intercepted calls.

    The host class provides ``self._interceptors_task`` and must call
    `_init_stream_response_mixin` during construction.
    """

    _response_aiter: Optional[AsyncIterable[ResponseType]]

    def _init_stream_response_mixin(self) -> None:
        # Initialized lazily; otherwise, if the iterator is never consumed,
        # Asyncio emits a logging warning.
        self._response_aiter = None

    async def _wait_for_interceptor_task_response_iterator(
            self) -> AsyncIterable[ResponseType]:
        """Yield the responses of the call produced by the interceptors."""
        call = await self._interceptors_task
        async for response in call:
            yield response

    def __aiter__(self) -> AsyncIterable[ResponseType]:
        if self._response_aiter is None:
            self._response_aiter = (
                self._wait_for_interceptor_task_response_iterator())
        return self._response_aiter

    async def read(self) -> ResponseType:
        """Read one response, or return `cygrpc.EOF` once the stream ends.

        `read()` and `async for` share the same underlying iterator, so they
        consume the same sequence of responses.
        """
        if self._response_aiter is None:
            self._response_aiter = (
                self._wait_for_interceptor_task_response_iterator())
        try:
            return await self._response_aiter.asend(None)
        except StopAsyncIteration:
            # Match the stream-call read() contract: end of stream is
            # signalled with EOF rather than by leaking StopAsyncIteration.
            return cygrpc.EOF
|
||
|
||
|
||
class _InterceptedStreamRequestMixin:
    """`write` / `done_writing` support for intercepted stream-request calls.

    When the caller does not supply a request iterator, writes are funneled
    through a size-1 queue and exposed to the RPC as an async generator.
    When an iterator is supplied, `write` and `done_writing` raise
    `cygrpc.UsageError`. The host class provides ``self._interceptors_task``.
    """

    _write_to_iterator_async_gen: Optional[AsyncIterable[RequestType]]
    _write_to_iterator_queue: Optional[asyncio.Queue]

    # Queued by done_writing() to make the proxy iterator stop.
    _FINISH_ITERATOR_SENTINEL = object()

    def _init_stream_request_mixin(
            self, request_iterator: Optional[RequestIterableType]
    ) -> RequestIterableType:
        """Return the request iterator the RPC should consume."""

        if request_iterator is None:
            # We provide our own request iterator which is a proxy
            # of the future writes that will be done by the caller.
            self._write_to_iterator_queue = asyncio.Queue(maxsize=1)
            self._write_to_iterator_async_gen = (
                self._proxy_writes_as_request_iterator())
            request_iterator = self._write_to_iterator_async_gen
        else:
            self._write_to_iterator_queue = None

        return request_iterator

    async def _proxy_writes_as_request_iterator(self):
        await self._interceptors_task

        while True:
            value = await self._write_to_iterator_queue.get()
            if value is _InterceptedStreamRequestMixin._FINISH_ITERATOR_SENTINEL:
                break
            yield value

    async def _put_unless_call_finishes(self, call, value) -> None:
        """Enqueue ``value`` for the proxy iterator unless ``call`` ends first.

        The put might never complete, since the call could finish abruptly.
        We therefore give up on whichever awaitable finishes first. Both are
        wrapped in tasks: `asyncio.wait` no longer accepts bare coroutines
        (deprecated in Python 3.8, rejected since 3.11). A put that lost the
        race is cancelled, so it is not left pending forever.
        """
        put_task = asyncio.ensure_future(
            self._write_to_iterator_queue.put(value))
        code_task = asyncio.ensure_future(call.code())
        try:
            await asyncio.wait((put_task, code_task),
                               return_when=asyncio.FIRST_COMPLETED)
        finally:
            if not put_task.done():
                put_task.cancel()

    async def write(self, request: RequestType) -> None:
        # If no queue was created, requests are expected through an
        # iterator provided by the caller.
        if self._write_to_iterator_queue is None:
            raise cygrpc.UsageError(_API_STYLE_ERROR)

        try:
            call = await self._interceptors_task
        except (asyncio.CancelledError, AioRpcError):
            raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS)

        if call.done():
            raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS)
        elif call._done_writing_flag:
            raise asyncio.InvalidStateError(_RPC_HALF_CLOSED_DETAILS)

        await self._put_unless_call_finishes(call, request)

        if call.done():
            raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS)

    async def done_writing(self) -> None:
        """Signal peer that client is done writing.

        This method is idempotent.
        """
        # If no queue was created, requests are expected through an
        # iterator provided by the caller.
        if self._write_to_iterator_queue is None:
            raise cygrpc.UsageError(_API_STYLE_ERROR)

        try:
            call = await self._interceptors_task
        except asyncio.CancelledError:
            raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS)

        await self._put_unless_call_finishes(
            call, _InterceptedStreamRequestMixin._FINISH_ITERATOR_SENTINEL)
|
||
|
||
|
||
class InterceptedUnaryUnaryCall(_InterceptedUnaryResponseMixin, InterceptedCall,
                                _base_call.UnaryUnaryCall):
    """Runs a `UnaryUnaryCall` through a chain of client interceptors.

    Awaiting an instance is forwarded to the intercepted call, but only once
    the interceptors task has finished and produced that call.
    """

    _loop: asyncio.AbstractEventLoop
    _channel: cygrpc.AioChannel

    # pylint: disable=too-many-arguments
    def __init__(self, interceptors: Sequence[UnaryUnaryClientInterceptor],
                 request: RequestType, timeout: Optional[float],
                 metadata: Metadata,
                 credentials: Optional[grpc.CallCredentials],
                 wait_for_ready: Optional[bool], channel: cygrpc.AioChannel,
                 method: bytes, request_serializer: SerializingFunction,
                 response_deserializer: DeserializingFunction,
                 loop: asyncio.AbstractEventLoop) -> None:
        self._loop = loop
        self._channel = channel
        invocation = self._invoke(interceptors, method, timeout, metadata,
                                  credentials, wait_for_ready, request,
                                  request_serializer, response_deserializer)
        super().__init__(loop.create_task(invocation))

    # pylint: disable=too-many-arguments
    async def _invoke(self, interceptors: Sequence[UnaryUnaryClientInterceptor],
                      method: bytes, timeout: Optional[float],
                      metadata: Optional[Metadata],
                      credentials: Optional[grpc.CallCredentials],
                      wait_for_ready: Optional[bool], request: RequestType,
                      request_serializer: SerializingFunction,
                      response_deserializer: DeserializingFunction
                     ) -> UnaryUnaryCall:
        """Run the RPC call wrapped in interceptors"""

        async def _run_interceptor(
                interceptors: Iterator[UnaryUnaryClientInterceptor],
                client_call_details: ClientCallDetails,
                request: RequestType) -> _base_call.UnaryUnaryCall:

            interceptor = next(interceptors, None)

            if not interceptor:
                # End of the chain: issue the real RPC.
                return UnaryUnaryCall(
                    request, _timeout_to_deadline(client_call_details.timeout),
                    client_call_details.metadata,
                    client_call_details.credentials,
                    client_call_details.wait_for_ready, self._channel,
                    client_call_details.method, request_serializer,
                    response_deserializer, self._loop)

            continuation = functools.partial(_run_interceptor, interceptors)
            outcome = await interceptor.intercept_unary_unary(
                continuation, client_call_details, request)

            # An interceptor may hand back a bare response instead of a call.
            if isinstance(outcome, _base_call.UnaryUnaryCall):
                return outcome
            return UnaryUnaryCallResponse(outcome)

        initial_details = ClientCallDetails(method, timeout, metadata,
                                            credentials, wait_for_ready)
        return await _run_interceptor(iter(interceptors), initial_details,
                                      request)

    def time_remaining(self) -> Optional[float]:
        raise NotImplementedError()
|
||
|
||
|
||
class InterceptedUnaryStreamCall(_InterceptedStreamResponseMixin,
                                 InterceptedCall, _base_call.UnaryStreamCall):
    """Used for running a `UnaryStreamCall` wrapped by interceptors."""

    _loop: asyncio.AbstractEventLoop
    _channel: cygrpc.AioChannel
    # Most recent call object produced while unwinding the interceptor chain.
    # Used to wrap a plain async iterator returned by an outer interceptor.
    # (Previously written as `= Optional[...]`, which created a bogus class
    # attribute instead of an annotation.)
    _last_returned_call_from_interceptors: Optional[_base_call.UnaryStreamCall]

    # pylint: disable=too-many-arguments
    def __init__(self, interceptors: Sequence[UnaryStreamClientInterceptor],
                 request: RequestType, timeout: Optional[float],
                 metadata: Metadata,
                 credentials: Optional[grpc.CallCredentials],
                 wait_for_ready: Optional[bool], channel: cygrpc.AioChannel,
                 method: bytes, request_serializer: SerializingFunction,
                 response_deserializer: DeserializingFunction,
                 loop: asyncio.AbstractEventLoop) -> None:
        self._loop = loop
        self._channel = channel
        self._init_stream_response_mixin()
        self._last_returned_call_from_interceptors = None
        interceptors_task = loop.create_task(
            self._invoke(interceptors, method, timeout, metadata, credentials,
                         wait_for_ready, request, request_serializer,
                         response_deserializer))
        super().__init__(interceptors_task)

    # pylint: disable=too-many-arguments
    async def _invoke(self, interceptors: Sequence[UnaryStreamClientInterceptor],
                      method: bytes, timeout: Optional[float],
                      metadata: Optional[Metadata],
                      credentials: Optional[grpc.CallCredentials],
                      wait_for_ready: Optional[bool], request: RequestType,
                      request_serializer: SerializingFunction,
                      response_deserializer: DeserializingFunction
                     ) -> UnaryStreamCall:
        """Run the RPC call wrapped in interceptors"""

        async def _run_interceptor(
                interceptors: Iterator[UnaryStreamClientInterceptor],
                client_call_details: ClientCallDetails,
                request: RequestType,
        ) -> _base_call.UnaryStreamCall:

            interceptor = next(interceptors, None)

            if interceptor:
                continuation = functools.partial(_run_interceptor, interceptors)

                call_or_response_iterator = await interceptor.intercept_unary_stream(
                    continuation, client_call_details, request)

                if isinstance(call_or_response_iterator,
                              _base_call.UnaryStreamCall):
                    self._last_returned_call_from_interceptors = call_or_response_iterator
                else:
                    # The interceptor returned an async iterator: wrap it so it
                    # still exposes the call API of the last produced call.
                    self._last_returned_call_from_interceptors = UnaryStreamCallResponseIterator(
                        self._last_returned_call_from_interceptors,
                        call_or_response_iterator)
                return self._last_returned_call_from_interceptors
            else:
                # End of the chain: issue the real RPC.
                self._last_returned_call_from_interceptors = UnaryStreamCall(
                    request, _timeout_to_deadline(client_call_details.timeout),
                    client_call_details.metadata,
                    client_call_details.credentials,
                    client_call_details.wait_for_ready, self._channel,
                    client_call_details.method, request_serializer,
                    response_deserializer, self._loop)

                return self._last_returned_call_from_interceptors

        client_call_details = ClientCallDetails(method, timeout, metadata,
                                                credentials, wait_for_ready)
        return await _run_interceptor(iter(interceptors), client_call_details,
                                      request)

    def time_remaining(self) -> Optional[float]:
        raise NotImplementedError()
|
||
|
||
|
||
class InterceptedStreamUnaryCall(_InterceptedUnaryResponseMixin,
                                 _InterceptedStreamRequestMixin,
                                 InterceptedCall, _base_call.StreamUnaryCall):
    """Used for running a `StreamUnaryCall` wrapped by interceptors.

    The `__await__` method is proxied to the intercepted call only once
    the interceptors task is finished.
    """

    _loop: asyncio.AbstractEventLoop
    _channel: cygrpc.AioChannel

    # pylint: disable=too-many-arguments
    def __init__(self, interceptors: Sequence[StreamUnaryClientInterceptor],
                 request_iterator: Optional[RequestIterableType],
                 timeout: Optional[float], metadata: Metadata,
                 credentials: Optional[grpc.CallCredentials],
                 wait_for_ready: Optional[bool], channel: cygrpc.AioChannel,
                 method: bytes, request_serializer: SerializingFunction,
                 response_deserializer: DeserializingFunction,
                 loop: asyncio.AbstractEventLoop) -> None:
        self._loop = loop
        self._channel = channel
        # When request_iterator is None, a proxy iterator fed by write()
        # and done_writing() is used instead.
        request_iterator = self._init_stream_request_mixin(request_iterator)
        interceptors_task = loop.create_task(
            self._invoke(interceptors, method, timeout, metadata, credentials,
                         wait_for_ready, request_iterator, request_serializer,
                         response_deserializer))
        super().__init__(interceptors_task)

    # pylint: disable=too-many-arguments
    async def _invoke(
            self, interceptors: Sequence[StreamUnaryClientInterceptor],
            method: bytes, timeout: Optional[float],
            metadata: Optional[Metadata],
            credentials: Optional[grpc.CallCredentials],
            wait_for_ready: Optional[bool],
            request_iterator: RequestIterableType,
            request_serializer: SerializingFunction,
            response_deserializer: DeserializingFunction) -> StreamUnaryCall:
        """Run the RPC call wrapped in interceptors"""

        async def _run_interceptor(
                interceptors: Iterator[StreamUnaryClientInterceptor],
                client_call_details: ClientCallDetails,
                request_iterator: RequestIterableType
        ) -> _base_call.StreamUnaryCall:

            interceptor = next(interceptors, None)

            if interceptor:
                continuation = functools.partial(_run_interceptor, interceptors)

                return await interceptor.intercept_stream_unary(
                    continuation, client_call_details, request_iterator)
            else:
                # End of the chain: issue the real RPC.
                return StreamUnaryCall(
                    request_iterator,
                    _timeout_to_deadline(client_call_details.timeout),
                    client_call_details.metadata,
                    client_call_details.credentials,
                    client_call_details.wait_for_ready, self._channel,
                    client_call_details.method, request_serializer,
                    response_deserializer, self._loop)

        client_call_details = ClientCallDetails(method, timeout, metadata,
                                                credentials, wait_for_ready)
        return await _run_interceptor(iter(interceptors), client_call_details,
                                      request_iterator)

    def time_remaining(self) -> Optional[float]:
        raise NotImplementedError()
|
||
|
||
|
||
class InterceptedStreamStreamCall(_InterceptedStreamResponseMixin,
|
||
_InterceptedStreamRequestMixin,
|
||
InterceptedCall, _base_call.StreamStreamCall):
|
||
"""Used for running a `StreamStreamCall` wrapped by interceptors."""
|
||
|
||
_loop: asyncio.AbstractEventLoop
|
||
_channel: cygrpc.AioChannel
|
||
_last_returned_call_from_interceptors = Optional[_base_call.UnaryStreamCall]
|
||
|
||
# pylint: disable=too-many-arguments
|
||
def __init__(self, interceptors: Sequence[StreamStreamClientInterceptor],
|
||
request_iterator: Optional[RequestIterableType],
|
||
timeout: Optional[float], metadata: Metadata,
|
||
credentials: Optional[grpc.CallCredentials],
|
||
wait_for_ready: Optional[bool], channel: cygrpc.AioChannel,
|
||
method: bytes, request_serializer: SerializingFunction,
|
||
response_deserializer: DeserializingFunction,
|
||
loop: asyncio.AbstractEventLoop) -> None:
|
||
self._loop = loop
|
||
self._channel = channel
|
||
self._init_stream_response_mixin()
|
||
request_iterator = self._init_stream_request_mixin(request_iterator)
|
||
self._last_returned_call_from_interceptors = None
|
||
interceptors_task = loop.create_task(
|
||
self._invoke(interceptors, method, timeout, metadata, credentials,
|
||
wait_for_ready, request_iterator, request_serializer,
|
||
response_deserializer))
|
||
super().__init__(interceptors_task)
|
||
|
||
# pylint: disable=too-many-arguments
|
||
async def _invoke(
|
||
self, interceptors: Sequence[StreamStreamClientInterceptor],
|
||
method: bytes, timeout: Optional[float],
|
||
metadata: Optional[Metadata],
|
||
credentials: Optional[grpc.CallCredentials],
|
||
wait_for_ready: Optional[bool],
|
||
request_iterator: RequestIterableType,
|
||
request_serializer: SerializingFunction,
|
||
response_deserializer: DeserializingFunction) -> StreamStreamCall:
|
||
"""Run the RPC call wrapped in interceptors"""
|
||
|
||
async def _run_interceptor(
|
||
interceptors: Iterator[StreamStreamClientInterceptor],
|
||
client_call_details: ClientCallDetails,
|
||
request_iterator: RequestIterableType
|
||
) -> _base_call.StreamStreamCall:
|
||
|
||
interceptor = next(interceptors, None)
|
||
|
||
if interceptor:
|
||
continuation = functools.partial(_run_interceptor, interceptors)
|
||
|
||
call_or_response_iterator = await interceptor.intercept_stream_stream(
|
||
continuation, client_call_details, request_iterator)
|
||
|
||
if isinstance(call_or_response_iterator,
|
||
_base_call.StreamStreamCall):
|
||
self._last_returned_call_from_interceptors = call_or_response_iterator
|
||
else:
|
||
self._last_returned_call_from_interceptors = StreamStreamCallResponseIterator(
|
||
self._last_returned_call_from_interceptors,
|
||
call_or_response_iterator)
|
||
return self._last_returned_call_from_interceptors
|
||
else:
|
||
self._last_returned_call_from_interceptors = StreamStreamCall(
|
||
request_iterator,
|
||
_timeout_to_deadline(client_call_details.timeout),
|
||
client_call_details.metadata,
|
||
client_call_details.credentials,
|
||
client_call_details.wait_for_ready, self._channel,
|
||
client_call_details.method, request_serializer,
|
||
response_deserializer, self._loop)
|
||
return self._last_returned_call_from_interceptors
|
||
|
||
client_call_details = ClientCallDetails(method, timeout, metadata,
|
||
credentials, wait_for_ready)
|
||
return await _run_interceptor(iter(interceptors), client_call_details,
|
||
request_iterator)
|
||
|
||
def time_remaining(self) -> Optional[float]:
    """Report the time left before the RPC deadline.

    This operation is not supported by this intercepted call.

    Raises:
      NotImplementedError: always.
    """
    raise NotImplementedError
|
||
|
||
|
||
class UnaryUnaryCallResponse(_base_call.UnaryUnaryCall):
    """A UnaryUnaryCall that has already finished with a known response.

    The call always reports itself as done and never cancelled, with
    status ``OK``, no initial or trailing metadata, empty details and no
    debug error string. Awaiting it produces the stored response without
    ever suspending. Registering done callbacks and querying the remaining
    time are not supported.
    """
    _response: ResponseType

    def __init__(self, response: ResponseType) -> None:
        self._response = response

    def cancel(self) -> bool:
        # The call is already complete, so there is nothing to cancel.
        return False

    def cancelled(self) -> bool:
        return False

    def done(self) -> bool:
        return True

    def add_done_callback(self, unused_callback) -> None:
        raise NotImplementedError

    def time_remaining(self) -> Optional[float]:
        raise NotImplementedError

    async def initial_metadata(self) -> Optional[Metadata]:
        return None

    async def trailing_metadata(self) -> Optional[Metadata]:
        return None

    async def code(self) -> grpc.StatusCode:
        return grpc.StatusCode.OK

    async def details(self) -> str:
        return ''

    async def debug_error_string(self) -> Optional[str]:
        return None

    def __await__(self):
        # ``__await__`` must return an iterator. Delegating to an empty tuple
        # yields nothing, but its presence makes this method a generator
        # whose return value becomes the result of ``await``.
        yield from ()
        return self._response

    async def wait_for_connection(self) -> None:
        # Already finished: there is no connection to wait for.
        return None
|
||
|
||
|
||
class _StreamCallResponseIterator:
    """Wrap a streaming call but serve responses from another iterator.

    Call-control and status queries (cancel, done, metadata, code, ...)
    are forwarded unchanged to the wrapped ``call``. ``async for``
    iteration is served by ``response_iterator`` instead.
    """

    _call: Union[_base_call.UnaryStreamCall, _base_call.StreamStreamCall]
    _response_iterator: AsyncIterable[ResponseType]

    def __init__(self, call: Union[_base_call.UnaryStreamCall,
                                   _base_call.StreamStreamCall],
                 response_iterator: AsyncIterable[ResponseType]) -> None:
        self._call = call
        self._response_iterator = response_iterator

    # --- Call control, forwarded to the wrapped call -------------------

    def cancel(self) -> bool:
        return self._call.cancel()

    def cancelled(self) -> bool:
        return self._call.cancelled()

    def done(self) -> bool:
        return self._call.done()

    def add_done_callback(self, callback) -> None:
        self._call.add_done_callback(callback)

    def time_remaining(self) -> Optional[float]:
        return self._call.time_remaining()

    # --- Status and metadata, forwarded to the wrapped call ------------

    async def initial_metadata(self) -> Optional[Metadata]:
        return await self._call.initial_metadata()

    async def trailing_metadata(self) -> Optional[Metadata]:
        return await self._call.trailing_metadata()

    async def code(self) -> grpc.StatusCode:
        return await self._call.code()

    async def details(self) -> str:
        return await self._call.details()

    async def debug_error_string(self) -> Optional[str]:
        return await self._call.debug_error_string()

    # --- Iteration uses the substitute iterator, not the call ----------

    def __aiter__(self):
        return self._response_iterator.__aiter__()

    async def wait_for_connection(self) -> None:
        return await self._call.wait_for_connection()
|
||
|
||
|
||
class UnaryStreamCallResponseIterator(_StreamCallResponseIterator,
                                      _base_call.UnaryStreamCall):
    """UnaryStreamCall whose responses come from an alternative iterator."""

    async def read(self) -> ResponseType:
        # Responses are consumed exclusively through the async iterator,
        # so this direct-read path is not expected to be reached.
        raise NotImplementedError
|
||
|
||
|
||
class StreamStreamCallResponseIterator(_StreamCallResponseIterator,
                                       _base_call.StreamStreamCall):
    """StreamStreamCall whose responses come from an alternative iterator.

    The direct ``read``, ``write`` and ``done_writing`` entry points are
    unsupported. All traffic is expected to flow through the async
    iterators instead.
    """

    async def read(self) -> ResponseType:
        # Responses are consumed exclusively through the async iterator,
        # so this direct-read path is not expected to be reached.
        raise NotImplementedError

    async def write(self, request: RequestType) -> None:
        # Everything goes through the async iterator provided by the
        # InterceptedStreamStreamCall; this path is not expected to be
        # reached.
        raise NotImplementedError

    async def done_writing(self) -> None:
        # Everything goes through the async iterator provided by the
        # InterceptedStreamStreamCall; this path is not expected to be
        # reached.
        raise NotImplementedError

    @property
    def _done_writing_flag(self) -> bool:
        """Mirror the wrapped call's private ``_done_writing_flag``."""
        return self._call._done_writing_flag
|