Added delete option to database storage.

This commit is contained in:
Batuhan Berk Başoğlu 2020-10-12 12:10:01 -04:00
parent 308604a33c
commit 963b5bc68b
1868 changed files with 192402 additions and 13278 deletions

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,58 @@
# Copyright 2016 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""GRPCAuthMetadataPlugins for standard authentication."""
import inspect
import grpc
def _sign_request(callback, token, error):
metadata = (('authorization', 'Bearer {}'.format(token)),)
callback(metadata, error)
class GoogleCallCredentials(grpc.AuthMetadataPlugin):
"""Metadata wrapper for GoogleCredentials from the oauth2client library."""
def __init__(self, credentials):
self._credentials = credentials
# Hack to determine if these are JWT creds and we need to pass
# additional_claims when getting a token
self._is_jwt = 'additional_claims' in inspect.getargspec( # pylint: disable=deprecated-method
credentials.get_access_token).args
def __call__(self, context, callback):
try:
if self._is_jwt:
access_token = self._credentials.get_access_token(
additional_claims={
'aud': context.service_url
}).access_token
else:
access_token = self._credentials.get_access_token().access_token
except Exception as exception: # pylint: disable=broad-except
_sign_request(callback, None, exception)
else:
_sign_request(callback, access_token, None)
class AccessTokenAuthMetadataPlugin(grpc.AuthMetadataPlugin):
"""Metadata wrapper for raw access token credentials."""
def __init__(self, access_token):
self._access_token = access_token
def __call__(self, context, callback):
_sign_request(callback, self._access_token, None)

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,168 @@
# Copyright 2016 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Shared implementation."""
import logging
import time
import six
import grpc
from grpc._cython import cygrpc
_LOGGER = logging.getLogger(__name__)
CYGRPC_CONNECTIVITY_STATE_TO_CHANNEL_CONNECTIVITY = {
cygrpc.ConnectivityState.idle:
grpc.ChannelConnectivity.IDLE,
cygrpc.ConnectivityState.connecting:
grpc.ChannelConnectivity.CONNECTING,
cygrpc.ConnectivityState.ready:
grpc.ChannelConnectivity.READY,
cygrpc.ConnectivityState.transient_failure:
grpc.ChannelConnectivity.TRANSIENT_FAILURE,
cygrpc.ConnectivityState.shutdown:
grpc.ChannelConnectivity.SHUTDOWN,
}
CYGRPC_STATUS_CODE_TO_STATUS_CODE = {
cygrpc.StatusCode.ok: grpc.StatusCode.OK,
cygrpc.StatusCode.cancelled: grpc.StatusCode.CANCELLED,
cygrpc.StatusCode.unknown: grpc.StatusCode.UNKNOWN,
cygrpc.StatusCode.invalid_argument: grpc.StatusCode.INVALID_ARGUMENT,
cygrpc.StatusCode.deadline_exceeded: grpc.StatusCode.DEADLINE_EXCEEDED,
cygrpc.StatusCode.not_found: grpc.StatusCode.NOT_FOUND,
cygrpc.StatusCode.already_exists: grpc.StatusCode.ALREADY_EXISTS,
cygrpc.StatusCode.permission_denied: grpc.StatusCode.PERMISSION_DENIED,
cygrpc.StatusCode.unauthenticated: grpc.StatusCode.UNAUTHENTICATED,
cygrpc.StatusCode.resource_exhausted: grpc.StatusCode.RESOURCE_EXHAUSTED,
cygrpc.StatusCode.failed_precondition: grpc.StatusCode.FAILED_PRECONDITION,
cygrpc.StatusCode.aborted: grpc.StatusCode.ABORTED,
cygrpc.StatusCode.out_of_range: grpc.StatusCode.OUT_OF_RANGE,
cygrpc.StatusCode.unimplemented: grpc.StatusCode.UNIMPLEMENTED,
cygrpc.StatusCode.internal: grpc.StatusCode.INTERNAL,
cygrpc.StatusCode.unavailable: grpc.StatusCode.UNAVAILABLE,
cygrpc.StatusCode.data_loss: grpc.StatusCode.DATA_LOSS,
}
STATUS_CODE_TO_CYGRPC_STATUS_CODE = {
grpc_code: cygrpc_code for cygrpc_code, grpc_code in six.iteritems(
CYGRPC_STATUS_CODE_TO_STATUS_CODE)
}
MAXIMUM_WAIT_TIMEOUT = 0.1
_ERROR_MESSAGE_PORT_BINDING_FAILED = 'Failed to bind to address %s; set ' \
'GRPC_VERBOSITY=debug environment variable to see detailed error message.'
def encode(s):
if isinstance(s, bytes):
return s
else:
return s.encode('utf8')
def decode(b):
if isinstance(b, bytes):
return b.decode('utf-8', 'replace')
return b
def _transform(message, transformer, exception_message):
if transformer is None:
return message
else:
try:
return transformer(message)
except Exception: # pylint: disable=broad-except
_LOGGER.exception(exception_message)
return None
def serialize(message, serializer):
return _transform(message, serializer, 'Exception serializing message!')
def deserialize(serialized_message, deserializer):
return _transform(serialized_message, deserializer,
'Exception deserializing message!')
def fully_qualified_method(group, method):
return '/{}/{}'.format(group, method)
def _wait_once(wait_fn, timeout, spin_cb):
wait_fn(timeout=timeout)
if spin_cb is not None:
spin_cb()
def wait(wait_fn, wait_complete_fn, timeout=None, spin_cb=None):
"""Blocks waiting for an event without blocking the thread indefinitely.
See https://github.com/grpc/grpc/issues/19464 for full context. CPython's
`threading.Event.wait` and `threading.Condition.wait` methods, if invoked
without a timeout kwarg, may block the calling thread indefinitely. If the
call is made from the main thread, this means that signal handlers may not
run for an arbitrarily long period of time.
This wrapper calls the supplied wait function with an arbitrary short
timeout to ensure that no signal handler has to wait longer than
MAXIMUM_WAIT_TIMEOUT before executing.
Args:
wait_fn: A callable acceptable a single float-valued kwarg named
`timeout`. This function is expected to be one of `threading.Event.wait`
or `threading.Condition.wait`.
wait_complete_fn: A callable taking no arguments and returning a bool.
When this function returns true, it indicates that waiting should cease.
timeout: An optional float-valued number of seconds after which the wait
should cease.
spin_cb: An optional Callable taking no arguments and returning nothing.
This callback will be called on each iteration of the spin. This may be
used for, e.g. work related to forking.
Returns:
True if a timeout was supplied and it was reached. False otherwise.
"""
if timeout is None:
while not wait_complete_fn():
_wait_once(wait_fn, MAXIMUM_WAIT_TIMEOUT, spin_cb)
else:
end = time.time() + timeout
while not wait_complete_fn():
remaining = min(end - time.time(), MAXIMUM_WAIT_TIMEOUT)
if remaining < 0:
return True
_wait_once(wait_fn, remaining, spin_cb)
return False
def validate_port_binding_result(address, port):
"""Validates if the port binding succeed.
If the port returned by Core is 0, the binding is failed. However, in that
case, the Core API doesn't return a detailed failing reason. The best we
can do is raising an exception to prevent further confusion.
Args:
address: The address string to be bound.
port: An int returned by core
"""
if port == 0:
# The Core API doesn't return a failure message. The best we can do
# is raising an exception to prevent further confusion.
raise RuntimeError(_ERROR_MESSAGE_PORT_BINDING_FAILED % address)
else:
return port

View file

@ -0,0 +1,55 @@
# Copyright 2019 The gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from grpc._cython import cygrpc
NoCompression = cygrpc.CompressionAlgorithm.none
Deflate = cygrpc.CompressionAlgorithm.deflate
Gzip = cygrpc.CompressionAlgorithm.gzip
_METADATA_STRING_MAPPING = {
NoCompression: 'identity',
Deflate: 'deflate',
Gzip: 'gzip',
}
def _compression_algorithm_to_metadata_value(compression):
return _METADATA_STRING_MAPPING[compression]
def compression_algorithm_to_metadata(compression):
return (cygrpc.GRPC_COMPRESSION_REQUEST_ALGORITHM_MD_KEY,
_compression_algorithm_to_metadata_value(compression))
def create_channel_option(compression):
return ((cygrpc.GRPC_COMPRESSION_CHANNEL_DEFAULT_ALGORITHM,
int(compression)),) if compression else ()
def augment_metadata(metadata, compression):
if not metadata and not compression:
return None
base_metadata = tuple(metadata) if metadata else ()
compression_metadata = (
compression_algorithm_to_metadata(compression),) if compression else ()
return base_metadata + compression_metadata
__all__ = (
"NoCompression",
"Deflate",
"Gzip",
)

View file

@ -0,0 +1,13 @@
# Copyright 2015 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,13 @@
# Copyright 2015 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View file

@ -0,0 +1 @@
__version__ = """1.32.0"""

View file

@ -0,0 +1,562 @@
# Copyright 2017 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implementation of gRPC Python interceptors."""
import collections
import sys
import grpc
class _ServicePipeline(object):
def __init__(self, interceptors):
self.interceptors = tuple(interceptors)
def _continuation(self, thunk, index):
return lambda context: self._intercept_at(thunk, index, context)
def _intercept_at(self, thunk, index, context):
if index < len(self.interceptors):
interceptor = self.interceptors[index]
thunk = self._continuation(thunk, index + 1)
return interceptor.intercept_service(thunk, context)
else:
return thunk(context)
def execute(self, thunk, context):
return self._intercept_at(thunk, 0, context)
def service_pipeline(interceptors):
return _ServicePipeline(interceptors) if interceptors else None
class _ClientCallDetails(
collections.namedtuple('_ClientCallDetails',
('method', 'timeout', 'metadata', 'credentials',
'wait_for_ready', 'compression')),
grpc.ClientCallDetails):
pass
def _unwrap_client_call_details(call_details, default_details):
try:
method = call_details.method
except AttributeError:
method = default_details.method
try:
timeout = call_details.timeout
except AttributeError:
timeout = default_details.timeout
try:
metadata = call_details.metadata
except AttributeError:
metadata = default_details.metadata
try:
credentials = call_details.credentials
except AttributeError:
credentials = default_details.credentials
try:
wait_for_ready = call_details.wait_for_ready
except AttributeError:
wait_for_ready = default_details.wait_for_ready
try:
compression = call_details.compression
except AttributeError:
compression = default_details.compression
return method, timeout, metadata, credentials, wait_for_ready, compression
class _FailureOutcome(grpc.RpcError, grpc.Future, grpc.Call): # pylint: disable=too-many-ancestors
def __init__(self, exception, traceback):
super(_FailureOutcome, self).__init__()
self._exception = exception
self._traceback = traceback
def initial_metadata(self):
return None
def trailing_metadata(self):
return None
def code(self):
return grpc.StatusCode.INTERNAL
def details(self):
return 'Exception raised while intercepting the RPC'
def cancel(self):
return False
def cancelled(self):
return False
def is_active(self):
return False
def time_remaining(self):
return None
def running(self):
return False
def done(self):
return True
def result(self, ignored_timeout=None):
raise self._exception
def exception(self, ignored_timeout=None):
return self._exception
def traceback(self, ignored_timeout=None):
return self._traceback
def add_callback(self, unused_callback):
return False
def add_done_callback(self, fn):
fn(self)
def __iter__(self):
return self
def __next__(self):
raise self._exception
def next(self):
return self.__next__()
class _UnaryOutcome(grpc.Call, grpc.Future):
def __init__(self, response, call):
self._response = response
self._call = call
def initial_metadata(self):
return self._call.initial_metadata()
def trailing_metadata(self):
return self._call.trailing_metadata()
def code(self):
return self._call.code()
def details(self):
return self._call.details()
def is_active(self):
return self._call.is_active()
def time_remaining(self):
return self._call.time_remaining()
def cancel(self):
return self._call.cancel()
def add_callback(self, callback):
return self._call.add_callback(callback)
def cancelled(self):
return False
def running(self):
return False
def done(self):
return True
def result(self, ignored_timeout=None):
return self._response
def exception(self, ignored_timeout=None):
return None
def traceback(self, ignored_timeout=None):
return None
def add_done_callback(self, fn):
fn(self)
class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable):
def __init__(self, thunk, method, interceptor):
self._thunk = thunk
self._method = method
self._interceptor = interceptor
def __call__(self,
request,
timeout=None,
metadata=None,
credentials=None,
wait_for_ready=None,
compression=None):
response, ignored_call = self._with_call(request,
timeout=timeout,
metadata=metadata,
credentials=credentials,
wait_for_ready=wait_for_ready,
compression=compression)
return response
def _with_call(self,
request,
timeout=None,
metadata=None,
credentials=None,
wait_for_ready=None,
compression=None):
client_call_details = _ClientCallDetails(self._method, timeout,
metadata, credentials,
wait_for_ready, compression)
def continuation(new_details, request):
(new_method, new_timeout, new_metadata, new_credentials,
new_wait_for_ready,
new_compression) = (_unwrap_client_call_details(
new_details, client_call_details))
try:
response, call = self._thunk(new_method).with_call(
request,
timeout=new_timeout,
metadata=new_metadata,
credentials=new_credentials,
wait_for_ready=new_wait_for_ready,
compression=new_compression)
return _UnaryOutcome(response, call)
except grpc.RpcError as rpc_error:
return rpc_error
except Exception as exception: # pylint:disable=broad-except
return _FailureOutcome(exception, sys.exc_info()[2])
call = self._interceptor.intercept_unary_unary(continuation,
client_call_details,
request)
return call.result(), call
def with_call(self,
request,
timeout=None,
metadata=None,
credentials=None,
wait_for_ready=None,
compression=None):
return self._with_call(request,
timeout=timeout,
metadata=metadata,
credentials=credentials,
wait_for_ready=wait_for_ready,
compression=compression)
def future(self,
request,
timeout=None,
metadata=None,
credentials=None,
wait_for_ready=None,
compression=None):
client_call_details = _ClientCallDetails(self._method, timeout,
metadata, credentials,
wait_for_ready, compression)
def continuation(new_details, request):
(new_method, new_timeout, new_metadata, new_credentials,
new_wait_for_ready,
new_compression) = (_unwrap_client_call_details(
new_details, client_call_details))
return self._thunk(new_method).future(
request,
timeout=new_timeout,
metadata=new_metadata,
credentials=new_credentials,
wait_for_ready=new_wait_for_ready,
compression=new_compression)
try:
return self._interceptor.intercept_unary_unary(
continuation, client_call_details, request)
except Exception as exception: # pylint:disable=broad-except
return _FailureOutcome(exception, sys.exc_info()[2])
class _UnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable):
def __init__(self, thunk, method, interceptor):
self._thunk = thunk
self._method = method
self._interceptor = interceptor
def __call__(self,
request,
timeout=None,
metadata=None,
credentials=None,
wait_for_ready=None,
compression=None):
client_call_details = _ClientCallDetails(self._method, timeout,
metadata, credentials,
wait_for_ready, compression)
def continuation(new_details, request):
(new_method, new_timeout, new_metadata, new_credentials,
new_wait_for_ready,
new_compression) = (_unwrap_client_call_details(
new_details, client_call_details))
return self._thunk(new_method)(request,
timeout=new_timeout,
metadata=new_metadata,
credentials=new_credentials,
wait_for_ready=new_wait_for_ready,
compression=new_compression)
try:
return self._interceptor.intercept_unary_stream(
continuation, client_call_details, request)
except Exception as exception: # pylint:disable=broad-except
return _FailureOutcome(exception, sys.exc_info()[2])
class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable):
def __init__(self, thunk, method, interceptor):
self._thunk = thunk
self._method = method
self._interceptor = interceptor
def __call__(self,
request_iterator,
timeout=None,
metadata=None,
credentials=None,
wait_for_ready=None,
compression=None):
response, ignored_call = self._with_call(request_iterator,
timeout=timeout,
metadata=metadata,
credentials=credentials,
wait_for_ready=wait_for_ready,
compression=compression)
return response
def _with_call(self,
request_iterator,
timeout=None,
metadata=None,
credentials=None,
wait_for_ready=None,
compression=None):
client_call_details = _ClientCallDetails(self._method, timeout,
metadata, credentials,
wait_for_ready, compression)
def continuation(new_details, request_iterator):
(new_method, new_timeout, new_metadata, new_credentials,
new_wait_for_ready,
new_compression) = (_unwrap_client_call_details(
new_details, client_call_details))
try:
response, call = self._thunk(new_method).with_call(
request_iterator,
timeout=new_timeout,
metadata=new_metadata,
credentials=new_credentials,
wait_for_ready=new_wait_for_ready,
compression=new_compression)
return _UnaryOutcome(response, call)
except grpc.RpcError as rpc_error:
return rpc_error
except Exception as exception: # pylint:disable=broad-except
return _FailureOutcome(exception, sys.exc_info()[2])
call = self._interceptor.intercept_stream_unary(continuation,
client_call_details,
request_iterator)
return call.result(), call
def with_call(self,
request_iterator,
timeout=None,
metadata=None,
credentials=None,
wait_for_ready=None,
compression=None):
return self._with_call(request_iterator,
timeout=timeout,
metadata=metadata,
credentials=credentials,
wait_for_ready=wait_for_ready,
compression=compression)
def future(self,
request_iterator,
timeout=None,
metadata=None,
credentials=None,
wait_for_ready=None,
compression=None):
client_call_details = _ClientCallDetails(self._method, timeout,
metadata, credentials,
wait_for_ready, compression)
def continuation(new_details, request_iterator):
(new_method, new_timeout, new_metadata, new_credentials,
new_wait_for_ready,
new_compression) = (_unwrap_client_call_details(
new_details, client_call_details))
return self._thunk(new_method).future(
request_iterator,
timeout=new_timeout,
metadata=new_metadata,
credentials=new_credentials,
wait_for_ready=new_wait_for_ready,
compression=new_compression)
try:
return self._interceptor.intercept_stream_unary(
continuation, client_call_details, request_iterator)
except Exception as exception: # pylint:disable=broad-except
return _FailureOutcome(exception, sys.exc_info()[2])
class _StreamStreamMultiCallable(grpc.StreamStreamMultiCallable):
def __init__(self, thunk, method, interceptor):
self._thunk = thunk
self._method = method
self._interceptor = interceptor
def __call__(self,
request_iterator,
timeout=None,
metadata=None,
credentials=None,
wait_for_ready=None,
compression=None):
client_call_details = _ClientCallDetails(self._method, timeout,
metadata, credentials,
wait_for_ready, compression)
def continuation(new_details, request_iterator):
(new_method, new_timeout, new_metadata, new_credentials,
new_wait_for_ready,
new_compression) = (_unwrap_client_call_details(
new_details, client_call_details))
return self._thunk(new_method)(request_iterator,
timeout=new_timeout,
metadata=new_metadata,
credentials=new_credentials,
wait_for_ready=new_wait_for_ready,
compression=new_compression)
try:
return self._interceptor.intercept_stream_stream(
continuation, client_call_details, request_iterator)
except Exception as exception: # pylint:disable=broad-except
return _FailureOutcome(exception, sys.exc_info()[2])
class _Channel(grpc.Channel):
def __init__(self, channel, interceptor):
self._channel = channel
self._interceptor = interceptor
def subscribe(self, callback, try_to_connect=False):
self._channel.subscribe(callback, try_to_connect=try_to_connect)
def unsubscribe(self, callback):
self._channel.unsubscribe(callback)
def unary_unary(self,
method,
request_serializer=None,
response_deserializer=None):
thunk = lambda m: self._channel.unary_unary(m, request_serializer,
response_deserializer)
if isinstance(self._interceptor, grpc.UnaryUnaryClientInterceptor):
return _UnaryUnaryMultiCallable(thunk, method, self._interceptor)
else:
return thunk(method)
def unary_stream(self,
method,
request_serializer=None,
response_deserializer=None):
thunk = lambda m: self._channel.unary_stream(m, request_serializer,
response_deserializer)
if isinstance(self._interceptor, grpc.UnaryStreamClientInterceptor):
return _UnaryStreamMultiCallable(thunk, method, self._interceptor)
else:
return thunk(method)
def stream_unary(self,
method,
request_serializer=None,
response_deserializer=None):
thunk = lambda m: self._channel.stream_unary(m, request_serializer,
response_deserializer)
if isinstance(self._interceptor, grpc.StreamUnaryClientInterceptor):
return _StreamUnaryMultiCallable(thunk, method, self._interceptor)
else:
return thunk(method)
def stream_stream(self,
method,
request_serializer=None,
response_deserializer=None):
thunk = lambda m: self._channel.stream_stream(m, request_serializer,
response_deserializer)
if isinstance(self._interceptor, grpc.StreamStreamClientInterceptor):
return _StreamStreamMultiCallable(thunk, method, self._interceptor)
else:
return thunk(method)
def _close(self):
self._channel.close()
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self._close()
return False
def close(self):
self._channel.close()
def intercept_channel(channel, *interceptors):
for interceptor in reversed(list(interceptors)):
if not isinstance(interceptor, grpc.UnaryUnaryClientInterceptor) and \
not isinstance(interceptor, grpc.UnaryStreamClientInterceptor) and \
not isinstance(interceptor, grpc.StreamUnaryClientInterceptor) and \
not isinstance(interceptor, grpc.StreamStreamClientInterceptor):
raise TypeError('interceptor must be '
'grpc.UnaryUnaryClientInterceptor or '
'grpc.UnaryStreamClientInterceptor or '
'grpc.StreamUnaryClientInterceptor or '
'grpc.StreamStreamClientInterceptor or ')
channel = _Channel(channel, interceptor)
return channel

View file

@ -0,0 +1,101 @@
# Copyright 2015 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import logging
import threading
import grpc
from grpc import _common
from grpc._cython import cygrpc
_LOGGER = logging.getLogger(__name__)
class _AuthMetadataContext(
collections.namedtuple('AuthMetadataContext', (
'service_url',
'method_name',
)), grpc.AuthMetadataContext):
pass
class _CallbackState(object):
def __init__(self):
self.lock = threading.Lock()
self.called = False
self.exception = None
class _AuthMetadataPluginCallback(grpc.AuthMetadataPluginCallback):
def __init__(self, state, callback):
self._state = state
self._callback = callback
def __call__(self, metadata, error):
with self._state.lock:
if self._state.exception is None:
if self._state.called:
raise RuntimeError(
'AuthMetadataPluginCallback invoked more than once!')
else:
self._state.called = True
else:
raise RuntimeError(
'AuthMetadataPluginCallback raised exception "{}"!'.format(
self._state.exception))
if error is None:
self._callback(metadata, cygrpc.StatusCode.ok, None)
else:
self._callback(None, cygrpc.StatusCode.internal,
_common.encode(str(error)))
class _Plugin(object):
def __init__(self, metadata_plugin):
self._metadata_plugin = metadata_plugin
def __call__(self, service_url, method_name, callback):
context = _AuthMetadataContext(_common.decode(service_url),
_common.decode(method_name))
callback_state = _CallbackState()
try:
self._metadata_plugin(
context, _AuthMetadataPluginCallback(callback_state, callback))
except Exception as exception: # pylint: disable=broad-except
_LOGGER.exception(
'AuthMetadataPluginCallback "%s" raised exception!',
self._metadata_plugin)
with callback_state.lock:
callback_state.exception = exception
if callback_state.called:
return
callback(None, cygrpc.StatusCode.internal,
_common.encode(str(exception)))
def metadata_plugin_call_credentials(metadata_plugin, name):
if name is None:
try:
effective_name = metadata_plugin.__name__
except AttributeError:
effective_name = metadata_plugin.__class__.__name__
else:
effective_name = name
return grpc.CallCredentials(
cygrpc.MetadataPluginCallCredentials(_Plugin(metadata_plugin),
_common.encode(effective_name)))

View file

@ -0,0 +1,171 @@
# Copyright 2020 The gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
_REQUIRED_SYMBOLS = ("_protos", "_services", "_protos_and_services")
def _uninstalled_protos(*args, **kwargs):
raise NotImplementedError(
"Install the grpcio-tools package (1.32.0+) to use the protos function."
)
def _uninstalled_services(*args, **kwargs):
raise NotImplementedError(
"Install the grpcio-tools package (1.32.0+) to use the services function."
)
def _uninstalled_protos_and_services(*args, **kwargs):
raise NotImplementedError(
"Install the grpcio-tools package (1.32.0+) to use the protos_and_services function."
)
def _interpreter_version_protos(*args, **kwargs):
raise NotImplementedError(
"The protos function is only on available on Python 3.X interpreters.")
def _interpreter_version_services(*args, **kwargs):
raise NotImplementedError(
"The services function is only on available on Python 3.X interpreters."
)
def _interpreter_version_protos_and_services(*args, **kwargs):
raise NotImplementedError(
"The protos_and_services function is only on available on Python 3.X interpreters."
)
def protos(protobuf_path): # pylint: disable=unused-argument
"""Returns a module generated by the indicated .proto file.
THIS IS AN EXPERIMENTAL API.
Use this function to retrieve classes corresponding to message
definitions in the .proto file.
To inspect the contents of the returned module, use the dir function.
For example:
```
protos = grpc.protos("foo.proto")
print(dir(protos))
```
The returned module object corresponds to the _pb2.py file generated
by protoc. The path is expected to be relative to an entry on sys.path
and all transitive dependencies of the file should also be resolveable
from an entry on sys.path.
To completely disable the machinery behind this function, set the
GRPC_PYTHON_DISABLE_DYNAMIC_STUBS environment variable to "true".
Args:
protobuf_path: The path to the .proto file on the filesystem. This path
must be resolveable from an entry on sys.path and so must all of its
transitive dependencies.
Returns:
A module object corresponding to the message code for the indicated
.proto file. Equivalent to a generated _pb2.py file.
"""
def services(protobuf_path): # pylint: disable=unused-argument
"""Returns a module generated by the indicated .proto file.
THIS IS AN EXPERIMENTAL API.
Use this function to retrieve classes and functions corresponding to
service definitions in the .proto file, including both stub and servicer
definitions.
To inspect the contents of the returned module, use the dir function.
For example:
```
services = grpc.services("foo.proto")
print(dir(services))
```
The returned module object corresponds to the _pb2_grpc.py file generated
by protoc. The path is expected to be relative to an entry on sys.path
and all transitive dependencies of the file should also be resolveable
from an entry on sys.path.
To completely disable the machinery behind this function, set the
GRPC_PYTHON_DISABLE_DYNAMIC_STUBS environment variable to "true".
Args:
protobuf_path: The path to the .proto file on the filesystem. This path
must be resolveable from an entry on sys.path and so must all of its
transitive dependencies.
Returns:
A module object corresponding to the stub/service code for the indicated
.proto file. Equivalent to a generated _pb2_grpc.py file.
"""
def protos_and_services(protobuf_path): # pylint: disable=unused-argument
"""Returns a 2-tuple of modules corresponding to protos and services.
THIS IS AN EXPERIMENTAL API.
The return value of this function is equivalent to a call to protos and a
call to services.
To completely disable the machinery behind this function, set the
GRPC_PYTHON_DISABLE_DYNAMIC_STUBS environment variable to "true".
Args:
protobuf_path: The path to the .proto file on the filesystem. This path
must be resolveable from an entry on sys.path and so must all of its
transitive dependencies.
Returns:
A 2-tuple of module objects corresponding to (protos(path), services(path)).
"""
if sys.version_info < (3, 5, 0):
protos = _interpreter_version_protos
services = _interpreter_version_services
protos_and_services = _interpreter_version_protos_and_services
else:
try:
import grpc_tools # pylint: disable=unused-import
except ImportError as e:
# NOTE: It's possible that we're encountering a transitive ImportError, so
# we check for that and re-raise if so.
if "grpc_tools" not in e.args[0]:
raise
protos = _uninstalled_protos
services = _uninstalled_services
protos_and_services = _uninstalled_protos_and_services
else:
import grpc_tools.protoc # pylint: disable=unused-import
if all(hasattr(grpc_tools.protoc, sym) for sym in _REQUIRED_SYMBOLS):
from grpc_tools.protoc import _protos as protos # pylint: disable=unused-import
from grpc_tools.protoc import _services as services # pylint: disable=unused-import
from grpc_tools.protoc import _protos_and_services as protos_and_services # pylint: disable=unused-import
else:
protos = _uninstalled_protos
services = _uninstalled_services
protos_and_services = _uninstalled_protos_and_services

View file

