Source code for adanet.subnetwork.report

"""Container for an `adanet.Subnetwork`'s attributes and metrics.

Copyright 2018 The AdaNet Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections

from adanet import tf_compat
import six
import tensorflow.compat.v1 as tf


[docs]class Report( collections.namedtuple("Report", ["hparams", "attributes", "metrics"])): # pyformat: disable """A container for data to be collected about a :class:`Subnetwork`. Args: hparams: A dict mapping strings to python strings, ints, bools, or floats. It is meant to contain the constants that define the :class:`adanet.subnetwork.Builder`, such as dropout, number of layers, or initial learning rate. attributes: A dict mapping strings to rank 0 Tensors of dtype string, int32, or float32. It is meant to contain properties that may or may not change over the course of training the :class:`adanet.subnetwork.Subnetwork`, such as the number of parameters, the Lipschitz constant, the :math:`L2` norm of the weights, or learning rate at materialization time. metrics: Dict of metric results keyed by name. The values of the dict are the results of calling a metric function, namely a `(metric_tensor, update_op)` tuple. `metric_tensor` should be evaluated without any impact on state (typically is a pure computation results based on variables.). For example, it should not trigger the :code:`update_op` or requires any input fetching. This is meant to contain metrics of interest, such as the training loss, complexity regularized loss, or standard deviation of the last layer outputs. Returns: A validated :class:`adanet.subnetwork.Report` object. Raises: ValueError: If validation fails. """ # pyformat: enable def __new__(cls, hparams, attributes, metrics): def _is_scalar(tensor): """Returns True iff tensor is scalar.""" return tensor.shape.ndims == 0 def _is_accepted_dtype(tensor): """Returns True iff tensor has the dtype we can handle.""" return tensor.dtype.base_dtype in (tf.bool, tf.int32, tf.float32, tf.float64, tf.string) # Validate hparams for key, value in hparams.items(): if not isinstance(value, (bool, int, float, six.string_types)): raise ValueError( "hparam '{}' refers to invalid value {}, type {}. type must be " "python primitive int, float, bool, or string.".format( key, value, type(value))) # Validate attributes for key, value in attributes.items(): if not isinstance(value, tf.Tensor): raise ValueError("attribute '{}' refers to invalid value: {}, type: {}." "type must be Tensor.".format(key, value, type(value))) if not (_is_scalar(value) and _is_accepted_dtype(value)): raise ValueError( "attribute '{}' refers to invalid tensor {}. Shape: {}".format( key, value, value.get_shape())) # Validate metrics metrics_copy = {} for key, value in metrics.items(): value = tf_compat.metric_op(value) if not isinstance(value, tuple): raise ValueError( "metric '{}' has invalid type {}. Must be a tuple.".format( key, type(value))) if len(value) < 2: raise ValueError( "metric tuple '{}' has fewer than 2 elements".format(key)) if not isinstance(value[0], (tf.Tensor, tf.Variable)): raise ValueError( "First element of metric tuple '{}' has value {} and type {}. " "Must be a Tensor or Variable.".format(key, value[0], type(value[0]))) if not _is_accepted_dtype(value[0]): raise ValueError( "First element of metric '{}' refers to Tensor of the wrong " "dtype {}. Must be one of tf.bool, tf.int32, tf.float32, " "tf.float64 or tf.string.".format(key, value[0].dtype)) if not _is_scalar(value[0]): tf.logging.warn( "First element of metric '{}' refers to Tensor of rank > 0. " "AdaNet is currently unable to store metrics of rank > 0 -- this " "metric will be dropped from the report. " "value: {}".format(key, value[0])) continue if not isinstance(value[1], (tf.Tensor, tf.Operation, tf.Variable)): raise ValueError( "Second element of metric tuple '{}' has value {} and type {}. " "Must be a Tensor, Operation, or Variable.".format( key, value[1], type(value[1]))) metrics_copy[key] = value return super(Report, cls).__new__( cls, hparams=hparams, attributes=attributes, metrics=metrics_copy)
[docs]class MaterializedReport( collections.namedtuple("MaterializedReport", [ "iteration_number", "name", "hparams", "attributes", "metrics", "included_in_final_ensemble" ])): # pyformat: disable """Data collected about a :class:`adanet.subnetwork.Subnetwork`. Args: iteration_number: A python integer for the AdaNet iteration number, starting from 0. name: A string, which is either the name of the corresponding Builder, or "previous_ensemble" if it refers to the previous_ensemble. hparams: A dict mapping strings to python strings, ints, or floats. These are constants passed from the author of the :class:`adanet.subnetwork.Builder` that was used to construct this :class:`adanet.subnetwork.Subnetwork`. It is meant to contain the arguments that defined the :class:`adanet.subnetwork.Builder`, such as dropout, number of layers, or initial learning rate. attributes: A dict mapping strings to python strings, ints, bools, or floats. These are python primitives that come from materialized Tensors; these Tensors were defined by the author of the :class:`adanet.subnetwork.Builder` that was used to construct this :class:`adanet.subnetwork.Subnetwork`. It is meant to contain properties that may or may not change over the course of training the :class:`adanet.subnetwork.Subnetwork`, such as the number of parameters, the Lipschitz constant, or the :math:`L2` norm of the weights. metrics: A dict mapping strings to python strings, ints, or floats. These are python primitives that come from metrics that were evaluated on the trained :class:`adanet.subnetwork.Subnetwork` over some dataset; these metrics were defined by the author of the :class:`adanet.subnetwork.Builder` that was used to construct this :class:`adanet.subnetwork.Subnetwork`. It is meant to contain performance metrics or measures that could predict generalization, such as the training loss, complexity regularized loss, or standard deviation of the last layer outputs. included_in_final_ensemble: A boolean denoting whether the associated :class:`adanet.subnetwork.Subnetwork` was included in the ensemble at the end of the AdaNet iteration. Returns: An :class:`adanet.subnetwork.MaterializedReport` object. """ # pyformat: enable def __new__(cls, iteration_number, name, hparams, attributes, metrics, included_in_final_ensemble=False): return super(MaterializedReport, cls).__new__( cls, iteration_number=iteration_number, name=name, hparams=hparams, attributes=attributes, metrics=metrics, included_in_final_ensemble=included_in_final_ensemble)