245 lines
8.4 KiB
Python
245 lines
8.4 KiB
Python
# Copyright 2016 Google Inc. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
import logging
|
|
|
|
import httplib2
|
|
import six
|
|
from six.moves import http_client
|
|
|
|
from oauth2client._helpers import _to_bytes
|
|
|
|
|
|
# Logger for this module; used to report token refreshes.
_LOGGER = logging.getLogger(__name__)

# A request body that exposes all of these attributes is treated as a
# file-like stream/buffer whose position can be recorded and restored.
_STREAM_PROPERTIES = (
    'read',
    'seek',
    'tell',
)

# HTTP status codes that cause the credentials to be refreshed and the
# request retried. Google Data client libraries may need to set this to
# [401, 403].
REFRESH_STATUS_CODES = (
    http_client.UNAUTHORIZED,
)
|
|
|
|
|
|
class MemoryCache(object):
    """In-process, dict-backed cache for use as an httplib2 cache.

    Entries live only as long as this object; nothing is persisted.
    """

    def __init__(self):
        # Maps cache keys to their stored values.
        self.cache = dict()

    def get(self, key):
        """Return the value stored under ``key``, or None if absent."""
        try:
            return self.cache[key]
        except KeyError:
            return None

    def set(self, key, value):
        """Store ``value`` under ``key``, replacing any previous entry."""
        self.cache[key] = value

    def delete(self, key):
        """Remove ``key`` from the cache; unknown keys are ignored."""
        if key in self.cache:
            del self.cache[key]
|
|
|
|
|
|
def get_cached_http():
    """Return the shared, module-level HTTP object backed by a MemoryCache.

    Every call hands back the same object, so responses cached during one
    call site's requests (for example the certificate fetches made by
    oauth2client.client.verify_id_token(), which repeatedly hits the same
    URI) can be reused by later callers.

    Returns:
        httplib2.Http, the shared HTTP object with a MemoryCache.
    """
    shared_http = _CACHED_HTTP
    return shared_http
|
|
|
|
|
|
def get_http_object(*args, **kwargs):
    """Return a new HTTP object.

    Any arguments are forwarded unchanged to the ``httplib2.Http``
    constructor, so callers can configure e.g. a cache or timeout.
    Calling with no arguments behaves exactly like ``httplib2.Http()``.

    Args:
        *args: tuple, positional arguments passed through when
               constructing the new HTTP object.
        **kwargs: dict, keyword arguments passed through when
                  constructing the new HTTP object.

    Returns:
        httplib2.Http, an HTTP object.
    """
    return httplib2.Http(*args, **kwargs)
|
|
|
|
|
|
def _initialize_headers(headers):
    """Return an independent copy of the request headers.

    Args:
        headers: dict or None, the request headers to copy.

    Returns:
        dict, a shallow copy of ``headers``; an empty dict when
        ``headers`` is None.
    """
    if headers is None:
        return {}
    return dict(headers)
|
|
|
|
|
|
def _apply_user_agent(headers, user_agent):
    """Prepend a user agent to the headers' 'user-agent' entry, in place.

    Args:
        headers: dict, request headers to update. Only the lowercase key
                 'user-agent' is consulted and written.
        user_agent: str or None, the user agent to add. When None the
                    headers are left untouched.

    Returns:
        dict, the same ``headers`` object that was passed in.
    """
    if user_agent is None:
        return headers

    # An existing user agent is kept, placed after the new one.
    combined = user_agent
    if 'user-agent' in headers:
        combined = user_agent + ' ' + headers['user-agent']
    headers['user-agent'] = combined
    return headers
|
|
|
|
|
|
def clean_headers(headers):
    """Return a copy of the headers with keys and values passed to _to_bytes.

    httplib concatenates header keys and values when building the message;
    if any of them is a unicode string the header block can become unicode,
    and concatenating that with a binary request body may then raise a
    unicode decode error. Normalizing everything up front avoids that.

    Any key or value that is not already ``six.binary_type`` is first
    converted with ``str()`` and then passed through ``_to_bytes``.

    Args:
        headers: dict, a dictionary of headers.

    Returns:
        dict, a new dictionary holding the converted keys and values.

    Raises:
        NonAsciiHeaderError: if a UnicodeEncodeError occurs while
            converting a header key or value.
    """
    converted = {}
    try:
        for name, value in six.iteritems(headers):
            name = name if isinstance(name, six.binary_type) else str(name)
            value = value if isinstance(value, six.binary_type) else str(value)
            converted[_to_bytes(name)] = _to_bytes(value)
    except UnicodeEncodeError:
        # Imported lazily, as in the original implementation.
        from oauth2client.client import NonAsciiHeaderError
        raise NonAsciiHeaderError(name, ': ', value)
    return converted
