Vehicle-Anti-Theft-Face-Rec.../venv/Lib/site-packages/firebase_admin/ml.py

984 lines
36 KiB
Python
Raw Normal View History

# Copyright 2019 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Firebase ML module.
This module contains functions for creating, updating, getting, listing,
deleting, publishing and unpublishing Firebase ML models.
"""
import datetime
import re
import time
import os
from urllib import parse
import requests
import firebase_admin
from firebase_admin import _http_client
from firebase_admin import _utils
from firebase_admin import exceptions
# pylint: disable=import-error,no-name-in-module
try:
from firebase_admin import storage
_GCS_ENABLED = True
except ImportError:
_GCS_ENABLED = False
# pylint: disable=import-error,no-name-in-module
try:
import tensorflow as tf
_TF_ENABLED = True
except ImportError:
_TF_ENABLED = False
# Key handed to _utils.get_app_service() to identify the ML service of an App.
_ML_ATTRIBUTE = '_ml'
# Upper bound enforced by _validate_page_size() for list requests.
_MAX_PAGE_SIZE = 100
# Model IDs: 1-60 characters drawn from letters, digits, underscore and hyphen.
_MODEL_ID_PATTERN = re.compile(r'^[A-Za-z0-9_-]{1,60}$')
# Display names: 1-32 characters drawn from letters, digits, underscore and hyphen.
_DISPLAY_NAME_PATTERN = re.compile(r'^[A-Za-z0-9_-]{1,32}$')
# Tags: same character set and length limit as display names.
_TAG_PATTERN = re.compile(r'^[A-Za-z0-9_-]{1,32}$')
# "gs://<bucket_name>/<blob_name>"; the bucket part is 3-63 characters long.
_GCS_TFLITE_URI_PATTERN = re.compile(
    r'^gs://(?P<bucket_name>[a-z0-9_.-]{3,63})/(?P<blob_name>.+)$')
# "projects/<project_id>/locations/<location_id>/models/<model_id>" (alphanumeric model ID).
_AUTO_ML_MODEL_PATTERN = re.compile(
    r'^projects/(?P<project_id>[a-z0-9-]{6,30})/locations/(?P<location_id>[^/]+)/' +
    r'models/(?P<model_id>[A-Za-z0-9]+)$')
# "projects/<project_id>/models/<model_id>"; model_id follows the _MODEL_ID_PATTERN charset.
_RESOURCE_NAME_PATTERN = re.compile(
    r'^projects/(?P<project_id>[a-z0-9-]{6,30})/models/(?P<model_id>[A-Za-z0-9_-]{1,60})$')
# "projects/<project_id>/operations/<id>", where <id> contains no slash.
_OPERATION_NAME_PATTERN = re.compile(
    r'^projects/(?P<project_id>[a-z0-9-]{6,30})/operations/[^/]+$')
def _get_ml_service(app):
    """Looks up the _MLService instance that belongs to an App.

    Args:
        app: A Firebase App instance (or None to use the default App).

    Returns:
        _MLService: An _MLService for the specified App instance.

    Raises:
        ValueError: If the app argument is invalid.
    """
    service = _utils.get_app_service(app, _ML_ATTRIBUTE, _MLService)
    return service
def create_model(model, app=None):
    """Creates a model in the current Firebase project.

    Args:
        model: An ml.Model to create.
        app: A Firebase app instance (or None to use the default app).

    Returns:
        Model: The model that was created in Firebase ML.
    """
    created = _get_ml_service(app).create_model(model)
    return Model.from_dict(created, app=app)
def update_model(model, app=None):
    """Updates a model's metadata or model file.

    Args:
        model: The ml.Model to update.
        app: A Firebase app instance (or None to use the default app).

    Returns:
        Model: The updated model.
    """
    updated = _get_ml_service(app).update_model(model)
    return Model.from_dict(updated, app=app)
def publish_model(model_id, app=None):
    """Publishes a Firebase ML model.

    A published model can be downloaded to client apps.

    Args:
        model_id: The id of the model to publish.
        app: A Firebase app instance (or None to use the default app).

    Returns:
        Model: The published model.
    """
    published = _get_ml_service(app).set_published(model_id, publish=True)
    return Model.from_dict(published, app=app)
def unpublish_model(model_id, app=None):
    """Unpublishes a Firebase ML model.

    Args:
        model_id: The id of the model to unpublish.
        app: A Firebase app instance (or None to use the default app).

    Returns:
        Model: The unpublished model.
    """
    unpublished = _get_ml_service(app).set_published(model_id, publish=False)
    return Model.from_dict(unpublished, app=app)
def get_model(model_id, app=None):
    """Gets the model specified by the given ID.

    Args:
        model_id: The id of the model to get.
        app: A Firebase app instance (or None to use the default app).

    Returns:
        Model: The requested model.
    """
    fetched = _get_ml_service(app).get_model(model_id)
    return Model.from_dict(fetched, app=app)
def list_models(list_filter=None, page_size=None, page_token=None, app=None):
    """Lists the current project's models.

    Args:
        list_filter: a list filter string such as ``tags:'tag_1'``. None will return all models.
        page_size: A number between 1 and 100 inclusive that specifies the maximum
            number of models to return per page. None for default.
        page_token: A next page token returned from a previous page of results. None
            for first page of results.
        app: A Firebase app instance (or None to use the default app).

    Returns:
        ListModelsPage: A (filtered) list of models.
    """
    # The page object keeps the bound method so it can fetch follow-up pages itself.
    fetch_page = _get_ml_service(app).list_models
    return ListModelsPage(fetch_page, list_filter, page_size, page_token, app=app)
def delete_model(model_id, app=None):
    """Deletes a model from the current project.

    Args:
        model_id: The id of the model you wish to delete.
        app: A Firebase app instance (or None to use the default app).
    """
    _get_ml_service(app).delete_model(model_id)
class Model:
    """A Firebase ML Model object.

    Args:
        display_name: The display name of your model - used to identify your model in code.
        tags: Optional list of strings associated with your model. Can be used in list queries.
        model_format: A subclass of ModelFormat. (e.g. TFLiteFormat) Specifies the model details.
    """

    def __init__(self, display_name=None, tags=None, model_format=None):
        self._app = None  # Only needed for wait_for_unlocked().
        self._data = {}
        self._model_format = None
        # Assign through the property setters so each supplied value is validated.
        if display_name is not None:
            self.display_name = display_name
        if tags is not None:
            self.tags = tags
        if model_format is not None:
            self.model_format = model_format

    @classmethod
    def from_dict(cls, data, app=None):
        """Create an instance of the object from a dict.

        Args:
            data: A dict of model properties. It is copied, not modified.
            app: A Firebase app instance to associate with the model (or None).

        Returns:
            Model: A model backed by a copy of ``data``.
        """
        data_copy = dict(data)
        tflite_format = None
        tflite_format_data = data_copy.pop('tfliteModel', None)
        data_copy.pop('@type', None)  # Returned by Operations. (Not needed)
        if tflite_format_data:
            tflite_format = TFLiteFormat.from_dict(tflite_format_data)
        model = Model(model_format=tflite_format)
        model._data = data_copy  # pylint: disable=protected-access
        model._app = app  # pylint: disable=protected-access
        return model

    def _update_from_dict(self, data):
        """Replaces this model's data and format with those parsed from ``data``."""
        copy = Model.from_dict(data)
        self.model_format = copy.model_format
        self._data = copy._data  # pylint: disable=protected-access

    def __eq__(self, other):
        if isinstance(other, self.__class__):
            # pylint: disable=protected-access
            return self._data == other._data and self._model_format == other._model_format
        return False

    def __ne__(self, other):
        return not self.__eq__(other)

    @property
    def model_id(self):
        """The model's ID, unique to the project."""
        if not self._data.get('name'):
            return None
        _, model_id = _validate_and_parse_name(self._data.get('name'))
        return model_id

    @property
    def display_name(self):
        """The model's display name, used to refer to the model in code and in
        the Firebase console."""
        return self._data.get('displayName')

    @display_name.setter
    def display_name(self, display_name):
        self._data['displayName'] = _validate_display_name(display_name)

    @staticmethod
    def _convert_to_millis(date_string):
        """Converts a UTC RFC 3339 timestamp string into milliseconds since the epoch.

        The fractional-seconds part is optional and may have any number of digits;
        digits beyond millisecond precision are truncated.

        Args:
            date_string: A timestamp such as ``2020-05-04T03:02:01.123456Z``, or a
                falsy value.

        Returns:
            int: Milliseconds since 1970-01-01T00:00:00Z, or None if ``date_string``
            is falsy.

        Raises:
            ValueError: If the string is not a "Z"-terminated timestamp in the
                expected layout.
        """
        if not date_string:
            return None
        if not date_string.endswith('Z'):
            raise ValueError(
                'Timestamp must be a UTC string ending in "Z": {0!r}'.format(date_string))
        seconds_part, dot, fraction = date_string[:-1].partition('.')
        if dot and not fraction.isdigit():
            raise ValueError(
                'Timestamp has an invalid fractional part: {0!r}'.format(date_string))
        # strptime's %f accepts at most six digits and is not optional, so the
        # fraction is handled separately from the whole-second part.
        whole = datetime.datetime.strptime(seconds_part, '%Y-%m-%dT%H:%M:%S')
        elapsed = whole - datetime.datetime(1970, 1, 1)
        whole_millis = (elapsed.days * 86400 + elapsed.seconds) * 1000
        # Pad short fractions ("5" -> "500") and truncate long ones to milliseconds.
        fraction_millis = int((fraction + '00')[:3]) if fraction else 0
        return whole_millis + fraction_millis

    @property
    def create_time(self):
        """The time the model was created."""
        return Model._convert_to_millis(self._data.get('createTime', None))

    @property
    def update_time(self):
        """The time the model was last updated."""
        return Model._convert_to_millis(self._data.get('updateTime', None))

    @property
    def validation_error(self):
        """Validation error message."""
        return self._data.get('state', {}).get('validationError', {}).get('message')

    @property
    def published(self):
        """True if the model is published and available for clients to
        download."""
        return bool(self._data.get('state', {}).get('published'))

    @property
    def etag(self):
        """The entity tag (ETag) of the model resource."""
        return self._data.get('etag')

    @property
    def model_hash(self):
        """SHA256 hash of the model binary."""
        return self._data.get('modelHash')

    @property
    def tags(self):
        """Tag strings, used for filtering query results."""
        return self._data.get('tags')

    @tags.setter
    def tags(self, tags):
        self._data['tags'] = _validate_tags(tags)

    @property
    def locked(self):
        """True if the Model object is locked by an active operation."""
        return bool(self._data.get('activeOperations') and
                    len(self._data.get('activeOperations')) > 0)

    def wait_for_unlocked(self, max_time_seconds=None):
        """Waits for the model to be unlocked. (All active operations complete)

        Args:
            max_time_seconds: The maximum number of seconds to wait for the model to unlock.
                (None for no limit)

        Raises:
            exceptions.DeadlineExceeded: If max_time_seconds passed and the model is still locked.
        """
        if not self.locked:
            return
        ml_service = _get_ml_service(self._app)
        # Only the first active operation is polled.
        op_name = self._data.get('activeOperations')[0].get('name')
        model_dict = ml_service.handle_operation(
            ml_service.get_operation(op_name),
            wait_for_operation=True,
            max_time_seconds=max_time_seconds)
        self._update_from_dict(model_dict)

    @property
    def model_format(self):
        """The model's ``ModelFormat`` object, which represents the model's
        format and storage location."""
        return self._model_format

    @model_format.setter
    def model_format(self, model_format):
        if model_format is not None:
            _validate_model_format(model_format)
        self._model_format = model_format  # Can be None

    def as_dict(self, for_upload=False):
        """Returns a serializable representation of the object.

        Args:
            for_upload: Forwarded to the model format's ``as_dict()``.

        Returns:
            dict: A copy of the model data, merged with the model format's dict
            when a format is set.
        """
        copy = dict(self._data)
        if self._model_format:
            copy.update(self._model_format.as_dict(for_upload=for_upload))
        return copy
