# Copyright 2019 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Invocation-side implementation of gRPC Asyncio Python."""

import asyncio
import sys
from typing import Any, Iterable, Optional, Sequence, List

import grpc
from grpc import _common, _compression, _grpcio_metadata
from grpc._cython import cygrpc

from . import _base_call, _base_channel
from ._call import (StreamStreamCall, StreamUnaryCall, UnaryStreamCall,
                    UnaryUnaryCall)
from ._interceptor import (
    InterceptedUnaryUnaryCall, InterceptedUnaryStreamCall,
    InterceptedStreamUnaryCall, InterceptedStreamStreamCall, ClientInterceptor,
    UnaryUnaryClientInterceptor, UnaryStreamClientInterceptor,
    StreamUnaryClientInterceptor, StreamStreamClientInterceptor)
from ._metadata import Metadata
from ._typing import (ChannelArgumentType, DeserializingFunction,
                      SerializingFunction, RequestIterableType)
from ._utils import _timeout_to_deadline

# User-agent string passed as a channel argument, tagged with the package version.
_USER_AGENT = f'grpc-python-asyncio/{_grpcio_metadata.__version__}'

# `asyncio.all_tasks()` only exists on Python 3.7+; older releases expose the
# equivalent `asyncio.Task.all_tasks()` classmethod instead. The full version
# tuple is compared so that the major version is taken into account as well,
# rather than only looking at the minor component.
if sys.version_info < (3, 7):

    def _all_tasks() -> Iterable[asyncio.Task]:
        """Returns the tasks known to asyncio (pre-3.7 spelling)."""
        return asyncio.Task.all_tasks()
else:

    def _all_tasks() -> Iterable[asyncio.Task]:
        """Returns the tasks reported by `asyncio.all_tasks()`.

        Must be invoked while an event loop is running.
        """
        return asyncio.all_tasks()


def _augment_channel_arguments(base_options: ChannelArgumentType,
                               compression: Optional[grpc.Compression]):
    """Extends the user supplied channel options with the built-in ones.

    Args:
      base_options: The channel options supplied by the caller.
      compression: The optional channel-wide compression setting, forwarded
        to `_compression.create_channel_option`.

    Returns:
      A tuple made of the caller's options, followed by the compression
      option(s) and finally the primary user-agent option.
    """
    compression_option = _compression.create_channel_option(compression)
    user_agent_option = (
        cygrpc.ChannelArgKey.primary_user_agent_string,
        _USER_AGENT,
    )
    return tuple(base_options) + compression_option + (user_agent_option,)


class _BaseMultiCallable:
    """Base class of all multi callable objects.

    Handles the initialization logic and stores common attributes.
    """
    _loop: asyncio.AbstractEventLoop
    _channel: cygrpc.AioChannel
    _method: bytes
    _request_serializer: SerializingFunction
    _response_deserializer: DeserializingFunction
    _interceptors: Optional[Sequence[ClientInterceptor]]

    # pylint: disable=too-many-arguments
    def __init__(
            self,
            channel: cygrpc.AioChannel,
            method: bytes,
            request_serializer: SerializingFunction,
            response_deserializer: DeserializingFunction,
            interceptors: Optional[Sequence[ClientInterceptor]],
            loop: asyncio.AbstractEventLoop,
    ) -> None:
        """Stores everything needed to later create calls for one method.

        Args:
          channel: The Cython channel the calls are created on.
          method: The encoded name of the RPC method.
          request_serializer: Serializer handed to every created call.
          response_deserializer: Deserializer handed to every created call.
          interceptors: Optional interceptors; when non-empty the subclasses
            create intercepted calls instead of plain ones.
          loop: The event loop handed to every created call.
        """
        self._loop = loop
        self._channel = channel
        self._method = method
        self._request_serializer = request_serializer
        self._response_deserializer = response_deserializer
        self._interceptors = interceptors

    @staticmethod
    def _init_metadata(metadata: Optional[Metadata] = None,
                       compression: Optional[grpc.Compression] = None
                      ) -> Metadata:
        """Builds the metadata that should be used for the current call.

        Args:
          metadata: The metadata supplied by the caller; a fresh `Metadata`
            is used when it is missing or falsy.
          compression: The optional per-call compression setting. When it is
            truthy the metadata is rebuilt from the result of
            `_compression.augment_metadata`.

        Returns:
          The `Metadata` instance to send with the call.
        """
        metadata = metadata or Metadata()
        if compression:
            metadata = Metadata(
                *_compression.augment_metadata(metadata, compression))
        return metadata


