
391 lines
15 KiB
Raw Permalink Normal View History

# Copyright 2020 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Firebase auth providers management sub module."""
from urllib import parse
import requests
from firebase_admin import _auth_utils
from firebase_admin import _user_mgt
class ProviderConfig:
    """Parent type for all authentication provider config types.

    Wraps the raw config dictionary and exposes the fields shared by every
    provider type as read-only properties.
    """

    def __init__(self, data):
        self._data = data

    @property
    def provider_id(self):
        """The provider ID: the last ``/``-separated segment of the config ``name``."""
        name = self._data['name']
        return name.split('/')[-1]

    @property
    def display_name(self):
        """The ``displayName`` of the config, or ``None`` when it is not set."""
        return self._data.get('displayName')

    @property
    def enabled(self):
        """Whether the config is enabled; ``False`` when the field is absent."""
        return self._data.get('enabled', False)
class OIDCProviderConfig(ProviderConfig):
    """Represents the OIDC auth provider configuration."""

    @property
    def issuer(self):
        """The issuer URL (``issuer`` field). Raises ``KeyError`` if absent."""
        return self._data['issuer']

    @property
    def client_id(self):
        """The client ID (``clientId`` field). Raises ``KeyError`` if absent."""
        return self._data['clientId']
class SAMLProviderConfig(ProviderConfig):
    """Represents the SAML auth provider configuration.

    Identity-provider settings are read from the ``idpConfig`` section of the raw
    config and service-provider settings from the ``spConfig`` section. Each
    property raises ``KeyError`` when its field is missing.
    """

    @property
    def idp_entity_id(self):
        """The identity provider entity ID (``idpConfig.idpEntityId``)."""
        return self._data.get('idpConfig', {})['idpEntityId']

    @property
    def sso_url(self):
        """The identity provider SSO URL (``idpConfig.ssoUrl``)."""
        return self._data.get('idpConfig', {})['ssoUrl']

    @property
    def x509_certificates(self):
        """The ``x509Certificate`` value of each entry in ``idpConfig.idpCertificates``."""
        certs = self._data.get('idpConfig', {})['idpCertificates']
        return [c['x509Certificate'] for c in certs]

    @property
    def callback_url(self):
        """The service provider callback URL (``spConfig.callbackUri``)."""
        return self._data.get('spConfig', {})['callbackUri']

    @property
    def rp_entity_id(self):
        """The relying party entity ID (``spConfig.spEntityId``)."""
        return self._data.get('spConfig', {})['spEntityId']
class ListProviderConfigsPage:
    """Represents a page of AuthProviderConfig instances retrieved from a Firebase project.

    Provides methods for traversing the provider configs included in this page, as well as
    retrieving subsequent pages. The iterator returned by ``iterate_all()`` can be used to iterate
    through all provider configs in the Firebase project starting from this page.
    """

    def __init__(self, download, page_token, max_results):
        # ``download(page_token, max_results)`` fetches one raw page. It is retained so
        # that ``get_next_page()`` can fetch the following page with the same page size.
        self._download = download
        self._max_results = max_results
        self._current = download(page_token, max_results)

    # The accessors below must be properties: ``has_next_page`` and ``get_next_page``
    # read them as attributes, and a bare bound method would always be truthy.
    @property
    def provider_configs(self):
        """A list of ``AuthProviderConfig`` instances available in this page."""
        raise NotImplementedError

    @property
    def next_page_token(self):
        """Page token string for the next page (empty string indicates no more pages)."""
        return self._current.get('nextPageToken', '')

    @property
    def has_next_page(self):
        """A boolean indicating whether more pages are available."""
        return bool(self.next_page_token)

    def get_next_page(self):
        """Retrieves the next page of provider configs, if available.

        Returns:
            ListProviderConfigsPage: Next page of provider configs, or None if this is the last
            page.
        """
        if self.has_next_page:
            return self.__class__(self._download, self.next_page_token, self._max_results)
        return None

    def iterate_all(self):
        """Retrieves an iterator for provider configs.

        Returned iterator will iterate through all the provider configs in the Firebase project
        starting from this page. The iterator will never buffer more than one page of configs
        in memory at a time.

        Returns:
            iterator: An iterator of AuthProviderConfig instances.
        """
        return _ProviderConfigIterator(self)