class ModelFormat:
    """Abstract base class for model formats such as TFLite."""

    def as_dict(self, for_upload=False):
        """Returns a serializable representation of the object.

        The base implementation always raises; subclasses provide the real one.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError()
class TFLiteFormat(ModelFormat):
    """Model format representing a TFLite model.

    Args:
        model_source: A TFLiteModelSource sub class. Specifies the details of the model source.
    """

    def __init__(self, model_source=None):
        self._data = {}
        self._model_source = None
        if model_source is None:
            return
        # Go through the property setter so the source type is checked.
        self.model_source = model_source

    @classmethod
    def from_dict(cls, data):
        """Create an instance of the object from a dict."""
        remaining = dict(data)
        # _init_model_source() pops the source keys, so `remaining` ends up
        # holding only the non-source properties.
        source = cls._init_model_source(remaining)
        instance = TFLiteFormat(model_source=source)
        instance._data = remaining  # pylint: disable=protected-access
        return instance

    def __eq__(self, other):
        if not isinstance(other, self.__class__):
            return False
        # pylint: disable=protected-access
        return self._data == other._data and self._model_source == other._model_source

    def __ne__(self, other):
        return not self.__eq__(other)

    @staticmethod
    def _init_model_source(data):
        """Pops the model source entry out of ``data`` and wraps it, if present."""
        gcs_uri = data.pop('gcsTfliteUri', None)
        if gcs_uri:
            return TFLiteGCSModelSource(gcs_tflite_uri=gcs_uri)
        automl_name = data.pop('automlModel', None)
        if automl_name:
            return TFLiteAutoMlSource(auto_ml_model=automl_name)
        return None

    @property
    def model_source(self):
        """The TF Lite model's location."""
        return self._model_source

    @model_source.setter
    def model_source(self, model_source):
        if model_source is not None and not isinstance(model_source, TFLiteModelSource):
            raise TypeError('Model source must be a TFLiteModelSource object.')
        self._model_source = model_source  # Can be None

    @property
    def size_bytes(self):
        """The size in bytes of the TF Lite model."""
        return self._data.get('sizeBytes')

    def as_dict(self, for_upload=False):
        """Returns a serializable representation of the object."""
        payload = dict(self._data)
        if self._model_source:
            source_fields = self._model_source.as_dict(for_upload=for_upload)
            payload.update(source_fields)
        return {'tfliteModel': payload}