class UnaryUnaryMultiCallable(_BaseMultiCallable,
                              _base_channel.UnaryUnaryMultiCallable):
    """Creates unary-unary calls for the method this object was built for."""

    def __call__(self,
                 request: Any,
                 *,
                 timeout: Optional[float] = None,
                 metadata: Optional[Metadata] = None,
                 credentials: Optional[grpc.CallCredentials] = None,
                 wait_for_ready: Optional[bool] = None,
                 compression: Optional[grpc.Compression] = None
                ) -> _base_call.UnaryUnaryCall:
        """Starts a unary-unary call.

        Args:
          request: The request value handed to the call.
          timeout: Optional timeout for the call.
          metadata: Optional metadata to send with the call.
          credentials: Optional call credentials.
          wait_for_ready: Optional wait-for-ready flag handed to the call.
          compression: Optional per-call compression setting, folded into
            the metadata.

        Returns:
          An `InterceptedUnaryUnaryCall` when interceptors are registered,
          a `UnaryUnaryCall` otherwise.
        """
        call_metadata = self._init_metadata(metadata, compression)

        if self._interceptors:
            # The intercepted call receives the raw timeout here, whereas the
            # direct call below receives the derived deadline.
            return InterceptedUnaryUnaryCall(
                self._interceptors, request, timeout, call_metadata,
                credentials, wait_for_ready, self._channel, self._method,
                self._request_serializer, self._response_deserializer,
                self._loop)

        return UnaryUnaryCall(request, _timeout_to_deadline(timeout),
                              call_metadata, credentials, wait_for_ready,
                              self._channel, self._method,
                              self._request_serializer,
                              self._response_deserializer, self._loop)


class UnaryStreamMultiCallable(_BaseMultiCallable,
                               _base_channel.UnaryStreamMultiCallable):
    """Creates unary-stream calls for the method this object was built for."""

    def __call__(self,
                 request: Any,
                 *,
                 timeout: Optional[float] = None,
                 metadata: Optional[Metadata] = None,
                 credentials: Optional[grpc.CallCredentials] = None,
                 wait_for_ready: Optional[bool] = None,
                 compression: Optional[grpc.Compression] = None
                ) -> _base_call.UnaryStreamCall:
        """Starts a unary-stream call.

        Args:
          request: The request value handed to the call.
          timeout: Optional timeout, converted with `_timeout_to_deadline`.
          metadata: Optional metadata to send with the call.
          credentials: Optional call credentials.
          wait_for_ready: Optional wait-for-ready flag handed to the call.
          compression: Optional per-call compression setting, folded into
            the metadata.

        Returns:
          An `InterceptedUnaryStreamCall` when interceptors are registered,
          a `UnaryStreamCall` otherwise.
        """
        call_metadata = self._init_metadata(metadata, compression)
        call_deadline = _timeout_to_deadline(timeout)

        if self._interceptors:
            return InterceptedUnaryStreamCall(
                self._interceptors, request, call_deadline, call_metadata,
                credentials, wait_for_ready, self._channel, self._method,
                self._request_serializer, self._response_deserializer,
                self._loop)

        return UnaryStreamCall(request, call_deadline, call_metadata,
                               credentials, wait_for_ready, self._channel,
                               self._method, self._request_serializer,
                               self._response_deserializer, self._loop)


class StreamUnaryMultiCallable(_BaseMultiCallable,
                               _base_channel.StreamUnaryMultiCallable):
    """Creates stream-unary calls for the method this object was built for."""

    def __call__(self,
                 request_iterator: Optional[RequestIterableType] = None,
                 timeout: Optional[float] = None,
                 metadata: Optional[Metadata] = None,
                 credentials: Optional[grpc.CallCredentials] = None,
                 wait_for_ready: Optional[bool] = None,
                 compression: Optional[grpc.Compression] = None
                ) -> _base_call.StreamUnaryCall:
        """Starts a stream-unary call.

        Args:
          request_iterator: Optional iterable of requests, handed to the
            call untouched.
          timeout: Optional timeout, converted with `_timeout_to_deadline`.
          metadata: Optional metadata to send with the call.
          credentials: Optional call credentials.
          wait_for_ready: Optional wait-for-ready flag handed to the call.
          compression: Optional per-call compression setting, folded into
            the metadata.

        Returns:
          An `InterceptedStreamUnaryCall` when interceptors are registered,
          a `StreamUnaryCall` otherwise.
        """
        call_metadata = self._init_metadata(metadata, compression)
        call_deadline = _timeout_to_deadline(timeout)

        if self._interceptors:
            return InterceptedStreamUnaryCall(
                self._interceptors, request_iterator, call_deadline,
                call_metadata, credentials, wait_for_ready, self._channel,
                self._method, self._request_serializer,
                self._response_deserializer, self._loop)

        return StreamUnaryCall(request_iterator, call_deadline, call_metadata,
                               credentials, wait_for_ready, self._channel,
                               self._method, self._request_serializer,
                               self._response_deserializer, self._loop)