class _ListOIDCProviderConfigsPage(ListProviderConfigsPage):
    """A page of OIDC provider configs, read from the ``oauthIdpConfigs`` response field."""

    @property
    def provider_configs(self):
        """A list of ``OIDCProviderConfig`` instances in this page (empty when none)."""
        return [OIDCProviderConfig(data) for data in self._current.get('oauthIdpConfigs', [])]
class _ListSAMLProviderConfigsPage(ListProviderConfigsPage):
    """A page of SAML provider configs, read from the ``inboundSamlConfigs`` response field."""

    @property
    def provider_configs(self):
        """A list of ``SAMLProviderConfig`` instances in this page (empty when none)."""
        return [SAMLProviderConfig(data) for data in self._current.get('inboundSamlConfigs', [])]
class _ProviderConfigIterator(_auth_utils.PageIterator):
    """Iterator returned by ``ListProviderConfigsPage.iterate_all()``."""

    @property
    def items(self):
        """The provider configs of the current page.

        Exposed as a property, like ``provider_configs`` on the page classes.
        ``_current_page`` is presumably maintained by ``PageIterator`` — confirm there.
        """
        return self._current_page.provider_configs
# Upper bound, and default, for the page size of the list operations below.
# NOTE(review): 100 matches the documented Firebase Admin SDK limit; if this module
# defines it elsewhere too, keep the two values in sync.
MAX_LIST_CONFIGS_RESULTS = 100


