390 lines
15 KiB
Python
390 lines
15 KiB
Python
# Copyright 2020 Google Inc.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
"""Firebase auth providers management sub module."""
|
|
|
|
from urllib import parse
|
|
|
|
import requests
|
|
|
|
from firebase_admin import _auth_utils
|
|
from firebase_admin import _user_mgt
|
|
|
|
|
|
MAX_LIST_CONFIGS_RESULTS = 100
|
|
|
|
|
|
class ProviderConfig:
|
|
"""Parent type for all authentication provider config types."""
|
|
|
|
def __init__(self, data):
|
|
self._data = data
|
|
|
|
@property
|
|
def provider_id(self):
|
|
name = self._data['name']
|
|
return name.split('/')[-1]
|
|
|
|
@property
|
|
def display_name(self):
|
|
return self._data.get('displayName')
|
|
|
|
@property
|
|
def enabled(self):
|
|
return self._data.get('enabled', False)
|
|
|
|
|
|
class OIDCProviderConfig(ProviderConfig):
|
|
"""Represents the OIDC auth provider configuration.
|
|
|
|
See https://openid.net/specs/openid-connect-core-1_0-final.html.
|
|
"""
|
|
|
|
@property
|
|
def issuer(self):
|
|
return self._data['issuer']
|
|
|
|
@property
|
|
def client_id(self):
|
|
return self._data['clientId']
|
|
|
|
|
|
class SAMLProviderConfig(ProviderConfig):
|
|
"""Represents he SAML auth provider configuration.
|
|
|
|
See http://docs.oasis-open.org/security/saml/Post2.0/sstc-saml-tech-overview-2.0.html.
|
|
"""
|
|
|
|
@property
|
|
def idp_entity_id(self):
|
|
return self._data.get('idpConfig', {})['idpEntityId']
|
|
|
|
@property
|
|
def sso_url(self):
|
|
return self._data.get('idpConfig', {})['ssoUrl']
|
|
|
|
@property
|
|
def x509_certificates(self):
|
|
certs = self._data.get('idpConfig', {})['idpCertificates']
|
|
return [c['x509Certificate'] for c in certs]
|
|
|
|
@property
|
|
def callback_url(self):
|
|
return self._data.get('spConfig', {})['callbackUri']
|
|
|
|
@property
|
|
def rp_entity_id(self):
|
|
return self._data.get('spConfig', {})['spEntityId']
|
|
|
|
|
|
class ListProviderConfigsPage:
|
|
"""Represents a page of AuthProviderConfig instances retrieved from a Firebase project.
|
|
|
|
Provides methods for traversing the provider configs included in this page, as well as
|
|
retrieving subsequent pages. The iterator returned by ``iterate_all()`` can be used to iterate
|
|
through all provider configs in the Firebase project starting from this page.
|
|
"""
|
|
|
|
def __init__(self, download, page_token, max_results):
|
|
self._download = download
|
|
self._max_results = max_results
|
|
self._current = download(page_token, max_results)
|
|
|
|
@property
|
|
def provider_configs(self):
|
|
"""A list of ``AuthProviderConfig`` instances available in this page."""
|
|
raise NotImplementedError
|
|
|
|
@property
|
|
def next_page_token(self):
|
|
"""Page token string for the next page (empty string indicates no more pages)."""
|
|
return self._current.get('nextPageToken', '')
|
|
|
|
@property
|
|
def has_next_page(self):
|
|
"""A boolean indicating whether more pages are available."""
|
|
return bool(self.next_page_token)
|
|
|
|
def get_next_page(self):
|
|
"""Retrieves the next page of provider configs, if available.
|
|
|
|
Returns:
|
|
ListProviderConfigsPage: Next page of provider configs, or None if this is the last
|
|
page.
|
|
"""
|
|
if self.has_next_page:
|
|
return self.__class__(self._download, self.next_page_token, self._max_results)
|
|
return None
|
|
|
|
def iterate_all(self):
|
|
"""Retrieves an iterator for provider configs.
|
|
|
|
Returned iterator will iterate through all the provider configs in the Firebase project
|
|
starting from this page. The iterator will never buffer more than one page of configs
|
|
in memory at a time.
|
|
|
|
Returns:
|
|
iterator: An iterator of AuthProviderConfig instances.
|
|
"""
|
|
return _ProviderConfigIterator(self)
|
|
|
|
|
|
class _ListOIDCProviderConfigsPage(ListProviderConfigsPage):
|
|
|
|
@property
|
|
def provider_configs(self):
|
|
return [OIDCProviderConfig(data) for data in self._current.get('oauthIdpConfigs', [])]
|
|
|
|
|
|
class _ListSAMLProviderConfigsPage(ListProviderConfigsPage):
|
|
|
|
@property
|
|
def provider_configs(self):
|
|
return [SAMLProviderConfig(data) for data in self._current.get('inboundSamlConfigs', [])]
|
|
|
|
|
|
class _ProviderConfigIterator(_auth_utils.PageIterator):
|
|
|
|
@property
|
|
def items(self):
|
|
return self._current_page.provider_configs
|
|
|
|
|
|
class ProviderConfigClient:
|
|
"""Client for managing Auth provider configurations."""
|
|
|
|
PROVIDER_CONFIG_URL = 'https://identitytoolkit.googleapis.com/v2beta1'
|
|
|
|
def __init__(self, http_client, project_id, tenant_id=None):
|
|
self.http_client = http_client
|
|
self.base_url = '{0}/projects/{1}'.format(self.PROVIDER_CONFIG_URL, project_id)
|
|
if tenant_id:
|
|
self.base_url += '/tenants/{0}'.format(tenant_id)
|
|
|
|
def get_oidc_provider_config(self, provider_id):
|
|
_validate_oidc_provider_id(provider_id)
|
|
body = self._make_request('get', '/oauthIdpConfigs/{0}'.format(provider_id))
|
|
return OIDCProviderConfig(body)
|
|
|
|
def create_oidc_provider_config(
|
|
self, provider_id, client_id, issuer, display_name=None, enabled=None):
|
|
"""Creates a new OIDC provider config from the given parameters."""
|
|
_validate_oidc_provider_id(provider_id)
|
|
req = {
|
|
'clientId': _validate_non_empty_string(client_id, 'client_id'),
|
|
'issuer': _validate_url(issuer, 'issuer'),
|
|
}
|
|
if display_name is not None:
|
|
req['displayName'] = _auth_utils.validate_string(display_name, 'display_name')
|
|
if enabled is not None:
|
|
req['enabled'] = _auth_utils.validate_boolean(enabled, 'enabled')
|
|
|
|
params = 'oauthIdpConfigId={0}'.format(provider_id)
|
|
body = self._make_request('post', '/oauthIdpConfigs', json=req, params=params)
|
|
return OIDCProviderConfig(body)
|
|
|
|
def update_oidc_provider_config(
|
|
self, provider_id, client_id=None, issuer=None, display_name=None, enabled=None):
|
|
"""Updates an existing OIDC provider config with the given parameters."""
|
|
_validate_oidc_provider_id(provider_id)
|
|
req = {}
|
|
if display_name is not None:
|
|
if display_name == _user_mgt.DELETE_ATTRIBUTE:
|
|
req['displayName'] = None
|
|
else:
|
|
req['displayName'] = _auth_utils.validate_string(display_name, 'display_name')
|
|
if enabled is not None:
|
|
req['enabled'] = _auth_utils.validate_boolean(enabled, 'enabled')
|
|
if client_id:
|
|
req['clientId'] = _validate_non_empty_string(client_id, 'client_id')
|
|
if issuer:
|
|
req['issuer'] = _validate_url(issuer, 'issuer')
|
|
|
|
if not req:
|
|
raise ValueError('At least one parameter must be specified for update.')
|
|
|
|
update_mask = _auth_utils.build_update_mask(req)
|
|
params = 'updateMask={0}'.format(','.join(update_mask))
|
|
url = '/oauthIdpConfigs/{0}'.format(provider_id)
|
|
body = self._make_request('patch', url, json=req, params=params)
|
|
return OIDCProviderConfig(body)
|
|
|
|
def delete_oidc_provider_config(self, provider_id):
|
|
_validate_oidc_provider_id(provider_id)
|
|
self._make_request('delete', '/oauthIdpConfigs/{0}'.format(provider_id))
|
|
|
|
def list_oidc_provider_configs(self, page_token=None, max_results=MAX_LIST_CONFIGS_RESULTS):
|
|
return _ListOIDCProviderConfigsPage(
|
|
self._fetch_oidc_provider_configs, page_token, max_results)
|
|
|
|
def _fetch_oidc_provider_configs(self, page_token=None, max_results=MAX_LIST_CONFIGS_RESULTS):
|
|
return self._fetch_provider_configs('/oauthIdpConfigs', page_token, max_results)
|
|
|
|
def get_saml_provider_config(self, provider_id):
|
|
_validate_saml_provider_id(provider_id)
|
|
body = self._make_request('get', '/inboundSamlConfigs/{0}'.format(provider_id))
|
|
return SAMLProviderConfig(body)
|
|
|
|
def create_saml_provider_config(
|
|
self, provider_id, idp_entity_id, sso_url, x509_certificates,
|
|
rp_entity_id, callback_url, display_name=None, enabled=None):
|
|
"""Creates a new SAML provider config from the given parameters."""
|
|
_validate_saml_provider_id(provider_id)
|
|
req = {
|
|
'idpConfig': {
|
|
'idpEntityId': _validate_non_empty_string(idp_entity_id, 'idp_entity_id'),
|
|
'ssoUrl': _validate_url(sso_url, 'sso_url'),
|
|
'idpCertificates': _validate_x509_certificates(x509_certificates),
|
|
},
|
|
'spConfig': {
|
|
'spEntityId': _validate_non_empty_string(rp_entity_id, 'rp_entity_id'),
|
|
'callbackUri': _validate_url(callback_url, 'callback_url'),
|
|
},
|
|
}
|
|
if display_name is not None:
|
|
req['displayName'] = _auth_utils.validate_string(display_name, 'display_name')
|
|
if enabled is not None:
|
|
req['enabled'] = _auth_utils.validate_boolean(enabled, 'enabled')
|
|
|
|
params = 'inboundSamlConfigId={0}'.format(provider_id)
|
|
body = self._make_request('post', '/inboundSamlConfigs', json=req, params=params)
|
|
return SAMLProviderConfig(body)
|
|
|
|
def update_saml_provider_config(
|
|
self, provider_id, idp_entity_id=None, sso_url=None, x509_certificates=None,
|
|
rp_entity_id=None, callback_url=None, display_name=None, enabled=None):
|
|
"""Updates an existing SAML provider config with the given parameters."""
|
|
_validate_saml_provider_id(provider_id)
|
|
idp_config = {}
|
|
if idp_entity_id is not None:
|
|
idp_config['idpEntityId'] = _validate_non_empty_string(idp_entity_id, 'idp_entity_id')
|
|
if sso_url is not None:
|
|
idp_config['ssoUrl'] = _validate_url(sso_url, 'sso_url')
|
|
if x509_certificates is not None:
|
|
idp_config['idpCertificates'] = _validate_x509_certificates(x509_certificates)
|
|
|
|
sp_config = {}
|
|
if rp_entity_id is not None:
|
|
sp_config['spEntityId'] = _validate_non_empty_string(rp_entity_id, 'rp_entity_id')
|
|
if callback_url is not None:
|
|
sp_config['callbackUri'] = _validate_url(callback_url, 'callback_url')
|
|
|
|
req = {}
|
|
if display_name is not None:
|
|
if display_name == _user_mgt.DELETE_ATTRIBUTE:
|
|
req['displayName'] = None
|
|
else:
|
|
req['displayName'] = _auth_utils.validate_string(display_name, 'display_name')
|
|
if enabled is not None:
|
|
req['enabled'] = _auth_utils.validate_boolean(enabled, 'enabled')
|
|
if idp_config:
|
|
req['idpConfig'] = idp_config
|
|
if sp_config:
|
|
req['spConfig'] = sp_config
|
|
|
|
if not req:
|
|
raise ValueError('At least one parameter must be specified for update.')
|
|
|
|
update_mask = _auth_utils.build_update_mask(req)
|
|
params = 'updateMask={0}'.format(','.join(update_mask))
|
|
url = '/inboundSamlConfigs/{0}'.format(provider_id)
|
|
body = self._make_request('patch', url, json=req, params=params)
|
|
return SAMLProviderConfig(body)
|
|
|
|
def delete_saml_provider_config(self, provider_id):
|
|
_validate_saml_provider_id(provider_id)
|
|
self._make_request('delete', '/inboundSamlConfigs/{0}'.format(provider_id))
|
|
|
|
def list_saml_provider_configs(self, page_token=None, max_results=MAX_LIST_CONFIGS_RESULTS):
|
|
return _ListSAMLProviderConfigsPage(
|
|
self._fetch_saml_provider_configs, page_token, max_results)
|
|
|
|
def _fetch_saml_provider_configs(self, page_token=None, max_results=MAX_LIST_CONFIGS_RESULTS):
|
|
return self._fetch_provider_configs('/inboundSamlConfigs', page_token, max_results)
|
|
|
|
def _fetch_provider_configs(self, path, page_token=None, max_results=MAX_LIST_CONFIGS_RESULTS):
|
|
"""Fetches a page of auth provider configs"""
|
|
if page_token is not None:
|
|
if not isinstance(page_token, str) or not page_token:
|
|
raise ValueError('Page token must be a non-empty string.')
|
|
if not isinstance(max_results, int):
|
|
raise ValueError('Max results must be an integer.')
|
|
if max_results < 1 or max_results > MAX_LIST_CONFIGS_RESULTS:
|
|
raise ValueError(
|
|
'Max results must be a positive integer less than or equal to '
|
|
'{0}.'.format(MAX_LIST_CONFIGS_RESULTS))
|
|
|
|
params = 'pageSize={0}'.format(max_results)
|
|
if page_token:
|
|
params += '&pageToken={0}'.format(page_token)
|
|
return self._make_request('get', path, params=params)
|
|
|
|
def _make_request(self, method, path, **kwargs):
|
|
url = '{0}{1}'.format(self.base_url, path)
|
|
try:
|
|
return self.http_client.body(method, url, **kwargs)
|
|
except requests.exceptions.RequestException as error:
|
|
raise _auth_utils.handle_auth_backend_error(error)
|
|
|
|
|
|
def _validate_oidc_provider_id(provider_id):
|
|
if not isinstance(provider_id, str):
|
|
raise ValueError(
|
|
'Invalid OIDC provider ID: {0}. Provider ID must be a non-empty string.'.format(
|
|
provider_id))
|
|
if not provider_id.startswith('oidc.'):
|
|
raise ValueError('Invalid OIDC provider ID: {0}.'.format(provider_id))
|
|
return provider_id
|
|
|
|
|
|
def _validate_saml_provider_id(provider_id):
|
|
if not isinstance(provider_id, str):
|
|
raise ValueError(
|
|
'Invalid SAML provider ID: {0}. Provider ID must be a non-empty string.'.format(
|
|
provider_id))
|
|
if not provider_id.startswith('saml.'):
|
|
raise ValueError('Invalid SAML provider ID: {0}.'.format(provider_id))
|
|
return provider_id
|
|
|
|
|
|
def _validate_non_empty_string(value, label):
|
|
"""Validates that the given value is a non-empty string."""
|
|
if not isinstance(value, str):
|
|
raise ValueError('Invalid type for {0}: {1}.'.format(label, value))
|
|
if not value:
|
|
raise ValueError('{0} must not be empty.'.format(label))
|
|
return value
|
|
|
|
|
|
def _validate_url(url, label):
|
|
"""Validates that the given value is a well-formed URL string."""
|
|
if not isinstance(url, str) or not url:
|
|
raise ValueError(
|
|
'Invalid photo URL: "{0}". {1} must be a non-empty '
|
|
'string.'.format(url, label))
|
|
try:
|
|
parsed = parse.urlparse(url)
|
|
if not parsed.netloc:
|
|
raise ValueError('Malformed {0}: "{1}".'.format(label, url))
|
|
return url
|
|
except Exception:
|
|
raise ValueError('Malformed {0}: "{1}".'.format(label, url))
|
|
|
|
|
|
def _validate_x509_certificates(x509_certificates):
|
|
if not isinstance(x509_certificates, list) or not x509_certificates:
|
|
raise ValueError('x509_certificates must be a non-empty list.')
|
|
if not all([isinstance(cert, str) and cert for cert in x509_certificates]):
|
|
raise ValueError('x509_certificates must only contain non-empty strings.')
|
|
return [{'x509Certificate': cert} for cert in x509_certificates]
|