class StreamStreamMultiCallable(_BaseMultiCallable,
                                _base_channel.StreamStreamMultiCallable):
    """Creates stream-stream calls for the method this object was built for."""

    def __call__(self,
                 request_iterator: Optional[RequestIterableType] = None,
                 timeout: Optional[float] = None,
                 metadata: Optional[Metadata] = None,
                 credentials: Optional[grpc.CallCredentials] = None,
                 wait_for_ready: Optional[bool] = None,
                 compression: Optional[grpc.Compression] = None
                ) -> _base_call.StreamStreamCall:
        """Starts a stream-stream call.

        Args:
          request_iterator: Optional iterable of requests, handed to the
            call untouched.
          timeout: Optional timeout, converted with `_timeout_to_deadline`.
          metadata: Optional metadata to send with the call.
          credentials: Optional call credentials.
          wait_for_ready: Optional wait-for-ready flag handed to the call.
          compression: Optional per-call compression setting, folded into
            the metadata.

        Returns:
          An `InterceptedStreamStreamCall` when interceptors are registered,
          a `StreamStreamCall` otherwise.
        """
        call_metadata = self._init_metadata(metadata, compression)
        call_deadline = _timeout_to_deadline(timeout)

        if self._interceptors:
            return InterceptedStreamStreamCall(
                self._interceptors, request_iterator, call_deadline,
                call_metadata, credentials, wait_for_ready, self._channel,
                self._method, self._request_serializer,
                self._response_deserializer, self._loop)

        return StreamStreamCall(request_iterator, call_deadline, call_metadata,
                                credentials, wait_for_ready, self._channel,
                                self._method, self._request_serializer,
                                self._response_deserializer, self._loop)