class ProviderConfigClient:
    """Client for managing Auth provider configurations.

    Each public method validates its arguments, sends a single request under
    ``base_url`` through ``http_client`` and wraps the response body in an
    ``OIDCProviderConfig`` / ``SAMLProviderConfig`` (or in a list page of them).
    """

    # Root of the REST resources; ``base_url`` is built from it.
    # NOTE(review): Identity Toolkit v2 endpoint — confirm against the backend in use.
    PROVIDER_CONFIG_URL = 'https://identitytoolkit.googleapis.com/v2'

    def __init__(self, http_client, project_id, tenant_id=None):
        self.http_client = http_client
        self.base_url = '{0}/projects/{1}'.format(self.PROVIDER_CONFIG_URL, project_id)
        if tenant_id:
            # Tenant-scoped configs are addressed under the tenant of the project.
            self.base_url += '/tenants/{0}'.format(tenant_id)

    def get_oidc_provider_config(self, provider_id):
        """Fetches the OIDC provider config identified by ``provider_id``."""
        _validate_oidc_provider_id(provider_id)
        body = self._make_request('get', '/oauthIdpConfigs/{0}'.format(provider_id))
        return OIDCProviderConfig(body)

    def create_oidc_provider_config(
            self, provider_id, client_id, issuer, display_name=None, enabled=None):
        """Creates a new OIDC provider config from the given parameters.

        Raises:
            ValueError: If ``provider_id``, ``client_id`` or ``issuer`` is invalid.
        """
        _validate_oidc_provider_id(provider_id)
        req = {
            'clientId': _validate_non_empty_string(client_id, 'client_id'),
            'issuer': _validate_url(issuer, 'issuer'),
        }
        if display_name is not None:
            req['displayName'] = _auth_utils.validate_string(display_name, 'display_name')
        if enabled is not None:
            req['enabled'] = _auth_utils.validate_boolean(enabled, 'enabled')

        params = 'oauthIdpConfigId={0}'.format(provider_id)
        body = self._make_request('post', '/oauthIdpConfigs', json=req, params=params)
        return OIDCProviderConfig(body)

    def update_oidc_provider_config(
            self, provider_id, client_id=None, issuer=None, display_name=None, enabled=None):
        """Updates an existing OIDC provider config with the given parameters.

        Only supplied fields are sent; the ``updateMask`` query parameter is derived
        from them via ``_auth_utils.build_update_mask``. Passing
        ``_user_mgt.DELETE_ATTRIBUTE`` as ``display_name`` sends ``None`` to clear it.

        Raises:
            ValueError: If ``provider_id`` is invalid, a supplied value fails validation,
                or no field to update is given.
        """
        _validate_oidc_provider_id(provider_id)
        req = {}
        if display_name is not None:
            if display_name == _user_mgt.DELETE_ATTRIBUTE:
                req['displayName'] = None
            else:
                req['displayName'] = _auth_utils.validate_string(display_name, 'display_name')
        if enabled is not None:
            req['enabled'] = _auth_utils.validate_boolean(enabled, 'enabled')
        if client_id:
            req['clientId'] = _validate_non_empty_string(client_id, 'client_id')
        if issuer:
            req['issuer'] = _validate_url(issuer, 'issuer')

        if not req:
            raise ValueError('At least one parameter must be specified for update.')

        update_mask = _auth_utils.build_update_mask(req)
        params = 'updateMask={0}'.format(','.join(update_mask))
        url = '/oauthIdpConfigs/{0}'.format(provider_id)
        body = self._make_request('patch', url, json=req, params=params)
        return OIDCProviderConfig(body)

    def delete_oidc_provider_config(self, provider_id):
        """Deletes the OIDC provider config identified by ``provider_id``."""
        _validate_oidc_provider_id(provider_id)
        self._make_request('delete', '/oauthIdpConfigs/{0}'.format(provider_id))

    def list_oidc_provider_configs(self, page_token=None, max_results=MAX_LIST_CONFIGS_RESULTS):
        """Returns the first requested page of OIDC provider configs."""
        return _ListOIDCProviderConfigsPage(
            self._fetch_oidc_provider_configs, page_token, max_results)

    def _fetch_oidc_provider_configs(self, page_token=None, max_results=MAX_LIST_CONFIGS_RESULTS):
        return self._fetch_provider_configs('/oauthIdpConfigs', page_token, max_results)

    def get_saml_provider_config(self, provider_id):
        """Fetches the SAML provider config identified by ``provider_id``."""
        _validate_saml_provider_id(provider_id)
        body = self._make_request('get', '/inboundSamlConfigs/{0}'.format(provider_id))
        return SAMLProviderConfig(body)

    def create_saml_provider_config(
            self, provider_id, idp_entity_id, sso_url, x509_certificates,
            rp_entity_id, callback_url, display_name=None, enabled=None):
        """Creates a new SAML provider config from the given parameters.

        Raises:
            ValueError: If ``provider_id`` or any identity/service provider setting is invalid.
        """
        _validate_saml_provider_id(provider_id)
        req = {
            'idpConfig': {
                'idpEntityId': _validate_non_empty_string(idp_entity_id, 'idp_entity_id'),
                'ssoUrl': _validate_url(sso_url, 'sso_url'),
                'idpCertificates': _validate_x509_certificates(x509_certificates),
            },
            'spConfig': {
                'spEntityId': _validate_non_empty_string(rp_entity_id, 'rp_entity_id'),
                'callbackUri': _validate_url(callback_url, 'callback_url'),
            },
        }
        if display_name is not None:
            req['displayName'] = _auth_utils.validate_string(display_name, 'display_name')
        if enabled is not None:
            req['enabled'] = _auth_utils.validate_boolean(enabled, 'enabled')

        params = 'inboundSamlConfigId={0}'.format(provider_id)
        body = self._make_request('post', '/inboundSamlConfigs', json=req, params=params)
        return SAMLProviderConfig(body)

    def update_saml_provider_config(
            self, provider_id, idp_entity_id=None, sso_url=None, x509_certificates=None,
            rp_entity_id=None, callback_url=None, display_name=None, enabled=None):
        """Updates an existing SAML provider config with the given parameters.

        Only supplied fields are sent; the ``updateMask`` query parameter is derived
        from them via ``_auth_utils.build_update_mask``. Passing
        ``_user_mgt.DELETE_ATTRIBUTE`` as ``display_name`` sends ``None`` to clear it.

        Raises:
            ValueError: If ``provider_id`` is invalid, a supplied value fails validation,
                or no field to update is given.
        """
        _validate_saml_provider_id(provider_id)
        idp_config = {}
        if idp_entity_id is not None:
            idp_config['idpEntityId'] = _validate_non_empty_string(idp_entity_id, 'idp_entity_id')
        if sso_url is not None:
            idp_config['ssoUrl'] = _validate_url(sso_url, 'sso_url')
        if x509_certificates is not None:
            idp_config['idpCertificates'] = _validate_x509_certificates(x509_certificates)

        sp_config = {}
        if rp_entity_id is not None:
            sp_config['spEntityId'] = _validate_non_empty_string(rp_entity_id, 'rp_entity_id')
        if callback_url is not None:
            sp_config['callbackUri'] = _validate_url(callback_url, 'callback_url')

        req = {}
        if display_name is not None:
            if display_name == _user_mgt.DELETE_ATTRIBUTE:
                req['displayName'] = None
            else:
                req['displayName'] = _auth_utils.validate_string(display_name, 'display_name')
        if enabled is not None:
            req['enabled'] = _auth_utils.validate_boolean(enabled, 'enabled')
        if idp_config:
            req['idpConfig'] = idp_config
        if sp_config:
            req['spConfig'] = sp_config

        if not req:
            raise ValueError('At least one parameter must be specified for update.')

        update_mask = _auth_utils.build_update_mask(req)
        params = 'updateMask={0}'.format(','.join(update_mask))
        url = '/inboundSamlConfigs/{0}'.format(provider_id)
        body = self._make_request('patch', url, json=req, params=params)
        return SAMLProviderConfig(body)

    def delete_saml_provider_config(self, provider_id):
        """Deletes the SAML provider config identified by ``provider_id``."""
        _validate_saml_provider_id(provider_id)
        self._make_request('delete', '/inboundSamlConfigs/{0}'.format(provider_id))

    def list_saml_provider_configs(self, page_token=None, max_results=MAX_LIST_CONFIGS_RESULTS):
        """Returns the first requested page of SAML provider configs."""
        return _ListSAMLProviderConfigsPage(
            self._fetch_saml_provider_configs, page_token, max_results)

    def _fetch_saml_provider_configs(self, page_token=None, max_results=MAX_LIST_CONFIGS_RESULTS):
        return self._fetch_provider_configs('/inboundSamlConfigs', page_token, max_results)

    def _fetch_provider_configs(self, path, page_token=None, max_results=MAX_LIST_CONFIGS_RESULTS):
        """Fetches a page of auth provider configs.

        Args:
            path: Collection path relative to ``base_url`` (e.g. ``/oauthIdpConfigs``).
            page_token: Token of the page to fetch, or ``None`` for the first page.
            max_results: Page size; an ``int`` in ``[1, MAX_LIST_CONFIGS_RESULTS]``.

        Returns:
            The response body returned by ``_make_request``.

        Raises:
            ValueError: If ``page_token`` or ``max_results`` is invalid.
        """
        if page_token is not None:
            if not isinstance(page_token, str) or not page_token:
                raise ValueError('Page token must be a non-empty string.')
        # bool is a subclass of int; without this check True would be sent as "pageSize=True".
        if isinstance(max_results, bool) or not isinstance(max_results, int):
            raise ValueError('Max results must be an integer.')
        if max_results < 1 or max_results > MAX_LIST_CONFIGS_RESULTS:
            raise ValueError(
                'Max results must be a positive integer less than or equal to '
                '{0}.'.format(MAX_LIST_CONFIGS_RESULTS))

        params = 'pageSize={0}'.format(max_results)
        if page_token:
            params += '&pageToken={0}'.format(page_token)
        return self._make_request('get', path, params=params)

    def _make_request(self, method, path, **kwargs):
        """Sends ``method`` to ``base_url + path`` and returns ``http_client.body(...)``.

        ``requests`` transport errors are converted by
        ``_auth_utils.handle_auth_backend_error`` and the result is raised instead.
        """
        url = '{0}{1}'.format(self.base_url, path)
        try:
            return self.http_client.body(method, url, **kwargs)
        except requests.exceptions.RequestException as error:
            raise _auth_utils.handle_auth_backend_error(error)