@ -0,0 +1,995 @@
# Copyright 2016 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Service-side implementation of gRPC Python."""
import collections
import enum
import logging
import threading
import time
from concurrent import futures
import six
import grpc
from grpc import _common
from grpc import _compression
from grpc import _interceptor
from grpc._cython import cygrpc
_LOGGER = logging.getLogger(__name__)
_SHUTDOWN_TAG = 'shutdown'
_REQUEST_CALL_TAG = 'request_call'
_RECEIVE_CLOSE_ON_SERVER_TOKEN = 'receive_close_on_server'
_SEND_INITIAL_METADATA_TOKEN = 'send_initial_metadata'
_RECEIVE_MESSAGE_TOKEN = 'receive_message'
_SEND_MESSAGE_TOKEN = 'send_message'
_SEND_INITIAL_METADATA_AND_SEND_MESSAGE_TOKEN = (
'send_initial_metadata * send_message')
_SEND_STATUS_FROM_SERVER_TOKEN = 'send_status_from_server'
_SEND_INITIAL_METADATA_AND_SEND_STATUS_FROM_SERVER_TOKEN = (
'send_initial_metadata * send_status_from_server')
_OPEN = 'open'
_CLOSED = 'closed'
_CANCELLED = 'cancelled'
_EMPTY_FLAGS = 0
_DEALLOCATED_SERVER_CHECK_PERIOD_S = 1.0
_INF_TIMEOUT = 1e9
def _serialized_request(request_event):
return request_event.batch_operations[0].message()
def _application_code(code):
cygrpc_code = _common.STATUS_CODE_TO_CYGRPC_STATUS_CODE.get(code)
return cygrpc.StatusCode.unknown if cygrpc_code is None else cygrpc_code
def _completion_code(state):
if state.code is None:
return cygrpc.StatusCode.ok
else:
return _application_code(state.code)
def _abortion_code(state, code):
if state.code is None:
return code
else:
return _application_code(state.code)
def _details(state):
return b'' if state.details is None else state.details
class _HandlerCallDetails(
collections.namedtuple('_HandlerCallDetails', (
'method',
'invocation_metadata',
)), grpc.HandlerCallDetails):
pass
class _RPCState(object):
def __init__(self):
self.condition = threading.Condition()
self.due = set()
self.request = None
self.client = _OPEN
self.initial_metadata_allowed = True
self.compression_algorithm = None
self.disable_next_compression = False
self.trailing_metadata = None
self.code = None
self.details = None
self.statused = False
self.rpc_errors = []
self.callbacks = []
self.aborted = False
def _raise_rpc_error(state):
rpc_error = grpc.RpcError()
state.rpc_errors.append(rpc_error)
raise rpc_error
def _possibly_finish_call(state, token):
state.due.remove(token)
if not _is_rpc_state_active(state) and not state.due:
callbacks = state.callbacks
state.callbacks = None
return state, callbacks
else:
return None, ()
def _send_status_from_server(state, token):
def send_status_from_server(unused_send_status_from_server_event):
with state.condition:
return _possibly_finish_call(state, token)
return send_status_from_server
def _get_initial_metadata(state, metadata):
with state.condition:
if state.compression_algorithm:
compression_metadata = (
_compression.compression_algorithm_to_metadata(
state.compression_algorithm),)
if metadata is None:
return compression_metadata
else:
return compression_metadata + tuple(metadata)
else:
return metadata
def _get_initial_metadata_operation(state, metadata):
operation = cygrpc.SendInitialMetadataOperation(
_get_initial_metadata(state, metadata), _EMPTY_FLAGS)
return operation
def _abort(state, call, code, details):
if state.client is not _CANCELLED:
effective_code = _abortion_code(state, code)
effective_details = details if state.details is None else state.details
if state.initial_metadata_allowed:
operations = (
_get_initial_metadata_operation(state, None),
cygrpc.SendStatusFromServerOperation(state.trailing_metadata,
effective_code,
effective_details,
_EMPTY_FLAGS),
)
token = _SEND_INITIAL_METADATA_AND_SEND_STATUS_FROM_SERVER_TOKEN
else:
operations = (cygrpc.SendStatusFromServerOperation(
state.trailing_metadata, effective_code, effective_details,
_EMPTY_FLAGS),)
token = _SEND_STATUS_FROM_SERVER_TOKEN
call.start_server_batch(operations,
_send_status_from_server(state, token))
state.statused = True
state.due.add(token)
def _receive_close_on_server(state):
def receive_close_on_server(receive_close_on_server_event):
with state.condition:
if receive_close_on_server_event.batch_operations[0].cancelled():
state.client = _CANCELLED
elif state.client is _OPEN:
state.client = _CLOSED
state.condition.notify_all()
return _possibly_finish_call(state, _RECEIVE_CLOSE_ON_SERVER_TOKEN)
return receive_close_on_server
def _receive_message(state, call, request_deserializer):
def receive_message(receive_message_event):
serialized_request = _serialized_request(receive_message_event)
if serialized_request is None:
with state.condition:
if state.client is _OPEN:
state.client = _CLOSED
state.condition.notify_all()
return _possibly_finish_call(state, _RECEIVE_MESSAGE_TOKEN)
else:
request = _common.deserialize(serialized_request,
request_deserializer)
with state.condition:
if request is None:
_abort(state, call, cygrpc.StatusCode.internal,
b'Exception deserializing request!')
else:
state.request = request
state.condition.notify_all()
return _possibly_finish_call(state, _RECEIVE_MESSAGE_TOKEN)
return receive_message
def _send_initial_metadata(state):
def send_initial_metadata(unused_send_initial_metadata_event):
with state.condition:
return _possibly_finish_call(state, _SEND_INITIAL_METADATA_TOKEN)
return send_initial_metadata
def _send_message(state, token):
def send_message(unused_send_message_event):
with state.condition:
state.condition.notify_all()
return _possibly_finish_call(state, token)
return send_message
class _Context(grpc.ServicerContext):
def __init__(self, rpc_event, state, request_deserializer):
self._rpc_event = rpc_event
self._state = state
self._request_deserializer = request_deserializer
def is_active(self):
with self._state.condition:
return _is_rpc_state_active(self._state)
def time_remaining(self):
return max(self._rpc_event.call_details.deadline - time.time(), 0)
def cancel(self):
self._rpc_event.call.cancel()
def add_callback(self, callback):
with self._state.condition:
if self._state.callbacks is None:
return False
else:
self._state.callbacks.append(callback)
return True
def disable_next_message_compression(self):
with self._state.condition:
self._state.disable_next_compression = True
def invocation_metadata(self):
return self._rpc_event.invocation_metadata
def peer(self):
return _common.decode(self._rpc_event.call.peer())
def peer_identities(self):
return cygrpc.peer_identities(self._rpc_event.call)
def peer_identity_key(self):
id_key = cygrpc.peer_identity_key(self._rpc_event.call)
return id_key if id_key is None else _common.decode(id_key)
def auth_context(self):
return {
_common.decode(key): value for key, value in six.iteritems(
cygrpc.auth_context(self._rpc_event.call))
}
def set_compression(self, compression):
with self._state.condition:
self._state.compression_algorithm = compression
def send_initial_metadata(self, initial_metadata):
with self._state.condition:
if self._state.client is _CANCELLED:
_raise_rpc_error(self._state)
else:
if self._state.initial_metadata_allowed:
operation = _get_initial_metadata_operation(
self._state, initial_metadata)
self._rpc_event.call.start_server_batch(
(operation,), _send_initial_metadata(self._state))
self._state.initial_metadata_allowed = False
self._state.due.add(_SEND_INITIAL_METADATA_TOKEN)
else:
raise ValueError('Initial metadata no longer allowed!')
def set_trailing_metadata(self, trailing_metadata):
with self._state.condition:
self._state.trailing_metadata = trailing_metadata
def abort(self, code, details):
# treat OK like other invalid arguments: fail the RPC
if code == grpc.StatusCode.OK:
_LOGGER.error(
'abort() called with StatusCode.OK; returning UNKNOWN')
code = grpc.StatusCode.UNKNOWN
details = ''
with self._state.condition:
self._state.code = code
self._state.details = _common.encode(details)
self._state.aborted = True
raise Exception()
def abort_with_status(self, status):
self._state.trailing_metadata = status.trailing_metadata
self.abort(status.code, status.details)
def set_code(self, code):
with self._state.condition:
self._state.code = code
def set_details(self, details):
with self._state.condition:
self._state.details = _common.encode(details)
def _finalize_state(self):
pass
class _RequestIterator(object):
def __init__(self, state, call, request_deserializer):
self._state = state
self._call = call
self._request_deserializer = request_deserializer
def _raise_or_start_receive_message(self):
if self._state.client is _CANCELLED:
_raise_rpc_error(self._state)
elif not _is_rpc_state_active(self._state):
raise StopIteration()
else:
self._call.start_server_batch(
(cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),),
_receive_message(self._state, self._call,
self._request_deserializer))
self._state.due.add(_RECEIVE_MESSAGE_TOKEN)
def _look_for_request(self):
if self._state.client is _CANCELLED:
_raise_rpc_error(self._state)
elif (self._state.request is None and
_RECEIVE_MESSAGE_TOKEN not in self._state.due):
raise StopIteration()
else:
request = self._state.request
self._state.request = None
return request
raise AssertionError() # should never run
def _next(self):
with self._state.condition:
self._raise_or_start_receive_message()
while True:
self._state.condition.wait()
request = self._look_for_request()
if request is not None:
return request
def __iter__(self):
return self
def __next__(self):
return self._next()
def next(self):
return self._next()
def _unary_request(rpc_event, state, request_deserializer):
def unary_request():
with state.condition:
if not _is_rpc_state_active(state):
return None
else:
rpc_event.call.start_server_batch(
(cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),),
_receive_message(state, rpc_event.call,
request_deserializer))
state.due.add(_RECEIVE_MESSAGE_TOKEN)
while True:
state.condition.wait()
if state.request is None:
if state.client is _CLOSED:
details = '"{}" requires exactly one request message.'.format(
rpc_event.call_details.method)
_abort(state, rpc_event.call,
cygrpc.StatusCode.unimplemented,
_common.encode(details))
return None
elif state.client is _CANCELLED:
return None
else:
request = state.request
state.request = None
return request
return unary_request
def _call_behavior(rpc_event,
state,
behavior,
argument,
request_deserializer,
send_response_callback=None):
from grpc import _create_servicer_context
with _create_servicer_context(rpc_event, state,
request_deserializer) as context:
try:
response_or_iterator = None
if send_response_callback is not None:
response_or_iterator = behavior(argument, context,
send_response_callback)
else:
response_or_iterator = behavior(argument, context)
return response_or_iterator, True
except Exception as exception: # pylint: disable=broad-except
with state.condition:
if state.aborted:
_abort(state, rpc_event.call, cygrpc.StatusCode.unknown,
b'RPC Aborted')
elif exception not in state.rpc_errors:
details = 'Exception calling application: {}'.format(
exception)
_LOGGER.exception(details)
_abort(state, rpc_event.call, cygrpc.StatusCode.unknown,
_common.encode(details))
return None, False
def _take_response_from_response_iterator(rpc_event, state, response_iterator):
try:
return next(response_iterator), True
except StopIteration:
return None, True
except Exception as exception: # pylint: disable=broad-except
with state.condition:
if state.aborted:
_abort(state, rpc_event.call, cygrpc.StatusCode.unknown,
b'RPC Aborted')
elif exception not in state.rpc_errors:
details = 'Exception iterating responses: {}'.format(exception)
_LOGGER.exception(details)
_abort(state, rpc_event.call, cygrpc.StatusCode.unknown,
_common.encode(details))
return None, False
def _serialize_response(rpc_event, state, response, response_serializer):
serialized_response = _common.serialize(response, response_serializer)
if serialized_response is None:
with state.condition:
_abort(state, rpc_event.call, cygrpc.StatusCode.internal,
b'Failed to serialize response!')
return None
else:
return serialized_response
def _get_send_message_op_flags_from_state(state):
if state.disable_next_compression:
return cygrpc.WriteFlag.no_compress
else:
return _EMPTY_FLAGS
def _reset_per_message_state(state):
with state.condition:
state.disable_next_compression = False
def _send_response(rpc_event, state, serialized_response):
with state.condition:
if not _is_rpc_state_active(state):
return False
else:
if state.initial_metadata_allowed:
operations = (
_get_initial_metadata_operation(state, None),
cygrpc.SendMessageOperation(
serialized_response,
_get_send_message_op_flags_from_state(state)),
)
state.initial_metadata_allowed = False
token = _SEND_INITIAL_METADATA_AND_SEND_MESSAGE_TOKEN
else:
operations = (cygrpc.SendMessageOperation(
serialized_response,
_get_send_message_op_flags_from_state(state)),)
token = _SEND_MESSAGE_TOKEN
rpc_event.call.start_server_batch(operations,
_send_message(state, token))
state.due.add(token)
_reset_per_message_state(state)
while True:
state.condition.wait()
if token not in state.due:
return _is_rpc_state_active(state)
def _status(rpc_event, state, serialized_response):
with state.condition:
if state.client is not _CANCELLED:
code = _completion_code(state)
details = _details(state)
operations = [
cygrpc.SendStatusFromServerOperation(state.trailing_metadata,
code, details,
_EMPTY_FLAGS),
]
if state.initial_metadata_allowed:
operations.append(_get_initial_metadata_operation(state, None))
if serialized_response is not None:
operations.append(
cygrpc.SendMessageOperation(
serialized_response,
_get_send_message_op_flags_from_state(state)))
rpc_event.call.start_server_batch(
operations,
_send_status_from_server(state, _SEND_STATUS_FROM_SERVER_TOKEN))
state.statused = True
_reset_per_message_state(state)
state.due.add(_SEND_STATUS_FROM_SERVER_TOKEN)
def _unary_response_in_pool(rpc_event, state, behavior, argument_thunk,
request_deserializer, response_serializer):
cygrpc.install_context_from_request_call_event(rpc_event)
try:
argument = argument_thunk()
if argument is not None:
response, proceed = _call_behavior(rpc_event, state, behavior,
argument, request_deserializer)
if proceed:
serialized_response = _serialize_response(
rpc_event, state, response, response_serializer)
if serialized_response is not None:
_status(rpc_event, state, serialized_response)
finally:
cygrpc.uninstall_context()
def _stream_response_in_pool(rpc_event, state, behavior, argument_thunk,
request_deserializer, response_serializer):
cygrpc.install_context_from_request_call_event(rpc_event)
def send_response(response):
if response is None:
_status(rpc_event, state, None)
else:
serialized_response = _serialize_response(rpc_event, state,
response,
response_serializer)
if serialized_response is not None:
_send_response(rpc_event, state, serialized_response)
try:
argument = argument_thunk()
if argument is not None:
if hasattr(behavior, 'experimental_non_blocking'
) and behavior.experimental_non_blocking:
_call_behavior(rpc_event,
state,
behavior,
argument,
request_deserializer,
send_response_callback=send_response)
else:
response_iterator, proceed = _call_behavior(
rpc_event, state, behavior, argument, request_deserializer)
if proceed:
_send_message_callback_to_blocking_iterator_adapter(
rpc_event, state, send_response, response_iterator)
finally:
cygrpc.uninstall_context()
def _is_rpc_state_active(state):
return state.client is not _CANCELLED and not state.statused
def _send_message_callback_to_blocking_iterator_adapter(rpc_event, state,
send_response_callback,
response_iterator):
while True:
response, proceed = _take_response_from_response_iterator(
rpc_event, state, response_iterator)
if proceed:
send_response_callback(response)
if not _is_rpc_state_active(state):
break
else:
break
def _select_thread_pool_for_behavior(behavior, default_thread_pool):
if hasattr(behavior, 'experimental_thread_pool') and isinstance(
behavior.experimental_thread_pool, futures.ThreadPoolExecutor):
return behavior.experimental_thread_pool
else:
return default_thread_pool
def _handle_unary_unary(rpc_event, state, method_handler, default_thread_pool):
unary_request = _unary_request(rpc_event, state,
method_handler.request_deserializer)
thread_pool = _select_thread_pool_for_behavior(method_handler.unary_unary,
default_thread_pool)
return thread_pool.submit(_unary_response_in_pool, rpc_event, state,
method_handler.unary_unary, unary_request,
method_handler.request_deserializer,
method_handler.response_serializer)
def _handle_unary_stream(rpc_event, state, method_handler, default_thread_pool):
unary_request = _unary_request(rpc_event, state,
method_handler.request_deserializer)
thread_pool = _select_thread_pool_for_behavior(method_handler.unary_stream,
default_thread_pool)
return thread_pool.submit(_stream_response_in_pool, rpc_event, state,
method_handler.unary_stream, unary_request,
method_handler.request_deserializer,
method_handler.response_serializer)
def _handle_stream_unary(rpc_event, state, method_handler, default_thread_pool):
request_iterator = _RequestIterator(state, rpc_event.call,
method_handler.request_deserializer)
thread_pool = _select_thread_pool_for_behavior(method_handler.stream_unary,
default_thread_pool)
return thread_pool.submit(_unary_response_in_pool, rpc_event, state,
method_handler.stream_unary,
lambda: request_iterator,
method_handler.request_deserializer,
method_handler.response_serializer)
def _handle_stream_stream(rpc_event, state, method_handler,
default_thread_pool):
request_iterator = _RequestIterator(state, rpc_event.call,
method_handler.request_deserializer)
thread_pool = _select_thread_pool_for_behavior(method_handler.stream_stream,
default_thread_pool)
return thread_pool.submit(_stream_response_in_pool, rpc_event, state,
method_handler.stream_stream,
lambda: request_iterator,
method_handler.request_deserializer,
method_handler.response_serializer)
def _find_method_handler(rpc_event, generic_handlers, interceptor_pipeline):
def query_handlers(handler_call_details):
for generic_handler in generic_handlers:
method_handler = generic_handler.service(handler_call_details)
if method_handler is not None:
return method_handler
return None
handler_call_details = _HandlerCallDetails(
_common.decode(rpc_event.call_details.method),
rpc_event.invocation_metadata)
if interceptor_pipeline is not None:
return interceptor_pipeline.execute(query_handlers,
handler_call_details)
else:
return query_handlers(handler_call_details)
def _reject_rpc(rpc_event, status, details):
rpc_state = _RPCState()
operations = (
_get_initial_metadata_operation(rpc_state, None),
cygrpc.ReceiveCloseOnServerOperation(_EMPTY_FLAGS),
cygrpc.SendStatusFromServerOperation(None, status, details,
_EMPTY_FLAGS),
)
rpc_event.call.start_server_batch(operations, lambda ignored_event: (
rpc_state,
(),
))
return rpc_state
def _handle_with_method_handler(rpc_event, method_handler, thread_pool):
state = _RPCState()
with state.condition:
rpc_event.call.start_server_batch(
(cygrpc.ReceiveCloseOnServerOperation(_EMPTY_FLAGS),),
_receive_close_on_server(state))
state.due.add(_RECEIVE_CLOSE_ON_SERVER_TOKEN)
if method_handler.request_streaming:
if method_handler.response_streaming:
return state, _handle_stream_stream(rpc_event, state,
method_handler, thread_pool)
else:
return state, _handle_stream_unary(rpc_event, state,
method_handler, thread_pool)
else:
if method_handler.response_streaming:
return state, _handle_unary_stream(rpc_event, state,
method_handler, thread_pool)
else:
return state, _handle_unary_unary(rpc_event, state,
method_handler, thread_pool)
def _handle_call(rpc_event, generic_handlers, interceptor_pipeline, thread_pool,
concurrency_exceeded):
if not rpc_event.success:
return None, None
if rpc_event.call_details.method is not None:
try:
method_handler = _find_method_handler(rpc_event, generic_handlers,
interceptor_pipeline)
except Exception as exception: # pylint: disable=broad-except
details = 'Exception servicing handler: {}'.format(exception)
_LOGGER.exception(details)
return _reject_rpc(rpc_event, cygrpc.StatusCode.unknown,
b'Error in service handler!'), None
if method_handler is None:
return _reject_rpc(rpc_event, cygrpc.StatusCode.unimplemented,
b'Method not found!'), None
elif concurrency_exceeded:
return _reject_rpc(rpc_event, cygrpc.StatusCode.resource_exhausted,
b'Concurrent RPC limit exceeded!'), None
else:
return _handle_with_method_handler(rpc_event, method_handler,
thread_pool)
else:
return None, None
@enum.unique
class _ServerStage(enum.Enum):
STOPPED = 'stopped'
STARTED = 'started'
GRACE = 'grace'
class _ServerState(object):
# pylint: disable=too-many-arguments
def __init__(self, completion_queue, server, generic_handlers,
interceptor_pipeline, thread_pool, maximum_concurrent_rpcs):
self.lock = threading.RLock()
self.completion_queue = completion_queue
self.server = server
self.generic_handlers = list(generic_handlers)
self.interceptor_pipeline = interceptor_pipeline
self.thread_pool = thread_pool
self.stage = _ServerStage.STOPPED
self.termination_event = threading.Event()
self.shutdown_events = [self.termination_event]
self.maximum_concurrent_rpcs = maximum_concurrent_rpcs
self.active_rpc_count = 0
# TODO(https://github.com/grpc/grpc/issues/6597): eliminate these fields.
self.rpc_states = set()
self.due = set()
# A "volatile" flag to interrupt the daemon serving thread
self.server_deallocated = False
def _add_generic_handlers(state, generic_handlers):
with state.lock:
state.generic_handlers.extend(generic_handlers)
def _add_insecure_port(state, address):
with state.lock:
return state.server.add_http2_port(address)
def _add_secure_port(state, address, server_credentials):
with state.lock:
return state.server.add_http2_port(address,
server_credentials._credentials)
def _request_call(state):
state.server.request_call(state.completion_queue, state.completion_queue,
_REQUEST_CALL_TAG)
state.due.add(_REQUEST_CALL_TAG)
# TODO(https://github.com/grpc/grpc/issues/6597): delete this function.
def _stop_serving(state):
if not state.rpc_states and not state.due:
state.server.destroy()
for shutdown_event in state.shutdown_events:
shutdown_event.set()
state.stage = _ServerStage.STOPPED
return True
else:
return False
def _on_call_completed(state):
with state.lock:
state.active_rpc_count -= 1
def _process_event_and_continue(state, event):
should_continue = True
if event.tag is _SHUTDOWN_TAG:
with state.lock:
state.due.remove(_SHUTDOWN_TAG)
if _stop_serving(state):
should_continue = False
elif event.tag is _REQUEST_CALL_TAG:
with state.lock:
state.due.remove(_REQUEST_CALL_TAG)
concurrency_exceeded = (
state.maximum_concurrent_rpcs is not None and
state.active_rpc_count >= state.maximum_concurrent_rpcs)
rpc_state, rpc_future = _handle_call(event, state.generic_handlers,
state.interceptor_pipeline,
state.thread_pool,
concurrency_exceeded)
if rpc_state is not None:
state.rpc_states.add(rpc_state)
if rpc_future is not None:
state.active_rpc_count += 1
rpc_future.add_done_callback(
lambda unused_future: _on_call_completed(state))
if state.stage is _ServerStage.STARTED:
_request_call(state)
elif _stop_serving(state):
should_continue = False
else:
rpc_state, callbacks = event.tag(event)
for callback in callbacks:
try:
callback()
except Exception: # pylint: disable=broad-except
_LOGGER.exception('Exception calling callback!')
if rpc_state is not None:
with state.lock:
state.rpc_states.remove(rpc_state)
if _stop_serving(state):
should_continue = False
return should_continue
def _serve(state):
while True:
timeout = time.time() + _DEALLOCATED_SERVER_CHECK_PERIOD_S
event = state.completion_queue.poll(timeout)
if state.server_deallocated:
_begin_shutdown_once(state)
if event.completion_type != cygrpc.CompletionType.queue_timeout:
if not _process_event_and_continue(state, event):
return
# We want to force the deletion of the previous event
# ~before~ we poll again; if the event has a reference
# to a shutdown Call object, this can induce spinlock.
event = None
def _begin_shutdown_once(state):
with state.lock:
if state.stage is _ServerStage.STARTED:
state.server.shutdown(state.completion_queue, _SHUTDOWN_TAG)
state.stage = _ServerStage.GRACE
state.due.add(_SHUTDOWN_TAG)
def _stop(state, grace):
with state.lock:
if state.stage is _ServerStage.STOPPED:
shutdown_event = threading.Event()
shutdown_event.set()
return shutdown_event
else:
_begin_shutdown_once(state)
shutdown_event = threading.Event()
state.shutdown_events.append(shutdown_event)
if grace is None:
state.server.cancel_all_calls()
else:
def cancel_all_calls_after_grace():
shutdown_event.wait(timeout=grace)
with state.lock:
state.server.cancel_all_calls()
thread = threading.Thread(target=cancel_all_calls_after_grace)
thread.start()
return shutdown_event
shutdown_event.wait()
return shutdown_event
def _start(state):
with state.lock:
if state.stage is not _ServerStage.STOPPED:
raise ValueError('Cannot start already-started server!')
state.server.start()
state.stage = _ServerStage.STARTED
_request_call(state)
thread = threading.Thread(target=_serve, args=(state,))
thread.daemon = True
thread.start()
def _validate_generic_rpc_handlers(generic_rpc_handlers):
for generic_rpc_handler in generic_rpc_handlers:
service_attribute = getattr(generic_rpc_handler, 'service', None)
if service_attribute is None:
raise AttributeError(
'"{}" must conform to grpc.GenericRpcHandler type but does '
'not have "service" method!'.format(generic_rpc_handler))
def _augment_options(base_options, compression):
compression_option = _compression.create_channel_option(compression)
return tuple(base_options) + compression_option
class _Server(grpc.Server):
# pylint: disable=too-many-arguments
def __init__(self, thread_pool, generic_handlers, interceptors, options,
maximum_concurrent_rpcs, compression):
completion_queue = cygrpc.CompletionQueue()
server = cygrpc.Server(_augment_options(options, compression))
server.register_completion_queue(completion_queue)
self._state = _ServerState(completion_queue, server, generic_handlers,
_interceptor.service_pipeline(interceptors),
thread_pool, maximum_concurrent_rpcs)
def add_generic_rpc_handlers(self, generic_rpc_handlers):
_validate_generic_rpc_handlers(generic_rpc_handlers)
_add_generic_handlers(self._state, generic_rpc_handlers)
def add_insecure_port(self, address):
return _common.validate_port_binding_result(
address, _add_insecure_port(self._state, _common.encode(address)))
def add_secure_port(self, address, server_credentials):
return _common.validate_port_binding_result(
address,
_add_secure_port(self._state, _common.encode(address),
server_credentials))
def start(self):
_start(self._state)
def wait_for_termination(self, timeout=None):
# NOTE(https://bugs.python.org/issue35935)
# Remove this workaround once threading.Event.wait() is working with
# CTRL+C across platforms.
return _common.wait(self._state.termination_event.wait,
self._state.termination_event.is_set,
timeout=timeout)
def stop(self, grace):
return _stop(self._state, grace)
def __del__(self):
if hasattr(self, '_state'):
# We can not grab a lock in __del__(), so set a flag to signal the
# serving daemon thread (if it exists) to initiate shutdown.
self._state.server_deallocated = True
def create_server(thread_pool, generic_rpc_handlers, interceptors, options,
maximum_concurrent_rpcs, compression):
_validate_generic_rpc_handlers(generic_rpc_handlers)
return _Server(thread_pool, generic_rpc_handlers, interceptors, options,
maximum_concurrent_rpcs, compression)

View file

