Source code for adanet.core.estimator

"""An AdaNet estimator implementation in Tensorflow using a single graph.

Copyright 2018 The AdaNet Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import contextlib
import errno
import os
import time

from adanet.core.candidate import _CandidateBuilder
from adanet.core.ensemble import _EnsembleBuilder
from adanet.core.ensemble import MixtureWeightType
from adanet.core.iteration import _IterationBuilder
from adanet.core.report_accessor import _ReportAccessor
from adanet.core.summary import _ScopedSummary
from adanet.core.timer import _CountDownTimer
import numpy as np
import six
import tensorflow as tf

from tensorflow.python.ops import resources


class _StopAfterTrainingHook(tf.train.SessionRunHook):
  """Hook that requests stop once iteration is over."""

  def __init__(self, iteration, after_fn):
    """Initializes a `_StopAfterTrainingHook`.

    Args:
      iteration: An `_Iteration` instance.
      after_fn: A function to call after training stopped.

    Returns:
      A `_StopAfterTrainingHook` instance.
    """

    self._iteration = iteration
    self._after_fn = after_fn

  def before_run(self, run_context):
    """See `SessionRunHook`."""

    del run_context  # Unused
    return tf.train.SessionRunArgs(self._iteration.is_over_fn())

  def after_run(self, run_context, run_values):
    """See `SessionRunHook`."""

    is_over = run_values.results
    if not is_over:
      return
    run_context.request_stop()
    self._after_fn()
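
For context, the stop-request mechanism above is the standard `tf.train.SessionRunHook` protocol: fetch a tensor in `before_run`, inspect the result in `after_run`, and call `request_stop` on the run context. A minimal self-contained sketch of the same pattern (the hook name and tensor are illustrative, not part of this module):

class _StopWhenTrueHook(tf.train.SessionRunHook):
  """Illustrative hook that requests stop once a boolean tensor is True."""

  def __init__(self, should_stop_tensor):
    self._should_stop_tensor = should_stop_tensor

  def before_run(self, run_context):
    # Fetch the stop condition alongside the ops run by the training loop.
    return tf.train.SessionRunArgs(self._should_stop_tensor)

  def after_run(self, run_context, run_values):
    # Ask the monitored session to stop once the condition holds.
    if run_values.results:
      run_context.request_stop()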


class _EvalMetricSaverHook(tf.train.SessionRunHook):
  """A hook for writing evaluation metrics as summaries to disk."""

  def __init__(self, name, eval_metric_ops, output_dir):
    """Initializes an `_EvalMetricSaverHook` instance.

    Args:
      name: String name of candidate owner of these metrics.
      eval_metric_ops: Dict of metric results keyed by name. The values of the
        dict are the results of calling a metric function, namely a
        `(metric_tensor, update_op)` tuple. `metric_tensor` should be evaluated
        without any impact on state (typically it is a pure computation based
        on variables). For example, it should not trigger the `update_op` or
        require any input fetching.
      output_dir: Directory for writing evaluation summaries.

    Returns:
      An `_EvalMetricSaverHook` instance.
    """

    self._name = name
    self._eval_metric_ops = eval_metric_ops
    self._output_dir = output_dir

  def before_run(self, run_context):
    """See `SessionRunHook`."""

    del run_context  # Unused
    return tf.train.SessionRunArgs(self._eval_metric_ops)

  def _dict_to_str(self, dictionary):
    """Get a `str` representation of a `dict`.

    Args:
      dictionary: The `dict` to be represented as `str`.

    Returns:
      A `str` representing the `dictionary`.
    """
    return ", ".join("%s = %s" % (k, v) for k, v in sorted(dictionary.items()))

  def end(self, session):
    """See `SessionRunHook`."""

    # Forked from tensorflow/python/estimator/estimator.py function called
    # _write_dict_to_summary.
    eval_dict = {}
    for key, metric in self._eval_metric_ops.items():
      eval_dict[key] = metric[0]
    current_global_step = tf.train.get_global_step()

    eval_dict, current_global_step = session.run((eval_dict,
                                                  current_global_step))

    tf.logging.info("Saving candidate '%s' dict for global step %d: %s",
                    self._name, current_global_step,
                    self._dict_to_str(eval_dict))
    summary_writer = tf.summary.FileWriterCache.get(self._output_dir)
    summary_proto = tf.summary.Summary()
    for key in eval_dict:
      value = eval_dict[key]
      if isinstance(value, (np.float32, float)):
        summary_proto.value.add(tag=key, simple_value=float(value))
      elif isinstance(value, six.binary_type):
        summ = tf.summary.Summary.FromString(value)
        for i, _ in enumerate(summ.value):
          summ.value[i].tag = "%s/%d" % (key, i)
        summary_proto.value.extend(summ.value)
      else:
        tf.logging.warn(
            "Skipping summary for %s, must be a float, np.float32, "
            "or a serialized string of Summary.", key)
    summary_writer.add_summary(summary_proto, current_global_step)
    summary_writer.flush()
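
As a point of reference, the `eval_metric_ops` dict consumed by this hook has the same shape as the output of `tf.metrics` functions: each value is a `(metric_tensor, update_op)` tuple. A hedged sketch follows; the metric names and output directory are hypothetical and only illustrate the expected structure:

labels = tf.constant([1.0, 0.0, 1.0])
predictions = tf.constant([1.0, 1.0, 1.0])
example_eval_metric_ops = {
    # Each entry is a (metric_tensor, update_op) tuple.
    "accuracy": tf.metrics.accuracy(labels=labels, predictions=predictions),
    "mean_prediction": tf.metrics.mean(predictions),
}
# The hook fetches the metric tensors (index 0 of each tuple) at the end of
# evaluation and writes them as scalar summaries under output_dir, e.g.:
# _EvalMetricSaverHook(name="candidate_0",
#                      eval_metric_ops=example_eval_metric_ops,
#                      output_dir="/tmp/model/candidate/candidate_0/eval")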


class Estimator(tf.estimator.Estimator):  # pyformat: disable
  r"""The AdaNet algorithm implemented as a :class:`tf.estimator.Estimator`.

  AdaNet is as defined in the paper: https://arxiv.org/abs/1607.01097.

  The AdaNet algorithm uses a weak learning algorithm to iteratively generate
  a set of candidate subnetworks that attempt to minimize the loss function
  defined in Equation (4) as part of an ensemble. At the end of each
  iteration, the best candidate is chosen based on its ensemble's
  complexity-regularized train loss. New subnetworks are allowed to use any
  subnetwork weights within the previous iteration's ensemble in order to
  improve upon them. If the complexity-regularized loss of the new ensemble,
  as defined in Equation (4), is less than that of the previous iteration's
  ensemble, the AdaNet algorithm continues onto the next iteration.

  AdaNet attempts to minimize the following loss function to learn the mixture
  weights 'w' of each subnetwork 'h' in the ensemble with differentiable
  convex non-increasing surrogate loss function Phi:

  Equation (4):

  .. math::

      F(w) = \frac{1}{m} \sum_{i=1}^{m} \Phi \left(\sum_{j=1}^{N}w_jh_j(x_i),
      y_i \right) + \sum_{j=1}^{N} \left(\lambda r(h_j) + \beta \right) |w_j|

  with :math:`\lambda >= 0` and :math:`\beta >= 0`.

  This implementation uses an :class:`adanet.subnetwork.Generator` as its weak
  learning algorithm for generating candidate subnetworks. These are trained
  in parallel using a single graph per iteration. At the end of each
  iteration, the estimator saves the sub-graph of the best subnetwork ensemble
  and its weights as a separate checkpoint. At the beginning of the next
  iteration, the estimator imports the previous iteration's frozen graph and
  adds ops for the next candidates as part of a new graph and session. This
  gives the estimator the performance of TensorFlow's static graph constraint
  (minus the performance hit of reconstructing a graph between iterations),
  while also providing the flexibility of a dynamic graph.

  NOTE: Subclassing :class:`tf.estimator.Estimator` is only necessary to work
  with :meth:`tf.estimator.train_and_evaluate` which asserts that the
  estimator argument is a :class:`tf.estimator.Estimator` subclass. However,
  all training is delegated to a separate :class:`tf.estimator.Estimator`
  instance. It is responsible for supporting both local and distributed
  training. As such, the :class:`adanet.Estimator` is only responsible for
  bookkeeping across iterations.

  Args:
    head: A :class:`tf.contrib.estimator.Head` instance for computing loss and
      evaluation metrics for every candidate.
    subnetwork_generator: The :class:`adanet.subnetwork.Generator` which
      defines the candidate subnetworks to train and evaluate at every AdaNet
      iteration.
    max_iteration_steps: Total number of steps for which to train candidates
      per iteration. If :class:`OutOfRange` or :class:`StopIteration` occurs
      in the middle, training stops before `max_iteration_steps` steps.
    mixture_weight_type: The :class:`adanet.MixtureWeightType` defining which
      mixture weight type to learn in the linear combination of subnetwork
      outputs:

      - :class:`SCALAR`: creates a rank 0 tensor mixture weight. It performs
        an element-wise multiplication with its subnetwork's logits. This
        mixture weight is the simplest to learn, the quickest to train, and
        most likely to generalize well.
      - :class:`VECTOR`: creates a tensor with shape [k] where k is the
        ensemble's logits dimension as defined by `head`. It is similar to
        `SCALAR` in that it performs an element-wise multiplication with its
        subnetwork's logits, but is more flexible in learning a subnetwork's
        preferences per class.
      - :class:`MATRIX`: creates a tensor of shape [a, b] where a is the
        number of outputs from the subnetwork's `last_layer` and b is the
        number of outputs from the ensemble's `logits`. This weight
        matrix-multiplies the subnetwork's `last_layer`. This mixture weight
        offers the most flexibility and expressivity, allowing subnetworks to
        have outputs of different dimensionalities. However, it also has the
        most trainable parameters (a*b), and is therefore the most sensitive
        to learning rates and regularization.
    mixture_weight_initializer: The initializer for mixture_weights. When
      `None`, the default is different according to `mixture_weight_type`:

      - :class:`SCALAR`: initializes to 1/N where N is the number of
        subnetworks in the ensemble giving a uniform average.
      - :class:`VECTOR`: initializes each entry to 1/N where N is the number
        of subnetworks in the ensemble giving a uniform average.
      - :class:`MATRIX`: uses :meth:`tf.zeros_initializer`.
    warm_start_mixture_weights: Whether, at the beginning of an iteration, to
      initialize the mixture weights of the subnetworks from the previous
      ensemble to their learned value at the previous iteration, as opposed to
      retraining them from scratch. Takes precedence over the value for
      `mixture_weight_initializer` for subnetworks from previous iterations.
    adanet_lambda: Float multiplier 'lambda' for applying L1 regularization to
      subnetworks' mixture weights 'w' in the ensemble proportional to their
      complexity. See Equation (4) in the AdaNet paper.
    adanet_beta: Float L1 regularization multiplier 'beta' to apply equally to
      all subnetworks' weights 'w' in the ensemble regardless of their
      complexity. See Equation (4) in the AdaNet paper.
    evaluator: An :class:`adanet.Evaluator` for candidate selection after all
      subnetworks are done training. When `None`, candidate selection uses a
      moving average of their :class:`adanet.Ensemble` AdaNet loss during
      training instead. In order to use the *AdaNet algorithm* as described in
      [Cortes et al., '17], the given :class:`adanet.Evaluator` must be
      created with the same dataset partition used during training. Otherwise,
      this framework will perform *AdaNet.HoldOut* which uses a holdout set
      for candidate selection, but does not benefit from learning guarantees.
    report_materializer: An :class:`adanet.ReportMaterializer`. Its reports
      are made available to the `subnetwork_generator` at the next iteration,
      so that it can adapt its search space. When `None`, the
      `subnetwork_generator` :meth:`generate_candidates` method will receive
      empty Lists for their `previous_ensemble_reports` and `all_reports`
      arguments.
    use_bias: Whether to add a bias term to the ensemble's logits. Adding a
      bias allows the ensemble to learn a shift in the data, often leading to
      more stable training and better predictions.
    metric_fn: A function for adding custom evaluation metrics, which should
      obey the following signature:

      - `Args`: Can only have the following three arguments in any order:

        - `predictions`: Predictions `Tensor` or dict of `Tensor` created by
          given `head`.
        - `features`: Input `dict` of `Tensor` objects created by `input_fn`
          which is given to `estimator.evaluate` as an argument.
        - `labels`: Labels `Tensor` or dict of `Tensor` (for multi-head)
          created by `input_fn` which is given to `estimator.evaluate` as an
          argument.

      - `Returns`: Dict of metric results keyed by name. Final metrics are a
        union of this and `head`'s existing metrics. If there is a name
        conflict between this and `head`'s existing metrics, this will
        override the existing one. The values of the dict are the results of
        calling a metric function, namely a `(metric_tensor, update_op)`
        tuple.
    force_grow: Boolean override that forces the ensemble to grow by one
      subnetwork at the end of each iteration. Normally at the end of each
      iteration, AdaNet selects the best candidate ensemble according to its
      performance on the AdaNet objective. In some cases, the best ensemble is
      the `previous_ensemble` as opposed to one that includes a newly trained
      subnetwork. When `True`, the algorithm will not select the
      `previous_ensemble` as the best candidate, and will ensure that after n
      iterations the final ensemble is composed of n subnetworks.
    replicate_ensemble_in_training: Whether to rebuild the frozen subnetworks
      of the ensemble in training mode, which can change the outputs of the
      frozen subnetworks in the ensemble. When `False` and during candidate
      training, the frozen subnetworks in the ensemble are in prediction mode,
      so training-only ops like dropout are not applied to them. When `True`
      and training the candidates, the frozen subnetworks will be in training
      mode as well, so they will apply training-only ops like dropout. This
      argument is useful for regularizing learning mixture weights, or for
      making training-only side inputs available in subsequent iterations. For
      most use-cases, this should be `False`.
    adanet_loss_decay: Float decay for the exponential-moving-average of the
      AdaNet objective throughout training. This moving average is a
      data-driven way of tracking the best candidate with only the training
      set.
    worker_wait_timeout_secs: Float number of seconds for workers to wait for
      chief to prepare the next iteration during distributed training. This is
      needed to prevent workers waiting indefinitely for a chief that may have
      crashed or been turned down. When the timeout is exceeded, the worker
      exits the train loop. In situations where the chief job is much slower
      than the worker jobs, this timeout should be increased.
    model_dir: Directory to save model parameters, graph, etc. This can also
      be used to load checkpoints from the directory into an estimator to
      continue training a previously saved model.
    report_dir: Directory where the `adanet.subnetwork.MaterializedReport`s
      materialized by `report_materializer` would be saved. If
      `report_materializer` is `None`, this will not save anything. If `None`
      or empty string, defaults to "<model_dir>/report".
    config: `RunConfig` object to configure the runtime settings.
    **kwargs: Extra keyword args passed to the parent.

  Returns:
    An `Estimator` instance.

  Raises:
    ValueError: If `subnetwork_generator` is `None`.
    ValueError: If `max_iteration_steps` is <= 0.
  """
  # pyformat: enable

  class _Keys(object):
    CURRENT_ITERATION = "current_iteration"
    EVALUATE_ENSEMBLES = "evaluate_ensembles"
    MATERIALIZE_REPORT = "materialize_report"
    INCREMENT_ITERATION = "increment_iteration"
    PREVIOUS_ENSEMBLE_ARCHITECTURE = "previous_ensemble_architecture"
    SUBNETWORK_GENERATOR = "subnetwork_generator"

  def __init__(self,
               head,
               subnetwork_generator,
               max_iteration_steps,
               mixture_weight_type=MixtureWeightType.SCALAR,
               mixture_weight_initializer=None,
               warm_start_mixture_weights=False,
               adanet_lambda=0.,
               adanet_beta=0.,
               evaluator=None,
               report_materializer=None,
               use_bias=False,
               metric_fn=None,
               force_grow=False,
               replicate_ensemble_in_training=False,
               adanet_loss_decay=.9,
               worker_wait_timeout_secs=7200,
               model_dir=None,
               report_dir=None,
               config=None,
               **kwargs):
    # TODO: Add argument to specify how many frozen graph
    # checkpoints to keep.

    if subnetwork_generator is None:
      raise ValueError("subnetwork_generator can't be None.")
    if max_iteration_steps <= 0.:
      raise ValueError("max_iteration_steps must be > 0.")

    self._subnetwork_generator = subnetwork_generator
    self._adanet_loss_decay = adanet_loss_decay

    # Overwrite superclass's assert that members are not overwritten in order
    # to overwrite public methods. Note that we are doing something that is
    # not explicitly supported by the Estimator API and may break in the
    # future.
    tf.estimator.Estimator._assert_members_are_not_overridden = staticmethod(  # pylint: disable=protected-access
        lambda _: None)

    self._evaluation_checkpoint_path = None
    self._evaluator = evaluator
    self._report_materializer = report_materializer

    self._force_grow = force_grow
    self._worker_wait_timeout_secs = worker_wait_timeout_secs

    self._evaluation_name = None

    self._inside_adanet_training_loop = False

    # This `Estimator` is responsible for bookkeeping across iterations, and
    # for training the subnetworks in both a local and distributed setting.
    # Subclassing improves future-proofing against new private methods being
    # added to `tf.estimator.Estimator` that are expected to be callable by
    # external functions, such as in b/110435640.
    super(Estimator, self).__init__(
        model_fn=self._adanet_model_fn,
        params={},
        config=config,
        model_dir=model_dir,
        **kwargs)

    # These are defined after base Estimator's init so that they can
    # use the same temporary model_dir as the underlying Estimator even if
    # model_dir is not provided.
    self._ensemble_builder = _EnsembleBuilder(
        head=head,
        mixture_weight_type=mixture_weight_type,
        mixture_weight_initializer=mixture_weight_initializer,
        warm_start_mixture_weights=warm_start_mixture_weights,
        checkpoint_dir=self._model_dir,
        adanet_lambda=adanet_lambda,
        adanet_beta=adanet_beta,
        use_bias=use_bias,
        metric_fn=metric_fn)
    candidate_builder = _CandidateBuilder(
        max_steps=max_iteration_steps,
        adanet_loss_decay=self._adanet_loss_decay)
    self._iteration_builder = _IterationBuilder(candidate_builder,
                                                self._ensemble_builder,
                                                replicate_ensemble_in_training)
    report_dir = report_dir or os.path.join(self._model_dir, "report")
    self._report_accessor = _ReportAccessor(report_dir)

  def _latest_checkpoint_iteration_number(self):
    """Returns the iteration number from the latest checkpoint."""

    latest_checkpoint = tf.train.latest_checkpoint(self.model_dir)
    if latest_checkpoint is None:
      return 0
    return tf.contrib.framework.load_variable(latest_checkpoint,
                                              self._Keys.CURRENT_ITERATION)

  def _latest_checkpoint_architecture(self):
    """Returns the previous ensemble architecture from the latest checkpoint."""

    latest_checkpoint = tf.train.latest_checkpoint(self.model_dir)
    if latest_checkpoint is None:
      return ""
    return tf.contrib.framework.load_variable(
        latest_checkpoint, self._Keys.PREVIOUS_ENSEMBLE_ARCHITECTURE)

  def _latest_checkpoint_global_step(self):
    """Returns the global step from the latest checkpoint."""

    latest_checkpoint = tf.train.latest_checkpoint(self.model_dir)
    if latest_checkpoint is None:
      return 0
    return tf.contrib.framework.load_variable(latest_checkpoint,
                                              tf.GraphKeys.GLOBAL_STEP)

  @contextlib.contextmanager
  def _train_loop_context(self):
    """Tracks where the context is within the AdaNet train loop."""

    self._inside_adanet_training_loop = True
    yield
    self._inside_adanet_training_loop = False

  def train(self,
            input_fn,
            hooks=None,
            steps=None,
            max_steps=None,
            saving_listeners=None):
    if (steps is not None) and (max_steps is not None):
      raise ValueError("Can not provide both steps and max_steps.")
    if steps is not None and steps <= 0:
      raise ValueError("Must specify steps > 0, given: {}".format(steps))

    if steps is not None:
      max_steps = self._latest_checkpoint_global_step() + steps

    # Each iteration of this AdaNet loop represents an `_Iteration`. The
    # current iteration number is stored as a variable in the checkpoint so
    # that training can be stopped and started at anytime.
    with self._train_loop_context():
      while True:
        current_iteration = self._latest_checkpoint_iteration_number()
        tf.logging.info("Beginning training AdaNet iteration %s",
                        current_iteration)
        self._iteration_ended = False
        result = super(Estimator, self).train(
            input_fn=input_fn,
            hooks=hooks,
            max_steps=max_steps,
            saving_listeners=saving_listeners)

        tf.logging.info("Finished training AdaNet iteration %s",
                        current_iteration)

        # If training ended because the maximum number of training steps
        # occurred, exit training.
        if self._latest_checkpoint_global_step() >= max_steps:
          return result

        # If training ended for any reason other than the iteration ending,
        # exit training.
        if not self._iteration_ended:
          return result

        tf.logging.info("Beginning bookkeeping phase for iteration %s",
                        current_iteration)

        # The chief prepares the next AdaNet iteration, and increments the
        # iteration number by 1.
        if self.config.is_chief:
          # As the chief, store the train hooks and make a placeholder
          # input_fn in order to use them when preparing the next iteration.
          self._train_hooks = hooks or ()
          self._prepare_next_iteration(input_fn)

        # This inner loop serves mainly for synchronizing the workers with the
        # chief during distributed training. Workers that finish training
        # early wait for the chief to prepare the next iteration and increment
        # the iteration number. Workers that are slow to finish training
        # quickly move onto the next iteration. And workers that go offline
        # and return online after training ended terminate gracefully.
        wait_for_chief = not self.config.is_chief
        timer = _CountDownTimer(self._worker_wait_timeout_secs)
        while wait_for_chief:
          # If the chief hits max_steps, it will stop training itself and not
          # increment the iteration number, so this is how the worker knows to
          # exit if it wakes up and the chief is gone.
          # TODO: Support steps parameter.
          if self._latest_checkpoint_global_step() >= max_steps:
            return result

          # In distributed training, a worker may end training before the
          # chief overwrites the checkpoint with the incremented iteration
          # number. If that is the case, it should wait for the chief to do
          # so. Otherwise the worker will get stuck waiting for its weights to
          # be initialized.
          next_iteration = self._latest_checkpoint_iteration_number()
          if next_iteration > current_iteration:
            break

          # Check timeout when waiting for potentially downed chief.
          if timer.secs_remaining() == 0:
            tf.logging.error(
                "Chief job did not prepare next iteration after %s secs. It "
                "may have been preempted, been turned down, or crashed. This "
                "worker is now exiting training.",
                self._worker_wait_timeout_secs)
            return result
          tf.logging.info("Waiting for chief to finish")
          time.sleep(5)

        # Stagger starting workers to prevent training instability.
        if not self.config.is_chief:
          task_id = self.config.task_id or 0
          # Wait 5 secs more for each new worker up to 60 secs.
          delay_secs = min(60, task_id * 5)
          tf.logging.info("Waiting %d secs before starting training.",
                          delay_secs)
          time.sleep(delay_secs)

        tf.logging.info("Finished bookkeeping phase for iteration %s",
                        current_iteration)

  def evaluate(self,
               input_fn,
               steps=None,
               hooks=None,
               checkpoint_path=None,
               name=None):
    if not checkpoint_path:
      checkpoint_path = tf.train.latest_checkpoint(self.model_dir)

    # Ensure that the read to get the iteration number and read to restore
    # variable values come from the same checkpoint during evaluation.
    self._evaluation_checkpoint_path = checkpoint_path
    self._evaluation_name = name

    result = super(Estimator, self).evaluate(
        input_fn,
        steps=steps,
        hooks=hooks,
        checkpoint_path=checkpoint_path,
        name=name)
    self._evaluation_checkpoint_path = None
    return result

  def _call_adanet_model_fn(self, input_fn, mode, params):
    """Calls model_fn with the given mode and parameters."""

    with tf.Graph().as_default():
      tf.set_random_seed(self.config.tf_random_seed)
      # Create global step before calling model_fn as does superclass.
      tf.train.get_or_create_global_step()
      features, labels = input_fn()
      self._adanet_model_fn(features, labels, mode, params)

  def _prepare_next_iteration(self, train_input_fn):
    """Prepares the next iteration.

    This method calls model_fn up to three times:
      1. To evaluate all candidate ensembles to find the best one.
      2. To materialize reports and store them to disk (if
         report_materializer exists).
      3. To overwrite the model directory's checkpoint with the next
         iteration's ops.

    Args:
      train_input_fn: The input_fn used during training.
    """

    # First, evaluate and choose the best ensemble for this iteration.
    params = self.params.copy()
    params[self._Keys.EVALUATE_ENSEMBLES] = True
    if self._evaluator:
      evaluator_input_fn = self._evaluator.input_fn
    else:
      evaluator_input_fn = train_input_fn
    self._call_adanet_model_fn(evaluator_input_fn, tf.estimator.ModeKeys.EVAL,
                               params)

    # Then materialize and store the subnetwork reports.
    if self._report_materializer:
      params = self.params.copy()
      params[self._Keys.MATERIALIZE_REPORT] = True
      self._call_adanet_model_fn(self._report_materializer.input_fn,
                                 tf.estimator.ModeKeys.EVAL, params)

    self._best_ensemble_index = None

    # Finally, create the graph for the next iteration and overwrite the
    # model directory checkpoint with the expanded graph.
    params = self.params.copy()
    params[self._Keys.INCREMENT_ITERATION] = True
    self._call_adanet_model_fn(train_input_fn, tf.estimator.ModeKeys.TRAIN,
                               params)

  def _architecture_filename(self, iteration_number):
    """Returns the filename of the given iteration's frozen graph."""

    frozen_checkpoint = os.path.join(self.model_dir, "architecture")
    return "{}-{}.txt".format(frozen_checkpoint, iteration_number)

  def _overwrite_checkpoint(self, current_iteration, iteration_number_tensor):
    """Overwrites the latest checkpoint with the current graph.

    This is necessary for two reasons:
      1. To add variables to the checkpoint that were newly created for the
         next iteration. Otherwise Estimator will raise an exception for
         having a checkpoint missing variables.
      2. To increment the current iteration number so that workers know when
         to begin training the next iteration.

    Args:
      current_iteration: Current `_Iteration` object.
      iteration_number_tensor: Int variable `Tensor` storing the current
        iteration number.
    """

    checkpoint_state = tf.train.get_checkpoint_state(self.model_dir)
    latest_checkpoint = checkpoint_state.model_checkpoint_path
    if not latest_checkpoint:
      return

    # Run train hook 'begin' methods which can add ops to the graph, so that
    # they are still present in the overwritten checkpoint.
    train_hooks = tuple(self._train_hooks) or ()
    for candidate in current_iteration.candidates:
      if not candidate.ensemble_spec.subnetwork_train_op:
        assert not candidate.ensemble_spec.ensemble_train_op
        continue
      train_hooks += candidate.ensemble_spec.subnetwork_train_op.chief_hooks
      train_hooks += candidate.ensemble_spec.subnetwork_train_op.hooks
      train_hooks += candidate.ensemble_spec.ensemble_train_op.chief_hooks
      train_hooks += candidate.ensemble_spec.ensemble_train_op.hooks
    for hook in train_hooks:
      hook.begin()

    global_step_tensor = tf.train.get_global_step()
    global_step = tf.contrib.framework.load_variable(latest_checkpoint,
                                                     tf.GraphKeys.GLOBAL_STEP)

    checkpoint_path = os.path.join(self.model_dir, "increment.ckpt")
    with tf.Session(target=self.config.master) as sess:
      init = tf.group(
          tf.global_variables_initializer(), tf.local_variables_initializer(),
          tf.tables_initializer(),
          resources.initialize_resources(resources.shared_resources()))
      sess.run(init)
      coord = tf.train.Coordinator()
      tf.train.start_queue_runners(sess=sess, coord=coord)
      control_deps = [
          tf.assign(global_step_tensor, global_step),
          tf.assign(iteration_number_tensor, current_iteration.number),
      ]
      with tf.control_dependencies(control_deps):
        saver = tf.train.Saver(
            sharded=True, max_to_keep=self.config.keep_checkpoint_max)
        saver.recover_last_checkpoints(
            checkpoint_state.all_model_checkpoint_paths)
        saver.save(
            sess, checkpoint_path, global_step=current_iteration.number)
      for hook in train_hooks:
        hook.end(sess)

  def _get_best_ensemble_index(self, current_iteration):
    """Returns the best candidate ensemble's index in this iteration.

    Evaluates the ensembles using an `Evaluator` when provided. Otherwise, it
    returns the index of the best candidate as defined by the `_Iteration`.

    Args:
      current_iteration: Current `_Iteration`.

    Returns:
      Index of the best ensemble in the iteration's list of `_Candidates`.
    """

    # Skip the evaluation phase when there is only one candidate subnetwork.
    if len(current_iteration.candidates) == 1:
      tf.logging.info(
          "As the only candidate, '%s' is moving onto the next iteration.",
          current_iteration.candidates[0].ensemble_spec.name)
      return 0

    # The zero-th index candidate at iteration t>0 is always the
    # previous_ensemble.
    if current_iteration.number > 0 and self._force_grow and (len(
        current_iteration.candidates) == 2):
      tf.logging.info(
          "As the only candidate with `force_grow` enabled, '%s' is moving "
          "onto the next iteration.",
          current_iteration.candidates[1].ensemble_spec.name)
      return 1

    latest_checkpoint = tf.train.latest_checkpoint(self.model_dir)
    tf.logging.info("Starting ensemble evaluation for iteration %s",
                    current_iteration.number)
    with tf.Session() as sess:
      init = tf.group(tf.global_variables_initializer(),
                      tf.local_variables_initializer(),
                      tf.tables_initializer())
      sess.run(init)
      saver = tf.train.Saver(sharded=True)
      saver.restore(sess, latest_checkpoint)
      coord = tf.train.Coordinator()
      tf.train.start_queue_runners(sess=sess, coord=coord)
      if self._evaluator:
        adanet_losses = [
            c.ensemble_spec.adanet_loss for c in current_iteration.candidates
        ]
        adanet_losses = self._evaluator.evaluate_adanet_losses(
            sess, adanet_losses)
      else:
        adanet_losses = sess.run(
            [c.adanet_loss for c in current_iteration.candidates])
      values = []
      for i in range(len(current_iteration.candidates)):
        metric_name = "adanet_loss"
        ensemble_name = current_iteration.candidates[i].ensemble_spec.name
        values.append("{}/{} = {:.6f}".format(metric_name, ensemble_name,
                                              adanet_losses[i]))
      tf.logging.info("Computed ensemble metrics: %s", ", ".join(values))
    if self._force_grow and current_iteration.number > 0:
      tf.logging.info(
          "The `force_grow` override is enabled, so the "
          "performance of the previous ensemble will be ignored.")
      # NOTE: The zero-th index candidate at iteration t>0 is always the
      # previous_ensemble.
      adanet_losses = adanet_losses[1:]
      index = np.argmin(adanet_losses) + 1
    else:
      index = np.argmin(adanet_losses)
    tf.logging.info("Finished ensemble evaluation for iteration %s",
                    current_iteration.number)
    tf.logging.info("'%s' at index %s is moving onto the next iteration",
                    current_iteration.candidates[index].ensemble_spec.name,
                    index)
    return index

  def _materialize_report(self, current_iteration):
    """Generates reports as defined by `Builder`s.

    Materializes the Tensors and metrics defined in the `Builder`s'
    `build_subnetwork_report` method using `ReportMaterializer`, and stores
    them to disk using `_ReportAccessor`.

    Args:
      current_iteration: Current `_Iteration`.
    """

    latest_checkpoint = tf.train.latest_checkpoint(self.model_dir)
    tf.logging.info("Starting metric logging for iteration %s",
                    current_iteration.number)

    assert self._best_ensemble_index is not None
    best_candidate = current_iteration.candidates[self._best_ensemble_index]
    best_ensemble = best_candidate.ensemble_spec.ensemble
    best_name = best_ensemble.weighted_subnetworks[-1].name
    included_subnetwork_names = [best_name]
    with tf.Session() as sess:
      init = tf.group(tf.global_variables_initializer(),
                      tf.local_variables_initializer(),
                      tf.tables_initializer())
      sess.run(init)
      saver = tf.train.Saver(sharded=True)
      saver.restore(sess, latest_checkpoint)
      coord = tf.train.Coordinator()
      tf.train.start_queue_runners(sess=sess, coord=coord)
      materialized_reports = (
          self._report_materializer.materialize_subnetwork_reports(
              sess, current_iteration.number,
              current_iteration.subnetwork_reports,
              included_subnetwork_names))
      self._report_accessor.write_iteration_report(current_iteration.number,
                                                   materialized_reports)

    tf.logging.info("Finished saving subnetwork reports for iteration %s",
                    current_iteration.number)

  def _training_hooks(self, current_iteration, training):
    """Returns training hooks for this iteration.

    Args:
      current_iteration: Current `_Iteration`.
      training: Whether in training mode.

    Returns:
      A list of `tf.train.SessionRunHook` instances.
    """

    if not training:
      return []

    def after_fn():
      self._iteration_ended = True

    training_hooks = list(current_iteration.estimator_spec.training_hooks) + [
        _StopAfterTrainingHook(current_iteration, after_fn=after_fn)
    ]

    for summary in current_iteration.summaries:
      output_dir = self.model_dir
      if summary.scope:
        output_dir = os.path.join(output_dir, "candidate", summary.scope)
      summary_saver_hook = tf.train.SummarySaverHook(
          save_steps=self.config.save_summary_steps,
          output_dir=output_dir,
          summary_op=summary.merge_all())
      training_hooks.append(summary_saver_hook)
    return training_hooks

  def _evaluation_hooks(self, current_iteration, training):
    """Returns evaluation hooks for this iteration.

    Args:
      current_iteration: Current `_Iteration`.
      training: Whether in training mode.

    Returns:
      A list of `tf.train.SessionRunHook` instances.
    """

    if training:
      return []

    evaluation_hooks = []
    for candidate in current_iteration.candidates:
      eval_subdir = "eval"
      if self._evaluation_name:
        eval_subdir = "eval_{}".format(self._evaluation_name)
      eval_metric_hook = _EvalMetricSaverHook(
          name=candidate.ensemble_spec.name,
          eval_metric_ops=candidate.ensemble_spec.eval_metric_ops,
          output_dir=os.path.join(self.model_dir, "candidate",
                                  candidate.ensemble_spec.name, eval_subdir))
      evaluation_hooks.append(eval_metric_hook)
    return evaluation_hooks

  def _save_architecture(self, filename, ensemble):
    """Persists the ensemble's architecture in a serialized format.

    Writes to a text file with one subnetwork's iteration number and name per
    line.

    Args:
      filename: String filename to persist the ensemble architecture.
      ensemble: Target `adanet.Ensemble` instance.
    """

    architecture = [
        "{}:{}".format(w.iteration_number, w.name)
        for w in ensemble.weighted_subnetworks
    ]
    # Make directories since model_dir may not have been created yet.
    tf.gfile.MakeDirs(os.path.dirname(filename))
    with tf.gfile.GFile(filename, "w") as record_file:
      record_file.write(os.linesep.join(architecture))

  def _read_architecture(self, filename):
    """Reads an ensemble architecture from disk.

    Assumes the file was written with `_save_architecture`.

    Args:
      filename: String filename where the architecture was recorded.

    Returns:
      A list of <iteration_number>:<subnetwork name> strings.

    Raises:
      OSError: When file not found at `filename`.
    """

    if not tf.gfile.Exists(filename):
      raise OSError(errno.ENOENT, os.strerror(errno.ENOENT), filename)

    architecture = []
    with tf.gfile.GFile(filename, "r") as record_file:
      for line in record_file:
        feature_name = line.rstrip()
        architecture.append(feature_name)
    return architecture

  # TODO: Refactor architecture building logic to its own module.
  def _architecture_ensemble_spec(self, architecture, features, mode, labels):
    """Returns an `_EnsembleSpec` with the given architecture.

    Creates the ensemble architecture by calling `generate_candidates` on
    `self._subnetwork_generator` and only calling `build_subnetwork` on
    `Builders` included in the architecture. Once their ops are created, their
    variables are restored from the checkpoint.

    Args:
      architecture: A list of <iteration_number>:<subnetwork name> strings.
      features: Dictionary of `Tensor` objects keyed by feature name.
      mode: Defines whether this is training, evaluation or prediction. See
        `ModeKeys`.
      labels: Labels `Tensor` or a dictionary of string label name to `Tensor`
        (for multi-head). Can be `None`.

    Returns:
      An `EnsembleSpec` instance for the given architecture.

    Raises:
      ValueError: If a subnetwork from `architecture` is not found in the
        generated candidate `Builders` of the specified iteration.
    """

    previous_ensemble_spec = None
    previous_ensemble = None
    for serialized_subnetwork in architecture:
      serialized_iteration_number, name = serialized_subnetwork.split(":")
      rebuild_iteration_number = int(serialized_iteration_number)
      previous_ensemble_reports, all_reports = [], []
      if self._report_materializer:
        previous_ensemble_reports, all_reports = (
            self._collate_subnetwork_reports(rebuild_iteration_number))
      generated_subnetwork_builders = (
          self._subnetwork_generator.generate_candidates(
              previous_ensemble=previous_ensemble,
              iteration_number=rebuild_iteration_number,
              previous_ensemble_reports=previous_ensemble_reports,
              all_reports=all_reports))
      rebuild_subnetwork_builder = None
      for builder in generated_subnetwork_builders:
        if builder.name == name:
          rebuild_subnetwork_builder = builder
          break
      if rebuild_subnetwork_builder is None:
        raise ValueError("Required subnetwork name is missing from "
                         "generated candidates: {}".format(name))
      previous_ensemble_summary = None
      if previous_ensemble_spec:
        # Always skip summaries when rebuilding previous architecture,
        # since they are not useful.
        previous_ensemble_summary = _ScopedSummary(
            previous_ensemble_spec.name, skip_summary=True)
      current_iteration = self._iteration_builder.build_iteration(
          iteration_number=rebuild_iteration_number,
          subnetwork_builders=[rebuild_subnetwork_builder],
          features=features,
          labels=labels,
          mode=mode,
          previous_ensemble_summary=previous_ensemble_summary,
          previous_ensemble_spec=previous_ensemble_spec,
          rebuilding=True)
      previous_ensemble_spec = current_iteration.candidates[-1].ensemble_spec
      previous_ensemble = previous_ensemble_spec.ensemble
    return previous_ensemble_spec

  def _collate_subnetwork_reports(self, iteration_number):
    """Prepares subnetwork.Reports to be passed to Generator.

    Reads subnetwork.MaterializedReports from past iterations, collates those
    that were included in previous_ensemble into previous_ensemble_reports as
    a List of subnetwork.MaterializedReports, and collates all reports from
    previous iterations into all_reports as another List of
    subnetwork.MaterializedReports.

    Args:
      iteration_number: Python integer AdaNet iteration number, starting from
        0.

    Returns:
      (previous_ensemble_reports: List<subnetwork.MaterializedReport>,
       all_reports: List<subnetwork.MaterializedReport>)
    """

    materialized_reports_all = (
        self._report_accessor.read_iteration_reports())
    previous_ensemble_reports = []
    all_reports = []

    # Since the number of iteration reports changes after the
    # MATERIALIZE_REPORT phase, we need to make sure that we always pass the
    # same reports to the Generator in the same iteration,
    # otherwise the graph that is built in the FREEZE_ENSEMBLE phase would be
    # different from the graph built in the training phase.

    # Iteration 0 should have 0 iteration reports passed to the
    # Generator, since there are no previous iterations.
    # Iteration 1 should have 1 list of reports for Builders
    # generated in iteration 0.
    # Iteration 2 should have 2 lists of reports -- one for iteration 0,
    # one for iteration 1. Note that the list of reports for iteration >= 1
    # should contain "previous_ensemble", in addition to the
    # Builders at the start of that iteration.
    # Iteration t should have t lists of reports.
    for i, iteration_reports in enumerate(materialized_reports_all):

      # This ensures that the FREEZE_ENSEMBLE phase does not pass the reports
      # generated in the previous phase of the same iteration to the
      # Generator when building the graph.
      if i >= iteration_number:
        break

      # Assumes that only one subnetwork is added to the ensemble in
      # each iteration.
      chosen_subnetwork_in_this_iteration = [
          subnetwork_report for subnetwork_report in iteration_reports
          if subnetwork_report.included_in_final_ensemble
      ][0]
      previous_ensemble_reports.append(chosen_subnetwork_in_this_iteration)
      all_reports.extend(iteration_reports)

    return previous_ensemble_reports, all_reports

  def _adanet_model_fn(self, features, labels, mode, params):
    """AdaNet model_fn.

    This model_fn is called at least three times per iteration:
      1. The first call generates, builds, and trains the candidate
         subnetworks to ensemble in this iteration.
      2. Once training is over, bookkeeping begins. The next call is to
         evaluate the best candidate ensembles according to the AdaNet
         objective.
      2.b. Optionally, when a report materializer is provided, another call
         creates the graph for producing subnetwork reports for the next
         iteration and other AdaNet runs.
      3. The final call is responsible for rebuilding the ensemble
         architecture from t-1 by regenerating the best builders and
         warm-starting their weights, adding ops and initializing the weights
         for the next candidate subnetworks, and overwriting the latest
         checkpoint with its graph and variables, so that the first call of
         the next iteration has the right variables in the checkpoint.

    Args:
      features: Dictionary of `Tensor` objects keyed by feature name.
      labels: Labels `Tensor` or a dictionary of string label name to `Tensor`
        (for multi-head). Can be `None`.
      mode: Defines whether this is training, evaluation or prediction. See
        `ModeKeys`.
      params: A dict of parameters.

    Returns:
      An `EstimatorSpec` instance.

    Raises:
      UserWarning: When calling model_fn directly in TRAIN mode.
    """

    training = mode == tf.estimator.ModeKeys.TRAIN
    if training and not self._inside_adanet_training_loop:
      raise UserWarning(
          "The adanet.Estimator's model_fn should not be called directly in "
          "TRAIN mode, because its behavior is undefined outside the context "
          "of its `train` method. If you are trying to add custom metrics "
          "with `tf.contrib.estimator.add_metrics`, pass the `metric_fn` to "
          "this `Estimator's` constructor instead.")

    iteration_number = self._latest_checkpoint_iteration_number()

    # Use the evaluation checkpoint path to get both the iteration number and
    # variable values to avoid any race conditions between the first and
    # second checkpoint reads.
    if mode == tf.estimator.ModeKeys.EVAL and self._evaluation_checkpoint_path:
      iteration_number = tf.contrib.framework.load_variable(
          self._evaluation_checkpoint_path, self._Keys.CURRENT_ITERATION)

    if self._Keys.INCREMENT_ITERATION in params:
      iteration_number += 1

    architecture_filename = self._architecture_filename(iteration_number - 1)
    architecture = []
    if tf.gfile.Exists(architecture_filename):
      architecture = self._read_architecture(architecture_filename)
      tf.logging.info(
          "Importing architecture from %s: [%s].", architecture_filename,
          ", ".join(sorted(["'{}'".format(f) for f in architecture])))

    skip_summaries = mode == tf.estimator.ModeKeys.PREDICT
    with tf.variable_scope("adanet"):
      previous_ensemble_spec = None
      previous_ensemble = None
      previous_ensemble_summary = None
      if architecture:
        previous_ensemble_spec = self._architecture_ensemble_spec(
            architecture, features, mode, labels)
        previous_ensemble = previous_ensemble_spec.ensemble
        previous_ensemble_summary = _ScopedSummary(
            previous_ensemble_spec.name, skip_summary=skip_summaries)
      if self._Keys.INCREMENT_ITERATION in params:
        latest_checkpoint = tf.train.latest_checkpoint(self.model_dir)
        tf.train.warm_start(latest_checkpoint, vars_to_warm_start=[".*"])
      previous_ensemble_reports, all_reports = [], []
      if self._report_materializer:
        previous_ensemble_reports, all_reports = (
            self._collate_subnetwork_reports(iteration_number))
      subnetwork_builders = self._subnetwork_generator.generate_candidates(
          previous_ensemble=previous_ensemble,
          iteration_number=iteration_number,
          previous_ensemble_reports=previous_ensemble_reports,
          all_reports=all_reports)
      current_iteration = self._iteration_builder.build_iteration(
          iteration_number=iteration_number,
          subnetwork_builders=subnetwork_builders,
          features=features,
          labels=labels,
          mode=mode,
          previous_ensemble_summary=previous_ensemble_summary,
          previous_ensemble_spec=previous_ensemble_spec)

    # Variable which allows us to read the current iteration from a
    # checkpoint.
    iteration_number_tensor = tf.get_variable(
        self._Keys.CURRENT_ITERATION,
        shape=[],
        dtype=tf.int64,
        initializer=tf.zeros_initializer(),
        trainable=False)

    adanet_summary = _ScopedSummary("global", skip_summaries)
    adanet_summary.scalar("iteration/adanet/iteration",
                          iteration_number_tensor)
    adanet_summary.scalar("iteration_step/adanet/iteration_step",
                          current_iteration.step)
    if current_iteration.estimator_spec.loss is not None:
      adanet_summary.scalar("loss", current_iteration.estimator_spec.loss)
      adanet_summary.scalar("loss/adanet/adanet_weighted_ensemble",
                            current_iteration.estimator_spec.loss)

    iteration_estimator_spec = current_iteration.estimator_spec
    estimator_spec = tf.estimator.EstimatorSpec(
        mode=mode,
        predictions=iteration_estimator_spec.predictions,
        loss=iteration_estimator_spec.loss,
        train_op=iteration_estimator_spec.train_op,
        eval_metric_ops=iteration_estimator_spec.eval_metric_ops,
        training_chief_hooks=iteration_estimator_spec.training_chief_hooks,
        training_hooks=self._training_hooks(current_iteration, training),
        evaluation_hooks=self._evaluation_hooks(current_iteration, training),
        scaffold=tf.train.Scaffold(summary_op=adanet_summary.merge_all()),
        export_outputs=iteration_estimator_spec.export_outputs)

    if self._Keys.EVALUATE_ENSEMBLES in params:
      assert self.config.is_chief
      self._best_ensemble_index = self._get_best_ensemble_index(
          current_iteration)
      ensemble = current_iteration.candidates[
          self._best_ensemble_index].ensemble_spec.ensemble
      new_architecture_filename = self._architecture_filename(
          iteration_number)
      self._save_architecture(new_architecture_filename, ensemble)
    elif self._Keys.MATERIALIZE_REPORT in params:
      assert self.config.is_chief
      assert self._best_ensemble_index is not None
      self._materialize_report(current_iteration)
    elif self._Keys.INCREMENT_ITERATION in params:
      assert self.config.is_chief
      latest_checkpoint = tf.train.latest_checkpoint(self.model_dir)
      tf.logging.info(
          "Overwriting checkpoint with new graph for iteration %s to %s",
          iteration_number, latest_checkpoint)
      self._overwrite_checkpoint(current_iteration, iteration_number_tensor)

    return estimator_spec
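
For reference, end-to-end usage of this class follows the constructor arguments documented above. The sketch below is commented out and hedged: the generator, input functions, and model directory are user-defined placeholders, not part of this module:

# import adanet  # assumes the adanet package is installed
#
# head = tf.contrib.estimator.binary_classification_head()
# estimator = adanet.Estimator(
#     head=head,
#     subnetwork_generator=MyGenerator(),  # a user-defined
#                                          # adanet.subnetwork.Generator
#     max_iteration_steps=1000,
#     evaluator=adanet.Evaluator(input_fn=eval_input_fn),  # optional
#     model_dir="/tmp/adanet_model")  # hypothetical directory
# estimator.train(input_fn=train_input_fn, max_steps=5000)
# metrics = estimator.evaluate(input_fn=eval_input_fn)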