Source code for adanet.core.tpu_estimator

"""An AdaNet estimator implementation which can run on TPU.

Copyright 2018 The AdaNet Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import contextlib
import functools

from absl import logging
from adanet import tf_compat
from adanet.core.estimator import Estimator
import tensorflow.compat.v2 as tf


# pylint: disable=g-classes-have-attributes
class TPUEstimator(Estimator, tf.compat.v1.estimator.tpu.TPUEstimator):
  """An :class:`adanet.Estimator` capable of training and evaluating on TPU.

  Unless :code:`use_tpu=False`, training will run on TPU. However, certain
  parts of the AdaNet training loop, such as report materialization and best
  candidate selection, will still occur on CPU. Furthermore, if using
  TPUEmbedding (i.e. :code:`embedding_config_spec` is supplied), inference
  will also occur on CPU.

  TODO: Provide the missing functionality detailed below.

  N.B: Embeddings using the TPUEmbedding (i.e. :code:`embedding_config_spec`
  is provided) only support :code:`shared_embedding_columns` when running for
  multiple AdaNet iterations. Using regular :code:`embedding_columns` will
  cause iterations 2..n to fail because of mismatched embedding scopes.

  Args:
    head: See :class:`adanet.Estimator`.
    subnetwork_generator: See :class:`adanet.Estimator`.
    max_iteration_steps: See :class:`adanet.Estimator`.
    ensemblers: See :class:`adanet.Estimator`.
    ensemble_strategies: See :class:`adanet.Estimator`.
    evaluator: See :class:`adanet.Estimator`.
    report_materializer: See :class:`adanet.Estimator`.
    metric_fn: See :class:`adanet.Estimator`.
    force_grow: See :class:`adanet.Estimator`.
    replicate_ensemble_in_training: See :class:`adanet.Estimator`.
    adanet_loss_decay: See :class:`adanet.Estimator`.
    model_dir: See :class:`adanet.Estimator`.
    report_dir: See :class:`adanet.Estimator`.
    config: See :class:`adanet.Estimator`.
    use_tpu: Boolean to enable training on TPU. Defaults to :code:`True` and
      is only provided to allow debugging models on CPU/GPU. Use
      :class:`adanet.Estimator` instead if you do not plan to run on TPU.
    eval_on_tpu: Boolean to enable evaluating on TPU. Defaults to
      :code:`True`. Ignored if :code:`use_tpu=False`.
    export_to_tpu: See :class:`tf.compat.v1.estimator.tpu.TPUEstimator`.
    train_batch_size: See :class:`tf.compat.v1.estimator.tpu.TPUEstimator`.
      Defaults to 0 if `None`.
    eval_batch_size: See :class:`tf.compat.v1.estimator.tpu.TPUEstimator`.
      Defaults to `train_batch_size` if `None`.
    predict_batch_size: See :class:`tf.compat.v1.estimator.tpu.TPUEstimator`.
      Defaults to `eval_batch_size` if `None`.
    embedding_config_spec: See
      :class:`tf.compat.v1.estimator.tpu.TPUEstimator`. If supplied,
      :code:`predict` will be called on CPU and no TPU compatible
      :code:`SavedModel` will be exported.
    debug: See :class:`adanet.Estimator`.
    enable_ensemble_summaries: See :class:`adanet.Estimator`.
    enable_subnetwork_summaries: See :class:`adanet.Estimator`.
    export_subnetwork_logits: Whether to include subnetwork logits in
      exports.
    export_subnetwork_last_layer: Whether to include subnetwork last layer in
      exports.
    global_step_combiner_fn: See :class:`adanet.Estimator`.
    max_iterations: See :class:`adanet.Estimator`.
    replay_config: See :class:`adanet.Estimator`.
    add_predict_batch_config: If True, supplies a default
      `tpu_estimator.BatchConfig` when calling
      `tpu_estimator.model_fn_inference_on_tpu`, otherwise supplies None.
    **kwargs: Extra keyword args passed to the parent.
  """
""" def __init__(self, head, subnetwork_generator, max_iteration_steps, ensemblers=None, ensemble_strategies=None, evaluator=None, report_materializer=None, metric_fn=None, force_grow=False, replicate_ensemble_in_training=False, adanet_loss_decay=.9, model_dir=None, report_dir=None, config=None, use_tpu=True, eval_on_tpu=True, export_to_tpu=True, train_batch_size=None, eval_batch_size=None, predict_batch_size=None, embedding_config_spec=None, debug=False, enable_ensemble_summaries=True, enable_subnetwork_summaries=True, export_subnetwork_logits=False, export_subnetwork_last_layer=True, global_step_combiner_fn=tf.math.reduce_mean, max_iterations=None, replay_config=None, add_predict_batch_config=True, **kwargs): self._use_tpu = use_tpu if not self._use_tpu: logging.warning( "This adanet.TPUEstimator is meant to be used for running on TPU. " "If you want to run on CPU/GPU, use adanet.Estimator instead.") # TPUEstimator modifies config under the hood. We keep track of it here so # we can use it from _create_temp_run_config. self._original_config = config or tf_compat.v1.estimator.tpu.RunConfig() self._eval_on_tpu = eval_on_tpu if self._use_tpu else False self._export_to_tpu = export_to_tpu self._train_batch_size = train_batch_size or 0 self._eval_batch_size = eval_batch_size or train_batch_size or 0 self._predict_batch_size = ( predict_batch_size or eval_batch_size or train_batch_size or 0) self._embedding_config_spec = embedding_config_spec self._add_predict_batch_config = add_predict_batch_config if self._embedding_config_spec: logging.warning( "TPU does not support inference with TPUEmbedding. Force setting " "`export_to_tpu=False` so no TPU SavedModel will be exported.") self._export_to_tpu = False from tensorflow_estimator.python.estimator.tpu import tpu_estimator # pylint: disable=g-direct-tensorflow-import,g-import-not-at-top super(TPUEstimator, self).__init__( head=head, subnetwork_generator=subnetwork_generator, max_iteration_steps=max_iteration_steps, ensemblers=ensemblers, ensemble_strategies=ensemble_strategies, evaluator=evaluator, report_materializer=report_materializer, metric_fn=metric_fn, force_grow=force_grow, replicate_ensemble_in_training=replicate_ensemble_in_training, adanet_loss_decay=adanet_loss_decay, model_dir=model_dir, report_dir=report_dir, config=self._original_config, use_tpu=self._use_tpu, eval_on_tpu=self._eval_on_tpu, export_to_tpu=self._export_to_tpu, export_saved_model_api_version=( tpu_estimator.ExportSavedModelApiVersion.V2), train_batch_size=self._train_batch_size, eval_batch_size=self._eval_batch_size, predict_batch_size=self._predict_batch_size, embedding_config_spec=self._embedding_config_spec, debug=debug, enable_ensemble_summaries=enable_ensemble_summaries, enable_subnetwork_summaries=enable_subnetwork_summaries, export_subnetwork_logits=export_subnetwork_logits, export_subnetwork_last_layer=export_subnetwork_last_layer, global_step_combiner_fn=global_step_combiner_fn, max_iterations=max_iterations, replay_config=replay_config, **kwargs)
  def predict(self,
              input_fn,
              predict_keys=None,
              hooks=None,
              checkpoint_path=None,
              yield_single_examples=True):
    use_tpu = self._use_tpu
    eval_on_tpu = self._eval_on_tpu
    if self._embedding_config_spec:
      logging.warning("TPU does not support inference with TPUEmbedding. "
                      "Falling back to CPU.")
      use_tpu = False
      eval_on_tpu = False
    if not checkpoint_path:
      checkpoint_path = tf.train.latest_checkpoint(self.model_dir)
    logging.info("Computing predictions for AdaNet model at checkpoint: %s",
                 checkpoint_path)
    params = self.params
    params.update({
        "best_ensemble_index":
            self._compute_best_ensemble_index(checkpoint_path, hooks),
        "checkpoint_path":
            checkpoint_path,
    })
    from tensorflow_estimator.python.estimator.tpu import tpu_estimator  # pylint: disable=g-direct-tensorflow-import,g-import-not-at-top
    # TODO: Consider extracting a common function to use here and in
    # _create_temp_estimator().
    estimator = tf_compat.v1.estimator.tpu.TPUEstimator(
        model_fn=self._create_model_fn(hooks=hooks, is_export=False),
        params=params,
        config=self._original_config,
        model_dir=self.model_dir,
        use_tpu=use_tpu,
        eval_on_tpu=eval_on_tpu,
        export_to_tpu=self._export_to_tpu,
        export_saved_model_api_version=(
            tpu_estimator.ExportSavedModelApiVersion.V2),
        train_batch_size=self._train_batch_size,
        eval_batch_size=self._eval_batch_size,
        predict_batch_size=self._predict_batch_size,
        embedding_config_spec=self._embedding_config_spec)
    return estimator.predict(
        input_fn,
        predict_keys=predict_keys,
        hooks=hooks,
        checkpoint_path=checkpoint_path,
        yield_single_examples=yield_single_examples)
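
  # --- Illustrative usage sketch (not part of the implementation) ---
  # `predict` restores the best ensemble index from the given (or latest)
  # checkpoint before delegating to a fresh `TPUEstimator`; when
  # `embedding_config_spec` is set it silently falls back to CPU.
  # `predict_input_fn` and `process` are placeholders:
  #
  #   for prediction in estimator.predict(
  #       input_fn=predict_input_fn,
  #       checkpoint_path=None,  # None selects the latest checkpoint.
  #       yield_single_examples=True):
  #     process(prediction)  # Hypothetical downstream handler.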
  def _create_temp_run_config(self, temp_model_dir):
    """See the `Estimator` base class for details."""

    return tf_compat.v1.estimator.tpu.RunConfig(
        model_dir=temp_model_dir,
        tpu_config=self._original_config.tpu_config,
        evaluation_master=self._original_config.evaluation_master,
        master=self._original_config.master,
        cluster=self._original_config.cluster,
        tf_random_seed=self._original_config.tf_random_seed,
        session_config=self._original_config.session_config,
        protocol=self._original_config.protocol)

  def _create_temp_estimator(self, config, **create_model_fn_args):
    """See the `Estimator` base class for details."""

    from tensorflow_estimator.python.estimator.tpu import tpu_estimator  # pylint: disable=g-direct-tensorflow-import,g-import-not-at-top
    temp_model_dir = config.model_dir
    return tf_compat.v1.estimator.tpu.TPUEstimator(
        model_fn=self._create_model_fn(**create_model_fn_args),
        config=config,
        model_dir=temp_model_dir,
        use_tpu=self._use_tpu,
        eval_on_tpu=self._eval_on_tpu,
        export_to_tpu=self._export_to_tpu,
        export_saved_model_api_version=(
            tpu_estimator.ExportSavedModelApiVersion.V2),
        train_batch_size=self._train_batch_size,
        eval_batch_size=self._eval_batch_size,
        predict_batch_size=self._predict_batch_size,
        embedding_config_spec=self._embedding_config_spec)

  @contextlib.contextmanager
  def _call_input_fn_in_new_graph(self, input_fn, mode, config):
    """See the `Estimator` base class for details."""

    # Bind parameters to input_fn since the parent's input_fn is not
    # expected to have any arguments.
    from tensorflow.python.util import function_utils  # pylint: disable=g-direct-tensorflow-import,g-import-not-at-top
    input_fn_args = function_utils.fn_args(input_fn)
    kwargs = {}
    if "mode" in input_fn_args:
      kwargs["mode"] = mode
    if "params" in input_fn_args:
      kwargs["params"] = self.params
    if "config" in input_fn_args:
      kwargs["config"] = config
    input_fn = functools.partial(input_fn, **kwargs)
    with super(TPUEstimator,
               self)._call_input_fn_in_new_graph(input_fn, mode,
                                                 config) as res:
      yield res

  def _create_estimator_spec(self, current_iteration, mode,
                             iteration_number_tensor, previous_iteration_vars,
                             is_growing_phase, evaluation_name):
    """See the `Estimator` base class for details."""

    if not self._use_tpu:
      return super(TPUEstimator, self)._create_estimator_spec(
          current_iteration, mode, iteration_number_tensor,
          previous_iteration_vars, is_growing_phase, evaluation_name)

    training = mode == tf.estimator.ModeKeys.TRAIN
    iteration_estimator_spec = current_iteration.estimator_spec
    training_hooks = self._training_hooks(current_iteration, training,
                                          iteration_number_tensor,
                                          previous_iteration_vars,
                                          is_growing_phase)
    if is_growing_phase:
      training_hooks = self._process_hooks_for_growing_phase(training_hooks)
    evaluation_hooks = self._evaluation_hooks(current_iteration, training,
                                              evaluation_name)
    return tf_compat.v1.estimator.tpu.TPUEstimatorSpec(
        mode=mode,
        predictions=iteration_estimator_spec.predictions,
        loss=iteration_estimator_spec.loss,
        train_op=self._train_op(iteration_estimator_spec, is_growing_phase),
        host_call=self._create_host_call(current_iteration, training),
        eval_metrics=iteration_estimator_spec.eval_metrics,
        export_outputs=iteration_estimator_spec.export_outputs,
        # Return a constant summary_op, otherwise `Estimator` creates summary
        # ops that do not work on TPU.
        scaffold_fn=lambda: tf.compat.v1.train.Scaffold(  # pylint: disable=g-long-lambda
            summary_op=tf.constant("")),
        training_hooks=training_hooks,
        evaluation_hooks=evaluation_hooks)

  def _training_hooks(self, current_iteration, training,
                      iteration_number_tensor, previous_iteration_vars,
                      is_growing_phase):
    """See the `Estimator` base class for details."""

    training_hooks = super(TPUEstimator,
                           self)._training_hooks(current_iteration, training,
                                                 iteration_number_tensor,
                                                 previous_iteration_vars,
                                                 is_growing_phase)
    if self._use_tpu:
      # Remove summary hooks on TPU since summaries are saved via host_call.
      training_hooks = [
          hook for hook in training_hooks
          if not isinstance(hook, tf.compat.v1.train.SummarySaverHook)
      ]
    return training_hooks

  def _create_host_call(self, current_iteration, training):
    """Construct a host_call writing scalar summaries.

    Args:
      current_iteration: The current `_Iteration`.
      training: Boolean indicating whether in training mode.

    Returns:
      (fn, args) Pair to be called by TPUEstimator as the host_call.
    """

    if not training:
      return lambda **kwargs: [tf.no_op()], {}

    # Collect and flatten summary functions and arguments.
    summary_kwargs = collections.OrderedDict()
    gs_t = tf.reshape(
        tf.cast(tf.compat.v1.train.get_global_step(), dtype=tf.int32), [1])
    summary_kwargs["global_step"] = gs_t

    summary_fns = collections.defaultdict(list)
    for i, summary in enumerate(current_iteration.summaries):
      for j, (summary_fn, tensor) in enumerate(summary.summary_tuples()):
        summary_fns[i].append(summary_fn)
        summary_kwargs["summary_{}_{}".format(i, j)] = tensor

    def _host_call_fn(**kwargs):
      """Training host call. Creates summaries for training metrics.

      Args:
        **kwargs: Dict of {str: Tensor}, with `Tensor` of shape `[batch]`.
          Must contain key "global_step" with value of the current
          global_step Tensor.

      Returns:
        List of summary ops to run on the CPU host.
      """

      from tensorflow.python.ops import summary_ops_v2  # pylint: disable=g-direct-tensorflow-import,g-import-not-at-top
      gs = tf.cast(kwargs.pop("global_step")[0], dtype=tf.int64)
      for i, summary in enumerate(current_iteration.summaries):
        with summary_ops_v2.create_file_writer(summary.logdir).as_default():
          with summary_ops_v2.record_summaries_every_n_global_steps(
              n=self.config.save_summary_steps, global_step=gs):
            for j, summary_fn in enumerate(summary_fns[i]):
              tensor = kwargs["summary_{}_{}".format(i, j)]
              summary_fn(tensor, step=gs)
        summary.clear_summary_tuples()
      return tf.compat.v1.summary.all_v2_summary_ops()

    return _host_call_fn, summary_kwargs

  def _create_model_fn(self,
                       is_growing_phase=False,
                       is_inside_training_loop=False,
                       is_export=False,
                       evaluation_name=None,
                       best_ensemble_index=None,
                       checkpoint_path=None,
                       hooks=None):
    """See the `Estimator` base class for details."""

    from tensorflow_estimator.python.estimator.tpu import tpu_estimator  # pylint: disable=g-direct-tensorflow-import,g-import-not-at-top

    adanet_model_fn = super(TPUEstimator, self)._create_model_fn(
        is_growing_phase, is_inside_training_loop, is_export, evaluation_name,
        best_ensemble_index, checkpoint_path, hooks)

    def _model_fn(features, labels, mode, params, config):
      """The model_fn to return which supports exporting on TPU."""

      if (is_export and params["use_tpu"] and
          mode == tf.estimator.ModeKeys.PREDICT):
        batch_config = None
        if self._add_predict_batch_config:
          batch_config = tpu_estimator.BatchConfig(
              # Set num_batch_threads to the number of TPU cores on
              # Servomatic.
              num_batch_threads=2,
              max_batch_size=self._predict_batch_size,
              # TODO: Magic number. Investigate whether there is a better way
              # to set this, or have the user pass it in.
              batch_timeout_micros=60 * 1000,
              allowed_batch_sizes=[self._predict_batch_size])
        return tpu_estimator.model_fn_inference_on_tpu(
            adanet_model_fn,
            features=features,
            labels=labels,
            config=config,
            params=params,
            batch_config=batch_config)

      return adanet_model_fn(features, labels, mode, params, config)

    return _model_fn
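
# --- host_call contract sketch (illustrative, not part of the module) ---
# `_create_host_call` above returns a `(fn, kwargs)` pair. TPUEstimator
# invokes `fn(**kwargs)` on the CPU host, with each tensor accumulated
# across the steps of a TPU training loop; this is why `global_step` is
# reshaped to `[1]` before being passed in and indexed with `[0]` inside
# the function. A minimal pair, with `gs_t` and `loss_t` hypothetical
# `[1]`-shaped tensors, might look like:
#
#   def host_call_fn(global_step, loss):
#     gs = tf.cast(global_step[0], tf.int64)
#     with summary_ops_v2.create_file_writer("/tmp/logdir").as_default():
#       with summary_ops_v2.record_summaries_every_n_global_steps(
#           n=100, global_step=gs):
#         tf.compat.v2.summary.scalar("loss", loss[0], step=gs)
#     return tf.compat.v1.summary.all_v2_summary_ops()
#
#   host_call = (host_call_fn, {"global_step": gs_t, "loss": loss_t})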