@ -0,0 +1,493 @@
# Copyright 2020 The gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Functions that obviate explicit stubs and explicit channels."""
import collections
import datetime
import os
import logging
import threading
from typing import (Any, AnyStr, Callable, Dict, Iterator, Optional, Sequence,
Tuple, TypeVar, Union)
import grpc
from grpc.experimental import experimental_api
RequestType = TypeVar('RequestType')
ResponseType = TypeVar('ResponseType')
OptionsType = Sequence[Tuple[str, str]]
CacheKey = Tuple[str, OptionsType, Optional[grpc.ChannelCredentials], Optional[
grpc.Compression]]
_LOGGER = logging.getLogger(__name__)
_EVICTION_PERIOD_KEY = "GRPC_PYTHON_MANAGED_CHANNEL_EVICTION_SECONDS"
if _EVICTION_PERIOD_KEY in os.environ:
_EVICTION_PERIOD = datetime.timedelta(
seconds=float(os.environ[_EVICTION_PERIOD_KEY]))
_LOGGER.debug("Setting managed channel eviction period to %s",
_EVICTION_PERIOD)
else:
_EVICTION_PERIOD = datetime.timedelta(minutes=10)
_MAXIMUM_CHANNELS_KEY = "GRPC_PYTHON_MANAGED_CHANNEL_MAXIMUM"
if _MAXIMUM_CHANNELS_KEY in os.environ:
_MAXIMUM_CHANNELS = int(os.environ[_MAXIMUM_CHANNELS_KEY])
_LOGGER.debug("Setting maximum managed channels to %d", _MAXIMUM_CHANNELS)
else:
_MAXIMUM_CHANNELS = 2**8
_DEFAULT_TIMEOUT_KEY = "GRPC_PYTHON_DEFAULT_TIMEOUT_SECONDS"
if _DEFAULT_TIMEOUT_KEY in os.environ:
_DEFAULT_TIMEOUT = float(os.environ[_DEFAULT_TIMEOUT_KEY])
_LOGGER.debug("Setting default timeout seconds to %f", _DEFAULT_TIMEOUT)
else:
_DEFAULT_TIMEOUT = 60.0
def _create_channel(target: str, options: Sequence[Tuple[str, str]],
channel_credentials: Optional[grpc.ChannelCredentials],
compression: Optional[grpc.Compression]) -> grpc.Channel:
if channel_credentials is grpc.experimental.insecure_channel_credentials():
_LOGGER.debug(f"Creating insecure channel with options '{options}' " +
f"and compression '{compression}'")
return grpc.insecure_channel(target,
options=options,
compression=compression)
else:
_LOGGER.debug(
f"Creating secure channel with credentials '{channel_credentials}', "
+ f"options '{options}' and compression '{compression}'")
return grpc.secure_channel(target,
credentials=channel_credentials,
options=options,
compression=compression)
class ChannelCache:
# NOTE(rbellevi): Untyped due to reference cycle.
_singleton = None
_lock: threading.RLock = threading.RLock()
_condition: threading.Condition = threading.Condition(lock=_lock)
_eviction_ready: threading.Event = threading.Event()
_mapping: Dict[CacheKey, Tuple[grpc.Channel, datetime.datetime]]
_eviction_thread: threading.Thread
def __init__(self):
self._mapping = collections.OrderedDict()
self._eviction_thread = threading.Thread(
target=ChannelCache._perform_evictions, daemon=True)
self._eviction_thread.start()
@staticmethod
def get():
with ChannelCache._lock:
if ChannelCache._singleton is None:
ChannelCache._singleton = ChannelCache()
ChannelCache._eviction_ready.wait()
return ChannelCache._singleton
def _evict_locked(self, key: CacheKey):
channel, _ = self._mapping.pop(key)
_LOGGER.debug("Evicting channel %s with configuration %s.", channel,
key)
channel.close()
del channel
@staticmethod
def _perform_evictions():
while True:
with ChannelCache._lock:
ChannelCache._eviction_ready.set()
if not ChannelCache._singleton._mapping:
ChannelCache._condition.wait()
elif len(ChannelCache._singleton._mapping) > _MAXIMUM_CHANNELS:
key = next(iter(ChannelCache._singleton._mapping.keys()))
ChannelCache._singleton._evict_locked(key)
# And immediately reevaluate.
else:
key, (_, eviction_time) = next(
iter(ChannelCache._singleton._mapping.items()))
now = datetime.datetime.now()
if eviction_time <= now:
ChannelCache._singleton._evict_locked(key)
continue
else:
time_to_eviction = (eviction_time - now).total_seconds()
# NOTE: We aim to *eventually* coalesce to a state in
# which no overdue channels are in the cache and the
# length of the cache is longer than _MAXIMUM_CHANNELS.
# We tolerate momentary states in which these two
# criteria are not met.
ChannelCache._condition.wait(timeout=time_to_eviction)
def get_channel(self, target: str, options: Sequence[Tuple[str, str]],
channel_credentials: Optional[grpc.ChannelCredentials],
insecure: bool,
compression: Optional[grpc.Compression]) -> grpc.Channel:
if insecure and channel_credentials:
raise ValueError("The insecure option is mutually exclusive with " +
"the channel_credentials option. Please use one " +
"or the other.")
if insecure:
channel_credentials = grpc.experimental.insecure_channel_credentials(
)
elif channel_credentials is None:
_LOGGER.debug("Defaulting to SSL channel credentials.")
channel_credentials = grpc.ssl_channel_credentials()
key = (target, options, channel_credentials, compression)
with self._lock:
channel_data = self._mapping.get(key, None)
if channel_data is not None:
channel = channel_data[0]
self._mapping.pop(key)
self._mapping[key] = (channel, datetime.datetime.now() +
_EVICTION_PERIOD)
return channel
else:
channel = _create_channel(target, options, channel_credentials,
compression)
self._mapping[key] = (channel, datetime.datetime.now() +
_EVICTION_PERIOD)
if len(self._mapping) == 1 or len(
self._mapping) >= _MAXIMUM_CHANNELS:
self._condition.notify()
return channel
def _test_only_channel_count(self) -> int:
with self._lock:
return len(self._mapping)
@experimental_api
def unary_unary(
request: RequestType,
target: str,
method: str,
request_serializer: Optional[Callable[[Any], bytes]] = None,
response_deserializer: Optional[Callable[[bytes], Any]] = None,
options: Sequence[Tuple[AnyStr, AnyStr]] = (),
channel_credentials: Optional[grpc.ChannelCredentials] = None,
insecure: bool = False,
call_credentials: Optional[grpc.CallCredentials] = None,
compression: Optional[grpc.Compression] = None,
wait_for_ready: Optional[bool] = None,
timeout: Optional[float] = _DEFAULT_TIMEOUT,
metadata: Optional[Sequence[Tuple[str, Union[str, bytes]]]] = None
) -> ResponseType:
"""Invokes a unary-unary RPC without an explicitly specified channel.
THIS IS AN EXPERIMENTAL API.
This is backed by a per-process cache of channels. Channels are evicted
from the cache after a fixed period by a background. Channels will also be
evicted if more than a configured maximum accumulate.
The default eviction period is 10 minutes. One may set the environment
variable "GRPC_PYTHON_MANAGED_CHANNEL_EVICTION_SECONDS" to configure this.
The default maximum number of channels is 256. One may set the
environment variable "GRPC_PYTHON_MANAGED_CHANNEL_MAXIMUM" to configure
this.
Args:
request: An iterator that yields request values for the RPC.
target: The server address.
method: The name of the RPC method.
request_serializer: Optional :term:`serializer` for serializing the request
message. Request goes unserialized in case None is passed.
response_deserializer: Optional :term:`deserializer` for deserializing the response
message. Response goes undeserialized in case None is passed.
options: An optional list of key-value pairs (:term:`channel_arguments` in gRPC Core
runtime) to configure the channel.
channel_credentials: A credential applied to the whole channel, e.g. the
return value of grpc.ssl_channel_credentials() or
grpc.insecure_channel_credentials().
insecure: If True, specifies channel_credentials as
:term:`grpc.insecure_channel_credentials()`. This option is mutually
exclusive with the `channel_credentials` option.
call_credentials: A call credential applied to each call individually,
e.g. the output of grpc.metadata_call_credentials() or
grpc.access_token_call_credentials().
compression: An optional value indicating the compression method to be
used over the lifetime of the channel, e.g. grpc.Compression.Gzip.
wait_for_ready: An optional flag indicating whether the RPC should fail
immediately if the connection is not ready at the time the RPC is
invoked, or if it should wait until the connection to the server
becomes ready. When using this option, the user will likely also want
to set a timeout. Defaults to True.
timeout: An optional duration of time in seconds to allow for the RPC,
after which an exception will be raised. If timeout is unspecified,
defaults to a timeout controlled by the
GRPC_PYTHON_DEFAULT_TIMEOUT_SECONDS environment variable. If that is
unset, defaults to 60 seconds. Supply a value of None to indicate that
no timeout should be enforced.
metadata: Optional metadata to send to the server.
Returns:
The response to the RPC.
"""
channel = ChannelCache.get().get_channel(target, options,
channel_credentials, insecure,
compression)
multicallable = channel.unary_unary(method, request_serializer,
response_deserializer)
wait_for_ready = wait_for_ready if wait_for_ready is not None else True
return multicallable(request,
metadata=metadata,
wait_for_ready=wait_for_ready,
credentials=call_credentials,
timeout=timeout)
@experimental_api
def unary_stream(
request: RequestType,
target: str,
method: str,
request_serializer: Optional[Callable[[Any], bytes]] = None,
response_deserializer: Optional[Callable[[bytes], Any]] = None,
options: Sequence[Tuple[AnyStr, AnyStr]] = (),
channel_credentials: Optional[grpc.ChannelCredentials] = None,
insecure: bool = False,
call_credentials: Optional[grpc.CallCredentials] = None,
compression: Optional[grpc.Compression] = None,
wait_for_ready: Optional[bool] = None,
timeout: Optional[float] = _DEFAULT_TIMEOUT,
metadata: Optional[Sequence[Tuple[str, Union[str, bytes]]]] = None
) -> Iterator[ResponseType]:
"""Invokes a unary-stream RPC without an explicitly specified channel.
THIS IS AN EXPERIMENTAL API.
This is backed by a per-process cache of channels. Channels are evicted
from the cache after a fixed period by a background. Channels will also be
evicted if more than a configured maximum accumulate.
The default eviction period is 10 minutes. One may set the environment
variable "GRPC_PYTHON_MANAGED_CHANNEL_EVICTION_SECONDS" to configure this.
The default maximum number of channels is 256. One may set the
environment variable "GRPC_PYTHON_MANAGED_CHANNEL_MAXIMUM" to configure
this.
Args:
request: An iterator that yields request values for the RPC.
target: The server address.
method: The name of the RPC method.
request_serializer: Optional :term:`serializer` for serializing the request
message. Request goes unserialized in case None is passed.
response_deserializer: Optional :term:`deserializer` for deserializing the response
message. Response goes undeserialized in case None is passed.
options: An optional list of key-value pairs (:term:`channel_arguments` in gRPC Core
runtime) to configure the channel.
channel_credentials: A credential applied to the whole channel, e.g. the
return value of grpc.ssl_channel_credentials().
insecure: If True, specifies channel_credentials as
:term:`grpc.insecure_channel_credentials()`. This option is mutually
exclusive with the `channel_credentials` option.
call_credentials: A call credential applied to each call individually,
e.g. the output of grpc.metadata_call_credentials() or
grpc.access_token_call_credentials().
compression: An optional value indicating the compression method to be
used over the lifetime of the channel, e.g. grpc.Compression.Gzip.
wait_for_ready: An optional flag indicating whether the RPC should fail
immediately if the connection is not ready at the time the RPC is
invoked, or if it should wait until the connection to the server
becomes ready. When using this option, the user will likely also want
to set a timeout. Defaults to True.
timeout: An optional duration of time in seconds to allow for the RPC,
after which an exception will be raised. If timeout is unspecified,
defaults to a timeout controlled by the
GRPC_PYTHON_DEFAULT_TIMEOUT_SECONDS environment variable. If that is
unset, defaults to 60 seconds. Supply a value of None to indicate that
no timeout should be enforced.
metadata: Optional metadata to send to the server.
Returns:
An iterator of responses.
"""
channel = ChannelCache.get().get_channel(target, options,
channel_credentials, insecure,
compression)
multicallable = channel.unary_stream(method, request_serializer,
response_deserializer)
wait_for_ready = wait_for_ready if wait_for_ready is not None else True
return multicallable(request,
metadata=metadata,
wait_for_ready=wait_for_ready,
credentials=call_credentials,
timeout=timeout)
@experimental_api
def stream_unary(
request_iterator: Iterator[RequestType],
target: str,
method: str,
request_serializer: Optional[Callable[[Any], bytes]] = None,
response_deserializer: Optional[Callable[[bytes], Any]] = None,
options: Sequence[Tuple[AnyStr, AnyStr]] = (),
channel_credentials: Optional[grpc.ChannelCredentials] = None,
insecure: bool = False,
call_credentials: Optional[grpc.CallCredentials] = None,
compression: Optional[grpc.Compression] = None,
wait_for_ready: Optional[bool] = None,
timeout: Optional[float] = _DEFAULT_TIMEOUT,
metadata: Optional[Sequence[Tuple[str, Union[str, bytes]]]] = None
) -> ResponseType:
"""Invokes a stream-unary RPC without an explicitly specified channel.
THIS IS AN EXPERIMENTAL API.
This is backed by a per-process cache of channels. Channels are evicted
from the cache after a fixed period by a background. Channels will also be
evicted if more than a configured maximum accumulate.
The default eviction period is 10 minutes. One may set the environment
variable "GRPC_PYTHON_MANAGED_CHANNEL_EVICTION_SECONDS" to configure this.
The default maximum number of channels is 256. One may set the
environment variable "GRPC_PYTHON_MANAGED_CHANNEL_MAXIMUM" to configure
this.
Args:
request_iterator: An iterator that yields request values for the RPC.
target: The server address.
method: The name of the RPC method.
request_serializer: Optional :term:`serializer` for serializing the request
message. Request goes unserialized in case None is passed.
response_deserializer: Optional :term:`deserializer` for deserializing the response
message. Response goes undeserialized in case None is passed.
options: An optional list of key-value pairs (:term:`channel_arguments` in gRPC Core
runtime) to configure the channel.
channel_credentials: A credential applied to the whole channel, e.g. the
return value of grpc.ssl_channel_credentials().
call_credentials: A call credential applied to each call individually,
e.g. the output of grpc.metadata_call_credentials() or
grpc.access_token_call_credentials().
insecure: If True, specifies channel_credentials as
:term:`grpc.insecure_channel_credentials()`. This option is mutually
exclusive with the `channel_credentials` option.
compression: An optional value indicating the compression method to be
used over the lifetime of the channel, e.g. grpc.Compression.Gzip.
wait_for_ready: An optional flag indicating whether the RPC should fail
immediately if the connection is not ready at the time the RPC is
invoked, or if it should wait until the connection to the server
becomes ready. When using this option, the user will likely also want
to set a timeout. Defaults to True.
timeout: An optional duration of time in seconds to allow for the RPC,
after which an exception will be raised. If timeout is unspecified,
defaults to a timeout controlled by the
GRPC_PYTHON_DEFAULT_TIMEOUT_SECONDS environment variable. If that is
unset, defaults to 60 seconds. Supply a value of None to indicate that
no timeout should be enforced.
metadata: Optional metadata to send to the server.
Returns:
The response to the RPC.
"""
channel = ChannelCache.get().get_channel(target, options,
channel_credentials, insecure,
compression)
multicallable = channel.stream_unary(method, request_serializer,
response_deserializer)
wait_for_ready = wait_for_ready if wait_for_ready is not None else True
return multicallable(request_iterator,
metadata=metadata,
wait_for_ready=wait_for_ready,
credentials=call_credentials,
timeout=timeout)
@experimental_api
def stream_stream(
request_iterator: Iterator[RequestType],
target: str,
method: str,
request_serializer: Optional[Callable[[Any], bytes]] = None,
response_deserializer: Optional[Callable[[bytes], Any]] = None,
options: Sequence[Tuple[AnyStr, AnyStr]] = (),
channel_credentials: Optional[grpc.ChannelCredentials] = None,
insecure: bool = False,
call_credentials: Optional[grpc.CallCredentials] = None,
compression: Optional[grpc.Compression] = None,
wait_for_ready: Optional[bool] = None,
timeout: Optional[float] = _DEFAULT_TIMEOUT,
metadata: Optional[Sequence[Tuple[str, Union[str, bytes]]]] = None
) -> Iterator[ResponseType]:
"""Invokes a stream-stream RPC without an explicitly specified channel.
THIS IS AN EXPERIMENTAL API.
This is backed by a per-process cache of channels. Channels are evicted
from the cache after a fixed period by a background. Channels will also be
evicted if more than a configured maximum accumulate.
The default eviction period is 10 minutes. One may set the environment
variable "GRPC_PYTHON_MANAGED_CHANNEL_EVICTION_SECONDS" to configure this.
The default maximum number of channels is 256. One may set the
environment variable "GRPC_PYTHON_MANAGED_CHANNEL_MAXIMUM" to configure
this.
Args:
request_iterator: An iterator that yields request values for the RPC.
target: The server address.
method: The name of the RPC method.
request_serializer: Optional :term:`serializer` for serializing the request
message. Request goes unserialized in case None is passed.
response_deserializer: Optional :term:`deserializer` for deserializing the response
message. Response goes undeserialized in case None is passed.
options: An optional list of key-value pairs (:term:`channel_arguments` in gRPC Core
runtime) to configure the channel.
channel_credentials: A credential applied to the whole channel, e.g. the
return value of grpc.ssl_channel_credentials().
call_credentials: A call credential applied to each call individually,
e.g. the output of grpc.metadata_call_credentials() or
grpc.access_token_call_credentials().
insecure: If True, specifies channel_credentials as
:term:`grpc.insecure_channel_credentials()`. This option is mutually
exclusive with the `channel_credentials` option.
compression: An optional value indicating the compression method to be
used over the lifetime of the channel, e.g. grpc.Compression.Gzip.
wait_for_ready: An optional flag indicating whether the RPC should fail
immediately if the connection is not ready at the time the RPC is
invoked, or if it should wait until the connection to the server
becomes ready. When using this option, the user will likely also want
to set a timeout. Defaults to True.
timeout: An optional duration of time in seconds to allow for the RPC,
after which an exception will be raised. If timeout is unspecified,
defaults to a timeout controlled by the
GRPC_PYTHON_DEFAULT_TIMEOUT_SECONDS environment variable. If that is
unset, defaults to 60 seconds. Supply a value of None to indicate that
no timeout should be enforced.
metadata: Optional metadata to send to the server.
Returns:
An iterator of responses.
"""
channel = ChannelCache.get().get_channel(target, options,
channel_credentials, insecure,
compression)
multicallable = channel.stream_stream(method, request_serializer,
response_deserializer)
wait_for_ready = wait_for_ready if wait_for_ready is not None else True
return multicallable(request_iterator,
metadata=metadata,
wait_for_ready=wait_for_ready,
credentials=call_credentials,
timeout=timeout)

View file

@ -0,0 +1,169 @@
# Copyright 2015 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Internal utilities for gRPC Python."""
import collections
import threading
import time
import logging
import six
import grpc
from grpc import _common
_LOGGER = logging.getLogger(__name__)
_DONE_CALLBACK_EXCEPTION_LOG_MESSAGE = (
'Exception calling connectivity future "done" callback!')
class RpcMethodHandler(
collections.namedtuple('_RpcMethodHandler', (
'request_streaming',
'response_streaming',
'request_deserializer',
'response_serializer',
'unary_unary',
'unary_stream',
'stream_unary',
'stream_stream',
)), grpc.RpcMethodHandler):
pass
class DictionaryGenericHandler(grpc.ServiceRpcHandler):
def __init__(self, service, method_handlers):
self._name = service
self._method_handlers = {
_common.fully_qualified_method(service, method): method_handler
for method, method_handler in six.iteritems(method_handlers)
}
def service_name(self):
return self._name
def service(self, handler_call_details):
return self._method_handlers.get(handler_call_details.method)
class _ChannelReadyFuture(grpc.Future):
def __init__(self, channel):
self._condition = threading.Condition()
self._channel = channel
self._matured = False
self._cancelled = False
self._done_callbacks = []
def _block(self, timeout):
until = None if timeout is None else time.time() + timeout
with self._condition:
while True:
if self._cancelled:
raise grpc.FutureCancelledError()
elif self._matured:
return
else:
if until is None:
self._condition.wait()
else:
remaining = until - time.time()
if remaining < 0:
raise grpc.FutureTimeoutError()
else:
self._condition.wait(timeout=remaining)
def _update(self, connectivity):
with self._condition:
if (not self._cancelled and
connectivity is grpc.ChannelConnectivity.READY):
self._matured = True
self._channel.unsubscribe(self._update)
self._condition.notify_all()
done_callbacks = tuple(self._done_callbacks)
self._done_callbacks = None
else:
return
for done_callback in done_callbacks:
try:
done_callback(self)
except Exception: # pylint: disable=broad-except
_LOGGER.exception(_DONE_CALLBACK_EXCEPTION_LOG_MESSAGE)
def cancel(self):
with self._condition:
if not self._matured:
self._cancelled = True
self._channel.unsubscribe(self._update)
self._condition.notify_all()
done_callbacks = tuple(self._done_callbacks)
self._done_callbacks = None
else:
return False
for done_callback in done_callbacks:
try:
done_callback(self)
except Exception: # pylint: disable=broad-except
_LOGGER.exception(_DONE_CALLBACK_EXCEPTION_LOG_MESSAGE)
return True
def cancelled(self):
with self._condition:
return self._cancelled
def running(self):
with self._condition:
return not self._cancelled and not self._matured
def done(self):
with self._condition:
return self._cancelled or self._matured
def result(self, timeout=None):
self._block(timeout)
def exception(self, timeout=None):
self._block(timeout)
def traceback(self, timeout=None):
self._block(timeout)
def add_done_callback(self, fn):
with self._condition:
if not self._cancelled and not self._matured:
self._done_callbacks.append(fn)
return
fn(self)
def start(self):
with self._condition:
self._channel.subscribe(self._update, try_to_connect=True)
def __del__(self):
with self._condition:
if not self._cancelled and not self._matured:
self._channel.unsubscribe(self._update)
def channel_ready_future(channel):
ready_future = _ChannelReadyFuture(channel)
ready_future.start()
return ready_future

View file

@ -0,0 +1,81 @@
# Copyright 2019 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""gRPC's Asynchronous Python API.
gRPC Async API objects may only be used on the thread on which they were
created. AsyncIO doesn't provide thread safety for most of its APIs.
"""
from typing import Any, Optional, Sequence, Tuple
import grpc
from grpc._cython.cygrpc import (init_grpc_aio, shutdown_grpc_aio, EOF,
AbortError, BaseError, InternalError,
UsageError)
from ._base_call import (Call, RpcContext, StreamStreamCall, StreamUnaryCall,
UnaryStreamCall, UnaryUnaryCall)
from ._base_channel import (Channel, StreamStreamMultiCallable,
StreamUnaryMultiCallable, UnaryStreamMultiCallable,
UnaryUnaryMultiCallable)
from ._call import AioRpcError
from ._interceptor import (ClientCallDetails, ClientInterceptor,
InterceptedUnaryUnaryCall,
UnaryUnaryClientInterceptor,
UnaryStreamClientInterceptor,
StreamUnaryClientInterceptor,
StreamStreamClientInterceptor, ServerInterceptor)
from ._server import server
from ._base_server import Server, ServicerContext
from ._typing import ChannelArgumentType
from ._channel import insecure_channel, secure_channel
from ._metadata import Metadata
################################### __all__ #################################
__all__ = (
'init_grpc_aio',
'shutdown_grpc_aio',
'AioRpcError',
'RpcContext',
'Call',
'UnaryUnaryCall',
'UnaryStreamCall',
'StreamUnaryCall',
'StreamStreamCall',
'Channel',
'UnaryUnaryMultiCallable',
'UnaryStreamMultiCallable',
'StreamUnaryMultiCallable',
'StreamStreamMultiCallable',
'ClientCallDetails',
'ClientInterceptor',
'UnaryStreamClientInterceptor',
'UnaryUnaryClientInterceptor',
'StreamUnaryClientInterceptor',
'StreamStreamClientInterceptor',
'InterceptedUnaryUnaryCall',
'ServerInterceptor',
'insecure_channel',
'server',
'Server',
'ServicerContext',
'EOF',
'secure_channel',
'AbortError',
'BaseError',
'UsageError',
'InternalError',
'Metadata',
)

View file

@ -0,0 +1,244 @@
# Copyright 2019 The gRPC Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Abstract base classes for client-side Call objects.
Call objects represents the RPC itself, and offer methods to access / modify
its information. They also offer methods to manipulate the life-cycle of the
RPC, e.g. cancellation.
"""
from abc import ABCMeta, abstractmethod
from typing import AsyncIterable, Awaitable, Generic, Optional, Union
import grpc
from ._typing import (DoneCallbackType, EOFType, RequestType, ResponseType)
from ._metadata import Metadata
__all__ = 'RpcContext', 'Call', 'UnaryUnaryCall', 'UnaryStreamCall'
class RpcContext(metaclass=ABCMeta):
"""Provides RPC-related information and action."""
@abstractmethod
def cancelled(self) -> bool:
"""Return True if the RPC is cancelled.
The RPC is cancelled when the cancellation was requested with cancel().
Returns:
A bool indicates whether the RPC is cancelled or not.
"""
@abstractmethod
def done(self) -> bool:
"""Return True if the RPC is done.
An RPC is done if the RPC is completed, cancelled or aborted.
Returns:
A bool indicates if the RPC is done.
"""
@abstractmethod
def time_remaining(self) -> Optional[float]:
"""Describes the length of allowed time remaining for the RPC.
Returns:
A nonnegative float indicating the length of allowed time in seconds
remaining for the RPC to complete before it is considered to have
timed out, or None if no deadline was specified for the RPC.
"""
@abstractmethod
def cancel(self) -> bool:
"""Cancels the RPC.
Idempotent and has no effect if the RPC has already terminated.
Returns:
A bool indicates if the cancellation is performed or not.
"""
@abstractmethod
def add_done_callback(self, callback: DoneCallbackType) -> None:
"""Registers a callback to be called on RPC termination.
Args:
callback: A callable object will be called with the call object as
its only argument.
"""
class Call(RpcContext, metaclass=ABCMeta):
"""The abstract base class of an RPC on the client-side."""
@abstractmethod
async def initial_metadata(self) -> Metadata:
"""Accesses the initial metadata sent by the server.
Returns:
The initial :term:`metadata`.
"""
@abstractmethod
async def trailing_metadata(self) -> Metadata:
"""Accesses the trailing metadata sent by the server.
Returns:
The trailing :term:`metadata`.
"""
@abstractmethod
async def code(self) -> grpc.StatusCode:
"""Accesses the status code sent by the server.
Returns:
The StatusCode value for the RPC.
"""
@abstractmethod
async def details(self) -> str:
"""Accesses the details sent by the server.
Returns:
The details string of the RPC.
"""
@abstractmethod
async def wait_for_connection(self) -> None:
"""Waits until connected to peer and raises aio.AioRpcError if failed.
This is an EXPERIMENTAL method.
This method ensures the RPC has been successfully connected. Otherwise,
an AioRpcError will be raised to explain the reason of the connection
failure.
This method is recommended for building retry mechanisms.
"""
class UnaryUnaryCall(Generic[RequestType, ResponseType],
Call,
metaclass=ABCMeta):
"""The abstract base class of an unary-unary RPC on the client-side."""
@abstractmethod
def __await__(self) -> Awaitable[ResponseType]:
"""Await the response message to be ready.
Returns:
The response message of the RPC.
"""
class UnaryStreamCall(Generic[RequestType, ResponseType],
Call,
metaclass=ABCMeta):
@abstractmethod
def __aiter__(self) -> AsyncIterable[ResponseType]:
"""Returns the async iterable representation that yields messages.
Under the hood, it is calling the "read" method.
Returns:
An async iterable object that yields messages.
"""
@abstractmethod
async def read(self) -> Union[EOFType, ResponseType]:
"""Reads one message from the stream.
Read operations must be serialized when called from multiple
coroutines.
Returns:
A response message, or an `grpc.aio.EOF` to indicate the end of the
stream.
"""
class StreamUnaryCall(Generic[RequestType, ResponseType],
Call,
metaclass=ABCMeta):
@abstractmethod
async def write(self, request: RequestType) -> None:
"""Writes one message to the stream.
Raises:
An RpcError exception if the write failed.
"""
@abstractmethod
async def done_writing(self) -> None:
"""Notifies server that the client is done sending messages.
After done_writing is called, any additional invocation to the write
function will fail. This function is idempotent.
"""
@abstractmethod
def __await__(self) -> Awaitable[ResponseType]:
"""Await the response message to be ready.
Returns:
The response message of the stream.
"""
class StreamStreamCall(Generic[RequestType, ResponseType],
Call,
metaclass=ABCMeta):
@abstractmethod
def __aiter__(self) -> AsyncIterable[ResponseType]:
"""Returns the async iterable representation that yields messages.
Under the hood, it is calling the "read" method.
Returns:
An async iterable object that yields messages.
"""
@abstractmethod
async def read(self) -> Union[EOFType, ResponseType]:
"""Reads one message from the stream.
Read operations must be serialized when called from multiple
coroutines.
Returns:
A response message, or an `grpc.aio.EOF` to indicate the end of the
stream.
"""
@abstractmethod
async def write(self, request: RequestType) -> None:
"""Writes one message to the stream.
Raises:
An RpcError exception if the write failed.
"""
@abstractmethod
async def done_writing(self) -> None:
"""Notifies server that the client is done sending messages.
After done_writing is called, any additional invocation to the write
function will fail. This function is idempotent.
"""

View file