|
|
|
|
|
|
def wrap_http_for_auth(credentials, http):
    """Replace an HTTP object's request method with an authorizing wrapper.

    The wrapper obtains an access token if none is present, adds the
    credentials' authorization and user agent to a copy of the request
    headers, and sends the request. If the response status is in
    REFRESH_STATUS_CODES (401 by default), it refreshes the token and
    retries, up to two times. A file-like body is rewound to its starting
    position before each retry.

    Args:
        credentials: Credentials, the credentials used to identify
                     the authenticated user.
        http: httplib2.Http, an http object to be used to make
              auth requests. Its ``request`` attribute is replaced, and
              the replacement carries a ``credentials`` attribute.
    """
    original_request = http.request

    # Closure installed in place of 'httplib2.Http.request'.
    def new_request(uri, method='GET', body=None, headers=None,
                    redirections=httplib2.DEFAULT_MAX_REDIRECTS,
                    connection_type=None):
        if not credentials.access_token:
            _LOGGER.info('Attempting refresh to obtain '
                         'initial access_token')
            credentials._refresh(original_request)

        # Work on a copy so the caller's headers dict is never modified.
        request_headers = _initialize_headers(headers)
        credentials.apply(request_headers)
        _apply_user_agent(request_headers, credentials.user_agent)

        # Record where a file-like body starts so a retry can resend it.
        body_is_stream = all(getattr(body, attr, None)
                             for attr in _STREAM_PROPERTIES)
        start_position = body.tell() if body_is_stream else None

        def send():
            # Headers are re-cleaned on every send because a refresh
            # rewrites them via credentials.apply().
            return original_request(uri, method, body,
                                    clean_headers(request_headers),
                                    redirections, connection_type)

        resp, content = send()

        # A stored token may expire between the time it is retrieved and
        # the time the request is made, so we may need to try twice.
        max_refresh_attempts = 2
        attempts_made = 0
        while (attempts_made < max_refresh_attempts and
               resp.status in REFRESH_STATUS_CODES):
            attempts_made += 1
            _LOGGER.info('Refreshing due to a %s (attempt %s/%s)',
                         resp.status, attempts_made,
                         max_refresh_attempts)
            credentials._refresh(original_request)
            credentials.apply(request_headers)
            if start_position is not None:
                body.seek(start_position)
            resp, content = send()

        return resp, content

    http.request = new_request
    # Expose the credentials on the installed request method.
    http.request.credentials = credentials
|
|
|
|
|
|
def wrap_http_for_jwt_access(credentials, http):
    """Prepares an HTTP object's request method for JWT access.

    First applies ``wrap_http_for_auth``, then installs another wrapper on
    top of it:

    * If the credentials carry an 'aud' (audience) claim in ``_kwargs``,
      the token is refreshed up front when missing or expired, and the
      request goes through the ``wrap_http_for_auth`` wrapper, which
      handles refreshing and retrying on auth failures.
    * Otherwise a one-time token is created whose audience is the request
      URI with any query string removed. The request is sent directly
      through the original request method, with no refresh or retry.

    As in ``wrap_http_for_auth``, the installed request method carries a
    ``credentials`` attribute.

    Args:
        credentials: _JWTAccessCredentials, the credentials used to identify
                     a service account that uses JWT access tokens.
        http: httplib2.Http, an http object to be used to make
              auth requests.
    """
    orig_request_method = http.request
    wrap_http_for_auth(credentials, http)
    # The new value of ``http.request`` set by ``wrap_http_for_auth``.
    authenticated_request_method = http.request

    # The closure that will replace 'httplib2.Http.request'.
    def new_request(uri, method='GET', body=None, headers=None,
                    redirections=httplib2.DEFAULT_MAX_REDIRECTS,
                    connection_type=None):
        if 'aud' in credentials._kwargs:
            # Preemptively refresh token, this is not done for OAuth2
            if (credentials.access_token is None or
                    credentials.access_token_expired):
                credentials.refresh(None)
            return authenticated_request_method(uri, method, body,
                                                headers, redirections,
                                                connection_type)

        # If we don't have an 'aud' (audience) claim,
        # create a 1-time token with the uri root as the audience
        headers = _initialize_headers(headers)
        _apply_user_agent(headers, credentials.user_agent)
        uri_root = uri.split('?', 1)[0]
        token, unused_expiry = credentials._create_token({'aud': uri_root})

        headers['Authorization'] = 'Bearer ' + token
        return orig_request_method(uri, method, body,
                                   clean_headers(headers),
                                   redirections, connection_type)

    # Replace the request method with our own closure.
    http.request = new_request

    # Set credentials as a property of the request method. Without this,
    # the attribute that wrap_http_for_auth attached would be lost when
    # its closure is replaced above.
    http.request.credentials = credentials
|
|
|
|
|
|
# Shared HTTP object returned by get_cached_http(); its responses are
# cached in-process by a MemoryCache.
_CACHED_HTTP = httplib2.Http(cache=MemoryCache())
|