class Channel(_base_channel.Channel):
    """Asyncio channel wrapping a `cygrpc.AioChannel`.

    Interceptors are sorted by kind at construction time, so that each multi
    callable created by the channel only receives the interceptors that apply
    to its call arity.
    """
    _loop: asyncio.AbstractEventLoop
    _channel: cygrpc.AioChannel
    _unary_unary_interceptors: List[UnaryUnaryClientInterceptor]
    _unary_stream_interceptors: List[UnaryStreamClientInterceptor]
    _stream_unary_interceptors: List[StreamUnaryClientInterceptor]
    _stream_stream_interceptors: List[StreamStreamClientInterceptor]

    def __init__(self, target: str, options: ChannelArgumentType,
                 credentials: Optional[grpc.ChannelCredentials],
                 compression: Optional[grpc.Compression],
                 interceptors: Optional[Sequence[ClientInterceptor]]):
        """Constructor.

        Args:
          target: The target to which to connect.
          options: Configuration options for the channel.
          credentials: A cygrpc.ChannelCredentials or None.
          compression: An optional value indicating the compression method to be
            used over the lifetime of the channel.
          interceptors: An optional list of interceptors that would be used for
            intercepting any RPC executed with that channel.

        Raises:
          ValueError: If an interceptor is not an instance of any of the four
            supported client interceptor classes.
        """
        self._unary_unary_interceptors = []
        self._unary_stream_interceptors = []
        self._stream_unary_interceptors = []
        self._stream_stream_interceptors = []

        if interceptors is not None:
            for interceptor in interceptors:
                # An interceptor is registered under the first kind it
                # matches, in the order of the checks below.
                if isinstance(interceptor, UnaryUnaryClientInterceptor):
                    self._unary_unary_interceptors.append(interceptor)
                elif isinstance(interceptor, UnaryStreamClientInterceptor):
                    self._unary_stream_interceptors.append(interceptor)
                elif isinstance(interceptor, StreamUnaryClientInterceptor):
                    self._stream_unary_interceptors.append(interceptor)
                elif isinstance(interceptor, StreamStreamClientInterceptor):
                    self._stream_stream_interceptors.append(interceptor)
                else:
                    raise ValueError(
                        "Interceptor {} must be ".format(interceptor) +
                        "{} or ".format(UnaryUnaryClientInterceptor.__name__) +
                        "{} or ".format(UnaryStreamClientInterceptor.__name__) +
                        "{} or ".format(StreamUnaryClientInterceptor.__name__) +
                        "{}. ".format(StreamStreamClientInterceptor.__name__))

        self._loop = cygrpc.get_working_loop()
        self._channel = cygrpc.AioChannel(
            _common.encode(target),
            _augment_channel_arguments(options, compression), credentials,
            self._loop)

    async def __aenter__(self):
        """Enters the async context manager; returns the channel itself."""
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Leaves the async context manager, closing without a grace period."""
        await self._close(None)

    async def _close(self, grace):  # pylint: disable=too-many-branches
        """Closes the channel, cancelling the calls still running on it.

        The calls are discovered by inspecting the innermost frame of every
        task returned by `_all_tasks()` and keeping those whose `self` is a
        call object bound to this channel.

        Args:
          grace: If truthy, the maximum time to wait for the discovered call
            tasks to finish before the calls get cancelled.

        Raises:
          cygrpc.InternalError: If a call object exposes neither `_channel`
            nor `_cython_call`.
        """
        if self._channel.closed():
            return

        # No new calls will be accepted by the Cython channel.
        self._channel.closing()

        # Iterate through running tasks
        tasks = _all_tasks()
        calls = []
        call_tasks = []
        for task in tasks:
            try:
                stack = task.get_stack(limit=1)
            except AttributeError as attribute_error:
                # NOTE(lidiz) tl;dr: If the Task is created with a CPython
                # object, it will trigger AttributeError.
                #
                # In the global finalizer, the event loop schedules
                # a CPython PyAsyncGenAThrow object.
                # https://github.com/python/cpython/blob/00e45877e33d32bb61aa13a2033e3bba370bda4d/Lib/asyncio/base_events.py#L484
                #
                # However, the PyAsyncGenAThrow object is written in C and
                # failed to include the normal Python frame objects. Hence,
                # this exception is a false negative, and it is safe to ignore
                # the failure. It is fixed by https://github.com/python/cpython/pull/18669,
                # but not available until 3.9 or 3.8.3. So, we have to keep it
                # for a while.
                # TODO(lidiz) drop this hack after 3.8 deprecation
                if 'frame' in str(attribute_error):
                    continue
                raise

            # If the Task is created by a C-extension, the stack will be empty.
            if not stack:
                continue

            # Locate ones created by `aio.Call`.
            frame = stack[0]
            candidate = frame.f_locals.get('self')
            # Rely on the type check alone (it is False for a missing `self`)
            # rather than on truthiness: `candidate` is an arbitrary object
            # taken from another task's frame, so evaluating it in a boolean
            # context could run a foreign `__bool__`/`__len__`.
            if not isinstance(candidate, _base_call.Call):
                continue

            if hasattr(candidate, '_channel'):
                # For intercepted Call object
                if candidate._channel is not self._channel:
                    continue
            elif hasattr(candidate, '_cython_call'):
                # For normal Call object
                if candidate._cython_call._channel is not self._channel:
                    continue
            else:
                # Unidentified Call object
                raise cygrpc.InternalError(
                    f'Unrecognized call object: {candidate}')

            calls.append(candidate)
            call_tasks.append(task)

        # If needed, try to wait for them to finish.
        # Call objects are not always awaitables.
        if grace and call_tasks:
            # The `loop` argument of `asyncio.wait` was deprecated in Python
            # 3.8 and removed in 3.10; the tasks gathered above already belong
            # to the loop running this coroutine, so it is not needed.
            await asyncio.wait(call_tasks, timeout=grace)

        # Time to cancel existing calls.
        for call in calls:
            call.cancel()

        # Destroy the channel
        self._channel.close()

    async def close(self, grace: Optional[float] = None):
        """Closes the channel.

        Args:
          grace: Optional time to wait for in-flight calls to finish before
            they are cancelled; no waiting happens when it is None or zero.
        """
        await self._close(grace)

    def get_state(self,
                  try_to_connect: bool = False) -> grpc.ChannelConnectivity:
        """Returns the current connectivity state of the channel.

        Args:
          try_to_connect: Forwarded to the Cython channel's connectivity
            check.
        """
        result = self._channel.check_connectivity_state(try_to_connect)
        return _common.CYGRPC_CONNECTIVITY_STATE_TO_CHANNEL_CONNECTIVITY[result]

    async def wait_for_state_change(
            self,
            last_observed_state: grpc.ChannelConnectivity,
    ) -> None:
        """Waits until the Cython channel reports a connectivity change.

        Args:
          last_observed_state: The state the caller last observed.
        """
        # Keep the await outside of the `assert`: assert statements are
        # stripped when Python runs with -O, which would otherwise skip the
        # wait entirely and turn `channel_ready` into a busy loop.
        state_changed = await self._channel.watch_connectivity_state(
            last_observed_state.value[0], None)
        assert state_changed

    async def channel_ready(self) -> None:
        """Returns once the channel state is `ChannelConnectivity.READY`."""
        state = self.get_state(try_to_connect=True)
        while state != grpc.ChannelConnectivity.READY:
            await self.wait_for_state_change(state)
            state = self.get_state(try_to_connect=True)

    def unary_unary(
            self,
            method: str,
            request_serializer: Optional[SerializingFunction] = None,
            response_deserializer: Optional[DeserializingFunction] = None
    ) -> UnaryUnaryMultiCallable:
        """Creates a multi callable for a unary-unary method.

        Args:
          method: The name of the RPC method.
          request_serializer: Optional serializer for the request.
          response_deserializer: Optional deserializer for the response.
        """
        return UnaryUnaryMultiCallable(self._channel, _common.encode(method),
                                       request_serializer,
                                       response_deserializer,
                                       self._unary_unary_interceptors,
                                       self._loop)

    def unary_stream(
            self,
            method: str,
            request_serializer: Optional[SerializingFunction] = None,
            response_deserializer: Optional[DeserializingFunction] = None
    ) -> UnaryStreamMultiCallable:
        """Creates a multi callable for a unary-stream method.

        Args:
          method: The name of the RPC method.
          request_serializer: Optional serializer for the request.
          response_deserializer: Optional deserializer for the responses.
        """
        return UnaryStreamMultiCallable(self._channel, _common.encode(method),
                                        request_serializer,
                                        response_deserializer,
                                        self._unary_stream_interceptors,
                                        self._loop)

    def stream_unary(
            self,
            method: str,
            request_serializer: Optional[SerializingFunction] = None,
            response_deserializer: Optional[DeserializingFunction] = None
    ) -> StreamUnaryMultiCallable:
        """Creates a multi callable for a stream-unary method.

        Args:
          method: The name of the RPC method.
          request_serializer: Optional serializer for the requests.
          response_deserializer: Optional deserializer for the response.
        """
        return StreamUnaryMultiCallable(self._channel, _common.encode(method),
                                        request_serializer,
                                        response_deserializer,
                                        self._stream_unary_interceptors,
                                        self._loop)

    def stream_stream(
            self,
            method: str,
            request_serializer: Optional[SerializingFunction] = None,
            response_deserializer: Optional[DeserializingFunction] = None
    ) -> StreamStreamMultiCallable:
        """Creates a multi callable for a stream-stream method.

        Args:
          method: The name of the RPC method.
          request_serializer: Optional serializer for the requests.
          response_deserializer: Optional deserializer for the responses.
        """
        return StreamStreamMultiCallable(self._channel, _common.encode(method),
                                         request_serializer,
                                         response_deserializer,
                                         self._stream_stream_interceptors,
                                         self._loop)


def insecure_channel(
        target: str,
        options: Optional[ChannelArgumentType] = None,
        compression: Optional[grpc.Compression] = None,
        interceptors: Optional[Sequence[ClientInterceptor]] = None):
    """Creates an insecure asynchronous Channel to a server.

    Args:
      target: The server address.
      options: An optional list of key-value pairs (:term:`channel_arguments`
        in gRPC Core runtime) to configure the channel. Defaults to no
        extra options.
      compression: An optional value indicating the compression method to be
        used over the lifetime of the channel. This is an EXPERIMENTAL option.
      interceptors: An optional sequence of interceptors that will be executed
        for any call executed with this channel.

    Returns:
      An aio.Channel created without channel credentials.
    """
    if options is None:
        options = ()
    return Channel(target, options, None, compression, interceptors)


def secure_channel(target: str,
                   credentials: grpc.ChannelCredentials,
                   options: Optional[ChannelArgumentType] = None,
                   compression: Optional[grpc.Compression] = None,
                   interceptors: Optional[Sequence[ClientInterceptor]] = None):
    """Creates a secure asynchronous Channel to a server.

    Args:
      target: The server address.
      credentials: A ChannelCredentials instance; its `_credentials`
        attribute is what gets handed to the channel.
      options: An optional list of key-value pairs (:term:`channel_arguments`
        in gRPC Core runtime) to configure the channel. Defaults to no
        extra options.
      compression: An optional value indicating the compression method to be
        used over the lifetime of the channel. This is an EXPERIMENTAL option.
      interceptors: An optional sequence of interceptors that will be executed
        for any call executed with this channel.

    Returns:
      An aio.Channel.
    """
    if options is None:
        options = ()
    return Channel(target, options, credentials._credentials, compression,
                   interceptors)