@ -0,0 +1,347 @@
# Copyright 2020 The gRPC Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Abstract base classes for Channel objects and Multicallable objects."""
import abc
from typing import Any, Optional
import grpc
from . import _base_call
from ._typing import (DeserializingFunction, RequestIterableType,
SerializingFunction)
from ._metadata import Metadata
class UnaryUnaryMultiCallable(abc.ABC):
"""Enables asynchronous invocation of a unary-call RPC."""
@abc.abstractmethod
def __call__(self,
request: Any,
*,
timeout: Optional[float] = None,
metadata: Optional[Metadata] = None,
credentials: Optional[grpc.CallCredentials] = None,
wait_for_ready: Optional[bool] = None,
compression: Optional[grpc.Compression] = None
) -> _base_call.UnaryUnaryCall:
"""Asynchronously invokes the underlying RPC.
Args:
request: The request value for the RPC.
timeout: An optional duration of time in seconds to allow
for the RPC.
metadata: Optional :term:`metadata` to be transmitted to the
service-side of the RPC.
credentials: An optional CallCredentials for the RPC. Only valid for
secure Channel.
wait_for_ready: This is an EXPERIMENTAL argument. An optional
flag to enable :term:`wait_for_ready` mechanism.
compression: An element of grpc.compression, e.g.
grpc.compression.Gzip. This is an EXPERIMENTAL option.
Returns:
A UnaryUnaryCall object.
Raises:
RpcError: Indicates that the RPC terminated with non-OK status. The
raised RpcError will also be a Call for the RPC affording the RPC's
metadata, status code, and details.
"""
class UnaryStreamMultiCallable(abc.ABC):
"""Enables asynchronous invocation of a server-streaming RPC."""
@abc.abstractmethod
def __call__(self,
request: Any,
*,
timeout: Optional[float] = None,
metadata: Optional[Metadata] = None,
credentials: Optional[grpc.CallCredentials] = None,
wait_for_ready: Optional[bool] = None,
compression: Optional[grpc.Compression] = None
) -> _base_call.UnaryStreamCall:
"""Asynchronously invokes the underlying RPC.
Args:
request: The request value for the RPC.
timeout: An optional duration of time in seconds to allow
for the RPC.
metadata: Optional :term:`metadata` to be transmitted to the
service-side of the RPC.
credentials: An optional CallCredentials for the RPC. Only valid for
secure Channel.
wait_for_ready: This is an EXPERIMENTAL argument. An optional
flag to enable :term:`wait_for_ready` mechanism.
compression: An element of grpc.compression, e.g.
grpc.compression.Gzip. This is an EXPERIMENTAL option.
Returns:
A UnaryStreamCall object.
Raises:
RpcError: Indicates that the RPC terminated with non-OK status. The
raised RpcError will also be a Call for the RPC affording the RPC's
metadata, status code, and details.
"""
class StreamUnaryMultiCallable(abc.ABC):
"""Enables asynchronous invocation of a client-streaming RPC."""
@abc.abstractmethod
def __call__(self,
request_iterator: Optional[RequestIterableType] = None,
timeout: Optional[float] = None,
metadata: Optional[Metadata] = None,
credentials: Optional[grpc.CallCredentials] = None,
wait_for_ready: Optional[bool] = None,
compression: Optional[grpc.Compression] = None
) -> _base_call.StreamUnaryCall:
"""Asynchronously invokes the underlying RPC.
Args:
request_iterator: An optional async iterable or iterable of request
messages for the RPC.
timeout: An optional duration of time in seconds to allow
for the RPC.
metadata: Optional :term:`metadata` to be transmitted to the
service-side of the RPC.
credentials: An optional CallCredentials for the RPC. Only valid for
secure Channel.
wait_for_ready: This is an EXPERIMENTAL argument. An optional
flag to enable :term:`wait_for_ready` mechanism.
compression: An element of grpc.compression, e.g.
grpc.compression.Gzip. This is an EXPERIMENTAL option.
Returns:
A StreamUnaryCall object.
Raises:
RpcError: Indicates that the RPC terminated with non-OK status. The
raised RpcError will also be a Call for the RPC affording the RPC's
metadata, status code, and details.
"""
class StreamStreamMultiCallable(abc.ABC):
"""Enables asynchronous invocation of a bidirectional-streaming RPC."""
@abc.abstractmethod
def __call__(self,
request_iterator: Optional[RequestIterableType] = None,
timeout: Optional[float] = None,
metadata: Optional[Metadata] = None,
credentials: Optional[grpc.CallCredentials] = None,
wait_for_ready: Optional[bool] = None,
compression: Optional[grpc.Compression] = None
) -> _base_call.StreamStreamCall:
"""Asynchronously invokes the underlying RPC.
Args:
request_iterator: An optional async iterable or iterable of request
messages for the RPC.
timeout: An optional duration of time in seconds to allow
for the RPC.
metadata: Optional :term:`metadata` to be transmitted to the
service-side of the RPC.
credentials: An optional CallCredentials for the RPC. Only valid for
secure Channel.
wait_for_ready: This is an EXPERIMENTAL argument. An optional
flag to enable :term:`wait_for_ready` mechanism.
compression: An element of grpc.compression, e.g.
grpc.compression.Gzip. This is an EXPERIMENTAL option.
Returns:
A StreamStreamCall object.
Raises:
RpcError: Indicates that the RPC terminated with non-OK status. The
raised RpcError will also be a Call for the RPC affording the RPC's
metadata, status code, and details.
"""
class Channel(abc.ABC):
"""Enables asynchronous RPC invocation as a client.
Channel objects implement the Asynchronous Context Manager (aka. async
with) type, although they are not supportted to be entered and exited
multiple times.
"""
@abc.abstractmethod
async def __aenter__(self):
"""Starts an asynchronous context manager.
Returns:
Channel the channel that was instantiated.
"""
@abc.abstractmethod
async def __aexit__(self, exc_type, exc_val, exc_tb):
"""Finishes the asynchronous context manager by closing the channel.
Still active RPCs will be cancelled.
"""
@abc.abstractmethod
async def close(self, grace: Optional[float] = None):
"""Closes this Channel and releases all resources held by it.
This method immediately stops the channel from executing new RPCs in
all cases.
If a grace period is specified, this method wait until all active
RPCs are finshed, once the grace period is reached the ones that haven't
been terminated are cancelled. If a grace period is not specified
(by passing None for grace), all existing RPCs are cancelled immediately.
This method is idempotent.
"""
@abc.abstractmethod
def get_state(self,
try_to_connect: bool = False) -> grpc.ChannelConnectivity:
"""Checks the connectivity state of a channel.
This is an EXPERIMENTAL API.
If the channel reaches a stable connectivity state, it is guaranteed
that the return value of this function will eventually converge to that
state.
Args:
try_to_connect: a bool indicate whether the Channel should try to
connect to peer or not.
Returns: A ChannelConnectivity object.
"""
@abc.abstractmethod
async def wait_for_state_change(
self,
last_observed_state: grpc.ChannelConnectivity,
) -> None:
"""Waits for a change in connectivity state.
This is an EXPERIMENTAL API.
The function blocks until there is a change in the channel connectivity
state from the "last_observed_state". If the state is already
different, this function will return immediately.
There is an inherent race between the invocation of
"Channel.wait_for_state_change" and "Channel.get_state". The state can
change arbitrary many times during the race, so there is no way to
observe every state transition.
If there is a need to put a timeout for this function, please refer to
"asyncio.wait_for".
Args:
last_observed_state: A grpc.ChannelConnectivity object representing
the last known state.
"""
@abc.abstractmethod
async def channel_ready(self) -> None:
"""Creates a coroutine that blocks until the Channel is READY."""
@abc.abstractmethod
def unary_unary(
self,
method: str,
request_serializer: Optional[SerializingFunction] = None,
response_deserializer: Optional[DeserializingFunction] = None
) -> UnaryUnaryMultiCallable:
"""Creates a UnaryUnaryMultiCallable for a unary-unary method.
Args:
method: The name of the RPC method.
request_serializer: Optional :term:`serializer` for serializing the request
message. Request goes unserialized in case None is passed.
response_deserializer: Optional :term:`deserializer` for deserializing the
response message. Response goes undeserialized in case None
is passed.
Returns:
A UnaryUnaryMultiCallable value for the named unary-unary method.
"""
@abc.abstractmethod
def unary_stream(
self,
method: str,
request_serializer: Optional[SerializingFunction] = None,
response_deserializer: Optional[DeserializingFunction] = None
) -> UnaryStreamMultiCallable:
"""Creates a UnaryStreamMultiCallable for a unary-stream method.
Args:
method: The name of the RPC method.
request_serializer: Optional :term:`serializer` for serializing the request
message. Request goes unserialized in case None is passed.
response_deserializer: Optional :term:`deserializer` for deserializing the
response message. Response goes undeserialized in case None
is passed.
Returns:
A UnarySteramMultiCallable value for the named unary-stream method.
"""
@abc.abstractmethod
def stream_unary(
self,
method: str,
request_serializer: Optional[SerializingFunction] = None,
response_deserializer: Optional[DeserializingFunction] = None
) -> StreamUnaryMultiCallable:
"""Creates a StreamUnaryMultiCallable for a stream-unary method.
Args:
method: The name of the RPC method.
request_serializer: Optional :term:`serializer` for serializing the request
message. Request goes unserialized in case None is passed.
response_deserializer: Optional :term:`deserializer` for deserializing the
response message. Response goes undeserialized in case None
is passed.
Returns:
A StreamUnaryMultiCallable value for the named stream-unary method.
"""
@abc.abstractmethod
def stream_stream(
self,
method: str,
request_serializer: Optional[SerializingFunction] = None,
response_deserializer: Optional[DeserializingFunction] = None
) -> StreamStreamMultiCallable:
"""Creates a StreamStreamMultiCallable for a stream-stream method.
Args:
method: The name of the RPC method.
request_serializer: Optional :term:`serializer` for serializing the request
message. Request goes unserialized in case None is passed.
response_deserializer: Optional :term:`deserializer` for deserializing the
response message. Response goes undeserialized in case None
is passed.
Returns:
A StreamStreamMultiCallable value for the named stream-stream method.
"""

View file

@ -0,0 +1,294 @@
# Copyright 2020 The gRPC Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Abstract base classes for server-side classes."""
import abc
from typing import Generic, Mapping, Optional, Iterable, Sequence
import grpc
from ._typing import RequestType, ResponseType
from ._metadata import Metadata
class Server(abc.ABC):
"""Serves RPCs."""
@abc.abstractmethod
def add_generic_rpc_handlers(
self,
generic_rpc_handlers: Sequence[grpc.GenericRpcHandler]) -> None:
"""Registers GenericRpcHandlers with this Server.
This method is only safe to call before the server is started.
Args:
generic_rpc_handlers: A sequence of GenericRpcHandlers that will be
used to service RPCs.
"""
@abc.abstractmethod
def add_insecure_port(self, address: str) -> int:
"""Opens an insecure port for accepting RPCs.
A port is a communication endpoint that used by networking protocols,
like TCP and UDP. To date, we only support TCP.
This method may only be called before starting the server.
Args:
address: The address for which to open a port. If the port is 0,
or not specified in the address, then the gRPC runtime will choose a port.
Returns:
An integer port on which the server will accept RPC requests.
"""
@abc.abstractmethod
def add_secure_port(self, address: str,
server_credentials: grpc.ServerCredentials) -> int:
"""Opens a secure port for accepting RPCs.
A port is a communication endpoint that used by networking protocols,
like TCP and UDP. To date, we only support TCP.
This method may only be called before starting the server.
Args:
address: The address for which to open a port.
if the port is 0, or not specified in the address, then the gRPC
runtime will choose a port.
server_credentials: A ServerCredentials object.
Returns:
An integer port on which the server will accept RPC requests.
"""
@abc.abstractmethod
async def start(self) -> None:
"""Starts this Server.
This method may only be called once. (i.e. it is not idempotent).
"""
@abc.abstractmethod
async def stop(self, grace: Optional[float]) -> None:
"""Stops this Server.
This method immediately stops the server from servicing new RPCs in
all cases.
If a grace period is specified, this method returns immediately and all
RPCs active at the end of the grace period are aborted. If a grace
period is not specified (by passing None for grace), all existing RPCs
are aborted immediately and this method blocks until the last RPC
handler terminates.
This method is idempotent and may be called at any time. Passing a
smaller grace value in a subsequent call will have the effect of
stopping the Server sooner (passing None will have the effect of
stopping the server immediately). Passing a larger grace value in a
subsequent call will not have the effect of stopping the server later
(i.e. the most restrictive grace value is used).
Args:
grace: A duration of time in seconds or None.
"""
@abc.abstractmethod
async def wait_for_termination(self,
timeout: Optional[float] = None) -> bool:
"""Continues current coroutine once the server stops.
This is an EXPERIMENTAL API.
The wait will not consume computational resources during blocking, and
it will block until one of the two following conditions are met:
1) The server is stopped or terminated;
2) A timeout occurs if timeout is not `None`.
The timeout argument works in the same way as `threading.Event.wait()`.
https://docs.python.org/3/library/threading.html#threading.Event.wait
Args:
timeout: A floating point number specifying a timeout for the
operation in seconds.
Returns:
A bool indicates if the operation times out.
"""
class ServicerContext(Generic[RequestType, ResponseType], abc.ABC):
"""A context object passed to method implementations."""
@abc.abstractmethod
async def read(self) -> RequestType:
"""Reads one message from the RPC.
Only one read operation is allowed simultaneously.
Returns:
A response message of the RPC.
Raises:
An RpcError exception if the read failed.
"""
@abc.abstractmethod
async def write(self, message: ResponseType) -> None:
"""Writes one message to the RPC.
Only one write operation is allowed simultaneously.
Raises:
An RpcError exception if the write failed.
"""
@abc.abstractmethod
async def send_initial_metadata(self, initial_metadata: Metadata) -> None:
"""Sends the initial metadata value to the client.
This method need not be called by implementations if they have no
metadata to add to what the gRPC runtime will transmit.
Args:
initial_metadata: The initial :term:`metadata`.
"""
@abc.abstractmethod
async def abort(self, code: grpc.StatusCode, details: str,
trailing_metadata: Metadata) -> None:
"""Raises an exception to terminate the RPC with a non-OK status.
The code and details passed as arguments will supercede any existing
ones.
Args:
code: A StatusCode object to be sent to the client.
It must not be StatusCode.OK.
details: A UTF-8-encodable string to be sent to the client upon
termination of the RPC.
trailing_metadata: A sequence of tuple represents the trailing
:term:`metadata`.
Raises:
Exception: An exception is always raised to signal the abortion the
RPC to the gRPC runtime.
"""
@abc.abstractmethod
async def set_trailing_metadata(self, trailing_metadata: Metadata) -> None:
"""Sends the trailing metadata for the RPC.
This method need not be called by implementations if they have no
metadata to add to what the gRPC runtime will transmit.
Args:
trailing_metadata: The trailing :term:`metadata`.
"""
@abc.abstractmethod
def invocation_metadata(self) -> Optional[Metadata]:
"""Accesses the metadata from the sent by the client.
Returns:
The invocation :term:`metadata`.
"""
@abc.abstractmethod
def set_code(self, code: grpc.StatusCode) -> None:
"""Sets the value to be used as status code upon RPC completion.
This method need not be called by method implementations if they wish
the gRPC runtime to determine the status code of the RPC.
Args:
code: A StatusCode object to be sent to the client.
"""
@abc.abstractmethod
def set_details(self, details: str) -> None:
"""Sets the value to be used the as detail string upon RPC completion.
This method need not be called by method implementations if they have
no details to transmit.
Args:
details: A UTF-8-encodable string to be sent to the client upon
termination of the RPC.
"""
@abc.abstractmethod
def set_compression(self, compression: grpc.Compression) -> None:
"""Set the compression algorithm to be used for the entire call.
This is an EXPERIMENTAL method.
Args:
compression: An element of grpc.compression, e.g.
grpc.compression.Gzip.
"""
@abc.abstractmethod
def disable_next_message_compression(self) -> None:
"""Disables compression for the next response message.
This is an EXPERIMENTAL method.
This method will override any compression configuration set during
server creation or set on the call.
"""
@abc.abstractmethod
def peer(self) -> str:
"""Identifies the peer that invoked the RPC being serviced.
Returns:
A string identifying the peer that invoked the RPC being serviced.
The string format is determined by gRPC runtime.
"""
@abc.abstractmethod
def peer_identities(self) -> Optional[Iterable[bytes]]:
"""Gets one or more peer identity(s).
Equivalent to
servicer_context.auth_context().get(servicer_context.peer_identity_key())
Returns:
An iterable of the identities, or None if the call is not
authenticated. Each identity is returned as a raw bytes type.
"""
@abc.abstractmethod
def peer_identity_key(self) -> Optional[str]:
"""The auth property used to identify the peer.
For example, "x509_common_name" or "x509_subject_alternative_name" are
used to identify an SSL peer.
Returns:
The auth property (string) that indicates the
peer identity, or None if the call is not authenticated.
"""
@abc.abstractmethod
def auth_context(self) -> Mapping[str, Iterable[bytes]]:
"""Gets the auth context for the call.
Returns:
A map of strings to an iterable of bytes for each auth property.
"""

View file

@ -0,0 +1,629 @@
# Copyright 2019 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Invocation-side implementation of gRPC Asyncio Python."""
import asyncio
import enum
import inspect
import logging
from functools import partial
from typing import AsyncIterable, Optional, Tuple
import grpc
from grpc import _common
from grpc._cython import cygrpc
from . import _base_call
from ._metadata import Metadata
from ._typing import (DeserializingFunction, DoneCallbackType, MetadatumType,
RequestIterableType, RequestType, ResponseType,
SerializingFunction)
__all__ = 'AioRpcError', 'Call', 'UnaryUnaryCall', 'UnaryStreamCall'
_LOCAL_CANCELLATION_DETAILS = 'Locally cancelled by application!'
_GC_CANCELLATION_DETAILS = 'Cancelled upon garbage collection!'
_RPC_ALREADY_FINISHED_DETAILS = 'RPC already finished.'
_RPC_HALF_CLOSED_DETAILS = 'RPC is half closed after calling "done_writing".'
_API_STYLE_ERROR = 'The iterator and read/write APIs may not be mixed on a single RPC.'
_OK_CALL_REPRESENTATION = ('<{} of RPC that terminated with:\n'
'\tstatus = {}\n'
'\tdetails = "{}"\n'
'>')
_NON_OK_CALL_REPRESENTATION = ('<{} of RPC that terminated with:\n'
'\tstatus = {}\n'
'\tdetails = "{}"\n'
'\tdebug_error_string = "{}"\n'
'>')
_LOGGER = logging.getLogger(__name__)
class AioRpcError(grpc.RpcError):
"""An implementation of RpcError to be used by the asynchronous API.
Raised RpcError is a snapshot of the final status of the RPC, values are
determined. Hence, its methods no longer needs to be coroutines.
"""
_code: grpc.StatusCode
_details: Optional[str]
_initial_metadata: Optional[Metadata]
_trailing_metadata: Optional[Metadata]
_debug_error_string: Optional[str]
def __init__(self,
code: grpc.StatusCode,
initial_metadata: Metadata,
trailing_metadata: Metadata,
details: Optional[str] = None,
debug_error_string: Optional[str] = None) -> None:
"""Constructor.
Args:
code: The status code with which the RPC has been finalized.
details: Optional details explaining the reason of the error.
initial_metadata: Optional initial metadata that could be sent by the
Server.
trailing_metadata: Optional metadata that could be sent by the Server.
"""
super().__init__(self)
self._code = code
self._details = details
self._initial_metadata = initial_metadata
self._trailing_metadata = trailing_metadata
self._debug_error_string = debug_error_string
def code(self) -> grpc.StatusCode:
"""Accesses the status code sent by the server.
Returns:
The `grpc.StatusCode` status code.
"""
return self._code
def details(self) -> Optional[str]:
"""Accesses the details sent by the server.
Returns:
The description of the error.
"""
return self._details
def initial_metadata(self) -> Metadata:
"""Accesses the initial metadata sent by the server.
Returns:
The initial metadata received.
"""
return self._initial_metadata
def trailing_metadata(self) -> Metadata:
"""Accesses the trailing metadata sent by the server.
Returns:
The trailing metadata received.
"""
return self._trailing_metadata
def debug_error_string(self) -> str:
"""Accesses the debug error string sent by the server.
Returns:
The debug error string received.
"""
return self._debug_error_string
def _repr(self) -> str:
"""Assembles the error string for the RPC error."""
return _NON_OK_CALL_REPRESENTATION.format(self.__class__.__name__,
self._code, self._details,
self._debug_error_string)
def __repr__(self) -> str:
return self._repr()
def __str__(self) -> str:
return self._repr()
def _create_rpc_error(initial_metadata: Metadata,
status: cygrpc.AioRpcStatus) -> AioRpcError:
return AioRpcError(
_common.CYGRPC_STATUS_CODE_TO_STATUS_CODE[status.code()],
Metadata.from_tuple(initial_metadata),
Metadata.from_tuple(status.trailing_metadata()),
details=status.details(),
debug_error_string=status.debug_error_string(),
)
class Call:
"""Base implementation of client RPC Call object.
Implements logic around final status, metadata and cancellation.
"""
_loop: asyncio.AbstractEventLoop
_code: grpc.StatusCode
_cython_call: cygrpc._AioCall
_metadata: Tuple[MetadatumType]
_request_serializer: SerializingFunction
_response_deserializer: DeserializingFunction
def __init__(self, cython_call: cygrpc._AioCall, metadata: Metadata,
request_serializer: SerializingFunction,
response_deserializer: DeserializingFunction,
loop: asyncio.AbstractEventLoop) -> None:
self._loop = loop
self._cython_call = cython_call
self._metadata = tuple(metadata)
self._request_serializer = request_serializer
self._response_deserializer = response_deserializer
def __del__(self) -> None:
# The '_cython_call' object might be destructed before Call object
if hasattr(self, '_cython_call'):
if not self._cython_call.done():
self._cancel(_GC_CANCELLATION_DETAILS)
def cancelled(self) -> bool:
return self._cython_call.cancelled()
def _cancel(self, details: str) -> bool:
"""Forwards the application cancellation reasoning."""
if not self._cython_call.done():
self._cython_call.cancel(details)
return True
else:
return False
def cancel(self) -> bool:
return self._cancel(_LOCAL_CANCELLATION_DETAILS)
def done(self) -> bool:
return self._cython_call.done()
def add_done_callback(self, callback: DoneCallbackType) -> None:
cb = partial(callback, self)
self._cython_call.add_done_callback(cb)
def time_remaining(self) -> Optional[float]:
return self._cython_call.time_remaining()
async def initial_metadata(self) -> Metadata:
raw_metadata_tuple = await self._cython_call.initial_metadata()
return Metadata.from_tuple(raw_metadata_tuple)
async def trailing_metadata(self) -> Metadata:
raw_metadata_tuple = (await
self._cython_call.status()).trailing_metadata()
return Metadata.from_tuple(raw_metadata_tuple)
async def code(self) -> grpc.StatusCode:
cygrpc_code = (await self._cython_call.status()).code()
return _common.CYGRPC_STATUS_CODE_TO_STATUS_CODE[cygrpc_code]
async def details(self) -> str:
return (await self._cython_call.status()).details()
async def debug_error_string(self) -> str:
return (await self._cython_call.status()).debug_error_string()
async def _raise_for_status(self) -> None:
if self._cython_call.is_locally_cancelled():
raise asyncio.CancelledError()
code = await self.code()
if code != grpc.StatusCode.OK:
raise _create_rpc_error(await self.initial_metadata(), await
self._cython_call.status())
def _repr(self) -> str:
return repr(self._cython_call)
def __repr__(self) -> str:
return self._repr()
def __str__(self) -> str:
return self._repr()
class _APIStyle(enum.IntEnum):
UNKNOWN = 0
ASYNC_GENERATOR = 1
READER_WRITER = 2
class _UnaryResponseMixin(Call):
_call_response: asyncio.Task
def _init_unary_response_mixin(self, response_task: asyncio.Task):
self._call_response = response_task
def cancel(self) -> bool:
if super().cancel():
self._call_response.cancel()
return True
else:
return False
def __await__(self) -> ResponseType:
"""Wait till the ongoing RPC request finishes."""
try:
response = yield from self._call_response
except asyncio.CancelledError:
# Even if we caught all other CancelledError, there is still
# this corner case. If the application cancels immediately after
# the Call object is created, we will observe this
# `CancelledError`.
if not self.cancelled():
self.cancel()
raise
# NOTE(lidiz) If we raise RpcError in the task, and users doesn't
# 'await' on it. AsyncIO will log 'Task exception was never retrieved'.
# Instead, if we move the exception raising here, the spam stops.
# Unfortunately, there can only be one 'yield from' in '__await__'. So,
# we need to access the private instance variable.
if response is cygrpc.EOF:
if self._cython_call.is_locally_cancelled():
raise asyncio.CancelledError()
else:
raise _create_rpc_error(self._cython_call._initial_metadata,
self._cython_call._status)
else:
return response
class _StreamResponseMixin(Call):
_message_aiter: AsyncIterable[ResponseType]
_preparation: asyncio.Task
_response_style: _APIStyle
def _init_stream_response_mixin(self, preparation: asyncio.Task):
self._message_aiter = None
self._preparation = preparation
self._response_style = _APIStyle.UNKNOWN
def _update_response_style(self, style: _APIStyle):
if self._response_style is _APIStyle.UNKNOWN:
self._response_style = style
elif self._response_style is not style:
raise cygrpc.UsageError(_API_STYLE_ERROR)
def cancel(self) -> bool:
if super().cancel():
self._preparation.cancel()
return True
else:
return False
async def _fetch_stream_responses(self) -> ResponseType:
message = await self._read()
while message is not cygrpc.EOF:
yield message
message = await self._read()
# If the read operation failed, Core should explain why.
await self._raise_for_status()
def __aiter__(self) -> AsyncIterable[ResponseType]:
self._update_response_style(_APIStyle.ASYNC_GENERATOR)
if self._message_aiter is None:
self._message_aiter = self._fetch_stream_responses()
return self._message_aiter
async def _read(self) -> ResponseType:
# Wait for the request being sent
await self._preparation
# Reads response message from Core
try:
raw_response = await self._cython_call.receive_serialized_message()
except asyncio.CancelledError:
if not self.cancelled():
self.cancel()
await self._raise_for_status()
if raw_response is cygrpc.EOF:
return cygrpc.EOF
else:
return _common.deserialize(raw_response,
self._response_deserializer)
async def read(self) -> ResponseType:
if self.done():
await self._raise_for_status()
return cygrpc.EOF
self._update_response_style(_APIStyle.READER_WRITER)
response_message = await self._read()
if response_message is cygrpc.EOF:
# If the read operation failed, Core should explain why.
await self._raise_for_status()
return response_message
class _StreamRequestMixin(Call):
_metadata_sent: asyncio.Event
_done_writing_flag: bool
_async_request_poller: Optional[asyncio.Task]
_request_style: _APIStyle
def _init_stream_request_mixin(
self, request_iterator: Optional[RequestIterableType]):
self._metadata_sent = asyncio.Event(loop=self._loop)
self._done_writing_flag = False
# If user passes in an async iterator, create a consumer Task.
if request_iterator is not None:
self._async_request_poller = self._loop.create_task(
self._consume_request_iterator(request_iterator))
self._request_style = _APIStyle.ASYNC_GENERATOR
else:
self._async_request_poller = None
self._request_style = _APIStyle.READER_WRITER
def _raise_for_different_style(self, style: _APIStyle):
if self._request_style is not style:
raise cygrpc.UsageError(_API_STYLE_ERROR)
def cancel(self) -> bool:
if super().cancel():
if self._async_request_poller is not None:
self._async_request_poller.cancel()
return True
else:
return False
def _metadata_sent_observer(self):
self._metadata_sent.set()
async def _consume_request_iterator(self,
request_iterator: RequestIterableType
) -> None:
try:
if inspect.isasyncgen(request_iterator) or hasattr(
request_iterator, '__aiter__'):
async for request in request_iterator:
await self._write(request)
else:
for request in request_iterator:
await self._write(request)
await self._done_writing()
except AioRpcError as rpc_error:
# Rpc status should be exposed through other API. Exceptions raised
# within this Task won't be retrieved by another coroutine. It's
# better to suppress the error than spamming users' screen.
_LOGGER.debug('Exception while consuming the request_iterator: %s',
rpc_error)
async def _write(self, request: RequestType) -> None:
if self.done():
raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS)
if self._done_writing_flag:
raise asyncio.InvalidStateError(_RPC_HALF_CLOSED_DETAILS)
if not self._metadata_sent.is_set():
await self._metadata_sent.wait()
if self.done():
await self._raise_for_status()
serialized_request = _common.serialize(request,
self._request_serializer)
try:
await self._cython_call.send_serialized_message(serialized_request)
except asyncio.CancelledError:
if not self.cancelled():
self.cancel()
await self._raise_for_status()
async def _done_writing(self) -> None:
if self.done():
# If the RPC is finished, do nothing.
return
if not self._done_writing_flag:
# If the done writing is not sent before, try to send it.
self._done_writing_flag = True
try:
await self._cython_call.send_receive_close()
except asyncio.CancelledError:
if not self.cancelled():
self.cancel()
await self._raise_for_status()
async def write(self, request: RequestType) -> None:
self._raise_for_different_style(_APIStyle.READER_WRITER)
await self._write(request)
async def done_writing(self) -> None:
"""Signal peer that client is done writing.
This method is idempotent.
"""
self._raise_for_different_style(_APIStyle.READER_WRITER)
await self._done_writing()
async def wait_for_connection(self) -> None:
await self._metadata_sent.wait()
if self.done():
await self._raise_for_status()
class UnaryUnaryCall(_UnaryResponseMixin, Call, _base_call.UnaryUnaryCall):
"""Object for managing unary-unary RPC calls.
Returned when an instance of `UnaryUnaryMultiCallable` object is called.
"""
_request: RequestType
_invocation_task: asyncio.Task
# pylint: disable=too-many-arguments
def __init__(self, request: RequestType, deadline: Optional[float],
metadata: Metadata,
credentials: Optional[grpc.CallCredentials],
wait_for_ready: Optional[bool], channel: cygrpc.AioChannel,
method: bytes, request_serializer: SerializingFunction,
response_deserializer: DeserializingFunction,
loop: asyncio.AbstractEventLoop) -> None:
super().__init__(
channel.call(method, deadline, credentials, wait_for_ready),
metadata, request_serializer, response_deserializer, loop)
self._request = request
self._invocation_task = loop.create_task(self._invoke())
self._init_unary_response_mixin(self._invocation_task)
async def _invoke(self) -> ResponseType:
serialized_request = _common.serialize(self._request,
self._request_serializer)
# NOTE(lidiz) asyncio.CancelledError is not a good transport for status,
# because the asyncio.Task class do not cache the exception object.
# https://github.com/python/cpython/blob/edad4d89e357c92f70c0324b937845d652b20afd/Lib/asyncio/tasks.py#L785
try:
serialized_response = await self._cython_call.unary_unary(
serialized_request, self._metadata)
except asyncio.CancelledError:
if not self.cancelled():
self.cancel()
if self._cython_call.is_ok():
return _common.deserialize(serialized_response,
self._response_deserializer)
else:
return cygrpc.EOF
async def wait_for_connection(self) -> None:
await self._invocation_task
if self.done():
await self._raise_for_status()
class UnaryStreamCall(_StreamResponseMixin, Call, _base_call.UnaryStreamCall):
"""Object for managing unary-stream RPC calls.
Returned when an instance of `UnaryStreamMultiCallable` object is called.
"""
_request: RequestType
_send_unary_request_task: asyncio.Task
# pylint: disable=too-many-arguments
def __init__(self, request: RequestType, deadline: Optional[float],
metadata: Metadata,
credentials: Optional[grpc.CallCredentials],
wait_for_ready: Optional[bool], channel: cygrpc.AioChannel,
method: bytes, request_serializer: SerializingFunction,
response_deserializer: DeserializingFunction,
loop: asyncio.AbstractEventLoop) -> None:
super().__init__(
channel.call(method, deadline, credentials, wait_for_ready),
metadata, request_serializer, response_deserializer, loop)
self._request = request
self._send_unary_request_task = loop.create_task(
self._send_unary_request())
self._init_stream_response_mixin(self._send_unary_request_task)
async def _send_unary_request(self) -> ResponseType:
serialized_request = _common.serialize(self._request,
self._request_serializer)
try:
await self._cython_call.initiate_unary_stream(
serialized_request, self._metadata)
except asyncio.CancelledError:
if not self.cancelled():
self.cancel()
raise
async def wait_for_connection(self) -> None:
await self._send_unary_request_task
if self.done():
await self._raise_for_status()
class StreamUnaryCall(_StreamRequestMixin, _UnaryResponseMixin, Call,
_base_call.StreamUnaryCall):
"""Object for managing stream-unary RPC calls.
Returned when an instance of `StreamUnaryMultiCallable` object is called.
"""
# pylint: disable=too-many-arguments
def __init__(self, request_iterator: Optional[RequestIterableType],
deadline: Optional[float], metadata: Metadata,
credentials: Optional[grpc.CallCredentials],
wait_for_ready: Optional[bool], channel: cygrpc.AioChannel,
method: bytes, request_serializer: SerializingFunction,
response_deserializer: DeserializingFunction,
loop: asyncio.AbstractEventLoop) -> None:
super().__init__(
channel.call(method, deadline, credentials, wait_for_ready),
metadata, request_serializer, response_deserializer, loop)
self._init_stream_request_mixin(request_iterator)
self._init_unary_response_mixin(loop.create_task(self._conduct_rpc()))
async def _conduct_rpc(self) -> ResponseType:
try:
serialized_response = await self._cython_call.stream_unary(
self._metadata, self._metadata_sent_observer)
except asyncio.CancelledError:
if not self.cancelled():
self.cancel()
if self._cython_call.is_ok():
return _common.deserialize(serialized_response,
self._response_deserializer)
else:
return cygrpc.EOF
class StreamStreamCall(_StreamRequestMixin, _StreamResponseMixin, Call,
_base_call.StreamStreamCall):
"""Object for managing stream-stream RPC calls.
Returned when an instance of `StreamStreamMultiCallable` object is called.
"""
_initializer: asyncio.Task
# pylint: disable=too-many-arguments
def __init__(self, request_iterator: Optional[RequestIterableType],
deadline: Optional[float], metadata: Metadata,
credentials: Optional[grpc.CallCredentials],
wait_for_ready: Optional[bool], channel: cygrpc.AioChannel,
method: bytes, request_serializer: SerializingFunction,
response_deserializer: DeserializingFunction,
loop: asyncio.AbstractEventLoop) -> None:
super().__init__(
channel.call(method, deadline, credentials, wait_for_ready),
metadata, request_serializer, response_deserializer, loop)
self._initializer = self._loop.create_task(self._prepare_rpc())
self._init_stream_request_mixin(request_iterator)
self._init_stream_response_mixin(self._initializer)
async def _prepare_rpc(self):
"""This method prepares the RPC for receiving/sending messages.
All other operations around the stream should only happen after the
completion of this method.
"""
try:
await self._cython_call.initiate_stream_stream(
self._metadata, self._metadata_sent_observer)
except asyncio.CancelledError:
if not self.cancelled():
self.cancel()
# No need to raise RpcError here, because no one will `await` this task.

