"""An AdaNet ensemble definition in Tensorflow using a single graph.

Copyright 2018 The AdaNet Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import contextlib
import functools
import inspect

from adanet.core.subnetwork import TrainOpSpec
import tensorflow as tf

from tensorflow.python.summary import summary as summary_lib
from tensorflow.python.training import training_util

_VALID_METRIC_FN_ARGS = set(["features", "labels", "predictions"])


class WeightedSubnetwork(
    collections.namedtuple(
        "WeightedSubnetwork",
        ["name", "iteration_number", "weight", "logits", "subnetwork"])):
  """An AdaNet weighted subnetwork.

  A weighted subnetwork is a weight 'w' applied to a subnetwork's last layer
  'u'. The result is the weighted subnetwork's logits, regularized by its
  complexity.

  Args:
    name: String name of `subnetwork` as defined by its
      :class:`adanet.subnetwork.Builder`.
    iteration_number: Integer iteration when the subnetwork was created.
    weight: The weight :class:`tf.Tensor` or dict of string to weight
      :class:`tf.Tensor` (for multi-head) to apply to this subnetwork. The
      AdaNet paper refers to this weight as 'w' in Equations (4), (5), and
      (6).
    logits: The output :class:`tf.Tensor` or dict of string to logits
      :class:`tf.Tensor` (for multi-head) after the matrix multiplication of
      `weight` and the subnetwork's :meth:`last_layer`. The output's shape is
      [batch_size, logits_dimension]. It is equivalent to a linear logits
      layer in a neural network.
    subnetwork: The :class:`adanet.subnetwork.Subnetwork` to weight.

  Returns:
    An :class:`adanet.WeightedSubnetwork` object.
  """

  def __new__(cls,
              name="",
              iteration_number=0,
              weight=None,
              logits=None,
              subnetwork=None):
    return super(WeightedSubnetwork, cls).__new__(
        cls,
        name=name,
        iteration_number=iteration_number,
        weight=weight,
        logits=logits,
        subnetwork=subnetwork)

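# Editor's illustration (not part of this module): for MATRIX mixture
# weights, a `WeightedSubnetwork`'s logits are the matrix product of the
# subnetwork's last layer 'u' and its weight 'w'. All names and shapes below
# are assumed for the sake of the example.
def _example_weighted_subnetwork_logits():
  last_layer = tf.random_normal([32, 64])  # 'u': [batch_size, 64].
  weight = tf.zeros([64, 10])  # 'w': [last_layer_size, logits_dimension].
  # The weighted subnetwork's logits: [batch_size, logits_dimension].
  return tf.matmul(last_layer, weight)
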
class Ensemble(
    collections.namedtuple("Ensemble",
                           ["weighted_subnetworks", "bias", "logits"])):
  """An AdaNet ensemble.

  An ensemble is a collection of subnetworks which forms a neural network
  through the weighted sum of their outputs. It is represented by 'f'
  throughout the AdaNet paper. Its component subnetworks' weights are
  complexity regularized (Gamma) as defined in Equation (4).

  Args:
    weighted_subnetworks: List of :class:`adanet.WeightedSubnetwork`
      instances that form this ensemble. Ordered from first to most recent.
    bias: Bias term :class:`tf.Tensor` or dict of string to bias term
      :class:`tf.Tensor` (for multi-head) for the ensemble's logits.
    logits: Logits :class:`tf.Tensor` or dict of string to logits
      :class:`tf.Tensor` (for multi-head). The result of the function 'f' as
      defined in Section 5.1, which is the sum of the logits of all
      :class:`adanet.WeightedSubnetwork` instances in the ensemble.

  Returns:
    An :class:`adanet.Ensemble` instance.
  """

  def __new__(cls, weighted_subnetworks, bias, logits):
    # TODO: Make weighted_subnetworks property a tuple so that
    # `Ensemble` is immutable.
    return super(Ensemble, cls).__new__(
        cls,
        weighted_subnetworks=weighted_subnetworks,
        bias=bias,
        logits=logits)

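# Editor's illustration (not part of this module): the `Ensemble` logits
# implement the function 'f' from the AdaNet paper as the bias term plus the
# sum of its weighted subnetworks' logits. This sketch assumes a single head,
# so each `logits` field is a plain `tf.Tensor`.
def _example_ensemble_logits(ensemble):
  logits = ensemble.bias
  for weighted_subnetwork in ensemble.weighted_subnetworks:
    logits += weighted_subnetwork.logits  # Each: [batch_size, logits_dim].
  return logits
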
class _EnsembleSpec(
    collections.namedtuple("_EnsembleSpec", [
        "name",
        "ensemble",
        "predictions",
        "loss",
        "adanet_loss",
        "subnetwork_train_op",
        "ensemble_train_op",
        "eval_metric_ops",
        "export_outputs",
    ])):
  """A collection of an ensemble's training and evaluation `Tensors`."""

  def __new__(cls,
              name,
              ensemble,
              predictions,
              loss=None,
              adanet_loss=None,
              subnetwork_train_op=None,
              ensemble_train_op=None,
              eval_metric_ops=None,
              export_outputs=None):
    """Creates an `_EnsembleSpec` instance.

    Args:
      name: String name of this ensemble. Should be unique in the graph.
      ensemble: The `Ensemble` of interest.
      predictions: Predictions `Tensor` or dict of `Tensor`.
      loss: Loss `Tensor` as defined by the surrogate loss function Phi in
        Equations (4), (5), and (6). Must be either scalar or with shape
        `[1]`.
      adanet_loss: Loss `Tensor` as defined by F(w) in Equation (4). Must be
        either scalar or with shape `[1]`. The AdaNet algorithm aims to
        minimize this objective, which balances training loss with the total
        complexity of the subnetworks in the ensemble.
      subnetwork_train_op: Candidate subnetwork's `TrainOpSpec`.
      ensemble_train_op: Candidate ensemble's mixture weights `TrainOpSpec`.
      eval_metric_ops: Dict of metric results keyed by name. The values of
        the dict are the results of calling a metric function, namely a
        `(metric_tensor, update_op)` tuple. `metric_tensor` should be
        evaluated without any impact on state (typically a pure computation
        based on variables). For example, it should not trigger the
        `update_op` or require any input fetching.
      export_outputs: Describes the output signatures to be exported to
        `SavedModel` and used during serving. See
        `tf.estimator.EstimatorSpec`.

    Returns:
      An `_EnsembleSpec` object.
    """
    return super(_EnsembleSpec, cls).__new__(
        cls,
        name=name,
        ensemble=ensemble,
        predictions=predictions,
        loss=loss,
        adanet_loss=adanet_loss,
        subnetwork_train_op=subnetwork_train_op,
        ensemble_train_op=ensemble_train_op,
        eval_metric_ops=eval_metric_ops,
        export_outputs=export_outputs)

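# Editor's illustration (not part of this module): how `adanet_loss` (F(w) in
# Equation (4)) relates to `loss` (the surrogate loss Phi). The
# per-subnetwork complexities and mixture weight L1 norms are hypothetical
# inputs here; the real computation lives in `_EnsembleBuilder` below.
def _example_adanet_loss(loss, weight_l1_norms, complexities, lam, beta):
  adanet_loss = loss
  for weight_l1_norm, complexity in zip(weight_l1_norms, complexities):
    # Each subnetwork contributes (lambda * r(h) + beta) * |w|_1.
    adanet_loss += (lam * complexity + beta) * weight_l1_norm
  return adanet_loss
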
class MixtureWeightType(object):
  """Mixture weight types available for learning subnetwork contributions.

  The following mixture weight types are defined:

  * `SCALAR`: Produces a rank 0 `Tensor` mixture weight.
  * `VECTOR`: Produces a rank 1 `Tensor` mixture weight.
  * `MATRIX`: Produces a rank 2 `Tensor` mixture weight.
  """

  SCALAR = "scalar"
  VECTOR = "vector"
  MATRIX = "matrix"

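# Editor's illustration (not part of this module): the mixture weight shape
# each type yields, assuming a subnetwork whose last layer has 64 units and
# whose logits have dimension 10.
_EXAMPLE_MIXTURE_WEIGHT_SHAPES = {
    MixtureWeightType.SCALAR: [],  # Rank 0: one weight for the subnetwork.
    MixtureWeightType.VECTOR: [10],  # Rank 1: one weight per logit.
    MixtureWeightType.MATRIX: [64, 10],  # Rank 2: last layer -> logits map.
}
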
def _architecture_as_metric(weighted_subnetworks):
  """Returns a representation of the ensemble's architecture as a tf.metric."""

  joined_names = " | ".join([w.name for w in weighted_subnetworks])
  architecture = tf.convert_to_tensor(
      "| {} |".format(joined_names), name="architecture")
  architecture_summary = tf.summary.text("architecture/adanet", architecture)
  return (architecture_summary, tf.no_op())


def _call_metric_fn(metric_fn, features, labels, predictions):
  """Calls the metric_fn with only the arguments it expects."""

  if not metric_fn:
    return {}
  metric_fn_args = inspect.getargspec(metric_fn).args
  kwargs = {}
  if "features" in metric_fn_args:
    kwargs["features"] = features
  if "labels" in metric_fn_args:
    kwargs["labels"] = labels
  if "predictions" in metric_fn_args:
    kwargs["predictions"] = predictions
  return metric_fn(**kwargs)


def _verify_metric_fn_args(metric_fn):
  if not metric_fn:
    return
  args = set(inspect.getargspec(metric_fn).args)
  invalid_args = list(args - _VALID_METRIC_FN_ARGS)
  if invalid_args:
    raise ValueError("metric_fn (%s) has the following unexpected args: %s" %
                     (metric_fn, invalid_args))


def _add_eval_metric_ops(eval_metric_ops, group_name, estimator_spec,
                         metric_fn):
  """Adds eval metric ops to the given dictionary for the given group name."""

  eval_metric_ops["loss/adanet/{}".format(group_name)] = tf.metrics.mean(
      estimator_spec.loss)
  metric_ops = estimator_spec.eval_metric_ops
  for metric in sorted(metric_ops):
    eval_metric_ops["{metric}/adanet/{group_name}".format(
        metric=metric, group_name=group_name)] = metric_ops[metric]
  metric_ops = metric_fn(predictions=estimator_spec.predictions)
  for metric in sorted(metric_ops):
    eval_metric_ops["{metric}/adanet/{group_name}".format(
        metric=metric, group_name=group_name)] = metric_ops[metric]


def _get_value(target, key):
  if isinstance(target, dict):
    return target[key]
  return target


def _to_train_op_spec(train_op):
  if isinstance(train_op, TrainOpSpec):
    return train_op
  return TrainOpSpec(train_op)


@contextlib.contextmanager
def _subnetwork_context(iteration_step_scope, scoped_summary):
  """Monkey-patches global attributes with subnetwork-specific ones."""

  old_get_global_step_fn = tf.train.get_global_step
  old_get_or_create_global_step_fn = tf.train.get_or_create_global_step
  old_summary_scalar = summary_lib.scalar
  old_summary_image = summary_lib.image
  old_summary_histogram = summary_lib.histogram
  old_summary_audio = summary_lib.audio

  def iteration_step(graph=None):
    del graph
    with tf.variable_scope(iteration_step_scope, reuse=tf.AUTO_REUSE):
      return tf.get_variable(
          "iteration_step",
          shape=[],
          initializer=tf.zeros_initializer(),
          trainable=False,
          dtype=tf.int64)

  # Monkey-patch global attributes.
  tf.summary.scalar = scoped_summary.scalar
  tf.summary.image = scoped_summary.image
  tf.summary.histogram = scoped_summary.histogram
  tf.summary.audio = scoped_summary.audio
  summary_lib.scalar = scoped_summary.scalar
  summary_lib.image = scoped_summary.image
  summary_lib.histogram = scoped_summary.histogram
  summary_lib.audio = scoped_summary.audio
  tf.train.get_global_step = iteration_step
  tf.train.get_or_create_global_step = iteration_step
  training_util.get_global_step = iteration_step
  training_util.get_or_create_global_step = iteration_step

  try:
    yield
  finally:
    # Revert monkey-patches.
    training_util.get_or_create_global_step = old_get_or_create_global_step_fn
    training_util.get_global_step = old_get_global_step_fn
    tf.train.get_or_create_global_step = old_get_or_create_global_step_fn
    tf.train.get_global_step = old_get_global_step_fn
    summary_lib.audio = old_summary_audio
    summary_lib.histogram = old_summary_histogram
    summary_lib.image = old_summary_image
    summary_lib.scalar = old_summary_scalar
    tf.summary.audio = old_summary_audio
    tf.summary.histogram = old_summary_histogram
    tf.summary.image = old_summary_image
    tf.summary.scalar = old_summary_scalar


def _clear_trainable_variables():
  del tf.get_collection_ref(tf.GraphKeys.TRAINABLE_VARIABLES)[:]


def _set_trainable_variables(var_list):
  _clear_trainable_variables()
  for var in var_list:
    tf.add_to_collections(tf.GraphKeys.TRAINABLE_VARIABLES, var)


class _EnsembleBuilder(object):
  """Builds `Ensemble` instances."""

  def __init__(self,
               head,
               mixture_weight_type,
               mixture_weight_initializer=None,
               warm_start_mixture_weights=False,
               checkpoint_dir=None,
               adanet_lambda=0.,
               adanet_beta=0.,
               use_bias=True,
               metric_fn=None):
    """Returns an initialized `_EnsembleBuilder`.

    Args:
      head: A `tf.contrib.estimator.Head` instance.
      mixture_weight_type: The `adanet.MixtureWeightType` defining which
        mixture weight type to learn.
      mixture_weight_initializer: The initializer for mixture_weights. When
        `None`, the default depends on `mixture_weight_type`: `SCALAR`
        initializes to 1/N, where N is the number of subnetworks in the
        ensemble, giving a uniform average; `VECTOR` initializes each entry
        to 1/N, likewise giving a uniform average; `MATRIX` uses
        `tf.zeros_initializer`.
      warm_start_mixture_weights: Whether, at the beginning of an iteration,
        to initialize the mixture weights of the subnetworks from the
        previous ensemble to their learned value at the previous iteration,
        as opposed to retraining them from scratch. Takes precedence over the
        value for `mixture_weight_initializer` for subnetworks from previous
        iterations.
      checkpoint_dir: The checkpoint_dir to use for warm-starting mixture
        weights and bias at the logit layer. Ignored if
        warm_start_mixture_weights is False.
      adanet_lambda: Float multiplier 'lambda' for applying L1 regularization
        to subnetworks' mixture weights 'w' in the ensemble proportional to
        their complexity. See Equation (4) in the AdaNet paper.
      adanet_beta: Float L1 regularization multiplier 'beta' to apply equally
        to all subnetworks' weights 'w' in the ensemble regardless of their
        complexity. See Equation (4) in the AdaNet paper.
      use_bias: Whether to add a bias term to the ensemble's logits.
      metric_fn: A function which should obey the following signature:

        - Args: can only have the following three arguments in any order:
          * predictions: Predictions `Tensor` or dict of `Tensor` created by
            the given `Head`.
          * features: Input `dict` of `Tensor` objects created by the
            `input_fn` which is given to `estimator.evaluate` as an argument.
          * labels: Labels `Tensor` or dict of `Tensor` (for multi-head)
            created by the `input_fn` which is given to `estimator.evaluate`
            as an argument.
        - Returns: Dict of metric results keyed by name. Final metrics are a
          union of this and `Head`'s existing metrics. If there is a name
          conflict between this and the estimator's existing metrics, this
          will override the existing one. The values of the dict are the
          results of calling a metric function, namely a
          `(metric_tensor, update_op)` tuple.

    Returns:
      An `_EnsembleBuilder` instance.

    Raises:
      ValueError: if warm_start_mixture_weights is True but checkpoint_dir is
        None.
      ValueError: if metric_fn is invalid.
    """

    if warm_start_mixture_weights:
      if checkpoint_dir is None:
        raise ValueError("checkpoint_dir cannot be None when "
                         "warm_start_mixture_weights is True.")
    _verify_metric_fn_args(metric_fn)

    self._head = head
    self._mixture_weight_type = mixture_weight_type
    self._mixture_weight_initializer = mixture_weight_initializer
    self._warm_start_mixture_weights = warm_start_mixture_weights
    self._checkpoint_dir = checkpoint_dir
    self._adanet_lambda = adanet_lambda
    self._adanet_beta = adanet_beta
    self._use_bias = use_bias
    self._metric_fn = metric_fn

  def append_new_subnetwork(self,
                            ensemble_name,
                            ensemble_spec,
                            subnetwork_builder,
                            iteration_number,
                            iteration_step,
                            summary,
                            features,
                            mode,
                            labels=None):
    """Adds a `Subnetwork` to an `_EnsembleSpec`.

    For iteration t > 0, the ensemble is built given the `Ensemble` for t-1
    and the new subnetwork to train as part of the ensemble. The `Ensemble`
    at iteration 0 is comprised of just the subnetwork.

    The subnetwork is first given a weight 'w' in a `WeightedSubnetwork`
    which determines its contribution to the ensemble. The subnetwork's
    complexity L1-regularizes this weight.

    Args:
      ensemble_name: String name of the ensemble.
      ensemble_spec: The recipient `_EnsembleSpec` for the `Subnetwork`.
      subnetwork_builder: An `adanet.subnetwork.Builder` instance which
        defines how to train the subnetwork and ensemble mixture weights.
      iteration_number: Integer current iteration number.
      iteration_step: Integer `Tensor` representing the step since the
        beginning of the current iteration, as opposed to the global step.
      summary: A `_ScopedSummary` instance for recording ensemble summaries.
      features: Input `dict` of `Tensor` objects.
      mode: Estimator's `ModeKeys`.
      labels: Labels `Tensor` or a dictionary of string label name to
        `Tensor` (for multi-head). Can be `None`.

    Returns:
      A new `_EnsembleSpec` instance with the `Subnetwork` appended.
    """

    with tf.variable_scope("ensemble_{}".format(ensemble_name)):
      weighted_subnetworks = []
      subnetwork_index = 0
      num_subnetworks = 1
      ensemble = None
      if ensemble_spec:
        ensemble = ensemble_spec.ensemble
        previous_subnetworks = [
            ensemble.weighted_subnetworks[index]
            for index in subnetwork_builder.prune_previous_ensemble(ensemble)
        ]
        num_subnetworks += len(previous_subnetworks)
        for weighted_subnetwork in previous_subnetworks:
          weight_initializer = None
          if self._warm_start_mixture_weights:
            weight_initializer = tf.contrib.framework.load_variable(
                self._checkpoint_dir, weighted_subnetwork.weight.op.name)
          with tf.variable_scope(
              "weighted_subnetwork_{}".format(subnetwork_index)):
            weighted_subnetworks.append(
                self._build_weighted_subnetwork(
                    weighted_subnetwork.name,
                    weighted_subnetwork.iteration_number,
                    weighted_subnetwork.subnetwork,
                    num_subnetworks,
                    weight_initializer=weight_initializer))
          subnetwork_index += 1

      ensemble_scope = tf.get_variable_scope()

      with tf.variable_scope(
          "weighted_subnetwork_{}".format(subnetwork_index)):
        with tf.variable_scope("subnetwork"):
          _clear_trainable_variables()
          build_subnetwork = functools.partial(
              subnetwork_builder.build_subnetwork,
              features=features,
              logits_dimension=self._head.logits_dimension,
              training=mode == tf.estimator.ModeKeys.TRAIN,
              iteration_step=iteration_step,
              summary=summary,
              previous_ensemble=ensemble)
          # Check which args are in the implemented build_subnetwork method
          # signature for backwards compatibility.
          defined_args = inspect.getargspec(
              subnetwork_builder.build_subnetwork).args
          if "labels" in defined_args:
            build_subnetwork = functools.partial(
                build_subnetwork, labels=labels)
          with summary.current_scope(), _subnetwork_context(
              iteration_step_scope=ensemble_scope, scoped_summary=summary):
            tf.logging.info("Building subnetwork '%s'",
                            subnetwork_builder.name)
            subnetwork = build_subnetwork()
          var_list = tf.trainable_variables()
        weighted_subnetworks.append(
            self._build_weighted_subnetwork(subnetwork_builder.name,
                                            iteration_number, subnetwork,
                                            num_subnetworks))

      if ensemble:
        if len(previous_subnetworks) == len(ensemble.weighted_subnetworks):
          bias = self._create_bias_term(
              weighted_subnetworks, prior=ensemble.bias)
        else:
          bias = self._create_bias_term(weighted_subnetworks)
          tf.logging.info(
              "Builder '%s' is using a subset of the subnetworks "
              "from the previous ensemble, so its ensemble's bias "
              "term will not be warm started with the previous "
              "ensemble's bias.", subnetwork_builder.name)
      else:
        bias = self._create_bias_term(weighted_subnetworks)

      return self._build_ensemble_spec(
          name=ensemble_name,
          weighted_subnetworks=weighted_subnetworks,
          summary=summary,
          bias=bias,
          features=features,
          mode=mode,
          iteration_step=iteration_step,
          labels=labels,
          subnetwork_builder=subnetwork_builder,
          var_list=var_list,
          previous_ensemble_spec=ensemble_spec)

  def _build_ensemble_spec(self,
                           name,
                           weighted_subnetworks,
                           summary,
                           bias,
                           features,
                           mode,
                           iteration_step,
                           labels=None,
                           subnetwork_builder=None,
                           var_list=None,
                           previous_ensemble_spec=None):
    """Builds an `_EnsembleSpec` with the given `WeightedSubnetwork`s.

    Args:
      name: The string name of the ensemble. Typically the name of the
        builder that returned the given `Subnetwork`.
      weighted_subnetworks: List of `WeightedSubnetwork` instances that form
        this ensemble. Ordered from first to most recent.
      summary: A `_ScopedSummary` instance for recording ensemble summaries.
      bias: Bias term `Tensor` or dict of string to `Tensor` (for multi-head)
        for the AdaNet-weighted ensemble logits.
      features: Input `dict` of `Tensor` objects.
      mode: Estimator `ModeKeys` indicating training, evaluation, or
        inference.
      iteration_step: Integer `Tensor` representing the step since the
        beginning of the current iteration, as opposed to the global step.
      labels: Labels `Tensor` or a dictionary of string label name to
        `Tensor` (for multi-head).
      subnetwork_builder: An `adanet.subnetwork.Builder` instance which
        defines how to train the subnetwork and ensemble mixture weights.
      var_list: Optional list or tuple of `Variable` objects to update to
        minimize `loss`.
      previous_ensemble_spec: The `_EnsembleSpec` from iteration t-1. Used
        for creating the subnetwork train_op.

    Returns:
      An `_EnsembleSpec` instance.
    """

    ensemble_logits, ensemble_complexity_regularization = (
        self._adanet_weighted_ensemble_logits(weighted_subnetworks, bias,
                                              summary))

    # The AdaNet-weighted ensemble.
    adanet_weighted_ensemble_spec = self._head.create_estimator_spec(
        features=features,
        mode=mode,
        logits=ensemble_logits,
        labels=labels,
        train_op_fn=lambda _: tf.no_op())

    # A baseline ensemble: the uniform average of subnetwork outputs.
    # It is practically free to compute, requiring no additional training,
    # and tends to generalize very well. However, the AdaNet-weighted
    # ensemble should perform at least as well given the correct
    # hyperparameters.
    uniform_average_ensemble_spec = self._head.create_estimator_spec(
        features=features,
        mode=mode,
        logits=self._uniform_average_ensemble_logits(weighted_subnetworks),
        labels=labels,
        train_op_fn=lambda _: tf.no_op())

    # The subnetwork.
    new_subnetwork = weighted_subnetworks[-1].subnetwork
    subnetwork_spec = self._head.create_estimator_spec(
        features=features,
        mode=mode,
        logits=new_subnetwork.logits,
        labels=labels,
        train_op_fn=lambda _: tf.no_op())

    ensemble_loss = adanet_weighted_ensemble_spec.loss
    adanet_loss = None
    eval_metric_ops = {}
    if mode != tf.estimator.ModeKeys.PREDICT:
      adanet_loss = ensemble_loss
      if isinstance(ensemble_complexity_regularization, dict):
        for key in sorted(ensemble_complexity_regularization):
          adanet_loss += ensemble_complexity_regularization[key]
      else:
        adanet_loss += ensemble_complexity_regularization

    if mode == tf.estimator.ModeKeys.EVAL:
      metric_fn = functools.partial(
          _call_metric_fn,
          metric_fn=self._metric_fn,
          features=features,
          labels=labels)
      _add_eval_metric_ops(
          eval_metric_ops=eval_metric_ops,
          group_name="adanet_weighted_ensemble",
          estimator_spec=adanet_weighted_ensemble_spec,
          metric_fn=metric_fn)
      _add_eval_metric_ops(
          eval_metric_ops=eval_metric_ops,
          group_name="uniform_average_ensemble",
          estimator_spec=uniform_average_ensemble_spec,
          metric_fn=metric_fn)
      _add_eval_metric_ops(
          eval_metric_ops=eval_metric_ops,
          group_name="subnetwork",
          estimator_spec=subnetwork_spec,
          metric_fn=metric_fn)
      eval_metric_ops["architecture/adanet/ensembles"] = (
          _architecture_as_metric(weighted_subnetworks))

    if mode == tf.estimator.ModeKeys.TRAIN:
      with summary.current_scope():
        summary.scalar("loss/adanet/adanet_weighted_ensemble",
                       adanet_weighted_ensemble_spec.loss)
        summary.scalar("loss/adanet/subnetwork", subnetwork_spec.loss)
        summary.scalar("loss/adanet/uniform_average_ensemble",
                       uniform_average_ensemble_spec.loss)

    # Create train ops for training subnetworks and learning mixture weights.
    subnetwork_train_op = None
    ensemble_train_op = None
    if mode == tf.estimator.ModeKeys.TRAIN and subnetwork_builder:
      ensemble_scope = tf.get_variable_scope()
      _set_trainable_variables(var_list)
      with tf.variable_scope("train_subnetwork"):
        previous_ensemble = None
        if previous_ensemble_spec:
          previous_ensemble = previous_ensemble_spec.ensemble

        with summary.current_scope(), _subnetwork_context(
            iteration_step_scope=ensemble_scope, scoped_summary=summary):
          subnetwork_train_op = _to_train_op_spec(
              subnetwork_builder.build_subnetwork_train_op(
                  subnetwork=new_subnetwork,
                  loss=subnetwork_spec.loss,
                  var_list=var_list,
                  labels=labels,
                  iteration_step=iteration_step,
                  summary=summary,
                  previous_ensemble=previous_ensemble))
      # Note that these mixture weights are on top of the last_layer of the
      # subnetwork constructed in TRAIN mode, which means that dropout is
      # still applied when the mixture weights are being trained.
      ensemble_var_list = [w.weight for w in weighted_subnetworks]
      if self._use_bias:
        ensemble_var_list.insert(0, bias)
      _set_trainable_variables(ensemble_var_list)
      ensemble_scope = tf.get_variable_scope()
      with tf.variable_scope("train_mixture_weights"):
        with summary.current_scope(), _subnetwork_context(
            iteration_step_scope=ensemble_scope, scoped_summary=summary):
          ensemble_train_op = _to_train_op_spec(
              subnetwork_builder.build_mixture_weights_train_op(
                  loss=adanet_loss,
                  var_list=ensemble_var_list,
                  logits=ensemble_logits,
                  labels=labels,
                  iteration_step=iteration_step,
                  summary=summary))

    return _EnsembleSpec(
        name=name,
        ensemble=Ensemble(
            weighted_subnetworks=weighted_subnetworks,
            bias=bias,
            logits=ensemble_logits,
        ),
        predictions=adanet_weighted_ensemble_spec.predictions,
        loss=ensemble_loss,
        adanet_loss=adanet_loss,
        subnetwork_train_op=subnetwork_train_op,
        ensemble_train_op=ensemble_train_op,
        eval_metric_ops=eval_metric_ops,
        export_outputs=adanet_weighted_ensemble_spec.export_outputs)

  def _complexity_regularization(self, weight_l1_norm, complexity):
    """For a subnetwork, computes: (lambda * r(h) + beta) * |w|."""

    if self._adanet_lambda == 0. and self._adanet_beta == 0.:
      return tf.constant(0., name="zero")
    return tf.scalar_mul(self._adanet_gamma(complexity), weight_l1_norm)

  def _adanet_gamma(self, complexity):
    """For a subnetwork, computes: lambda * r(h) + beta."""

    if self._adanet_lambda == 0.:
      return self._adanet_beta
    return tf.scalar_mul(self._adanet_lambda,
                         tf.to_float(complexity)) + self._adanet_beta

  def _select_mixture_weight_initializer(self, num_subnetworks):
    if self._mixture_weight_initializer:
      return self._mixture_weight_initializer
    if (self._mixture_weight_type == MixtureWeightType.SCALAR or
        self._mixture_weight_type == MixtureWeightType.VECTOR):
      return tf.constant_initializer(1. / num_subnetworks)
    return tf.zeros_initializer()

  def _build_weighted_subnetwork(self,
                                 name,
                                 iteration_number,
                                 subnetwork,
                                 num_subnetworks,
                                 weight_initializer=None):
    """Builds a `WeightedSubnetwork`.

    Args:
      name: String name of `subnetwork`.
      iteration_number: Integer iteration when the subnetwork was created.
      subnetwork: The `Subnetwork` to weight.
      num_subnetworks: The number of subnetworks in the ensemble.
      weight_initializer: Initializer for the weight variable.

    Returns:
      A `WeightedSubnetwork` instance.

    Raises:
      ValueError: When the subnetwork's last layer and logits dimensions do
        not match while a SCALAR or VECTOR mixture weight is required.
    """

    if isinstance(subnetwork.last_layer, dict):
      logits, weight = {}, {}
      for key in sorted(subnetwork.last_layer):
        logits[key], weight[key] = self._build_weighted_subnetwork_helper(
            subnetwork, num_subnetworks, weight_initializer, key)
    else:
      logits, weight = self._build_weighted_subnetwork_helper(
          subnetwork, num_subnetworks, weight_initializer)

    return WeightedSubnetwork(
        name=name,
        iteration_number=iteration_number,
        subnetwork=subnetwork,
        logits=logits,
        weight=weight)

  def _build_weighted_subnetwork_helper(self,
                                        subnetwork,
                                        num_subnetworks,
                                        weight_initializer=None,
                                        key=None):
    """Returns the logits and weight of the `WeightedSubnetwork` for key."""

    # Treat subnetworks as if their weights are frozen, and ensure that
    # mixture weight gradients do not propagate through.
    last_layer = _get_value(subnetwork.last_layer, key)
    logits = _get_value(subnetwork.logits, key)
    weight_shape = None
    last_layer_size = last_layer.get_shape().as_list()[-1]
    logits_size = logits.get_shape().as_list()[-1]
    batch_size = tf.shape(last_layer)[0]

    if weight_initializer is None:
      weight_initializer = self._select_mixture_weight_initializer(
          num_subnetworks)
      if self._mixture_weight_type == MixtureWeightType.SCALAR:
        weight_shape = []
      if self._mixture_weight_type == MixtureWeightType.VECTOR:
        weight_shape = [logits_size]
      if self._mixture_weight_type == MixtureWeightType.MATRIX:
        weight_shape = [last_layer_size, logits_size]

    with tf.variable_scope("{}logits".format(key + "_" if key else "")):
      # Mark as not trainable to not add to the TRAINABLE_VARIABLES
      # collection. Training is handled explicitly with var_lists.
      weight = tf.get_variable(
          name="mixture_weight",
          shape=weight_shape,
          initializer=weight_initializer,
          trainable=False)
      if self._mixture_weight_type == MixtureWeightType.MATRIX:
        # TODO: Add unit tests for the ndims == 3 path.
        ndims = len(last_layer.get_shape().as_list())
        if ndims > 3:
          raise NotImplementedError(
              "Last layers with more than 3 dimensions are not supported "
              "with matrix mixture weights.")
        # Rank 3 tensors like [batch_size, timesteps, emb_dim] are reshaped
        # to rank 2 [batch_size x timesteps, emb_dim] for the matrix
        # multiplication, and reshaped back afterwards.
        if ndims == 3:
          tf.logging.info(
              "Rank 3 tensors like [batch_size, timesteps, d] are reshaped "
              "to rank 2 [batch_size x timesteps, d] for the weight matrix "
              "multiplication, and are reshaped back to their original "
              "shape afterwards.")
          last_layer = tf.reshape(last_layer, [-1, last_layer_size])
        logits = tf.matmul(last_layer, weight)
        if ndims == 3:
          logits = tf.reshape(logits, [batch_size, -1, logits_size])
      else:
        logits = tf.multiply(_get_value(subnetwork.logits, key), weight)
    return logits, weight

  def _create_bias_term(self, weighted_subnetworks, prior=None):
    """Returns a bias term vector.

    If `use_bias` is set, then it returns a trainable bias variable
    initialized to zero, or warm-started with the given prior. Otherwise it
    returns a zero constant bias.

    Args:
      weighted_subnetworks: List of `WeightedSubnetwork` instances that form
        this ensemble. Ordered from first to most recent.
      prior: Prior bias term `Tensor` or dict of string to `Tensor` (for
        multi-head) for warm-starting the bias term variable.

    Returns:
      A bias term `Tensor` or dict of string to bias term `Tensor` (for
      multi-head).
    """

    if not isinstance(weighted_subnetworks[0].subnetwork.logits, dict):
      return self._create_bias_term_helper(weighted_subnetworks, prior)
    bias_terms = {}
    for key in sorted(weighted_subnetworks[0].subnetwork.logits):
      bias_terms[key] = self._create_bias_term_helper(weighted_subnetworks,
                                                      prior, key)
    return bias_terms

  def _create_bias_term_helper(self, weighted_subnetworks, prior, key=None):
    """Returns a bias term for weights with the given key."""

    shape = None
    if prior is None:
      prior = tf.zeros_initializer()
      logits = _get_value(weighted_subnetworks[0].subnetwork.logits, key)
      logits_dimension = logits.get_shape().as_list()[-1]
      shape = logits_dimension
    else:
      prior = tf.contrib.framework.load_variable(
          self._checkpoint_dir,
          _get_value(prior, key).op.name)
    # Mark as not trainable to not add to the TRAINABLE_VARIABLES collection.
    # Training is handled explicitly with var_lists.
    return tf.get_variable(
        name="{}bias".format(key + "_" if key else ""),
        shape=shape,
        initializer=prior,
        trainable=False)

  def _adanet_weighted_ensemble_logits(self, weighted_subnetworks, bias,
                                       summary):
    """Computes the AdaNet weighted ensemble logits.

    The ensemble logits are the sum of the bias term and every weighted
    subnetwork's logits. Alongside them, this computes the ensemble's total
    complexity regularization term from Equation (4).

    Args:
      weighted_subnetworks: List of `WeightedSubnetwork` instances that form
        this ensemble. Ordered from first to most recent.
      bias: Bias term `Tensor` or dict of string to `Tensor` (for multi-head)
        for the AdaNet-weighted ensemble logits.
      summary: A `_ScopedSummary` instance for recording ensemble summaries.

    Returns:
      A two-tuple of:
       1. Ensemble logits `Tensor` or dict of string to logits `Tensor` (for
          multi-head).
       2. Ensemble complexity regularization `Tensor` or dict of string to
          `Tensor` (for multi-head).
    """

    if not isinstance(weighted_subnetworks[0].subnetwork.logits, dict):
      return self._adanet_weighted_ensemble_logits_helper(
          weighted_subnetworks, bias, summary)
    logits, ensemble_complexity_regularization = {}, {}
    for key in sorted(weighted_subnetworks[0].subnetwork.logits):
      logits[key], ensemble_complexity_regularization[key] = (
          self._adanet_weighted_ensemble_logits_helper(
              weighted_subnetworks, bias, summary, key))
    return logits, ensemble_complexity_regularization

  def _adanet_weighted_ensemble_logits_helper(self,
                                              weighted_subnetworks,
                                              bias,
                                              summary,
                                              key=None):
    """Returns the AdaNet ensemble logits and regularization term for key."""

    subnetwork_logits = []
    ensemble_complexity_regularization = 0
    total_weight_l1_norms = 0
    weights = []
    for weighted_subnetwork in weighted_subnetworks:
      weight_l1_norm = tf.norm(
          _get_value(weighted_subnetwork.weight, key), ord=1)
      total_weight_l1_norms += weight_l1_norm
      ensemble_complexity_regularization += self._complexity_regularization(
          weight_l1_norm, weighted_subnetwork.subnetwork.complexity)
      subnetwork_logits.append(_get_value(weighted_subnetwork.logits, key))
      weights.append(weight_l1_norm)

    with tf.variable_scope("{}logits".format(key + "_" if key else "")):
      ensemble_logits = _get_value(bias, key)
      for logits in subnetwork_logits:
        ensemble_logits = tf.add(ensemble_logits, logits)

    with summary.current_scope():
      summary.scalar(
          "complexity_regularization/adanet/adanet_weighted_ensemble",
          ensemble_complexity_regularization)
      summary.histogram("mixture_weights/adanet/adanet_weighted_ensemble",
                        weights)
      for iteration, weight in enumerate(weights):
        scope = "adanet/adanet_weighted_ensemble/subnetwork_{}".format(
            iteration)
        summary.scalar("mixture_weight_norms/{}".format(scope), weight)
        fraction = weight / total_weight_l1_norms
        summary.scalar("mixture_weight_fractions/{}".format(scope), fraction)
    return ensemble_logits, ensemble_complexity_regularization

  def _uniform_average_ensemble_logits(self, weighted_subnetworks):
    """Computes the uniform average ensemble logits.

    Args:
      weighted_subnetworks: List of `WeightedSubnetwork` instances that form
        this ensemble. Ordered from first to most recent.

    Returns:
      Ensemble logits `Tensor` or dict of string to logits `Tensor` (for
      multi-head).
""" if not isinstance(weighted_subnetworks[0].subnetwork.logits, dict): return self._uniform_average_ensemble_logits_helper(weighted_subnetworks) logits = {} for key in sorted(weighted_subnetworks[0].subnetwork.logits): logits[key] = self._uniform_average_ensemble_logits_helper( weighted_subnetworks, key) return logits def _uniform_average_ensemble_logits_helper(self, weighted_subnetworks, key=None): """Returns logits for the baseline ensemble for the given key.""" return tf.add_n([ _get_value(wwl.subnetwork.logits, key) for wwl in weighted_subnetworks ]) / len(weighted_subnetworks)