def _validate_oidc_provider_id(provider_id):
    """Validates an OIDC provider ID and returns it unchanged.

    Raises:
        ValueError: If ``provider_id`` is not a non-empty string, or does not start
            with ``oidc.``.
    """
    if not isinstance(provider_id, str) or not provider_id:
        raise ValueError(
            'Invalid OIDC provider ID: {0}. Provider ID must be a non-empty string.'.format(
                provider_id))
    if not provider_id.startswith('oidc.'):
        raise ValueError('Invalid OIDC provider ID: {0}.'.format(provider_id))
    return provider_id
def _validate_saml_provider_id(provider_id):
    """Validates a SAML provider ID and returns it unchanged.

    Raises:
        ValueError: If ``provider_id`` is not a non-empty string, or does not start
            with ``saml.``.
    """
    if not isinstance(provider_id, str) or not provider_id:
        raise ValueError(
            'Invalid SAML provider ID: {0}. Provider ID must be a non-empty string.'.format(
                provider_id))
    if not provider_id.startswith('saml.'):
        raise ValueError('Invalid SAML provider ID: {0}.'.format(provider_id))
    return provider_id
def _validate_non_empty_string(value, label):
    """Returns ``value`` when it is a non-empty ``str``; raises ``ValueError`` otherwise.

    ``label`` names the argument being checked and appears in the error message.
    """
    if isinstance(value, str):
        if value:
            return value
        raise ValueError('{0} must not be empty.'.format(label))
    raise ValueError('Invalid type for {0}: {1}.'.format(label, value))