View file

@ -0,0 +1,469 @@
# Copyright 2019 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Invocation-side implementation of gRPC Asyncio Python."""
import asyncio
import sys
from typing import Any, Iterable, Optional, Sequence, List
import grpc
from grpc import _common, _compression, _grpcio_metadata
from grpc._cython import cygrpc
from . import _base_call, _base_channel
from ._call import (StreamStreamCall, StreamUnaryCall, UnaryStreamCall,
UnaryUnaryCall)
from ._interceptor import (
InterceptedUnaryUnaryCall, InterceptedUnaryStreamCall,
InterceptedStreamUnaryCall, InterceptedStreamStreamCall, ClientInterceptor,
UnaryUnaryClientInterceptor, UnaryStreamClientInterceptor,
StreamUnaryClientInterceptor, StreamStreamClientInterceptor)
from ._metadata import Metadata
from ._typing import (ChannelArgumentType, DeserializingFunction,
SerializingFunction, RequestIterableType)
from ._utils import _timeout_to_deadline
_USER_AGENT = 'grpc-python-asyncio/{}'.format(_grpcio_metadata.__version__)
if sys.version_info[1] < 7:
def _all_tasks() -> Iterable[asyncio.Task]:
return asyncio.Task.all_tasks()
else:
def _all_tasks() -> Iterable[asyncio.Task]:
return asyncio.all_tasks()
def _augment_channel_arguments(base_options: ChannelArgumentType,
compression: Optional[grpc.Compression]):
compression_channel_argument = _compression.create_channel_option(
compression)
user_agent_channel_argument = ((
cygrpc.ChannelArgKey.primary_user_agent_string,
_USER_AGENT,
),)
return tuple(base_options
) + compression_channel_argument + user_agent_channel_argument
class _BaseMultiCallable:
"""Base class of all multi callable objects.
Handles the initialization logic and stores common attributes.
"""
_loop: asyncio.AbstractEventLoop
_channel: cygrpc.AioChannel
_method: bytes
_request_serializer: SerializingFunction
_response_deserializer: DeserializingFunction
_interceptors: Optional[Sequence[ClientInterceptor]]
_loop: asyncio.AbstractEventLoop
# pylint: disable=too-many-arguments
def __init__(
self,
channel: cygrpc.AioChannel,
method: bytes,
request_serializer: SerializingFunction,
response_deserializer: DeserializingFunction,
interceptors: Optional[Sequence[ClientInterceptor]],
loop: asyncio.AbstractEventLoop,
) -> None:
self._loop = loop
self._channel = channel
self._method = method
self._request_serializer = request_serializer
self._response_deserializer = response_deserializer
self._interceptors = interceptors
@staticmethod
def _init_metadata(metadata: Optional[Metadata] = None,
compression: Optional[grpc.Compression] = None
) -> Metadata:
"""Based on the provided values for <metadata> or <compression> initialise the final
metadata, as it should be used for the current call.
"""
metadata = metadata or Metadata()
if compression:
metadata = Metadata(
*_compression.augment_metadata(metadata, compression))
return metadata
class UnaryUnaryMultiCallable(_BaseMultiCallable,
_base_channel.UnaryUnaryMultiCallable):
def __call__(self,
request: Any,
*,
timeout: Optional[float] = None,
metadata: Optional[Metadata] = None,
credentials: Optional[grpc.CallCredentials] = None,
wait_for_ready: Optional[bool] = None,
compression: Optional[grpc.Compression] = None
) -> _base_call.UnaryUnaryCall:
metadata = self._init_metadata(metadata, compression)
if not self._interceptors:
call = UnaryUnaryCall(request, _timeout_to_deadline(timeout),
metadata, credentials, wait_for_ready,
self._channel, self._method,
self._request_serializer,
self._response_deserializer, self._loop)
else:
call = InterceptedUnaryUnaryCall(
self._interceptors, request, timeout, metadata, credentials,
wait_for_ready, self._channel, self._method,
self._request_serializer, self._response_deserializer,
self._loop)
return call
class UnaryStreamMultiCallable(_BaseMultiCallable,
_base_channel.UnaryStreamMultiCallable):
def __call__(self,
request: Any,
*,
timeout: Optional[float] = None,
metadata: Optional[Metadata] = None,
credentials: Optional[grpc.CallCredentials] = None,
wait_for_ready: Optional[bool] = None,
compression: Optional[grpc.Compression] = None
) -> _base_call.UnaryStreamCall:
metadata = self._init_metadata(metadata, compression)
deadline = _timeout_to_deadline(timeout)
if not self._interceptors:
call = UnaryStreamCall(request, deadline, metadata, credentials,
wait_for_ready, self._channel, self._method,
self._request_serializer,
self._response_deserializer, self._loop)
else:
call = InterceptedUnaryStreamCall(
self._interceptors, request, deadline, metadata, credentials,
wait_for_ready, self._channel, self._method,
self._request_serializer, self._response_deserializer,
self._loop)
return call
class StreamUnaryMultiCallable(_BaseMultiCallable,
_base_channel.StreamUnaryMultiCallable):
def __call__(self,
request_iterator: Optional[RequestIterableType] = None,
timeout: Optional[float] = None,
metadata: Optional[Metadata] = None,
credentials: Optional[grpc.CallCredentials] = None,
wait_for_ready: Optional[bool] = None,
compression: Optional[grpc.Compression] = None
) -> _base_call.StreamUnaryCall:
metadata = self._init_metadata(metadata, compression)
deadline = _timeout_to_deadline(timeout)
if not self._interceptors:
call = StreamUnaryCall(request_iterator, deadline, metadata,
credentials, wait_for_ready, self._channel,
self._method, self._request_serializer,
self._response_deserializer, self._loop)
else:
call = InterceptedStreamUnaryCall(
self._interceptors, request_iterator, deadline, metadata,
credentials, wait_for_ready, self._channel, self._method,
self._request_serializer, self._response_deserializer,
self._loop)
return call
class StreamStreamMultiCallable(_BaseMultiCallable,
_base_channel.StreamStreamMultiCallable):
def __call__(self,
request_iterator: Optional[RequestIterableType] = None,
timeout: Optional[float] = None,
metadata: Optional[Metadata] = None,
credentials: Optional[grpc.CallCredentials] = None,
wait_for_ready: Optional[bool] = None,
compression: Optional[grpc.Compression] = None
) -> _base_call.StreamStreamCall:
metadata = self._init_metadata(metadata, compression)
deadline = _timeout_to_deadline(timeout)
if not self._interceptors:
call = StreamStreamCall(request_iterator, deadline, metadata,
credentials, wait_for_ready, self._channel,
self._method, self._request_serializer,
self._response_deserializer, self._loop)
else:
call = InterceptedStreamStreamCall(
self._interceptors, request_iterator, deadline, metadata,
credentials, wait_for_ready, self._channel, self._method,
self._request_serializer, self._response_deserializer,
self._loop)
return call
class Channel(_base_channel.Channel):
_loop: asyncio.AbstractEventLoop
_channel: cygrpc.AioChannel
_unary_unary_interceptors: List[UnaryUnaryClientInterceptor]
_unary_stream_interceptors: List[UnaryStreamClientInterceptor]
_stream_unary_interceptors: List[StreamUnaryClientInterceptor]
_stream_stream_interceptors: List[StreamStreamClientInterceptor]
def __init__(self, target: str, options: ChannelArgumentType,
credentials: Optional[grpc.ChannelCredentials],
compression: Optional[grpc.Compression],
interceptors: Optional[Sequence[ClientInterceptor]]):
"""Constructor.
Args:
target: The target to which to connect.
options: Configuration options for the channel.
credentials: A cygrpc.ChannelCredentials or None.
compression: An optional value indicating the compression method to be
used over the lifetime of the channel.
interceptors: An optional list of interceptors that would be used for
intercepting any RPC executed with that channel.
"""
self._unary_unary_interceptors = []
self._unary_stream_interceptors = []
self._stream_unary_interceptors = []
self._stream_stream_interceptors = []
if interceptors is not None:
for interceptor in interceptors:
if isinstance(interceptor, UnaryUnaryClientInterceptor):
self._unary_unary_interceptors.append(interceptor)
elif isinstance(interceptor, UnaryStreamClientInterceptor):
self._unary_stream_interceptors.append(interceptor)
elif isinstance(interceptor, StreamUnaryClientInterceptor):
self._stream_unary_interceptors.append(interceptor)
elif isinstance(interceptor, StreamStreamClientInterceptor):
self._stream_stream_interceptors.append(interceptor)
else:
raise ValueError(
"Interceptor {} must be ".format(interceptor) +
"{} or ".format(UnaryUnaryClientInterceptor.__name__) +
"{} or ".format(UnaryStreamClientInterceptor.__name__) +
"{} or ".format(StreamUnaryClientInterceptor.__name__) +
"{}. ".format(StreamStreamClientInterceptor.__name__))
self._loop = cygrpc.get_working_loop()
self._channel = cygrpc.AioChannel(
_common.encode(target),
_augment_channel_arguments(options, compression), credentials,
self._loop)
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
await self._close(None)
async def _close(self, grace): # pylint: disable=too-many-branches
if self._channel.closed():
return
# No new calls will be accepted by the Cython channel.
self._channel.closing()
# Iterate through running tasks
tasks = _all_tasks()
calls = []
call_tasks = []
for task in tasks:
try:
stack = task.get_stack(limit=1)
except AttributeError as attribute_error:
# NOTE(lidiz) tl;dr: If the Task is created with a CPython
# object, it will trigger AttributeError.
#
# In the global finalizer, the event loop schedules
# a CPython PyAsyncGenAThrow object.
# https://github.com/python/cpython/blob/00e45877e33d32bb61aa13a2033e3bba370bda4d/Lib/asyncio/base_events.py#L484
#
# However, the PyAsyncGenAThrow object is written in C and
# failed to include the normal Python frame objects. Hence,
# this exception is a false negative, and it is safe to ignore
# the failure. It is fixed by https://github.com/python/cpython/pull/18669,
# but not available until 3.9 or 3.8.3. So, we have to keep it
# for a while.
# TODO(lidiz) drop this hack after 3.8 deprecation
if 'frame' in str(attribute_error):
continue
else:
raise
# If the Task is created by a C-extension, the stack will be empty.
if not stack:
continue
# Locate ones created by `aio.Call`.
frame = stack[0]
candidate = frame.f_locals.get('self')
if candidate:
if isinstance(candidate, _base_call.Call):
if hasattr(candidate, '_channel'):
# For intercepted Call object
if candidate._channel is not self._channel:
continue
elif hasattr(candidate, '_cython_call'):
# For normal Call object
if candidate._cython_call._channel is not self._channel:
continue
else:
# Unidentified Call object
raise cygrpc.InternalError(
f'Unrecognized call object: {candidate}')
calls.append(candidate)
call_tasks.append(task)
# If needed, try to wait for them to finish.
# Call objects are not always awaitables.
if grace and call_tasks:
await asyncio.wait(call_tasks, timeout=grace, loop=self._loop)
# Time to cancel existing calls.
for call in calls:
call.cancel()
# Destroy the channel
self._channel.close()
async def close(self, grace: Optional[float] = None):
await self._close(grace)
def get_state(self,
try_to_connect: bool = False) -> grpc.ChannelConnectivity:
result = self._channel.check_connectivity_state(try_to_connect)
return _common.CYGRPC_CONNECTIVITY_STATE_TO_CHANNEL_CONNECTIVITY[result]
async def wait_for_state_change(
self,
last_observed_state: grpc.ChannelConnectivity,
) -> None:
assert await self._channel.watch_connectivity_state(
last_observed_state.value[0], None)
async def channel_ready(self) -> None:
state = self.get_state(try_to_connect=True)
while state != grpc.ChannelConnectivity.READY:
await self.wait_for_state_change(state)
state = self.get_state(try_to_connect=True)
def unary_unary(
self,
method: str,
request_serializer: Optional[SerializingFunction] = None,
response_deserializer: Optional[DeserializingFunction] = None
) -> UnaryUnaryMultiCallable:
return UnaryUnaryMultiCallable(self._channel, _common.encode(method),
request_serializer,
response_deserializer,
self._unary_unary_interceptors,
self._loop)
def unary_stream(
self,
method: str,
request_serializer: Optional[SerializingFunction] = None,
response_deserializer: Optional[DeserializingFunction] = None
) -> UnaryStreamMultiCallable:
return UnaryStreamMultiCallable(self._channel, _common.encode(method),
request_serializer,
response_deserializer,
self._unary_stream_interceptors,
self._loop)
def stream_unary(
self,
method: str,
request_serializer: Optional[SerializingFunction] = None,
response_deserializer: Optional[DeserializingFunction] = None
) -> StreamUnaryMultiCallable:
return StreamUnaryMultiCallable(self._channel, _common.encode(method),
request_serializer,
response_deserializer,
self._stream_unary_interceptors,
self._loop)
def stream_stream(
self,
method: str,
request_serializer: Optional[SerializingFunction] = None,
response_deserializer: Optional[DeserializingFunction] = None
) -> StreamStreamMultiCallable:
return StreamStreamMultiCallable(self._channel, _common.encode(method),
request_serializer,
response_deserializer,
self._stream_stream_interceptors,
self._loop)
def insecure_channel(
target: str,
options: Optional[ChannelArgumentType] = None,
compression: Optional[grpc.Compression] = None,
interceptors: Optional[Sequence[ClientInterceptor]] = None):
"""Creates an insecure asynchronous Channel to a server.
Args:
target: The server address
options: An optional list of key-value pairs (:term:`channel_arguments`
in gRPC Core runtime) to configure the channel.
compression: An optional value indicating the compression method to be
used over the lifetime of the channel. This is an EXPERIMENTAL option.
interceptors: An optional sequence of interceptors that will be executed for
any call executed with this channel.
Returns:
A Channel.
"""
return Channel(target, () if options is None else options, None,
compression, interceptors)
def secure_channel(target: str,
credentials: grpc.ChannelCredentials,
options: Optional[ChannelArgumentType] = None,
compression: Optional[grpc.Compression] = None,
interceptors: Optional[Sequence[ClientInterceptor]] = None):
"""Creates a secure asynchronous Channel to a server.
Args:
target: The server address.
credentials: A ChannelCredentials instance.
options: An optional list of key-value pairs (:term:`channel_arguments`
in gRPC Core runtime) to configure the channel.
compression: An optional value indicating the compression method to be
used over the lifetime of the channel. This is an EXPERIMENTAL option.
interceptors: An optional sequence of interceptors that will be executed for
any call executed with this channel.
Returns:
An aio.Channel.
"""
return Channel(target, () if options is None else options,
credentials._credentials, compression, interceptors)

View file