class TFLiteModelSource:
    """Abstract base class for the model sources of TFLite format models."""

    def as_dict(self, for_upload=False):
        """Returns a serializable representation of the object.

        The base implementation always raises; subclasses provide the real one.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError()
class _CloudStorageClient:
"""Cloud Storage helper class"""
GCS_URI = 'gs://{0}/{1}'
BLOB_NAME = 'Firebase/ML/Models/{0}'
@staticmethod
def _assert_gcs_enabled():
if not _GCS_ENABLED:
raise ImportError('Failed to import the Cloud Storage library for Python. Make sure '
'to install the "google-cloud-storage" module.')
@staticmethod
def _parse_gcs_tflite_uri(uri):
# GCS Bucket naming rules are complex. The regex is not comprehensive.
# See https://cloud.google.com/storage/docs/naming for full details.
matcher = _GCS_TFLITE_URI_PATTERN.match(uri)
if not matcher:
raise ValueError('GCS TFLite URI format is invalid.')
return matcher.group('bucket_name'), matcher.group('blob_name')
@staticmethod
def upload(bucket_name, model_file_name, app):
"""Upload a model file to the specified Storage bucket."""
_CloudStorageClient._assert_gcs_enabled()
file_name = os.path.basename(model_file_name)
bucket = storage.bucket(bucket_name, app=app)
blob_name = _CloudStorageClient.BLOB_NAME.format(file_name)
blob = bucket.blob(blob_name)
blob.upload_from_filename(model_file_name)
return _CloudStorageClient.GCS_URI.format(bucket.name, blob_name)
@staticmethod
def sign_uri(gcs_tflite_uri, app):
"""Makes the gcs_tflite_uri readable for GET for 10 minutes via signed_uri."""
_CloudStorageClient._assert_gcs_enabled()
bucket_name, blob_name = _CloudStorageClient._parse_gcs_tflite_uri(gcs_tflite_uri)
bucket = storage.bucket(bucket_name, app=app)
blob = bucket.blob(blob_name)
return blob.generate_signed_url(
version='v4',
expiration=datetime.timedelta(minutes=10),
method='GET'
)
class TFLiteGCSModelSource(TFLiteModelSource):
    """TFLite model source representing a tflite model file stored in GCS."""

    _STORAGE_CLIENT = _CloudStorageClient()

    def __init__(self, gcs_tflite_uri, app=None):
        self._app = app
        self._gcs_tflite_uri = _validate_gcs_tflite_uri(gcs_tflite_uri)

    def __eq__(self, other):
        if not isinstance(other, self.__class__):
            return False
        return self._gcs_tflite_uri == other._gcs_tflite_uri  # pylint: disable=protected-access

    def __ne__(self, other):
        return not self.__eq__(other)

    @classmethod
    def from_tflite_model_file(cls, model_file_name, bucket_name=None, app=None):
        """Uploads the model file to an existing Google Cloud Storage bucket.

        Args:
            model_file_name: The name of the model file.
            bucket_name: The name of an existing bucket. None to use the default bucket configured
                in the app.
            app: A Firebase app instance (or None to use the default app).

        Returns:
            TFLiteGCSModelSource: The source created from the model_file

        Raises:
            ImportError: If the Cloud Storage Library has not been installed.
        """
        storage_client = TFLiteGCSModelSource._STORAGE_CLIENT
        uploaded_uri = storage_client.upload(bucket_name, model_file_name, app)
        return TFLiteGCSModelSource(gcs_tflite_uri=uploaded_uri, app=app)

    @staticmethod
    def _assert_tf_enabled():
        """Raises ImportError unless tensorflow 1.x or 2.x was imported."""
        if not _TF_ENABLED:
            raise ImportError('Failed to import the tensorflow library for Python. Make sure '
                              'to install the tensorflow module.')
        supported = (tf.version.VERSION.startswith('1.') or
                     tf.version.VERSION.startswith('2.'))
        if not supported:
            raise ImportError('Expected tensorflow version 1.x or 2.x, but found {0}'
                              .format(tf.version.VERSION))

    @staticmethod
    def _tf_convert_from_saved_model(saved_model_dir):
        """Converts the saved model in the given directory into a TF Lite model."""
        # Same for both v1.x and v2.x
        return tf.lite.TFLiteConverter.from_saved_model(saved_model_dir).convert()

    @staticmethod
    def _tf_convert_from_keras_model(keras_model):
        """Converts the given Keras model into a TF Lite model."""
        # Version 1.x conversion function takes a model file. Version 2.x takes the model itself.
        if not tf.version.VERSION.startswith('1.'):
            converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
        else:
            keras_file = 'firebase_keras_model.h5'
            tf.keras.models.save_model(keras_model, keras_file)
            converter = tf.lite.TFLiteConverter.from_keras_model_file(keras_file)
        return converter.convert()

    @classmethod
    def from_saved_model(cls, saved_model_dir, model_file_name='firebase_ml_model.tflite',
                         bucket_name=None, app=None):
        """Creates a Tensor Flow Lite model from the saved model, and uploads the model to GCS.

        Args:
            saved_model_dir: The saved model directory.
            model_file_name: The name that the tflite model will be saved as in Cloud Storage.
            bucket_name: The name of an existing bucket. None to use the default bucket configured
                in the app.
            app: Optional. A Firebase app instance (or None to use the default app)

        Returns:
            TFLiteGCSModelSource: The source created from the saved_model_dir

        Raises:
            ImportError: If the Tensor Flow or Cloud Storage Libraries have not been installed.
        """
        TFLiteGCSModelSource._assert_tf_enabled()
        converted = TFLiteGCSModelSource._tf_convert_from_saved_model(saved_model_dir)
        # The converted model is written to a local file first, then uploaded from it.
        with open(model_file_name, 'wb') as output:
            output.write(converted)
        return TFLiteGCSModelSource.from_tflite_model_file(model_file_name, bucket_name, app)

    @classmethod
    def from_keras_model(cls, keras_model, model_file_name='firebase_ml_model.tflite',
                         bucket_name=None, app=None):
        """Creates a Tensor Flow Lite model from the keras model, and uploads the model to GCS.

        Args:
            keras_model: A tf.keras model.
            model_file_name: The name that the tflite model will be saved as in Cloud Storage.
            bucket_name: The name of an existing bucket. None to use the default bucket configured
                in the app.
            app: Optional. A Firebase app instance (or None to use the default app)

        Returns:
            TFLiteGCSModelSource: The source created from the keras_model

        Raises:
            ImportError: If the Tensor Flow or Cloud Storage Libraries have not been installed.
        """
        TFLiteGCSModelSource._assert_tf_enabled()
        converted = TFLiteGCSModelSource._tf_convert_from_keras_model(keras_model)
        # The converted model is written to a local file first, then uploaded from it.
        with open(model_file_name, 'wb') as output:
            output.write(converted)
        return TFLiteGCSModelSource.from_tflite_model_file(model_file_name, bucket_name, app)

    @property
    def gcs_tflite_uri(self):
        """URI of the model file in Cloud Storage."""
        return self._gcs_tflite_uri

    @gcs_tflite_uri.setter
    def gcs_tflite_uri(self, gcs_tflite_uri):
        self._gcs_tflite_uri = _validate_gcs_tflite_uri(gcs_tflite_uri)

    def _get_signed_gcs_tflite_uri(self):
        """Signs the GCS uri, so the model file can be uploaded to Firebase ML and verified."""
        storage_client = TFLiteGCSModelSource._STORAGE_CLIENT
        return storage_client.sign_uri(self._gcs_tflite_uri, self._app)

    def as_dict(self, for_upload=False):
        """Returns a serializable representation of the object.

        When ``for_upload`` is true the URI is replaced by its signed counterpart.
        """
        uri = self._get_signed_gcs_tflite_uri() if for_upload else self._gcs_tflite_uri
        return {'gcsTfliteUri': uri}
class TFLiteAutoMlSource(TFLiteModelSource):
    """TFLite model source representing a tflite model created with AutoML."""

    def __init__(self, auto_ml_model, app=None):
        self._app = app
        # Assigned through the property so the resource name is validated.
        self.auto_ml_model = auto_ml_model

    def __eq__(self, other):
        if not isinstance(other, self.__class__):
            return False
        return self.auto_ml_model == other.auto_ml_model

    def __ne__(self, other):
        return not self.__eq__(other)

    @property
    def auto_ml_model(self):
        """Resource name of the model, created by the AutoML API or Cloud console."""
        return self._auto_ml_model

    @auto_ml_model.setter
    def auto_ml_model(self, auto_ml_model):
        validated_name = _validate_auto_ml_model(auto_ml_model)
        self._auto_ml_model = validated_name

    def as_dict(self, for_upload=False):
        """Returns a serializable representation of the object."""
        # Upload is irrelevant for auto_ml models
        serialized = {'automlModel': self._auto_ml_model}
        return serialized
class ListModelsPage:
    """Represents a page of models in a Firebase project.

    Provides methods for traversing the models included in this page, as well as
    retrieving subsequent pages of models. The iterator returned by
    ``iterate_all()`` can be used to iterate through all the models in the
    Firebase project starting from this page.
    """

    def __init__(self, list_models_func, list_filter, page_size, page_token, app):
        self._list_models_func = list_models_func
        self._list_filter = list_filter
        self._page_size = page_size
        self._page_token = page_token
        self._app = app
        # The page's contents are fetched eagerly, at construction time.
        self._list_response = list_models_func(list_filter, page_size, page_token)

    @property
    def models(self):
        """A list of Models from this page."""
        raw_models = self._list_response.get('models', [])
        return [Model.from_dict(raw, app=self._app) for raw in raw_models]

    @property
    def list_filter(self):
        """The filter string used to filter the models."""
        return self._list_filter

    @property
    def next_page_token(self):
        """Token identifying the next page of results."""
        return self._list_response.get('nextPageToken', '')

    @property
    def has_next_page(self):
        """True if more pages are available."""
        return bool(self.next_page_token)

    def get_next_page(self):
        """Retrieves the next page of models if available.

        Returns:
            ListModelsPage: Next page of models, or None if this is the last page.
        """
        if not self.has_next_page:
            return None
        return ListModelsPage(
            self._list_models_func,
            self._list_filter,
            self._page_size,
            self.next_page_token,
            self._app)

    def iterate_all(self):
        """Retrieves an iterator for Models.

        Returned iterator will iterate through all the models in the Firebase
        project starting from this page. The iterator will never buffer more than
        one page of models in memory at a time.

        Returns:
            iterator: An iterator of Model instances.
        """
        return _ModelIterator(self)
class _ModelIterator:
    """An iterator that allows iterating over models, one at a time.

    This implementation loads a page of models into memory, and iterates on them.
    When the whole page has been traversed, it loads another page. This class
    never keeps more than one page of entries in memory.

    Args:
        current_page: The ListModelsPage to start iterating from.

    Raises:
        TypeError: If ``current_page`` is not a ListModelsPage.
    """

    def __init__(self, current_page):
        if not isinstance(current_page, ListModelsPage):
            raise TypeError('Current page must be a ListModelsPage')
        self._current_page = current_page
        self._index = 0
        # Models of the current page, built on first use. ``ListModelsPage.models``
        # builds a fresh list of Model objects on every access, so it is read
        # once per page instead of several times per item.
        self._models = None

    def _page_models(self):
        """Returns the (cached) list of models of the current page."""
        if self._models is None:
            self._models = self._current_page.models
        return self._models

    def next(self):
        """Returns the next model, fetching further pages as needed.

        Raises:
            StopIteration: When the last model of the last page has been returned.
        """
        models = self._page_models()
        # Loop rather than advance once: a page can be empty and still carry a
        # next-page token, and stopping there would silently drop later models.
        while self._index >= len(models) and self._current_page.has_next_page:
            self._current_page = self._current_page.get_next_page()
            self._index = 0
            self._models = None
            models = self._page_models()
        if self._index < len(models):
            result = models[self._index]
            self._index += 1
            return result
        raise StopIteration

    def __next__(self):
        return self.next()

    def __iter__(self):
        return self
def _validate_and_parse_name(name):
    """Splits a model resource name into its (project_id, model_id) parts.

    Raises:
        ValueError: If ``name`` does not match the model resource name pattern.
    """
    # The resource name is added automatically from API call responses.
    # The only way it could be invalid is if someone tries to
    # create a model from a dictionary manually and does it incorrectly.
    match = _RESOURCE_NAME_PATTERN.match(name)
    if match is None:
        raise ValueError('Model resource name format is invalid.')
    project_id = match.group('project_id')
    model_id = match.group('model_id')
    return project_id, model_id
def _validate_model(model, update_mask=None):
    """Checks that ``model`` is an ml.Model that is usable for a request.

    A display name is only required when no update mask is given.

    Raises:
        TypeError: If ``model`` is not an ml.Model.
        ValueError: If no update mask is given and the model has no display name.
    """
    if not isinstance(model, Model):
        raise TypeError('Model must be an ml.Model.')
    needs_display_name = update_mask is None
    if needs_display_name and not model.display_name:
        raise ValueError('Model must have a display name.')
def _validate_model_id(model_id):
    """Raises ValueError unless ``model_id`` matches the model ID pattern."""
    if _MODEL_ID_PATTERN.match(model_id) is None:
        raise ValueError('Model ID format is invalid.')
def _validate_operation_name(op_name):
    """Returns ``op_name`` unchanged if it matches the operation name pattern.

    Raises:
        ValueError: If the name does not match.
    """
    if _OPERATION_NAME_PATTERN.match(op_name) is None:
        raise ValueError('Operation name format is invalid.')
    return op_name
def _validate_display_name(display_name):
    """Returns ``display_name`` unchanged if it matches the display name pattern.

    Raises:
        ValueError: If the name does not match.
    """
    if _DISPLAY_NAME_PATTERN.match(display_name) is None:
        raise ValueError('Display name format is invalid.')
    return display_name
def _validate_tags(tags):
    """Returns ``tags`` unchanged if it is a list of well-formed tag strings.

    Raises:
        TypeError: If ``tags`` is not a list, or holds a non-string item.
        ValueError: If any tag does not match the tag pattern.
    """
    is_string_list = isinstance(tags, list) and all(isinstance(tag, str) for tag in tags)
    if not is_string_list:
        raise TypeError('Tags must be a list of strings.')
    for tag in tags:
        if not _TAG_PATTERN.match(tag):
            raise ValueError('Tag format is invalid.')
    return tags
def _validate_gcs_tflite_uri(uri):
    """Returns ``uri`` unchanged if it matches the ``gs://`` URI pattern.

    Raises:
        ValueError: If the URI does not match.
    """
    # GCS Bucket naming rules are complex. The regex is not comprehensive.
    # See https://cloud.google.com/storage/docs/naming for full details.
    if _GCS_TFLITE_URI_PATTERN.match(uri) is None:
        raise ValueError('GCS TFLite URI format is invalid.')
    return uri
def _validate_auto_ml_model(model):
    """Returns ``model`` unchanged if it matches the AutoML model name pattern.

    Raises:
        ValueError: If the resource name does not match.
    """
    if _AUTO_ML_MODEL_PATTERN.match(model) is None:
        raise ValueError('Model resource name format is invalid.')
    return model
def _validate_model_format(model_format):
    """Returns ``model_format`` unchanged if it is a ModelFormat instance.

    Raises:
        TypeError: If it is any other kind of object.
    """
    if isinstance(model_format, ModelFormat):
        return model_format
    raise TypeError('Model format must be a ModelFormat object.')
def _validate_list_filter(list_filter):
if list_filter is not None:
if not isinstance(list_filter, str):
raise TypeError('List filter must be a string or None.')
def _validate_page_size(page_size):
if page_size is not None:
if type(page_size) is not int: # pylint: disable=unidiomatic-typecheck
# Specifically type() to disallow boolean which is a subtype of int
raise TypeError('Page size must be a number or None.')
if page_size < 1 or page_size > _MAX_PAGE_SIZE:
raise ValueError('Page size must be a positive integer between '
'1 and {0}'.format(_MAX_PAGE_SIZE))
def _validate_page_token(page_token):
if page_token is not None:
if not isinstance(page_token, str):
raise TypeError('Page token must be a string or None.')
class _MLService:
"""Firebase ML service."""
PROJECT_URL = 'https://firebaseml.googleapis.com/v1beta2/projects/{0}/'
OPERATION_URL = 'https://firebaseml.googleapis.com/v1beta2/'
POLL_EXPONENTIAL_BACKOFF_FACTOR = 1.5
POLL_BASE_WAIT_TIME_SECONDS = 3
def __init__(self, app):
self._project_id = app.project_id
if not self._project_id:
raise ValueError(
'Project ID is required to access ML service. Either set the '
'projectId option, or use service account credentials.')
self._project_url = _MLService.PROJECT_URL.format(self._project_id)
ml_headers = {
'X-FIREBASE-CLIENT': 'fire-admin-python/{0}'.format(firebase_admin.__version__),
}
self._client = _http_client.JsonHttpClient(
credential=app.credential.get_credential(),
headers=ml_headers,
base_url=self._project_url)
self._operation_client = _http_client.JsonHttpClient(
credential=app.credential.get_credential(),
headers=ml_headers,
base_url=_MLService.OPERATION_URL)
def get_operation(self, op_name):
_validate_operation_name(op_name)
try:
return self._operation_client.body('get', url=op_name)
except requests.exceptions.RequestException as error:
raise _utils.handle_platform_error_from_requests(error)
def _exponential_backoff(self, current_attempt, stop_time):
"""Sleeps for the appropriate amount of time. Or throws deadline exceeded."""
delay_factor = pow(_MLService.POLL_EXPONENTIAL_BACKOFF_FACTOR, current_attempt)
wait_time_seconds = delay_factor * _MLService.POLL_BASE_WAIT_TIME_SECONDS
if stop_time is not None:
max_seconds_left = (stop_time - datetime.datetime.now()).total_seconds()
if max_seconds_left < 1: # allow a bit of time for rpc
raise exceptions.DeadlineExceededError('Polling max time exceeded.')
wait_time_seconds = min(wait_time_seconds, max_seconds_left - 1)
time.sleep(wait_time_seconds)
def handle_operation(self, operation, wait_for_operation=False, max_time_seconds=None):
"""Handles long running operations.
Args:
operation: The operation to handle.
wait_for_operation: Should we allow polling for the operation to complete.
If no polling is requested, a locked model will be returned instead.
max_time_seconds: The maximum seconds to try polling for operation complete.
(None for no limit)
Returns:
dict: A dictionary of the returned model properties.
Raises:
TypeError: if the operation is not a dictionary.
ValueError: If the operation is malformed.
UnknownError: If the server responds with an unexpected response.
err: If the operation exceeds polling attempts or stop_time
"""
if not isinstance(operation, dict):
raise TypeError('Operation must be a dictionary.')
if operation.get('done'):
# Operations which are immediately done don't have an operation name
if operation.get('response'):
return operation.get('response')
if operation.get('error'):
raise _utils.handle_operation_error(operation.get('error'))
raise exceptions.UnknownError(message='Internal Error: Malformed Operation.')
op_name = _validate_operation_name(operation.get('name'))
metadata = operation.get('metadata', {})
metadata_type = metadata.get('@type', '')
if not metadata_type.endswith('ModelOperationMetadata'):
raise TypeError('Unknown type of operation metadata.')
_, model_id = _validate_and_parse_name(metadata.get('name'))
current_attempt = 0
start_time = datetime.datetime.now()
stop_time = (None if max_time_seconds is None else
start_time + datetime.timedelta(seconds=max_time_seconds))
while wait_for_operation and not operation.get('done'):
# We just got this operation. Wait before getting another
# so we don't exceed the GetOperation maximum request rate.
self._exponential_backoff(current_attempt, stop_time)
operation = self.get_operation(op_name)
current_attempt += 1
if operation.get('done'):
if operation.get('response'):
return operation.get('response')
if operation.get('error'):
raise _utils.handle_operation_error(operation.get('error'))
# If the operation is not complete or timed out, return a (locked) model instead
return get_model(model_id).as_dict()
def create_model(self, model):
_validate_model(model)
try:
return self.handle_operation(
self._client.body('post', url='models', json=model.as_dict(for_upload=True)))
except requests.exceptions.RequestException as error:
raise _utils.handle_platform_error_from_requests(error)
def update_model(self, model, update_mask=None):
_validate_model(model, update_mask)
path = 'models/{0}'.format(model.model_id)
if update_mask is not None:
path = path + '?updateMask={0}'.format(update_mask)
try:
return self.handle_operation(
self._client.body('patch', url=path, json=model.as_dict(for_upload=True)))
except requests.exceptions.RequestException as error:
raise _utils.handle_platform_error_from_requests(error)
def set_published(self, model_id, publish):
_validate_model_id(model_id)
model_name = 'projects/{0}/models/{1}'.format(self._project_id, model_id)
model = Model.from_dict({
'name': model_name,
'state': {
'published': publish
}
})
return self.update_model(model, update_mask='state.published')
def get_model(self, model_id):
_validate_model_id(model_id)
try:
return self._client.body('get', url='models/{0}'.format(model_id))
except requests.exceptions.RequestException as error:
raise _utils.handle_platform_error_from_requests(error)
def list_models(self, list_filter, page_size, page_token):
""" lists Firebase ML models."""
_validate_list_filter(list_filter)
_validate_page_size(page_size)
_validate_page_token(page_token)
params = {}
if list_filter:
params['filter'] = list_filter
if page_size:
params['page_size'] = page_size
if page_token:
params['page_token'] = page_token
path = 'models'
if params:
param_str = parse.urlencode(sorted(params.items()), True)
path = path + '?' + param_str
try:
return self._client.body('get', url=path)
except requests.exceptions.RequestException as error:
raise _utils.handle_platform_error_from_requests(error)
def delete_model(self, model_id):
_validate_model_id(model_id)
try:
self._client.body('delete', url='models/{0}'.format(model_id))
except requests.exceptions.RequestException as error:
raise _utils.handle_platform_error_from_requests(error)