def _validate_url(url, label):
    """Validates that the given value is a well-formed URL string.

    A value is accepted when it is a non-empty string whose parsed form has a
    network location (for example ``https://example.com/path``).

    Args:
        url: The value to validate.
        label: Name of the argument being validated; used in error messages.

    Returns:
        str: ``url``, unchanged.

    Raises:
        ValueError: If ``url`` is not a non-empty string, cannot be parsed, or has
            no network location.
    """
    if not isinstance(url, str) or not url:
        raise ValueError(
            'Invalid {0}: "{1}". {0} must be a non-empty string.'.format(label, url))
    try:
        parsed = parse.urlparse(url)
    except ValueError as err:
        # urlparse rejects some malformed input outright (e.g. an unbalanced IPv6 bracket).
        raise ValueError('Malformed {0}: "{1}".'.format(label, url)) from err
    if not parsed.netloc:
        raise ValueError('Malformed {0}: "{1}".'.format(label, url))
    return url
def _validate_x509_certificates(x509_certificates):
    """Validates x509 certificate strings and converts them to the request format.

    Args:
        x509_certificates: A non-empty list of non-empty certificate strings.

    Returns:
        list: One ``{'x509Certificate': cert}`` dict per input certificate, in order.

    Raises:
        ValueError: If the argument is not a non-empty list, or any entry is not a
            non-empty string.
    """
    if not isinstance(x509_certificates, list) or not x509_certificates:
        raise ValueError('x509_certificates must be a non-empty list.')
    # Generator (not a list) so all() stops at the first invalid entry.
    if not all(isinstance(cert, str) and cert for cert in x509_certificates):
        raise ValueError('x509_certificates must only contain non-empty strings.')
    return [{'x509Certificate': cert} for cert in x509_certificates]