@ -0,0 +1,987 @@
# Copyright 2019 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Interceptors implementation of gRPC Asyncio Python."""
import asyncio
import collections
import functools
from abc import ABCMeta, abstractmethod
from typing import Callable, Optional, Iterator, Sequence, Union, Awaitable, AsyncIterable
import grpc
from grpc._cython import cygrpc
from . import _base_call
from ._call import UnaryUnaryCall, UnaryStreamCall, StreamUnaryCall, StreamStreamCall, AioRpcError
from ._call import _RPC_ALREADY_FINISHED_DETAILS, _RPC_HALF_CLOSED_DETAILS
from ._call import _API_STYLE_ERROR
from ._utils import _timeout_to_deadline
from ._typing import (RequestType, SerializingFunction, DeserializingFunction,
ResponseType, DoneCallbackType, RequestIterableType,
ResponseIterableType)
from ._metadata import Metadata
_LOCAL_CANCELLATION_DETAILS = 'Locally cancelled by application!'
class ServerInterceptor(metaclass=ABCMeta):
"""Affords intercepting incoming RPCs on the service-side.
This is an EXPERIMENTAL API.
"""
@abstractmethod
async def intercept_service(
self, continuation: Callable[[grpc.HandlerCallDetails], Awaitable[
grpc.RpcMethodHandler]],
handler_call_details: grpc.HandlerCallDetails
) -> grpc.RpcMethodHandler:
"""Intercepts incoming RPCs before handing them over to a handler.
Args:
continuation: A function that takes a HandlerCallDetails and
proceeds to invoke the next interceptor in the chain, if any,
or the RPC handler lookup logic, with the call details passed
as an argument, and returns an RpcMethodHandler instance if
the RPC is considered serviced, or None otherwise.
handler_call_details: A HandlerCallDetails describing the RPC.
Returns:
An RpcMethodHandler with which the RPC may be serviced if the
interceptor chooses to service this RPC, or None otherwise.
"""
class ClientCallDetails(
collections.namedtuple(
'ClientCallDetails',
('method', 'timeout', 'metadata', 'credentials', 'wait_for_ready')),
grpc.ClientCallDetails):
"""Describes an RPC to be invoked.
This is an EXPERIMENTAL API.
Args:
method: The method name of the RPC.
timeout: An optional duration of time in seconds to allow for the RPC.
metadata: Optional metadata to be transmitted to the service-side of
the RPC.
credentials: An optional CallCredentials for the RPC.
wait_for_ready: This is an EXPERIMENTAL argument. An optional
flag to enable :term:`wait_for_ready` mechanism.
"""
method: str
timeout: Optional[float]
metadata: Optional[Metadata]
credentials: Optional[grpc.CallCredentials]
wait_for_ready: Optional[bool]
class ClientInterceptor(metaclass=ABCMeta):
"""Base class used for all Aio Client Interceptor classes"""
class UnaryUnaryClientInterceptor(ClientInterceptor, metaclass=ABCMeta):
"""Affords intercepting unary-unary invocations."""
@abstractmethod
async def intercept_unary_unary(
self, continuation: Callable[[ClientCallDetails, RequestType],
UnaryUnaryCall],
client_call_details: ClientCallDetails,
request: RequestType) -> Union[UnaryUnaryCall, ResponseType]:
"""Intercepts a unary-unary invocation asynchronously.
Args:
continuation: A coroutine that proceeds with the invocation by
executing the next interceptor in the chain or invoking the
actual RPC on the underlying Channel. It is the interceptor's
responsibility to call it if it decides to move the RPC forward.
The interceptor can use
`call = await continuation(client_call_details, request)`
to continue with the RPC. `continuation` returns the call to the
RPC.
client_call_details: A ClientCallDetails object describing the
outgoing RPC.
request: The request value for the RPC.
Returns:
An object with the RPC response.
Raises:
AioRpcError: Indicating that the RPC terminated with non-OK status.
asyncio.CancelledError: Indicating that the RPC was canceled.
"""
class UnaryStreamClientInterceptor(ClientInterceptor, metaclass=ABCMeta):
"""Affords intercepting unary-stream invocations."""
@abstractmethod
async def intercept_unary_stream(
self, continuation: Callable[[ClientCallDetails, RequestType],
UnaryStreamCall],
client_call_details: ClientCallDetails, request: RequestType
) -> Union[ResponseIterableType, UnaryStreamCall]:
"""Intercepts a unary-stream invocation asynchronously.
The function could return the call object or an asynchronous
iterator, in case of being an asyncrhonous iterator this will
become the source of the reads done by the caller.
Args:
continuation: A coroutine that proceeds with the invocation by
executing the next interceptor in the chain or invoking the
actual RPC on the underlying Channel. It is the interceptor's
responsibility to call it if it decides to move the RPC forward.
The interceptor can use
`call = await continuation(client_call_details, request)`
to continue with the RPC. `continuation` returns the call to the
RPC.
client_call_details: A ClientCallDetails object describing the
outgoing RPC.
request: The request value for the RPC.
Returns:
The RPC Call or an asynchronous iterator.
Raises:
AioRpcError: Indicating that the RPC terminated with non-OK status.
asyncio.CancelledError: Indicating that the RPC was canceled.
"""
class StreamUnaryClientInterceptor(ClientInterceptor, metaclass=ABCMeta):
"""Affords intercepting stream-unary invocations."""
@abstractmethod
async def intercept_stream_unary(
self,
continuation: Callable[[ClientCallDetails, RequestType],
UnaryStreamCall],
client_call_details: ClientCallDetails,
request_iterator: RequestIterableType,
) -> StreamUnaryCall:
"""Intercepts a stream-unary invocation asynchronously.
Within the interceptor the usage of the call methods like `write` or
even awaiting the call should be done carefully, since the caller
could be expecting an untouched call, for example for start writing
messages to it.
Args:
continuation: A coroutine that proceeds with the invocation by
executing the next interceptor in the chain or invoking the
actual RPC on the underlying Channel. It is the interceptor's
responsibility to call it if it decides to move the RPC forward.
The interceptor can use
`call = await continuation(client_call_details, request_iterator)`
to continue with the RPC. `continuation` returns the call to the
RPC.
client_call_details: A ClientCallDetails object describing the
outgoing RPC.
request_iterator: The request iterator that will produce requests
for the RPC.
Returns:
The RPC Call.
Raises:
AioRpcError: Indicating that the RPC terminated with non-OK status.
asyncio.CancelledError: Indicating that the RPC was canceled.
"""
class StreamStreamClientInterceptor(ClientInterceptor, metaclass=ABCMeta):
"""Affords intercepting stream-stream invocations."""
@abstractmethod
async def intercept_stream_stream(
self,
continuation: Callable[[ClientCallDetails, RequestType],
UnaryStreamCall],
client_call_details: ClientCallDetails,
request_iterator: RequestIterableType,
) -> Union[ResponseIterableType, StreamStreamCall]:
"""Intercepts a stream-stream invocation asynchronously.
Within the interceptor the usage of the call methods like `write` or
even awaiting the call should be done carefully, since the caller
could be expecting an untouched call, for example for start writing
messages to it.
The function could return the call object or an asynchronous
iterator, in case of being an asyncrhonous iterator this will
become the source of the reads done by the caller.
Args:
continuation: A coroutine that proceeds with the invocation by
executing the next interceptor in the chain or invoking the
actual RPC on the underlying Channel. It is the interceptor's
responsibility to call it if it decides to move the RPC forward.
The interceptor can use
`call = await continuation(client_call_details, request_iterator)`
to continue with the RPC. `continuation` returns the call to the
RPC.
client_call_details: A ClientCallDetails object describing the
outgoing RPC.
request_iterator: The request iterator that will produce requests
for the RPC.
Returns:
The RPC Call or an asynchronous iterator.
Raises:
AioRpcError: Indicating that the RPC terminated with non-OK status.
asyncio.CancelledError: Indicating that the RPC was canceled.
"""
class InterceptedCall:
"""Base implementation for all intercepted call arities.
Interceptors might have some work to do before the RPC invocation with
the capacity of changing the invocation parameters, and some work to do
after the RPC invocation with the capacity for accessing to the wrapped
`UnaryUnaryCall`.
It handles also early and later cancellations, when the RPC has not even
started and the execution is still held by the interceptors or when the
RPC has finished but again the execution is still held by the interceptors.
Once the RPC is finally executed, all methods are finally done against the
intercepted call, being at the same time the same call returned to the
interceptors.
As a base class for all of the interceptors implements the logic around
final status, metadata and cancellation.
"""
_interceptors_task: asyncio.Task
_pending_add_done_callbacks: Sequence[DoneCallbackType]
def __init__(self, interceptors_task: asyncio.Task) -> None:
self._interceptors_task = interceptors_task
self._pending_add_done_callbacks = []
self._interceptors_task.add_done_callback(
self._fire_or_add_pending_done_callbacks)
def __del__(self):
self.cancel()
def _fire_or_add_pending_done_callbacks(self,
interceptors_task: asyncio.Task
) -> None:
if not self._pending_add_done_callbacks:
return
call_completed = False
try:
call = interceptors_task.result()
if call.done():
call_completed = True
except (AioRpcError, asyncio.CancelledError):
call_completed = True
if call_completed:
for callback in self._pending_add_done_callbacks:
callback(self)
else:
for callback in self._pending_add_done_callbacks:
callback = functools.partial(self._wrap_add_done_callback,
callback)
call.add_done_callback(callback)
self._pending_add_done_callbacks = []
def _wrap_add_done_callback(self, callback: DoneCallbackType,
unused_call: _base_call.Call) -> None:
callback(self)
def cancel(self) -> bool:
if not self._interceptors_task.done():
# There is no yet the intercepted call available,
# Trying to cancel it by using the generic Asyncio
# cancellation method.
return self._interceptors_task.cancel()
try:
call = self._interceptors_task.result()
except AioRpcError:
return False
except asyncio.CancelledError:
return False
return call.cancel()
def cancelled(self) -> bool:
if not self._interceptors_task.done():
return False
try:
call = self._interceptors_task.result()
except AioRpcError as err:
return err.code() == grpc.StatusCode.CANCELLED
except asyncio.CancelledError:
return True
return call.cancelled()
def done(self) -> bool:
if not self._interceptors_task.done():
return False
try:
call = self._interceptors_task.result()
except (AioRpcError, asyncio.CancelledError):
return True
return call.done()
def add_done_callback(self, callback: DoneCallbackType) -> None:
if not self._interceptors_task.done():
self._pending_add_done_callbacks.append(callback)
return
try:
call = self._interceptors_task.result()
except (AioRpcError, asyncio.CancelledError):
callback(self)
return
if call.done():
callback(self)
else:
callback = functools.partial(self._wrap_add_done_callback, callback)
call.add_done_callback(callback)
def time_remaining(self) -> Optional[float]:
raise NotImplementedError()
async def initial_metadata(self) -> Optional[Metadata]:
try:
call = await self._interceptors_task
except AioRpcError as err:
return err.initial_metadata()
except asyncio.CancelledError:
return None
return await call.initial_metadata()
async def trailing_metadata(self) -> Optional[Metadata]:
try:
call = await self._interceptors_task
except AioRpcError as err:
return err.trailing_metadata()
except asyncio.CancelledError:
return None
return await call.trailing_metadata()
async def code(self) -> grpc.StatusCode:
try:
call = await self._interceptors_task
except AioRpcError as err:
return err.code()
except asyncio.CancelledError:
return grpc.StatusCode.CANCELLED
return await call.code()
async def details(self) -> str:
try:
call = await self._interceptors_task
except AioRpcError as err:
return err.details()
except asyncio.CancelledError:
return _LOCAL_CANCELLATION_DETAILS
return await call.details()
async def debug_error_string(self) -> Optional[str]:
try:
call = await self._interceptors_task
except AioRpcError as err:
return err.debug_error_string()
except asyncio.CancelledError:
return ''
return await call.debug_error_string()
async def wait_for_connection(self) -> None:
call = await self._interceptors_task
return await call.wait_for_connection()
class _InterceptedUnaryResponseMixin:
def __await__(self):
call = yield from self._interceptors_task.__await__()
response = yield from call.__await__()
return response
class _InterceptedStreamResponseMixin:
_response_aiter: Optional[AsyncIterable[ResponseType]]
def _init_stream_response_mixin(self) -> None:
# Is initalized later, otherwise if the iterator is not finnally
# consumed a logging warning is emmited by Asyncio.
self._response_aiter = None
async def _wait_for_interceptor_task_response_iterator(self
) -> ResponseType:
call = await self._interceptors_task
async for response in call:
yield response
def __aiter__(self) -> AsyncIterable[ResponseType]:
if self._response_aiter is None:
self._response_aiter = self._wait_for_interceptor_task_response_iterator(
)
return self._response_aiter
async def read(self) -> ResponseType:
if self._response_aiter is None:
self._response_aiter = self._wait_for_interceptor_task_response_iterator(
)
return await self._response_aiter.asend(None)
class _InterceptedStreamRequestMixin:
_write_to_iterator_async_gen: Optional[AsyncIterable[RequestType]]
_write_to_iterator_queue: Optional[asyncio.Queue]
_FINISH_ITERATOR_SENTINEL = object()
def _init_stream_request_mixin(
self, request_iterator: Optional[RequestIterableType]
) -> RequestIterableType:
if request_iterator is None:
# We provide our own request iterator which is a proxy
# of the futures writes that will be done by the caller.
self._write_to_iterator_queue = asyncio.Queue(maxsize=1)
self._write_to_iterator_async_gen = self._proxy_writes_as_request_iterator(
)
request_iterator = self._write_to_iterator_async_gen
else:
self._write_to_iterator_queue = None
return request_iterator
async def _proxy_writes_as_request_iterator(self):
await self._interceptors_task
while True:
value = await self._write_to_iterator_queue.get()
if value is _InterceptedStreamRequestMixin._FINISH_ITERATOR_SENTINEL:
break
yield value
async def write(self, request: RequestType) -> None:
# If no queue was created it means that requests
# should be expected through an iterators provided
# by the caller.
if self._write_to_iterator_queue is None:
raise cygrpc.UsageError(_API_STYLE_ERROR)
try:
call = await self._interceptors_task
except (asyncio.CancelledError, AioRpcError):
raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS)
if call.done():
raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS)
elif call._done_writing_flag:
raise asyncio.InvalidStateError(_RPC_HALF_CLOSED_DETAILS)
# Write might never end up since the call could abrubtly finish,
# we give up on the first awaitable object that finishes.
_, _ = await asyncio.wait(
(self._write_to_iterator_queue.put(request), call.code()),
return_when=asyncio.FIRST_COMPLETED)
if call.done():
raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS)
async def done_writing(self) -> None:
"""Signal peer that client is done writing.
This method is idempotent.
"""
# If no queue was created it means that requests
# should be expected through an iterators provided
# by the caller.
if self._write_to_iterator_queue is None:
raise cygrpc.UsageError(_API_STYLE_ERROR)
try:
call = await self._interceptors_task
except asyncio.CancelledError:
raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS)
# Write might never end up since the call could abrubtly finish,
# we give up on the first awaitable object that finishes.
_, _ = await asyncio.wait((self._write_to_iterator_queue.put(
_InterceptedStreamRequestMixin._FINISH_ITERATOR_SENTINEL),
call.code()),
return_when=asyncio.FIRST_COMPLETED)
class InterceptedUnaryUnaryCall(_InterceptedUnaryResponseMixin, InterceptedCall,
_base_call.UnaryUnaryCall):
"""Used for running a `UnaryUnaryCall` wrapped by interceptors.
For the `__await__` method is it is proxied to the intercepted call only when
the interceptor task is finished.
"""
_loop: asyncio.AbstractEventLoop
_channel: cygrpc.AioChannel
# pylint: disable=too-many-arguments
def __init__(self, interceptors: Sequence[UnaryUnaryClientInterceptor],
request: RequestType, timeout: Optional[float],
metadata: Metadata,
credentials: Optional[grpc.CallCredentials],
wait_for_ready: Optional[bool], channel: cygrpc.AioChannel,
method: bytes, request_serializer: SerializingFunction,
response_deserializer: DeserializingFunction,
loop: asyncio.AbstractEventLoop) -> None:
self._loop = loop
self._channel = channel
interceptors_task = loop.create_task(
self._invoke(interceptors, method, timeout, metadata, credentials,
wait_for_ready, request, request_serializer,
response_deserializer))
super().__init__(interceptors_task)
# pylint: disable=too-many-arguments
async def _invoke(self, interceptors: Sequence[UnaryUnaryClientInterceptor],
method: bytes, timeout: Optional[float],
metadata: Optional[Metadata],
credentials: Optional[grpc.CallCredentials],
wait_for_ready: Optional[bool], request: RequestType,
request_serializer: SerializingFunction,
response_deserializer: DeserializingFunction
) -> UnaryUnaryCall:
"""Run the RPC call wrapped in interceptors"""
async def _run_interceptor(
interceptors: Iterator[UnaryUnaryClientInterceptor],
client_call_details: ClientCallDetails,
request: RequestType) -> _base_call.UnaryUnaryCall:
interceptor = next(interceptors, None)
if interceptor:
continuation = functools.partial(_run_interceptor, interceptors)
call_or_response = await interceptor.intercept_unary_unary(
continuation, client_call_details, request)
if isinstance(call_or_response, _base_call.UnaryUnaryCall):
return call_or_response
else:
return UnaryUnaryCallResponse(call_or_response)
else:
return UnaryUnaryCall(
request, _timeout_to_deadline(client_call_details.timeout),
client_call_details.metadata,
client_call_details.credentials,
client_call_details.wait_for_ready, self._channel,
client_call_details.method, request_serializer,
response_deserializer, self._loop)
client_call_details = ClientCallDetails(method, timeout, metadata,
credentials, wait_for_ready)
return await _run_interceptor(iter(interceptors), client_call_details,
request)
def time_remaining(self) -> Optional[float]:
raise NotImplementedError()
class InterceptedUnaryStreamCall(_InterceptedStreamResponseMixin,
InterceptedCall, _base_call.UnaryStreamCall):
"""Used for running a `UnaryStreamCall` wrapped by interceptors."""
_loop: asyncio.AbstractEventLoop
_channel: cygrpc.AioChannel
_last_returned_call_from_interceptors = Optional[_base_call.UnaryStreamCall]
# pylint: disable=too-many-arguments
def __init__(self, interceptors: Sequence[UnaryStreamClientInterceptor],
request: RequestType, timeout: Optional[float],
metadata: Metadata,
credentials: Optional[grpc.CallCredentials],
wait_for_ready: Optional[bool], channel: cygrpc.AioChannel,
method: bytes, request_serializer: SerializingFunction,
response_deserializer: DeserializingFunction,
loop: asyncio.AbstractEventLoop) -> None:
self._loop = loop
self._channel = channel
self._init_stream_response_mixin()
self._last_returned_call_from_interceptors = None
interceptors_task = loop.create_task(
self._invoke(interceptors, method, timeout, metadata, credentials,
wait_for_ready, request, request_serializer,
response_deserializer))
super().__init__(interceptors_task)
# pylint: disable=too-many-arguments
async def _invoke(self, interceptors: Sequence[UnaryUnaryClientInterceptor],
method: bytes, timeout: Optional[float],
metadata: Optional[Metadata],
credentials: Optional[grpc.CallCredentials],
wait_for_ready: Optional[bool], request: RequestType,
request_serializer: SerializingFunction,
response_deserializer: DeserializingFunction
) -> UnaryStreamCall:
"""Run the RPC call wrapped in interceptors"""
async def _run_interceptor(
interceptors: Iterator[UnaryStreamClientInterceptor],
client_call_details: ClientCallDetails,
request: RequestType,
) -> _base_call.UnaryUnaryCall:
interceptor = next(interceptors, None)
if interceptor:
continuation = functools.partial(_run_interceptor, interceptors)
call_or_response_iterator = await interceptor.intercept_unary_stream(
continuation, client_call_details, request)
if isinstance(call_or_response_iterator,
_base_call.UnaryStreamCall):
self._last_returned_call_from_interceptors = call_or_response_iterator
else:
self._last_returned_call_from_interceptors = UnaryStreamCallResponseIterator(
self._last_returned_call_from_interceptors,
call_or_response_iterator)
return self._last_returned_call_from_interceptors
else:
self._last_returned_call_from_interceptors = UnaryStreamCall(
request, _timeout_to_deadline(client_call_details.timeout),
client_call_details.metadata,
client_call_details.credentials,
client_call_details.wait_for_ready, self._channel,
client_call_details.method, request_serializer,
response_deserializer, self._loop)
return self._last_returned_call_from_interceptors
client_call_details = ClientCallDetails(method, timeout, metadata,
credentials, wait_for_ready)
return await _run_interceptor(iter(interceptors), client_call_details,
request)
def time_remaining(self) -> Optional[float]:
raise NotImplementedError()
class InterceptedStreamUnaryCall(_InterceptedUnaryResponseMixin,
_InterceptedStreamRequestMixin,
InterceptedCall, _base_call.StreamUnaryCall):
"""Used for running a `StreamUnaryCall` wrapped by interceptors.
For the `__await__` method is it is proxied to the intercepted call only when
the interceptor task is finished.
"""
_loop: asyncio.AbstractEventLoop
_channel: cygrpc.AioChannel
# pylint: disable=too-many-arguments
def __init__(self, interceptors: Sequence[StreamUnaryClientInterceptor],
request_iterator: Optional[RequestIterableType],
timeout: Optional[float], metadata: Metadata,
credentials: Optional[grpc.CallCredentials],
wait_for_ready: Optional[bool], channel: cygrpc.AioChannel,
method: bytes, request_serializer: SerializingFunction,
response_deserializer: DeserializingFunction,
loop: asyncio.AbstractEventLoop) -> None:
self._loop = loop
self._channel = channel
request_iterator = self._init_stream_request_mixin(request_iterator)
interceptors_task = loop.create_task(
self._invoke(interceptors, method, timeout, metadata, credentials,
wait_for_ready, request_iterator, request_serializer,
response_deserializer))
super().__init__(interceptors_task)
# pylint: disable=too-many-arguments
async def _invoke(
self, interceptors: Sequence[StreamUnaryClientInterceptor],
method: bytes, timeout: Optional[float],
metadata: Optional[Metadata],
credentials: Optional[grpc.CallCredentials],
wait_for_ready: Optional[bool],
request_iterator: RequestIterableType,
request_serializer: SerializingFunction,
response_deserializer: DeserializingFunction) -> StreamUnaryCall:
"""Run the RPC call wrapped in interceptors"""
async def _run_interceptor(
interceptors: Iterator[UnaryUnaryClientInterceptor],
client_call_details: ClientCallDetails,
request_iterator: RequestIterableType
) -> _base_call.StreamUnaryCall:
interceptor = next(interceptors, None)
if interceptor:
continuation = functools.partial(_run_interceptor, interceptors)
return await interceptor.intercept_stream_unary(
continuation, client_call_details, request_iterator)
else:
return StreamUnaryCall(
request_iterator,
_timeout_to_deadline(client_call_details.timeout),
client_call_details.metadata,
client_call_details.credentials,
client_call_details.wait_for_ready, self._channel,
client_call_details.method, request_serializer,
response_deserializer, self._loop)
client_call_details = ClientCallDetails(method, timeout, metadata,
credentials, wait_for_ready)
return await _run_interceptor(iter(interceptors), client_call_details,
request_iterator)
def time_remaining(self) -> Optional[float]:
raise NotImplementedError()
class InterceptedStreamStreamCall(_InterceptedStreamResponseMixin,
_InterceptedStreamRequestMixin,
InterceptedCall, _base_call.StreamStreamCall):
"""Used for running a `StreamStreamCall` wrapped by interceptors."""
_loop: asyncio.AbstractEventLoop
_channel: cygrpc.AioChannel
_last_returned_call_from_interceptors = Optional[_base_call.UnaryStreamCall]
# pylint: disable=too-many-arguments
def __init__(self, interceptors: Sequence[StreamStreamClientInterceptor],
request_iterator: Optional[RequestIterableType],
timeout: Optional[float], metadata: Metadata,
credentials: Optional[grpc.CallCredentials],
wait_for_ready: Optional[bool], channel: cygrpc.AioChannel,
method: bytes, request_serializer: SerializingFunction,
response_deserializer: DeserializingFunction,
loop: asyncio.AbstractEventLoop) -> None:
self._loop = loop
self._channel = channel
self._init_stream_response_mixin()
request_iterator = self._init_stream_request_mixin(request_iterator)
self._last_returned_call_from_interceptors = None
interceptors_task = loop.create_task(
self._invoke(interceptors, method, timeout, metadata, credentials,
wait_for_ready, request_iterator, request_serializer,
response_deserializer))
super().__init__(interceptors_task)
# pylint: disable=too-many-arguments
async def _invoke(
self, interceptors: Sequence[StreamStreamClientInterceptor],
method: bytes, timeout: Optional[float],
metadata: Optional[Metadata],
credentials: Optional[grpc.CallCredentials],
wait_for_ready: Optional[bool],
request_iterator: RequestIterableType,
request_serializer: SerializingFunction,
response_deserializer: DeserializingFunction) -> StreamStreamCall:
"""Run the RPC call wrapped in interceptors"""
async def _run_interceptor(
interceptors: Iterator[StreamStreamClientInterceptor],
client_call_details: ClientCallDetails,
request_iterator: RequestIterableType
) -> _base_call.StreamStreamCall:
interceptor = next(interceptors, None)
if interceptor:
continuation = functools.partial(_run_interceptor, interceptors)
call_or_response_iterator = await interceptor.intercept_stream_stream(
continuation, client_call_details, request_iterator)
if isinstance(call_or_response_iterator,
_base_call.StreamStreamCall):
self._last_returned_call_from_interceptors = call_or_response_iterator
else:
self._last_returned_call_from_interceptors = StreamStreamCallResponseIterator(
self._last_returned_call_from_interceptors,
call_or_response_iterator)
return self._last_returned_call_from_interceptors
else:
self._last_returned_call_from_interceptors = StreamStreamCall(
request_iterator,
_timeout_to_deadline(client_call_details.timeout),
client_call_details.metadata,
client_call_details.credentials,
client_call_details.wait_for_ready, self._channel,
client_call_details.method, request_serializer,
response_deserializer, self._loop)
return self._last_returned_call_from_interceptors
client_call_details = ClientCallDetails(method, timeout, metadata,
credentials, wait_for_ready)
return await _run_interceptor(iter(interceptors), client_call_details,
request_iterator)
def time_remaining(self) -> Optional[float]:
raise NotImplementedError()
class UnaryUnaryCallResponse(_base_call.UnaryUnaryCall):
"""Final UnaryUnaryCall class finished with a response."""
_response: ResponseType
def __init__(self, response: ResponseType) -> None:
self._response = response
def cancel(self) -> bool:
return False
def cancelled(self) -> bool:
return False
def done(self) -> bool:
return True
def add_done_callback(self, unused_callback) -> None:
raise NotImplementedError()
def time_remaining(self) -> Optional[float]:
raise NotImplementedError()
async def initial_metadata(self) -> Optional[Metadata]:
return None
async def trailing_metadata(self) -> Optional[Metadata]:
return None
async def code(self) -> grpc.StatusCode:
return grpc.StatusCode.OK
async def details(self) -> str:
return ''
async def debug_error_string(self) -> Optional[str]:
return None
def __await__(self):
if False: # pylint: disable=using-constant-test
# This code path is never used, but a yield statement is needed
# for telling the interpreter that __await__ is a generator.
yield None
return self._response
async def wait_for_connection(self) -> None:
pass
class _StreamCallResponseIterator:
_call: Union[_base_call.UnaryStreamCall, _base_call.StreamStreamCall]
_response_iterator: AsyncIterable[ResponseType]
def __init__(self, call: Union[_base_call.UnaryStreamCall, _base_call.
StreamStreamCall],
response_iterator: AsyncIterable[ResponseType]) -> None:
self._response_iterator = response_iterator
self._call = call
def cancel(self) -> bool:
return self._call.cancel()
def cancelled(self) -> bool:
return self._call.cancelled()
def done(self) -> bool:
return self._call.done()
def add_done_callback(self, callback) -> None:
self._call.add_done_callback(callback)
def time_remaining(self) -> Optional[float]:
return self._call.time_remaining()
async def initial_metadata(self) -> Optional[Metadata]:
return await self._call.initial_metadata()
async def trailing_metadata(self) -> Optional[Metadata]:
return await self._call.trailing_metadata()
async def code(self) -> grpc.StatusCode:
return await self._call.code()
async def details(self) -> str:
return await self._call.details()
async def debug_error_string(self) -> Optional[str]:
return await self._call.debug_error_string()
def __aiter__(self):
return self._response_iterator.__aiter__()
async def wait_for_connection(self) -> None:
return await self._call.wait_for_connection()
class UnaryStreamCallResponseIterator(_StreamCallResponseIterator,
_base_call.UnaryStreamCall):
"""UnaryStreamCall class wich uses an alternative response iterator."""
async def read(self) -> ResponseType:
# Behind the scenes everyting goes through the
# async iterator. So this path should not be reached.
raise NotImplementedError()
class StreamStreamCallResponseIterator(_StreamCallResponseIterator,
_base_call.StreamStreamCall):
"""StreamStreamCall class wich uses an alternative response iterator."""
async def read(self) -> ResponseType:
# Behind the scenes everyting goes through the
# async iterator. So this path should not be reached.
raise NotImplementedError()
async def write(self, request: RequestType) -> None:
# Behind the scenes everyting goes through the
# async iterator provided by the InterceptedStreamStreamCall.
# So this path should not be reached.
raise NotImplementedError()
async def done_writing(self) -> None:
# Behind the scenes everyting goes through the
# async iterator provided by the InterceptedStreamStreamCall.
# So this path should not be reached.
raise NotImplementedError()
@property
def _done_writing_flag(self) -> bool:
return self._call._done_writing_flag

View file

@ -0,0 +1,119 @@
# Copyright 2020 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implementation of the metadata abstraction for gRPC Asyncio Python."""
from typing import List, Tuple, Iterator, Any, Union
from collections import abc, OrderedDict
MetadataKey = str
MetadataValue = Union[str, bytes]
class Metadata(abc.Mapping):
"""Metadata abstraction for the asynchronous calls and interceptors.
The metadata is a mapping from str -> List[str]
Traits
* Multiple entries are allowed for the same key
* The order of the values by key is preserved
* Getting by an element by key, retrieves the first mapped value
* Supports an immutable view of the data
* Allows partial mutation on the data without recreating the new object from scratch.
"""
def __init__(self, *args: Tuple[MetadataKey, MetadataValue]) -> None:
self._metadata = OrderedDict()
for md_key, md_value in args:
self.add(md_key, md_value)
@classmethod
def from_tuple(cls, raw_metadata: tuple):
if raw_metadata:
return cls(*raw_metadata)
return cls()
def add(self, key: MetadataKey, value: MetadataValue) -> None:
self._metadata.setdefault(key, [])
self._metadata[key].append(value)
def __len__(self) -> int:
"""Return the total number of elements that there are in the metadata,
including multiple values for the same key.
"""
return sum(map(len, self._metadata.values()))
def __getitem__(self, key: MetadataKey) -> MetadataValue:
"""When calling <metadata>[<key>], the first element of all those
mapped for <key> is returned.
"""
try:
return self._metadata[key][0]
except (ValueError, IndexError) as e:
raise KeyError("{0!r}".format(key)) from e
def __setitem__(self, key: MetadataKey, value: MetadataValue) -> None:
"""Calling metadata[<key>] = <value>
Maps <value> to the first instance of <key>.
"""
if key not in self:
self._metadata[key] = [value]
else:
current_values = self.get_all(key)
self._metadata[key] = [value, *current_values[1:]]
def __delitem__(self, key: MetadataKey) -> None:
"""``del metadata[<key>]`` deletes the first mapping for <key>."""
current_values = self.get_all(key)
if not current_values:
raise KeyError(repr(key))
self._metadata[key] = current_values[1:]
def delete_all(self, key: MetadataKey) -> None:
"""Delete all mappings for <key>."""
del self._metadata[key]
def __iter__(self) -> Iterator[Tuple[MetadataKey, MetadataValue]]:
for key, values in self._metadata.items():
for value in values:
yield (key, value)
def get_all(self, key: MetadataKey) -> List[MetadataValue]:
"""For compatibility with other Metadata abstraction objects (like in Java),
this would return all items under the desired <key>.
"""
return self._metadata.get(key, [])
def set_all(self, key: MetadataKey, values: List[MetadataValue]) -> None:
self._metadata[key] = values
def __contains__(self, key: MetadataKey) -> bool:
return key in self._metadata
def __eq__(self, other: Any) -> bool:
if isinstance(other, self.__class__):
return self._metadata == other._metadata
if isinstance(other, tuple):
return tuple(self) == other
return NotImplemented # pytype: disable=bad-return-type
def __add__(self, other: Any) -> 'Metadata':
if isinstance(other, self.__class__):
return Metadata(*(tuple(self) + tuple(other)))
if isinstance(other, tuple):
return Metadata(*(tuple(self) + other))
return NotImplemented # pytype: disable=bad-return-type
def __repr__(self) -> str:
view = tuple(self)
return "{0}({1!r})".format(self.__class__.__name__, view)

View file

@ -0,0 +1,209 @@
# Copyright 2019 The gRPC Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Server-side implementation of gRPC Asyncio Python."""
from concurrent.futures import Executor
from typing import Any, Optional, Sequence
import grpc
from grpc import _common, _compression
from grpc._cython import cygrpc
from . import _base_server
from ._typing import ChannelArgumentType
from ._interceptor import ServerInterceptor
def _augment_channel_arguments(base_options: ChannelArgumentType,
compression: Optional[grpc.Compression]):
compression_option = _compression.create_channel_option(compression)
return tuple(base_options) + compression_option
class Server(_base_server.Server):
"""Serves RPCs."""
def __init__(self, thread_pool: Optional[Executor],
generic_handlers: Optional[Sequence[grpc.GenericRpcHandler]],
interceptors: Optional[Sequence[Any]],
options: ChannelArgumentType,
maximum_concurrent_rpcs: Optional[int],
compression: Optional[grpc.Compression]):
self._loop = cygrpc.get_working_loop()
if interceptors:
invalid_interceptors = [
interceptor for interceptor in interceptors
if not isinstance(interceptor, ServerInterceptor)
]
if invalid_interceptors:
raise ValueError(
'Interceptor must be ServerInterceptor, the '
f'following are invalid: {invalid_interceptors}')
self._server = cygrpc.AioServer(
self._loop, thread_pool, generic_handlers, interceptors,
_augment_channel_arguments(options, compression),
maximum_concurrent_rpcs)
def add_generic_rpc_handlers(
self,
generic_rpc_handlers: Sequence[grpc.GenericRpcHandler]) -> None:
"""Registers GenericRpcHandlers with this Server.
This method is only safe to call before the server is started.
Args:
generic_rpc_handlers: A sequence of GenericRpcHandlers that will be
used to service RPCs.
"""
self._server.add_generic_rpc_handlers(generic_rpc_handlers)
def add_insecure_port(self, address: str) -> int:
"""Opens an insecure port for accepting RPCs.
This method may only be called before starting the server.
Args:
address: The address for which to open a port. If the port is 0,
or not specified in the address, then the gRPC runtime will choose a port.
Returns:
An integer port on which the server will accept RPC requests.
"""
return _common.validate_port_binding_result(
address, self._server.add_insecure_port(_common.encode(address)))
def add_secure_port(self, address: str,
server_credentials: grpc.ServerCredentials) -> int:
"""Opens a secure port for accepting RPCs.
This method may only be called before starting the server.
Args:
address: The address for which to open a port.
if the port is 0, or not specified in the address, then the gRPC
runtime will choose a port.
server_credentials: A ServerCredentials object.
Returns:
An integer port on which the server will accept RPC requests.
"""
return _common.validate_port_binding_result(
address,
self._server.add_secure_port(_common.encode(address),
server_credentials))
async def start(self) -> None:
"""Starts this Server.
This method may only be called once. (i.e. it is not idempotent).
"""
await self._server.start()
async def stop(self, grace: Optional[float]) -> None:
"""Stops this Server.
This method immediately stops the server from servicing new RPCs in
all cases.
If a grace period is specified, this method returns immediately and all
RPCs active at the end of the grace period are aborted. If a grace
period is not specified (by passing None for grace), all existing RPCs
are aborted immediately and this method blocks until the last RPC
handler terminates.
This method is idempotent and may be called at any time. Passing a
smaller grace value in a subsequent call will have the effect of
stopping the Server sooner (passing None will have the effect of
stopping the server immediately). Passing a larger grace value in a
subsequent call will not have the effect of stopping the server later
(i.e. the most restrictive grace value is used).
Args:
grace: A duration of time in seconds or None.
"""
await self._server.shutdown(grace)
async def wait_for_termination(self,
timeout: Optional[float] = None) -> bool:
"""Block current coroutine until the server stops.
This is an EXPERIMENTAL API.
The wait will not consume computational resources during blocking, and
it will block until one of the two following conditions are met:
1) The server is stopped or terminated;
2) A timeout occurs if timeout is not `None`.
The timeout argument works in the same way as `threading.Event.wait()`.
https://docs.python.org/3/library/threading.html#threading.Event.wait
Args:
timeout: A floating point number specifying a timeout for the
operation in seconds.
Returns:
A bool indicates if the operation times out.
"""
return await self._server.wait_for_termination(timeout)
def __del__(self):
"""Schedules a graceful shutdown in current event loop.
The Cython AioServer doesn't hold a ref-count to this class. It should
be safe to slightly extend the underlying Cython object's life span.
"""
if hasattr(self, '_server'):
if self._server.is_running():
cygrpc.schedule_coro_threadsafe(
self._server.shutdown(None),
self._loop,
)
def server(migration_thread_pool: Optional[Executor] = None,
handlers: Optional[Sequence[grpc.GenericRpcHandler]] = None,
interceptors: Optional[Sequence[Any]] = None,
options: Optional[ChannelArgumentType] = None,
maximum_concurrent_rpcs: Optional[int] = None,
compression: Optional[grpc.Compression] = None):
"""Creates a Server with which RPCs can be serviced.
Args:
migration_thread_pool: A futures.ThreadPoolExecutor to be used by the
Server to execute non-AsyncIO RPC handlers for migration purpose.
handlers: An optional list of GenericRpcHandlers used for executing RPCs.
More handlers may be added by calling add_generic_rpc_handlers any time
before the server is started.
interceptors: An optional list of ServerInterceptor objects that observe
and optionally manipulate the incoming RPCs before handing them over to
handlers. The interceptors are given control in the order they are
specified. This is an EXPERIMENTAL API.
options: An optional list of key-value pairs (:term:`channel_arguments` in gRPC runtime)
to configure the channel.
maximum_concurrent_rpcs: The maximum number of concurrent RPCs this server
will service before returning RESOURCE_EXHAUSTED status, or None to
indicate no limit.
compression: An element of grpc.compression, e.g.
grpc.compression.Gzip. This compression algorithm will be used for the
lifetime of the server unless overridden by set_compression. This is an
EXPERIMENTAL option.
Returns:
A Server object.
"""
return Server(migration_thread_pool, () if handlers is None else handlers,
() if interceptors is None else interceptors,
() if options is None else options, maximum_concurrent_rpcs,
compression)

View file

@ -0,0 +1,32 @@
# Copyright 2019 The gRPC Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Common types for gRPC Async API"""
from typing import (Any, AsyncIterable, Callable, Iterable, Sequence, Tuple,
TypeVar, Union)
from grpc._cython.cygrpc import EOF
from ._metadata import Metadata, MetadataKey, MetadataValue
RequestType = TypeVar('RequestType')
ResponseType = TypeVar('ResponseType')
SerializingFunction = Callable[[Any], bytes]
DeserializingFunction = Callable[[bytes], Any]
MetadatumType = Tuple[MetadataKey, MetadataValue]
MetadataType = Metadata
ChannelArgumentType = Sequence[Tuple[str, Any]]
EOFType = type(EOF)
DoneCallbackType = Callable[[Any], None]
RequestIterableType = Union[Iterable[Any], AsyncIterable[Any]]
ResponseIterableType = AsyncIterable[Any]

View file

@ -0,0 +1,22 @@
# Copyright 2019 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Internal utilities used by the gRPC Aio module."""
import time
from typing import Optional
def _timeout_to_deadline(timeout: Optional[float]) -> Optional[float]:
if timeout is None:
return None
return time.time() + timeout

View file

@ -0,0 +1,13 @@
# Copyright 2015 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View file

@ -0,0 +1,706 @@
# Copyright 2016 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Translates gRPC's client-side API into gRPC's client-side Beta API."""
import grpc
from grpc import _common
from grpc.beta import _metadata
from grpc.beta import interfaces
from grpc.framework.common import cardinality
from grpc.framework.foundation import future
from grpc.framework.interfaces.face import face
# pylint: disable=too-many-arguments,too-many-locals,unused-argument
_STATUS_CODE_TO_ABORTION_KIND_AND_ABORTION_ERROR_CLASS = {
grpc.StatusCode.CANCELLED:
(face.Abortion.Kind.CANCELLED, face.CancellationError),
grpc.StatusCode.UNKNOWN:
(face.Abortion.Kind.REMOTE_FAILURE, face.RemoteError),
grpc.StatusCode.DEADLINE_EXCEEDED:
(face.Abortion.Kind.EXPIRED, face.ExpirationError),
grpc.StatusCode.UNIMPLEMENTED:
(face.Abortion.Kind.LOCAL_FAILURE, face.LocalError),
}
def _effective_metadata(metadata, metadata_transformer):
non_none_metadata = () if metadata is None else metadata
if metadata_transformer is None:
return non_none_metadata
else:
return metadata_transformer(non_none_metadata)
def _credentials(grpc_call_options):
return None if grpc_call_options is None else grpc_call_options.credentials
def _abortion(rpc_error_call):
code = rpc_error_call.code()
pair = _STATUS_CODE_TO_ABORTION_KIND_AND_ABORTION_ERROR_CLASS.get(code)
error_kind = face.Abortion.Kind.LOCAL_FAILURE if pair is None else pair[0]
return face.Abortion(error_kind, rpc_error_call.initial_metadata(),
rpc_error_call.trailing_metadata(), code,
rpc_error_call.details())
def _abortion_error(rpc_error_call):
code = rpc_error_call.code()
pair = _STATUS_CODE_TO_ABORTION_KIND_AND_ABORTION_ERROR_CLASS.get(code)
exception_class = face.AbortionError if pair is None else pair[1]
return exception_class(rpc_error_call.initial_metadata(),
rpc_error_call.trailing_metadata(), code,
rpc_error_call.details())
class _InvocationProtocolContext(interfaces.GRPCInvocationContext):
def disable_next_request_compression(self):
pass # TODO(https://github.com/grpc/grpc/issues/4078): design, implement.
class _Rendezvous(future.Future, face.Call):
def __init__(self, response_future, response_iterator, call):
self._future = response_future
self._iterator = response_iterator
self._call = call
def cancel(self):
return self._call.cancel()
def cancelled(self):
return self._future.cancelled()
def running(self):
return self._future.running()
def done(self):
return self._future.done()
def result(self, timeout=None):
try:
return self._future.result(timeout=timeout)
except grpc.RpcError as rpc_error_call:
raise _abortion_error(rpc_error_call)
except grpc.FutureTimeoutError:
raise future.TimeoutError()
except grpc.FutureCancelledError:
raise future.CancelledError()
def exception(self, timeout=None):
try:
rpc_error_call = self._future.exception(timeout=timeout)
if rpc_error_call is None:
return None
else:
return _abortion_error(rpc_error_call)
except grpc.FutureTimeoutError:
raise future.TimeoutError()
except grpc.FutureCancelledError:
raise future.CancelledError()
def traceback(self, timeout=None):
try:
return self._future.traceback(timeout=timeout)
except grpc.FutureTimeoutError:
raise future.TimeoutError()
except grpc.FutureCancelledError:
raise future.CancelledError()
def add_done_callback(self, fn):
self._future.add_done_callback(lambda ignored_callback: fn(self))
def __iter__(self):
return self
def _next(self):
try:
return next(self._iterator)
except grpc.RpcError as rpc_error_call:
raise _abortion_error(rpc_error_call)
def __next__(self):
return self._next()
def next(self):
return self._next()
def is_active(self):
return self._call.is_active()
def time_remaining(self):
return self._call.time_remaining()
def add_abortion_callback(self, abortion_callback):
def done_callback():
if self.code() is not grpc.StatusCode.OK:
abortion_callback(_abortion(self._call))
registered = self._call.add_callback(done_callback)
return None if registered else done_callback()
def protocol_context(self):
return _InvocationProtocolContext()
def initial_metadata(self):
return _metadata.beta(self._call.initial_metadata())
def terminal_metadata(self):
return _metadata.beta(self._call.terminal_metadata())
def code(self):
return self._call.code()
def details(self):
return self._call.details()
def _blocking_unary_unary(channel, group, method, timeout, with_call,
protocol_options, metadata, metadata_transformer,
request, request_serializer, response_deserializer):
try:
multi_callable = channel.unary_unary(
_common.fully_qualified_method(group, method),
request_serializer=request_serializer,
response_deserializer=response_deserializer)
effective_metadata = _effective_metadata(metadata, metadata_transformer)
if with_call:
response, call = multi_callable.with_call(
request,
timeout=timeout,
metadata=_metadata.unbeta(effective_metadata),
credentials=_credentials(protocol_options))
return response, _Rendezvous(None, None, call)
else:
return multi_callable(request,
timeout=timeout,
metadata=_metadata.unbeta(effective_metadata),
credentials=_credentials(protocol_options))
except grpc.RpcError as rpc_error_call:
raise _abortion_error(rpc_error_call)
def _future_unary_unary(channel, group, method, timeout, protocol_options,
metadata, metadata_transformer, request,
request_serializer, response_deserializer):
multi_callable = channel.unary_unary(
_common.fully_qualified_method(group, method),
request_serializer=request_serializer,
response_deserializer=response_deserializer)
effective_metadata = _effective_metadata(metadata, metadata_transformer)
response_future = multi_callable.future(
request,
timeout=timeout,
metadata=_metadata.unbeta(effective_metadata),
credentials=_credentials(protocol_options))
return _Rendezvous(response_future, None, response_future)
def _unary_stream(channel, group, method, timeout, protocol_options, metadata,
metadata_transformer, request, request_serializer,
response_deserializer):
multi_callable = channel.unary_stream(
_common.fully_qualified_method(group, method),
request_serializer=request_serializer,
response_deserializer=response_deserializer)
effective_metadata = _effective_metadata(metadata, metadata_transformer)
response_iterator = multi_callable(
request,
timeout=timeout,
metadata=_metadata.unbeta(effective_metadata),
credentials=_credentials(protocol_options))
return _Rendezvous(None, response_iterator, response_iterator)
def _blocking_stream_unary(channel, group, method, timeout, with_call,
protocol_options, metadata, metadata_transformer,
request_iterator, request_serializer,
response_deserializer):
try:
multi_callable = channel.stream_unary(
_common.fully_qualified_method(group, method),
request_serializer=request_serializer,
response_deserializer=response_deserializer)
effective_metadata = _effective_metadata(metadata, metadata_transformer)
if with_call:
response, call = multi_callable.with_call(
request_iterator,
timeout=timeout,
metadata=_metadata.unbeta(effective_metadata),
credentials=_credentials(protocol_options))
return response, _Rendezvous(None, None, call)
else:
return multi_callable(request_iterator,
timeout=timeout,
metadata=_metadata.unbeta(effective_metadata),
credentials=_credentials(protocol_options))
except grpc.RpcError as rpc_error_call:
raise _abortion_error(rpc_error_call)
def _future_stream_unary(channel, group, method, timeout, protocol_options,
metadata, metadata_transformer, request_iterator,
request_serializer, response_deserializer):
multi_callable = channel.stream_unary(
_common.fully_qualified_method(group, method),
request_serializer=request_serializer,
response_deserializer=response_deserializer)
effective_metadata = _effective_metadata(metadata, metadata_transformer)
response_future = multi_callable.future(
request_iterator,
timeout=timeout,
metadata=_metadata.unbeta(effective_metadata),
credentials=_credentials(protocol_options))
return _Rendezvous(response_future, None, response_future)
def _stream_stream(channel, group, method, timeout, protocol_options, metadata,
metadata_transformer, request_iterator, request_serializer,
response_deserializer):
multi_callable = channel.stream_stream(
_common.fully_qualified_method(group, method),
request_serializer=request_serializer,
response_deserializer=response_deserializer)
effective_metadata = _effective_metadata(metadata, metadata_transformer)
response_iterator = multi_callable(
request_iterator,
timeout=timeout,
metadata=_metadata.unbeta(effective_metadata),
credentials=_credentials(protocol_options))
return _Rendezvous(None, response_iterator, response_iterator)
class _UnaryUnaryMultiCallable(face.UnaryUnaryMultiCallable):
def __init__(self, channel, group, method, metadata_transformer,
request_serializer, response_deserializer):
self._channel = channel
self._group = group
self._method = method
self._metadata_transformer = metadata_transformer
self._request_serializer = request_serializer
self._response_deserializer = response_deserializer
def __call__(self,
request,
timeout,
metadata=None,
with_call=False,
protocol_options=None):
return _blocking_unary_unary(self._channel, self._group, self._method,
timeout, with_call, protocol_options,
metadata, self._metadata_transformer,
request, self._request_serializer,
self._response_deserializer)
def future(self, request, timeout, metadata=None, protocol_options=None):
return _future_unary_unary(self._channel, self._group, self._method,
timeout, protocol_options, metadata,
self._metadata_transformer, request,
self._request_serializer,
self._response_deserializer)
def event(self,
request,
receiver,
abortion_callback,
timeout,
metadata=None,
protocol_options=None):
raise NotImplementedError()
class _UnaryStreamMultiCallable(face.UnaryStreamMultiCallable):
def __init__(self, channel, group, method, metadata_transformer,
request_serializer, response_deserializer):
self._channel = channel
self._group = group
self._method = method
self._metadata_transformer = metadata_transformer
self._request_serializer = request_serializer
self._response_deserializer = response_deserializer
def __call__(self, request, timeout, metadata=None, protocol_options=None):
return _unary_stream(self._channel, self._group, self._method, timeout,
protocol_options, metadata,
self._metadata_transformer, request,
self._request_serializer,
self._response_deserializer)
def event(self,
request,
receiver,
abortion_callback,
timeout,
metadata=None,
protocol_options=None):
raise NotImplementedError()
class _StreamUnaryMultiCallable(face.StreamUnaryMultiCallable):
def __init__(self, channel, group, method, metadata_transformer,
request_serializer, response_deserializer):
self._channel = channel
self._group = group
self._method = method
self._metadata_transformer = metadata_transformer
self._request_serializer = request_serializer
self._response_deserializer = response_deserializer
def __call__(self,
request_iterator,
timeout,
metadata=None,
with_call=False,
protocol_options=None):
return _blocking_stream_unary(self._channel, self._group, self._method,
timeout, with_call, protocol_options,
metadata, self._metadata_transformer,
request_iterator,
self._request_serializer,
self._response_deserializer)
def future(self,
request_iterator,
timeout,
metadata=None,
protocol_options=None):
return _future_stream_unary(self._channel, self._group, self._method,
timeout, protocol_options, metadata,
self._metadata_transformer,
request_iterator, self._request_serializer,
self._response_deserializer)
def event(self,
receiver,
abortion_callback,
timeout,
metadata=None,
protocol_options=None):
raise NotImplementedError()
class _StreamStreamMultiCallable(face.StreamStreamMultiCallable):
def __init__(self, channel, group, method, metadata_transformer,
request_serializer, response_deserializer):
self._channel = channel
self._group = group
self._method = method
self._metadata_transformer = metadata_transformer
self._request_serializer = request_serializer
self._response_deserializer = response_deserializer
def __call__(self,
request_iterator,
timeout,
metadata=None,
protocol_options=None):
return _stream_stream(self._channel, self._group, self._method, timeout,
protocol_options, metadata,
self._metadata_transformer, request_iterator,
self._request_serializer,
self._response_deserializer)
def event(self,
receiver,
abortion_callback,
timeout,
metadata=None,
protocol_options=None):
raise NotImplementedError()
class _GenericStub(face.GenericStub):
def __init__(self, channel, metadata_transformer, request_serializers,
response_deserializers):
self._channel = channel
self._metadata_transformer = metadata_transformer
self._request_serializers = request_serializers or {}
self._response_deserializers = response_deserializers or {}
def blocking_unary_unary(self,
group,
method,
request,
timeout,
metadata=None,
with_call=None,
protocol_options=None):
request_serializer = self._request_serializers.get((
group,
method,
))
response_deserializer = self._response_deserializers.get((
group,
method,
))
return _blocking_unary_unary(self._channel, group, method, timeout,
with_call, protocol_options, metadata,
self._metadata_transformer, request,
request_serializer, response_deserializer)
def future_unary_unary(self,
group,
method,
request,
timeout,
metadata=None,
protocol_options=None):
request_serializer = self._request_serializers.get((
group,
method,
))
response_deserializer = self._response_deserializers.get((
group,
method,
))
return _future_unary_unary(self._channel, group, method, timeout,
protocol_options, metadata,
self._metadata_transformer, request,
request_serializer, response_deserializer)
def inline_unary_stream(self,
group,
method,
request,
timeout,
metadata=None,
protocol_options=None):
request_serializer = self._request_serializers.get((
group,
method,
))
response_deserializer = self._response_deserializers.get((
group,
method,
))
return _unary_stream(self._channel, group, method, timeout,
protocol_options, metadata,
self._metadata_transformer, request,
request_serializer, response_deserializer)
def blocking_stream_unary(self,
group,
method,
request_iterator,
timeout,
metadata=None,
with_call=None,
protocol_options=None):
request_serializer = self._request_serializers.get((
group,
method,
))
response_deserializer = self._response_deserializers.get((
group,
method,
))
return _blocking_stream_unary(self._channel, group, method, timeout,
with_call, protocol_options, metadata,
self._metadata_transformer,
request_iterator, request_serializer,
response_deserializer)
def future_stream_unary(self,
group,
method,
request_iterator,
timeout,
metadata=None,
protocol_options=None):
request_serializer = self._request_serializers.get((
group,
method,
))
response_deserializer = self._response_deserializers.get((
group,
method,
))
return _future_stream_unary(self._channel, group, method, timeout,
protocol_options, metadata,
self._metadata_transformer,
request_iterator, request_serializer,
response_deserializer)
def inline_stream_stream(self,
group,
method,
request_iterator,
timeout,
metadata=None,
protocol_options=None):
request_serializer = self._request_serializers.get((
group,
method,
))
response_deserializer = self._response_deserializers.get((
group,
method,
))
return _stream_stream(self._channel, group, method, timeout,
protocol_options, metadata,
self._metadata_transformer, request_iterator,
request_serializer, response_deserializer)
def event_unary_unary(self,
group,
method,
request,
receiver,
abortion_callback,
timeout,
metadata=None,
protocol_options=None):
raise NotImplementedError()
def event_unary_stream(self,
group,
method,
request,
receiver,
abortion_callback,
timeout,
metadata=None,
protocol_options=None):
raise NotImplementedError()
def event_stream_unary(self,
group,
method,
receiver,
abortion_callback,
timeout,
metadata=None,
protocol_options=None):
raise NotImplementedError()
def event_stream_stream(self,
group,
method,
receiver,
abortion_callback,
timeout,
metadata=None,
protocol_options=None):
raise NotImplementedError()
def unary_unary(self, group, method):
request_serializer = self._request_serializers.get((
group,
method,
))
response_deserializer = self._response_deserializers.get((
group,
method,
))
return _UnaryUnaryMultiCallable(self._channel, group, method,
self._metadata_transformer,
request_serializer,
response_deserializer)
def unary_stream(self, group, method):
request_serializer = self._request_serializers.get((
group,
method,
))
response_deserializer = self._response_deserializers.get((
group,
method,
))
return _UnaryStreamMultiCallable(self._channel, group, method,
self._metadata_transformer,
request_serializer,
response_deserializer)
def stream_unary(self, group, method):
request_serializer = self._request_serializers.get((
group,
method,
))
response_deserializer = self._response_deserializers.get((
group,
method,
))
return _StreamUnaryMultiCallable(self._channel, group, method,
self._metadata_transformer,
request_serializer,
response_deserializer)
def stream_stream(self, group, method):
request_serializer = self._request_serializers.get((
group,
method,
))
response_deserializer = self._response_deserializers.get((
group,
method,
))
return _StreamStreamMultiCallable(self._channel, group, method,
self._metadata_transformer,
request_serializer,
response_deserializer)
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
return False
class _DynamicStub(face.DynamicStub):
def __init__(self, backing_generic_stub, group, cardinalities):
self._generic_stub = backing_generic_stub
self._group = group
self._cardinalities = cardinalities
def __getattr__(self, attr):
method_cardinality = self._cardinalities.get(attr)
if method_cardinality is cardinality.Cardinality.UNARY_UNARY:
return self._generic_stub.unary_unary(self._group, attr)
elif method_cardinality is cardinality.Cardinality.UNARY_STREAM:
return self._generic_stub.unary_stream(self._group, attr)
elif method_cardinality is cardinality.Cardinality.STREAM_UNARY:
return self._generic_stub.stream_unary(self._group, attr)
elif method_cardinality is cardinality.Cardinality.STREAM_STREAM:
return self._generic_stub.stream_stream(self._group, attr)
else:
raise AttributeError('_DynamicStub object has no attribute "%s"!' %
attr)
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
return False
def generic_stub(channel, host, metadata_transformer, request_serializers,
response_deserializers):
return _GenericStub(channel, metadata_transformer, request_serializers,
response_deserializers)
def dynamic_stub(channel, service, cardinalities, host, metadata_transformer,
request_serializers, response_deserializers):
return _DynamicStub(
_GenericStub(channel, metadata_transformer, request_serializers,
response_deserializers), service, cardinalities)

View file

@ -0,0 +1,52 @@
# Copyright 2017 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""API metadata conversion utilities."""
import collections
_Metadatum = collections.namedtuple('_Metadatum', (
'key',
'value',
))
def _beta_metadatum(key, value):
beta_key = key if isinstance(key, (bytes,)) else key.encode('ascii')
beta_value = value if isinstance(value, (bytes,)) else value.encode('ascii')
return _Metadatum(beta_key, beta_value)
def _metadatum(beta_key, beta_value):
key = beta_key if isinstance(beta_key, (str,)) else beta_key.decode('utf8')
if isinstance(beta_value, (str,)) or key[-4:] == '-bin':
value = beta_value
else:
value = beta_value.decode('utf8')
return _Metadatum(key, value)
def beta(metadata):
if metadata is None:
return ()
else:
return tuple(_beta_metadatum(key, value) for key, value in metadata)
def unbeta(beta_metadata):
if beta_metadata is None:
return ()
else:
return tuple(
_metadatum(beta_key, beta_value)
for beta_key, beta_value in beta_metadata)

View file

@ -0,0 +1,385 @@
# Copyright 2016 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Translates gRPC's server-side API into gRPC's server-side Beta API."""
import collections
import threading
import grpc
from grpc import _common
from grpc.beta import _metadata
from grpc.beta import interfaces
from grpc.framework.common import cardinality
from grpc.framework.common import style
from grpc.framework.foundation import abandonment
from grpc.framework.foundation import logging_pool
from grpc.framework.foundation import stream
from grpc.framework.interfaces.face import face
# pylint: disable=too-many-return-statements
_DEFAULT_POOL_SIZE = 8
class _ServerProtocolContext(interfaces.GRPCServicerContext):
def __init__(self, servicer_context):
self._servicer_context = servicer_context
def peer(self):
return self._servicer_context.peer()
def disable_next_response_compression(self):
pass # TODO(https://github.com/grpc/grpc/issues/4078): design, implement.
class _FaceServicerContext(face.ServicerContext):
def __init__(self, servicer_context):
self._servicer_context = servicer_context
def is_active(self):
return self._servicer_context.is_active()
def time_remaining(self):
return self._servicer_context.time_remaining()
def add_abortion_callback(self, abortion_callback):
raise NotImplementedError(
'add_abortion_callback no longer supported server-side!')
def cancel(self):
self._servicer_context.cancel()
def protocol_context(self):
return _ServerProtocolContext(self._servicer_context)
def invocation_metadata(self):
return _metadata.beta(self._servicer_context.invocation_metadata())
def initial_metadata(self, initial_metadata):
self._servicer_context.send_initial_metadata(
_metadata.unbeta(initial_metadata))
def terminal_metadata(self, terminal_metadata):
self._servicer_context.set_terminal_metadata(
_metadata.unbeta(terminal_metadata))
def code(self, code):
self._servicer_context.set_code(code)
def details(self, details):
self._servicer_context.set_details(details)
def _adapt_unary_request_inline(unary_request_inline):
def adaptation(request, servicer_context):
return unary_request_inline(request,
_FaceServicerContext(servicer_context))
return adaptation
def _adapt_stream_request_inline(stream_request_inline):
def adaptation(request_iterator, servicer_context):
return stream_request_inline(request_iterator,
_FaceServicerContext(servicer_context))
return adaptation
class _Callback(stream.Consumer):
def __init__(self):
self._condition = threading.Condition()
self._values = []
self._terminated = False
self._cancelled = False
def consume(self, value):
with self._condition:
self._values.append(value)
self._condition.notify_all()
def terminate(self):
with self._condition:
self._terminated = True
self._condition.notify_all()
def consume_and_terminate(self, value):
with self._condition:
self._values.append(value)
self._terminated = True
self._condition.notify_all()
def cancel(self):
with self._condition:
self._cancelled = True
self._condition.notify_all()
def draw_one_value(self):
with self._condition:
while True:
if self._cancelled:
raise abandonment.Abandoned()
elif self._values:
return self._values.pop(0)
elif self._terminated:
return None
else:
self._condition.wait()
def draw_all_values(self):
with self._condition:
while True:
if self._cancelled:
raise abandonment.Abandoned()
elif self._terminated:
all_values = tuple(self._values)
self._values = None
return all_values
else:
self._condition.wait()
def _run_request_pipe_thread(request_iterator, request_consumer,
servicer_context):
thread_joined = threading.Event()
def pipe_requests():
for request in request_iterator:
if not servicer_context.is_active() or thread_joined.is_set():
return
request_consumer.consume(request)
if not servicer_context.is_active() or thread_joined.is_set():
return
request_consumer.terminate()
request_pipe_thread = threading.Thread(target=pipe_requests)
request_pipe_thread.daemon = True
request_pipe_thread.start()
def _adapt_unary_unary_event(unary_unary_event):
def adaptation(request, servicer_context):
callback = _Callback()
if not servicer_context.add_callback(callback.cancel):
raise abandonment.Abandoned()
unary_unary_event(request, callback.consume_and_terminate,
_FaceServicerContext(servicer_context))
return callback.draw_all_values()[0]
return adaptation
def _adapt_unary_stream_event(unary_stream_event):
def adaptation(request, servicer_context):
callback = _Callback()
if not servicer_context.add_callback(callback.cancel):
raise abandonment.Abandoned()
unary_stream_event(request, callback,
_FaceServicerContext(servicer_context))
while True:
response = callback.draw_one_value()
if response is None:
return
else:
yield response
return adaptation
def _adapt_stream_unary_event(stream_unary_event):
def adaptation(request_iterator, servicer_context):
callback = _Callback()
if not servicer_context.add_callback(callback.cancel):
raise abandonment.Abandoned()
request_consumer = stream_unary_event(
callback.consume_and_terminate,
_FaceServicerContext(servicer_context))
_run_request_pipe_thread(request_iterator, request_consumer,
servicer_context)
return callback.draw_all_values()[0]
return adaptation
def _adapt_stream_stream_event(stream_stream_event):
def adaptation(request_iterator, servicer_context):
callback = _Callback()
if not servicer_context.add_callback(callback.cancel):
raise abandonment.Abandoned()
request_consumer = stream_stream_event(
callback, _FaceServicerContext(servicer_context))
_run_request_pipe_thread(request_iterator, request_consumer,
servicer_context)
while True:
response = callback.draw_one_value()
if response is None:
return
else:
yield response
return adaptation
class _SimpleMethodHandler(
collections.namedtuple('_MethodHandler', (
'request_streaming',
'response_streaming',
'request_deserializer',
'response_serializer',
'unary_unary',
'unary_stream',
'stream_unary',
'stream_stream',
)), grpc.RpcMethodHandler):
pass
def _simple_method_handler(implementation, request_deserializer,
response_serializer):
if implementation.style is style.Service.INLINE:
if implementation.cardinality is cardinality.Cardinality.UNARY_UNARY:
return _SimpleMethodHandler(
False, False, request_deserializer, response_serializer,
_adapt_unary_request_inline(implementation.unary_unary_inline),
None, None, None)
elif implementation.cardinality is cardinality.Cardinality.UNARY_STREAM:
return _SimpleMethodHandler(
False, True, request_deserializer, response_serializer, None,
_adapt_unary_request_inline(implementation.unary_stream_inline),
None, None)
elif implementation.cardinality is cardinality.Cardinality.STREAM_UNARY:
return _SimpleMethodHandler(
True, False, request_deserializer, response_serializer, None,
None,
_adapt_stream_request_inline(
implementation.stream_unary_inline), None)
elif implementation.cardinality is cardinality.Cardinality.STREAM_STREAM:
return _SimpleMethodHandler(
True, True, request_deserializer, response_serializer, None,
None, None,
_adapt_stream_request_inline(
implementation.stream_stream_inline))
elif implementation.style is style.Service.EVENT:
if implementation.cardinality is cardinality.Cardinality.UNARY_UNARY:
return _SimpleMethodHandler(
False, False, request_deserializer, response_serializer,
_adapt_unary_unary_event(implementation.unary_unary_event),
None, None, None)
elif implementation.cardinality is cardinality.Cardinality.UNARY_STREAM:
return _SimpleMethodHandler(
False, True, request_deserializer, response_serializer, None,
_adapt_unary_stream_event(implementation.unary_stream_event),
None, None)
elif implementation.cardinality is cardinality.Cardinality.STREAM_UNARY:
return _SimpleMethodHandler(
True, False, request_deserializer, response_serializer, None,
None,
_adapt_stream_unary_event(implementation.stream_unary_event),
None)
elif implementation.cardinality is cardinality.Cardinality.STREAM_STREAM:
return _SimpleMethodHandler(
True, True, request_deserializer, response_serializer, None,
None, None,
_adapt_stream_stream_event(implementation.stream_stream_event))
raise ValueError()
def _flatten_method_pair_map(method_pair_map):
method_pair_map = method_pair_map or {}
flat_map = {}
for method_pair in method_pair_map:
method = _common.fully_qualified_method(method_pair[0], method_pair[1])
flat_map[method] = method_pair_map[method_pair]
return flat_map
class _GenericRpcHandler(grpc.GenericRpcHandler):
def __init__(self, method_implementations, multi_method_implementation,
request_deserializers, response_serializers):
self._method_implementations = _flatten_method_pair_map(
method_implementations)
self._request_deserializers = _flatten_method_pair_map(
request_deserializers)
self._response_serializers = _flatten_method_pair_map(
response_serializers)
self._multi_method_implementation = multi_method_implementation
def service(self, handler_call_details):
method_implementation = self._method_implementations.get(
handler_call_details.method)
if method_implementation is not None:
return _simple_method_handler(
method_implementation,
self._request_deserializers.get(handler_call_details.method),
self._response_serializers.get(handler_call_details.method))
elif self._multi_method_implementation is None:
return None
else:
try:
return None #TODO(nathaniel): call the multimethod.
except face.NoSuchMethodError:
return None
class _Server(interfaces.Server):
def __init__(self, grpc_server):
self._grpc_server = grpc_server
def add_insecure_port(self, address):
return self._grpc_server.add_insecure_port(address)
def add_secure_port(self, address, server_credentials):
return self._grpc_server.add_secure_port(address, server_credentials)
def start(self):
self._grpc_server.start()
def stop(self, grace):
return self._grpc_server.stop(grace)
def __enter__(self):
self._grpc_server.start()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self._grpc_server.stop(None)
return False
def server(service_implementations, multi_method_implementation,
request_deserializers, response_serializers, thread_pool,
thread_pool_size):
generic_rpc_handler = _GenericRpcHandler(service_implementations,
multi_method_implementation,
request_deserializers,
response_serializers)
if thread_pool is None:
effective_thread_pool = logging_pool.pool(
_DEFAULT_POOL_SIZE if thread_pool_size is None else thread_pool_size
)
else:
effective_thread_pool = thread_pool
return _Server(
grpc.server(effective_thread_pool, handlers=(generic_rpc_handler,)))

View file

@ -0,0 +1,310 @@
# Copyright 2015-2016 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Entry points into the Beta API of gRPC Python."""
# threading is referenced from specification in this module.
import threading # pylint: disable=unused-import
# interfaces, cardinality, and face are referenced from specification in this
# module.
import grpc
from grpc import _auth
from grpc.beta import _client_adaptations
from grpc.beta import _metadata
from grpc.beta import _server_adaptations
from grpc.beta import interfaces # pylint: disable=unused-import
from grpc.framework.common import cardinality # pylint: disable=unused-import
from grpc.framework.interfaces.face import face # pylint: disable=unused-import
# pylint: disable=too-many-arguments
ChannelCredentials = grpc.ChannelCredentials
ssl_channel_credentials = grpc.ssl_channel_credentials
CallCredentials = grpc.CallCredentials
def metadata_call_credentials(metadata_plugin, name=None):
def plugin(context, callback):
def wrapped_callback(beta_metadata, error):
callback(_metadata.unbeta(beta_metadata), error)
metadata_plugin(context, wrapped_callback)
return grpc.metadata_call_credentials(plugin, name=name)
def google_call_credentials(credentials):
"""Construct CallCredentials from GoogleCredentials.
Args:
credentials: A GoogleCredentials object from the oauth2client library.
Returns:
A CallCredentials object for use in a GRPCCallOptions object.
"""
return metadata_call_credentials(_auth.GoogleCallCredentials(credentials))
access_token_call_credentials = grpc.access_token_call_credentials
composite_call_credentials = grpc.composite_call_credentials
composite_channel_credentials = grpc.composite_channel_credentials
class Channel(object):
"""A channel to a remote host through which RPCs may be conducted.
Only the "subscribe" and "unsubscribe" methods are supported for application
use. This class' instance constructor and all other attributes are
unsupported.
"""
def __init__(self, channel):
self._channel = channel
def subscribe(self, callback, try_to_connect=None):
"""Subscribes to this Channel's connectivity.
Args:
callback: A callable to be invoked and passed an
interfaces.ChannelConnectivity identifying this Channel's connectivity.
The callable will be invoked immediately upon subscription and again for
every change to this Channel's connectivity thereafter until it is
unsubscribed.
try_to_connect: A boolean indicating whether or not this Channel should
attempt to connect if it is not already connected and ready to conduct
RPCs.
"""
self._channel.subscribe(callback, try_to_connect=try_to_connect)
def unsubscribe(self, callback):
"""Unsubscribes a callback from this Channel's connectivity.
Args:
callback: A callable previously registered with this Channel from having
been passed to its "subscribe" method.
"""
self._channel.unsubscribe(callback)
def insecure_channel(host, port):
"""Creates an insecure Channel to a remote host.
Args:
host: The name of the remote host to which to connect.
port: The port of the remote host to which to connect.
If None only the 'host' part will be used.
Returns:
A Channel to the remote host through which RPCs may be conducted.
"""
channel = grpc.insecure_channel(host if port is None else '%s:%d' %
(host, port))
return Channel(channel)
def secure_channel(host, port, channel_credentials):
"""Creates a secure Channel to a remote host.
Args:
host: The name of the remote host to which to connect.
port: The port of the remote host to which to connect.
If None only the 'host' part will be used.
channel_credentials: A ChannelCredentials.
Returns:
A secure Channel to the remote host through which RPCs may be conducted.
"""
channel = grpc.secure_channel(
host if port is None else '%s:%d' % (host, port), channel_credentials)
return Channel(channel)
class StubOptions(object):
"""A value encapsulating the various options for creation of a Stub.
This class and its instances have no supported interface - it exists to define
the type of its instances and its instances exist to be passed to other
functions.
"""
def __init__(self, host, request_serializers, response_deserializers,
metadata_transformer, thread_pool, thread_pool_size):
self.host = host
self.request_serializers = request_serializers
self.response_deserializers = response_deserializers
self.metadata_transformer = metadata_transformer
self.thread_pool = thread_pool
self.thread_pool_size = thread_pool_size
_EMPTY_STUB_OPTIONS = StubOptions(None, None, None, None, None, None)
def stub_options(host=None,
request_serializers=None,
response_deserializers=None,
metadata_transformer=None,
thread_pool=None,
thread_pool_size=None):
"""Creates a StubOptions value to be passed at stub creation.
All parameters are optional and should always be passed by keyword.
Args:
host: A host string to set on RPC calls.
request_serializers: A dictionary from service name-method name pair to
request serialization behavior.
response_deserializers: A dictionary from service name-method name pair to
response deserialization behavior.
metadata_transformer: A callable that given a metadata object produces
another metadata object to be used in the underlying communication on the
wire.
thread_pool: A thread pool to use in stubs.
thread_pool_size: The size of thread pool to create for use in stubs;
ignored if thread_pool has been passed.
Returns:
A StubOptions value created from the passed parameters.
"""
return StubOptions(host, request_serializers, response_deserializers,
metadata_transformer, thread_pool, thread_pool_size)
def generic_stub(channel, options=None):
"""Creates a face.GenericStub on which RPCs can be made.
Args:
channel: A Channel for use by the created stub.
options: A StubOptions customizing the created stub.
Returns:
A face.GenericStub on which RPCs can be made.
"""
effective_options = _EMPTY_STUB_OPTIONS if options is None else options
return _client_adaptations.generic_stub(
channel._channel, # pylint: disable=protected-access
effective_options.host,
effective_options.metadata_transformer,
effective_options.request_serializers,
effective_options.response_deserializers)
def dynamic_stub(channel, service, cardinalities, options=None):
"""Creates a face.DynamicStub with which RPCs can be invoked.
Args:
channel: A Channel for the returned face.DynamicStub to use.
service: The package-qualified full name of the service.
cardinalities: A dictionary from RPC method name to cardinality.Cardinality
value identifying the cardinality of the RPC method.
options: An optional StubOptions value further customizing the functionality
of the returned face.DynamicStub.
Returns:
A face.DynamicStub with which RPCs can be invoked.
"""
effective_options = _EMPTY_STUB_OPTIONS if options is None else options
return _client_adaptations.dynamic_stub(
channel._channel, # pylint: disable=protected-access
service,
cardinalities,
effective_options.host,
effective_options.metadata_transformer,
effective_options.request_serializers,
effective_options.response_deserializers)
ServerCredentials = grpc.ServerCredentials
ssl_server_credentials = grpc.ssl_server_credentials
class ServerOptions(object):
"""A value encapsulating the various options for creation of a Server.
This class and its instances have no supported interface - it exists to define
the type of its instances and its instances exist to be passed to other
functions.
"""
def __init__(self, multi_method_implementation, request_deserializers,
response_serializers, thread_pool, thread_pool_size,
default_timeout, maximum_timeout):
self.multi_method_implementation = multi_method_implementation
self.request_deserializers = request_deserializers
self.response_serializers = response_serializers
self.thread_pool = thread_pool
self.thread_pool_size = thread_pool_size
self.default_timeout = default_timeout
self.maximum_timeout = maximum_timeout
_EMPTY_SERVER_OPTIONS = ServerOptions(None, None, None, None, None, None, None)
def server_options(multi_method_implementation=None,
request_deserializers=None,
response_serializers=None,
thread_pool=None,
thread_pool_size=None,
default_timeout=None,
maximum_timeout=None):
"""Creates a ServerOptions value to be passed at server creation.
All parameters are optional and should always be passed by keyword.
Args:
multi_method_implementation: A face.MultiMethodImplementation to be called
to service an RPC if the server has no specific method implementation for
the name of the RPC for which service was requested.
request_deserializers: A dictionary from service name-method name pair to
request deserialization behavior.
response_serializers: A dictionary from service name-method name pair to
response serialization behavior.
thread_pool: A thread pool to use in stubs.
thread_pool_size: The size of thread pool to create for use in stubs;
ignored if thread_pool has been passed.
default_timeout: A duration in seconds to allow for RPC service when
servicing RPCs that did not include a timeout value when invoked.
maximum_timeout: A duration in seconds to allow for RPC service when
servicing RPCs no matter what timeout value was passed when the RPC was
invoked.
Returns:
A StubOptions value created from the passed parameters.
"""
return ServerOptions(multi_method_implementation, request_deserializers,
response_serializers, thread_pool, thread_pool_size,
default_timeout, maximum_timeout)
def server(service_implementations, options=None):
"""Creates an interfaces.Server with which RPCs can be serviced.
Args:
service_implementations: A dictionary from service name-method name pair to
face.MethodImplementation.
options: An optional ServerOptions value further customizing the
functionality of the returned Server.
Returns:
An interfaces.Server with which RPCs can be serviced.
"""
effective_options = _EMPTY_SERVER_OPTIONS if options is None else options
return _server_adaptations.server(
service_implementations, effective_options.multi_method_implementation,
effective_options.request_deserializers,
effective_options.response_serializers, effective_options.thread_pool,
effective_options.thread_pool_size)

View file

@ -0,0 +1,165 @@
# Copyright 2015 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Constants and interfaces of the Beta API of gRPC Python."""
import abc
import six
import grpc
ChannelConnectivity = grpc.ChannelConnectivity
# FATAL_FAILURE was a Beta-API name for SHUTDOWN
ChannelConnectivity.FATAL_FAILURE = ChannelConnectivity.SHUTDOWN
StatusCode = grpc.StatusCode
class GRPCCallOptions(object):
"""A value encapsulating gRPC-specific options passed on RPC invocation.
This class and its instances have no supported interface - it exists to
define the type of its instances and its instances exist to be passed to
other functions.
"""
def __init__(self, disable_compression, subcall_of, credentials):
self.disable_compression = disable_compression
self.subcall_of = subcall_of
self.credentials = credentials
def grpc_call_options(disable_compression=False, credentials=None):
"""Creates a GRPCCallOptions value to be passed at RPC invocation.
All parameters are optional and should always be passed by keyword.
Args:
disable_compression: A boolean indicating whether or not compression should
be disabled for the request object of the RPC. Only valid for
request-unary RPCs.
credentials: A CallCredentials object to use for the invoked RPC.
"""
return GRPCCallOptions(disable_compression, None, credentials)
GRPCAuthMetadataContext = grpc.AuthMetadataContext
GRPCAuthMetadataPluginCallback = grpc.AuthMetadataPluginCallback
GRPCAuthMetadataPlugin = grpc.AuthMetadataPlugin
class GRPCServicerContext(six.with_metaclass(abc.ABCMeta)):
"""Exposes gRPC-specific options and behaviors to code servicing RPCs."""
@abc.abstractmethod
def peer(self):
"""Identifies the peer that invoked the RPC being serviced.
Returns:
A string identifying the peer that invoked the RPC being serviced.
"""
raise NotImplementedError()
@abc.abstractmethod
def disable_next_response_compression(self):
"""Disables compression of the next response passed by the application."""
raise NotImplementedError()
class GRPCInvocationContext(six.with_metaclass(abc.ABCMeta)):
"""Exposes gRPC-specific options and behaviors to code invoking RPCs."""
@abc.abstractmethod
def disable_next_request_compression(self):
"""Disables compression of the next request passed by the application."""
raise NotImplementedError()
class Server(six.with_metaclass(abc.ABCMeta)):
"""Services RPCs."""
@abc.abstractmethod
def add_insecure_port(self, address):
"""Reserves a port for insecure RPC service once this Server becomes active.
This method may only be called before calling this Server's start method is
called.
Args:
address: The address for which to open a port.
Returns:
An integer port on which RPCs will be serviced after this link has been
started. This is typically the same number as the port number contained
in the passed address, but will likely be different if the port number
contained in the passed address was zero.
"""
raise NotImplementedError()
@abc.abstractmethod
def add_secure_port(self, address, server_credentials):
"""Reserves a port for secure RPC service after this Server becomes active.
This method may only be called before calling this Server's start method is
called.
Args:
address: The address for which to open a port.
server_credentials: A ServerCredentials.
Returns:
An integer port on which RPCs will be serviced after this link has been
started. This is typically the same number as the port number contained
in the passed address, but will likely be different if the port number
contained in the passed address was zero.
"""
raise NotImplementedError()
@abc.abstractmethod
def start(self):
"""Starts this Server's service of RPCs.
This method may only be called while the server is not serving RPCs (i.e. it
is not idempotent).
"""
raise NotImplementedError()
@abc.abstractmethod
def stop(self, grace):
"""Stops this Server's service of RPCs.
All calls to this method immediately stop service of new RPCs. When existing
RPCs are aborted is controlled by the grace period parameter passed to this
method.
This method may be called at any time and is idempotent. Passing a smaller
grace value than has been passed in a previous call will have the effect of
stopping the Server sooner. Passing a larger grace value than has been
passed in a previous call will not have the effect of stopping the server
later.
Args:
grace: A duration of time in seconds to allow existing RPCs to complete
before being aborted by this Server's stopping. May be zero for
immediate abortion of all in-progress RPCs.
Returns:
A threading.Event that will be set when this Server has completely
stopped. The returned event may not be set until after the full grace
period (if some ongoing RPC continues for the full length of the period)
of it may be set much sooner (such as if this Server had no RPCs underway
at the time it was stopped or if all RPCs that it had underway completed
very early in the grace period).
"""
raise NotImplementedError()

View file

@ -0,0 +1,149 @@
# Copyright 2015 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for the gRPC Python Beta API."""
import threading
import time
# implementations is referenced from specification in this module.
from grpc.beta import implementations # pylint: disable=unused-import
from grpc.beta import interfaces
from grpc.framework.foundation import callable_util
from grpc.framework.foundation import future
_DONE_CALLBACK_EXCEPTION_LOG_MESSAGE = (
'Exception calling connectivity future "done" callback!')
class _ChannelReadyFuture(future.Future):
def __init__(self, channel):
self._condition = threading.Condition()
self._channel = channel
self._matured = False
self._cancelled = False
self._done_callbacks = []
def _block(self, timeout):
until = None if timeout is None else time.time() + timeout
with self._condition:
while True:
if self._cancelled:
raise future.CancelledError()
elif self._matured:
return
else:
if until is None:
self._condition.wait()
else:
remaining = until - time.time()
if remaining < 0:
raise future.TimeoutError()
else:
self._condition.wait(timeout=remaining)
def _update(self, connectivity):
with self._condition:
if (not self._cancelled and
connectivity is interfaces.ChannelConnectivity.READY):
self._matured = True
self._channel.unsubscribe(self._update)
self._condition.notify_all()
done_callbacks = tuple(self._done_callbacks)
self._done_callbacks = None
else:
return
for done_callback in done_callbacks:
callable_util.call_logging_exceptions(
done_callback, _DONE_CALLBACK_EXCEPTION_LOG_MESSAGE, self)
def cancel(self):
with self._condition:
if not self._matured:
self._cancelled = True
self._channel.unsubscribe(self._update)
self._condition.notify_all()
done_callbacks = tuple(self._done_callbacks)
self._done_callbacks = None
else:
return False
for done_callback in done_callbacks:
callable_util.call_logging_exceptions(
done_callback, _DONE_CALLBACK_EXCEPTION_LOG_MESSAGE, self)
return True
def cancelled(self):
with self._condition:
return self._cancelled
def running(self):
with self._condition:
return not self._cancelled and not self._matured
def done(self):
with self._condition:
return self._cancelled or self._matured
def result(self, timeout=None):
self._block(timeout)
return None
def exception(self, timeout=None):
self._block(timeout)
return None
def traceback(self, timeout=None):
self._block(timeout)
return None
def add_done_callback(self, fn):
with self._condition:
if not self._cancelled and not self._matured:
self._done_callbacks.append(fn)
return
fn(self)
def start(self):
with self._condition:
self._channel.subscribe(self._update, try_to_connect=True)
def __del__(self):
with self._condition:
if not self._cancelled and not self._matured:
self._channel.unsubscribe(self._update)
def channel_ready_future(channel):
"""Creates a future.Future tracking when an implementations.Channel is ready.
Cancelling the returned future.Future does not tell the given
implementations.Channel to abandon attempts it may have been making to
connect; cancelling merely deactivates the return future.Future's
subscription to the given implementations.Channel's connectivity.
Args:
channel: An implementations.Channel.
Returns:
A future.Future that matures when the given Channel has connectivity
interfaces.ChannelConnectivity.READY.
"""
ready_future = _ChannelReadyFuture(channel)
ready_future.start()
return ready_future

View file

@ -0,0 +1,127 @@
# Copyright 2018 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""gRPC's experimental APIs.
These APIs are subject to be removed during any minor version release.
"""
import copy
import functools
import sys
import warnings
import grpc
_EXPERIMENTAL_APIS_USED = set()
class ChannelOptions(object):
"""Indicates a channel option unique to gRPC Python.
This enumeration is part of an EXPERIMENTAL API.
Attributes:
SingleThreadedUnaryStream: Perform unary-stream RPCs on a single thread.
"""
SingleThreadedUnaryStream = "SingleThreadedUnaryStream"
class UsageError(Exception):
"""Raised by the gRPC library to indicate usage not allowed by the API."""
_insecure_channel_credentials_sentinel = object()
_insecure_channel_credentials = grpc.ChannelCredentials(
_insecure_channel_credentials_sentinel)
def insecure_channel_credentials():
"""Creates a ChannelCredentials for use with an insecure channel.
THIS IS AN EXPERIMENTAL API.
This is not for use with secure_channel function. Intead, this should be
used with grpc.unary_unary, grpc.unary_stream, grpc.stream_unary, or
grpc.stream_stream.
"""
return _insecure_channel_credentials
class ExperimentalApiWarning(Warning):
"""A warning that an API is experimental."""
def _warn_experimental(api_name, stack_offset):
if api_name not in _EXPERIMENTAL_APIS_USED:
_EXPERIMENTAL_APIS_USED.add(api_name)
msg = ("'{}' is an experimental API. It is subject to change or ".
format(api_name) +
"removal between minor releases. Proceed with caution.")
warnings.warn(msg, ExperimentalApiWarning, stacklevel=2 + stack_offset)
def experimental_api(f):
@functools.wraps(f)
def _wrapper(*args, **kwargs):
_warn_experimental(f.__name__, 1)
return f(*args, **kwargs)
return _wrapper
def wrap_server_method_handler(wrapper, handler):
"""Wraps the server method handler function.
The server implementation requires all server handlers being wrapped as
RpcMethodHandler objects. This helper function ease the pain of writing
server handler wrappers.
Args:
wrapper: A wrapper function that takes in a method handler behavior
(the actual function) and returns a wrapped function.
handler: A RpcMethodHandler object to be wrapped.
Returns:
A newly created RpcMethodHandler.
"""
if not handler:
return None
if not handler.request_streaming:
if not handler.response_streaming:
# NOTE(lidiz) _replace is a public API:
# https://docs.python.org/dev/library/collections.html
return handler._replace(unary_unary=wrapper(handler.unary_unary))
else:
return handler._replace(unary_stream=wrapper(handler.unary_stream))
else:
if not handler.response_streaming:
return handler._replace(stream_unary=wrapper(handler.stream_unary))
else:
return handler._replace(
stream_stream=wrapper(handler.stream_stream))
__all__ = (
'ChannelOptions',
'ExperimentalApiWarning',
'UsageError',
'insecure_channel_credentials',
'wrap_server_method_handler',
)
if sys.version_info > (3, 6):
from grpc._simple_stubs import unary_unary, unary_stream, stream_unary, stream_stream
__all__ = __all__ + (unary_unary, unary_stream, stream_unary, stream_stream)

View file

@ -0,0 +1,16 @@
# Copyright 2020 The gRPC Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Alias of grpc.aio to keep backward compatibility."""
from grpc.aio import *

View file

@ -0,0 +1,27 @@
# Copyright 2018 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""gRPC's Python gEvent APIs."""
from grpc._cython import cygrpc as _cygrpc
def init_gevent():
"""Patches gRPC's libraries to be compatible with gevent.
This must be called AFTER the python standard lib has been patched,
but BEFORE creating and gRPC objects.
In order for progress to be made, the application must drive the event loop.
"""
_cygrpc.init_grpc_gevent()

View file

@ -0,0 +1,45 @@
# Copyright 2018 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""gRPC's APIs for TLS Session Resumption support"""
from grpc._cython import cygrpc as _cygrpc
def ssl_session_cache_lru(capacity):
"""Creates an SSLSessionCache with LRU replacement policy
Args:
capacity: Size of the cache
Returns:
An SSLSessionCache with LRU replacement policy that can be passed as a value for
the grpc.ssl_session_cache option to a grpc.Channel. SSL session caches are used
to store session tickets, which clients can present to resume previous TLS sessions
with a server.
"""
return SSLSessionCache(_cygrpc.SSLSessionCacheLRU(capacity))
class SSLSessionCache(object):
"""An encapsulation of a session cache used for TLS session resumption.
Instances of this class can be passed to a Channel as values for the
grpc.ssl_session_cache option
"""
def __init__(self, cache):
self._cache = cache
def __int__(self):
return int(self._cache)

View file

@ -0,0 +1,13 @@
# Copyright 2015 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View file

@ -0,0 +1,13 @@
# Copyright 2015 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View file

@ -0,0 +1,26 @@
# Copyright 2015 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Defines an enum for classifying RPC methods by streaming semantics."""
import enum
@enum.unique
class Cardinality(enum.Enum):
"""Describes the streaming semantics of an RPC method."""
UNARY_UNARY = 'request-unary/response-unary'
UNARY_STREAM = 'request-unary/response-streaming'
STREAM_UNARY = 'request-streaming/response-unary'
STREAM_STREAM = 'request-streaming/response-streaming'

View file

@ -0,0 +1,24 @@
# Copyright 2015 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Defines an enum for classifying RPC methods by control flow semantics."""
import enum
@enum.unique
class Service(enum.Enum):
"""Describes the control flow style of RPC method implementation."""
INLINE = 'inline'
EVENT = 'event'

View file

@ -0,0 +1,13 @@
# Copyright 2015 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View file

@ -0,0 +1,22 @@
# Copyright 2015 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for indicating abandonment of computation."""
class Abandoned(Exception):
"""Indicates that some computation is being abandoned.
Abandoning a computation is different than returning a value or raising
an exception indicating some operational or programming defect.
"""

View file

@ -0,0 +1,96 @@
# Copyright 2015 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for working with callables."""
import abc
import collections
import enum
import functools
import logging
import six
_LOGGER = logging.getLogger(__name__)
class Outcome(six.with_metaclass(abc.ABCMeta)):
"""A sum type describing the outcome of some call.
Attributes:
kind: One of Kind.RETURNED or Kind.RAISED respectively indicating that the
call returned a value or raised an exception.
return_value: The value returned by the call. Must be present if kind is
Kind.RETURNED.
exception: The exception raised by the call. Must be present if kind is
Kind.RAISED.
"""
@enum.unique
class Kind(enum.Enum):
"""Identifies the general kind of the outcome of some call."""
RETURNED = object()
RAISED = object()
class _EasyOutcome(
collections.namedtuple('_EasyOutcome',
['kind', 'return_value', 'exception']), Outcome):
"""A trivial implementation of Outcome."""
def _call_logging_exceptions(behavior, message, *args, **kwargs):
try:
return _EasyOutcome(Outcome.Kind.RETURNED, behavior(*args, **kwargs),
None)
except Exception as e: # pylint: disable=broad-except
_LOGGER.exception(message)
return _EasyOutcome(Outcome.Kind.RAISED, None, e)
def with_exceptions_logged(behavior, message):
"""Wraps a callable in a try-except that logs any exceptions it raises.
Args:
behavior: Any callable.
message: A string to log if the behavior raises an exception.
Returns:
A callable that when executed invokes the given behavior. The returned
callable takes the same arguments as the given behavior but returns a
future.Outcome describing whether the given behavior returned a value or
raised an exception.
"""
@functools.wraps(behavior)
def wrapped_behavior(*args, **kwargs):
return _call_logging_exceptions(behavior, message, *args, **kwargs)
return wrapped_behavior
def call_logging_exceptions(behavior, message, *args, **kwargs):
"""Calls a behavior in a try-except that logs any exceptions it raises.
Args:
behavior: Any callable.
message: A string to log if the behavior raises an exception.
*args: Positional arguments to pass to the given behavior.
**kwargs: Keyword arguments to pass to the given behavior.
Returns:
An Outcome describing whether the given behavior returned a value or raised
an exception.
"""
return _call_logging_exceptions(behavior, message, *args, **kwargs)

View file

@ -0,0 +1,221 @@
# Copyright 2015 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A Future interface.
Python doesn't have a Future interface in its standard library. In the absence
of such a standard, three separate, incompatible implementations
(concurrent.futures.Future, ndb.Future, and asyncio.Future) have appeared. This
interface attempts to be as compatible as possible with
concurrent.futures.Future. From ndb.Future it adopts a traceback-object accessor
method.
Unlike the concrete and implemented Future classes listed above, the Future
class defined in this module is an entirely abstract interface that anyone may
implement and use.
The one known incompatibility between this interface and the interface of
concurrent.futures.Future is that this interface defines its own CancelledError
and TimeoutError exceptions rather than raising the implementation-private
concurrent.futures._base.CancelledError and the
built-in-but-only-in-3.3-and-later TimeoutError.
"""
import abc
import six
class TimeoutError(Exception):
"""Indicates that a particular call timed out."""
class CancelledError(Exception):
"""Indicates that the computation underlying a Future was cancelled."""
class Future(six.with_metaclass(abc.ABCMeta)):
"""A representation of a computation in another control flow.
Computations represented by a Future may be yet to be begun, may be ongoing,
or may have already completed.
"""
# NOTE(nathaniel): This isn't the return type that I would want to have if it
# were up to me. Were this interface being written from scratch, the return
# type of this method would probably be a sum type like:
#
# NOT_COMMENCED
# COMMENCED_AND_NOT_COMPLETED
# PARTIAL_RESULT<Partial_Result_Type>
# COMPLETED<Result_Type>
# UNCANCELLABLE
# NOT_IMMEDIATELY_DETERMINABLE
@abc.abstractmethod
def cancel(self):
"""Attempts to cancel the computation.
This method does not block.
Returns:
True if the computation has not yet begun, will not be allowed to take
place, and determination of both was possible without blocking. False
under all other circumstances including but not limited to the
computation's already having begun, the computation's already having
finished, and the computation's having been scheduled for execution on a
remote system for which a determination of whether or not it commenced
before being cancelled cannot be made without blocking.
"""
raise NotImplementedError()
# NOTE(nathaniel): Here too this isn't the return type that I'd want this
# method to have if it were up to me. I think I'd go with another sum type
# like:
#
# NOT_CANCELLED (this object's cancel method hasn't been called)
# NOT_COMMENCED
# COMMENCED_AND_NOT_COMPLETED
# PARTIAL_RESULT<Partial_Result_Type>
# COMPLETED<Result_Type>
# UNCANCELLABLE
# NOT_IMMEDIATELY_DETERMINABLE
#
# Notice how giving the cancel method the right semantics obviates most
# reasons for this method to exist.
@abc.abstractmethod
def cancelled(self):
"""Describes whether the computation was cancelled.
This method does not block.
Returns:
True if the computation was cancelled any time before its result became
immediately available. False under all other circumstances including but
not limited to this object's cancel method not having been called and
the computation's result having become immediately available.
"""
raise NotImplementedError()
@abc.abstractmethod
def running(self):
"""Describes whether the computation is taking place.
This method does not block.
Returns:
True if the computation is scheduled to take place in the future or is
taking place now, or False if the computation took place in the past or
was cancelled.
"""
raise NotImplementedError()
# NOTE(nathaniel): These aren't quite the semantics I'd like here either. I
# would rather this only returned True in cases in which the underlying
# computation completed successfully. A computation's having been cancelled
# conflicts with considering that computation "done".
@abc.abstractmethod
def done(self):
"""Describes whether the computation has taken place.
This method does not block.
Returns:
True if the computation is known to have either completed or have been
unscheduled or interrupted. False if the computation may possibly be
executing or scheduled to execute later.
"""
raise NotImplementedError()
@abc.abstractmethod
def result(self, timeout=None):
"""Accesses the outcome of the computation or raises its exception.
This method may return immediately or may block.
Args:
timeout: The length of time in seconds to wait for the computation to
finish or be cancelled, or None if this method should block until the
computation has finished or is cancelled no matter how long that takes.
Returns:
The return value of the computation.
Raises:
TimeoutError: If a timeout value is passed and the computation does not
terminate within the allotted time.
CancelledError: If the computation was cancelled.
Exception: If the computation raised an exception, this call will raise
the same exception.
"""
raise NotImplementedError()
@abc.abstractmethod
def exception(self, timeout=None):
"""Return the exception raised by the computation.
This method may return immediately or may block.
Args:
timeout: The length of time in seconds to wait for the computation to
terminate or be cancelled, or None if this method should block until
the computation is terminated or is cancelled no matter how long that
takes.
Returns:
The exception raised by the computation, or None if the computation did
not raise an exception.
Raises:
TimeoutError: If a timeout value is passed and the computation does not
terminate within the allotted time.
CancelledError: If the computation was cancelled.
"""
raise NotImplementedError()
@abc.abstractmethod
def traceback(self, timeout=None):
"""Access the traceback of the exception raised by the computation.
This method may return immediately or may block.
Args:
timeout: The length of time in seconds to wait for the computation to
terminate or be cancelled, or None if this method should block until
the computation is terminated or is cancelled no matter how long that
takes.
Returns:
The traceback of the exception raised by the computation, or None if the
computation did not raise an exception.
Raises:
TimeoutError: If a timeout value is passed and the computation does not
terminate within the allotted time.
CancelledError: If the computation was cancelled.
"""
raise NotImplementedError()
@abc.abstractmethod
def add_done_callback(self, fn):
"""Adds a function to be called at completion of the computation.
The callback will be passed this Future object describing the outcome of
the computation.
If the computation has already completed, the callback will be called
immediately.
Args:
fn: A callable taking this Future object as its single parameter.
"""
raise NotImplementedError()

View file

@ -0,0 +1,72 @@
# Copyright 2015 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A thread pool that logs exceptions raised by tasks executed within it."""
import logging
from concurrent import futures
_LOGGER = logging.getLogger(__name__)
def _wrap(behavior):
"""Wraps an arbitrary callable behavior in exception-logging."""
def _wrapping(*args, **kwargs):
try:
return behavior(*args, **kwargs)
except Exception:
_LOGGER.exception(
'Unexpected exception from %s executed in logging pool!',
behavior)
raise
return _wrapping
class _LoggingPool(object):
"""An exception-logging futures.ThreadPoolExecutor-compatible thread pool."""
def __init__(self, backing_pool):
self._backing_pool = backing_pool
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self._backing_pool.shutdown(wait=True)
def submit(self, fn, *args, **kwargs):
return self._backing_pool.submit(_wrap(fn), *args, **kwargs)
def map(self, func, *iterables, **kwargs):
return self._backing_pool.map(_wrap(func),
*iterables,
timeout=kwargs.get('timeout', None))
def shutdown(self, wait=True):
self._backing_pool.shutdown(wait=wait)
def pool(max_workers):
"""Creates a thread pool that logs exceptions raised by the tasks within it.
Args:
max_workers: The maximum number of worker threads to allow the pool.
Returns:
A futures.ThreadPoolExecutor-compatible thread pool that logs exceptions
raised by the tasks executed within it.
"""
return _LoggingPool(futures.ThreadPoolExecutor(max_workers))

View file

@ -0,0 +1,45 @@
# Copyright 2015 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Interfaces related to streams of values or objects."""
import abc
import six
class Consumer(six.with_metaclass(abc.ABCMeta)):
"""Interface for consumers of finite streams of values or objects."""
@abc.abstractmethod
def consume(self, value):
"""Accepts a value.
Args:
value: Any value accepted by this Consumer.
"""
raise NotImplementedError()
@abc.abstractmethod
def terminate(self):
"""Indicates to this Consumer that no more values will be supplied."""
raise NotImplementedError()
@abc.abstractmethod
def consume_and_terminate(self, value):
"""Supplies a value and signals that no more values will be supplied.
Args:
value: Any value accepted by this Consumer.
"""
raise NotImplementedError()

View file

@ -0,0 +1,148 @@
# Copyright 2015 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Helpful utilities related to the stream module."""
import logging
import threading
from grpc.framework.foundation import stream
_NO_VALUE = object()
_LOGGER = logging.getLogger(__name__)
class TransformingConsumer(stream.Consumer):
"""A stream.Consumer that passes a transformation of its input to another."""
def __init__(self, transformation, downstream):
self._transformation = transformation
self._downstream = downstream
def consume(self, value):
self._downstream.consume(self._transformation(value))
def terminate(self):
self._downstream.terminate()
def consume_and_terminate(self, value):
self._downstream.consume_and_terminate(self._transformation(value))
class IterableConsumer(stream.Consumer):
"""A Consumer that when iterated over emits the values it has consumed."""
def __init__(self):
self._condition = threading.Condition()
self._values = []
self._active = True
def consume(self, value):
with self._condition:
if self._active:
self._values.append(value)
self._condition.notify()
def terminate(self):
with self._condition:
self._active = False
self._condition.notify()
def consume_and_terminate(self, value):
with self._condition:
if self._active:
self._values.append(value)
self._active = False
self._condition.notify()
def __iter__(self):
return self
def __next__(self):
return self.next()
def next(self):
with self._condition:
while self._active and not self._values:
self._condition.wait()
if self._values:
return self._values.pop(0)
else:
raise StopIteration()
class ThreadSwitchingConsumer(stream.Consumer):
"""A Consumer decorator that affords serialization and asynchrony."""
def __init__(self, sink, pool):
self._lock = threading.Lock()
self._sink = sink
self._pool = pool
# True if self._spin has been submitted to the pool to be called once and
# that call has not yet returned, False otherwise.
self._spinning = False
self._values = []
self._active = True
def _spin(self, sink, value, terminate):
while True:
try:
if value is _NO_VALUE:
sink.terminate()
elif terminate:
sink.consume_and_terminate(value)
else:
sink.consume(value)
except Exception as e: # pylint:disable=broad-except
_LOGGER.exception(e)
with self._lock:
if terminate:
self._spinning = False
return
elif self._values:
value = self._values.pop(0)
terminate = not self._values and not self._active
elif not self._active:
value = _NO_VALUE
terminate = True
else:
self._spinning = False
return
def consume(self, value):
with self._lock:
if self._active:
if self._spinning:
self._values.append(value)
else:
self._pool.submit(self._spin, self._sink, value, False)
self._spinning = True
def terminate(self):
with self._lock:
if self._active:
self._active = False
if not self._spinning:
self._pool.submit(self._spin, self._sink, _NO_VALUE, True)
self._spinning = True
def consume_and_terminate(self, value):
with self._lock:
if self._active:
self._active = False
if self._spinning:
self._values.append(value)
else:
self._pool.submit(self._spin, self._sink, value, True)
self._spinning = True

View file

@ -0,0 +1,13 @@
# Copyright 2015 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View file

@ -0,0 +1,13 @@
# Copyright 2015 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Some files were not shown because too many files have changed in this diff Show more