Source code for adanet.core.summary

"""Tensorboard summaries for the single graph AdaNet implementation.

Copyright 2018 The AdaNet Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc
import contextlib

import tensorflow as tf
from tensorflow.contrib.tpu.python.tpu import tpu_function
from tensorflow.python.ops import summary_op_util


[docs]class Summary(object): """Interface for writing summaries to Tensorboard.""" __metaclass__ = abc.ABCMeta
[docs] @abc.abstractmethod def scalar(self, name, tensor, family=None): """Outputs a `tf.Summary` protocol buffer containing a single scalar value. The generated tf.Summary has a Tensor.proto containing the input Tensor. Args: name: A name for the generated node. Will also serve as the series name in TensorBoard. tensor: A real numeric Tensor containing a single value. family: Optional; if provided, used as the prefix of the summary tag name, which controls the tab name used for display on Tensorboard. Returns: A scalar `Tensor` of type `string`. Which contains a `tf.Summary` protobuf. Raises: ValueError: If tensor has the wrong shape or type. """
[docs] @abc.abstractmethod def image(self, name, tensor, max_outputs=3, family=None): """Outputs a `tf.Summary` protocol buffer with images. The summary has up to `max_outputs` summary values containing images. The images are built from `tensor` which must be 4-D with shape `[batch_size, height, width, channels]` and where `channels` can be: * 1: `tensor` is interpreted as Grayscale. * 3: `tensor` is interpreted as RGB. * 4: `tensor` is interpreted as RGBA. The images have the same number of channels as the input tensor. For float input, the values are normalized one image at a time to fit in the range `[0, 255]`. `uint8` values are unchanged. The op uses two different normalization algorithms: * If the input values are all positive, they are rescaled so the largest one is 255. * If any input value is negative, the values are shifted so input value 0.0 is at 127. They are then rescaled so that either the smallest value is 0, or the largest one is 255. The `tag` in the outputted tf.Summary.Value protobufs is generated based on the name, with a suffix depending on the max_outputs setting: * If `max_outputs` is 1, the summary value tag is '*name*/image'. * If `max_outputs` is greater than 1, the summary value tags are generated sequentially as '*name*/image/0', '*name*/image/1', etc. Args: name: A name for the generated node. Will also serve as a series name in TensorBoard. tensor: A 4-D `uint8` or `float32` `Tensor` of shape `[batch_size, height, width, channels]` where `channels` is 1, 3, or 4. max_outputs: Max number of batch elements to generate images for. family: Optional; if provided, used as the prefix of the summary tag name, which controls the tab name used for display on Tensorboard. Returns: A scalar `Tensor` of type `string`. The serialized `tf.Summary` protocol buffer. """
[docs] @abc.abstractmethod def histogram(self, name, values, family=None): """Outputs a `tf.Summary` protocol buffer with a histogram. Adding a histogram summary makes it possible to visualize your data's distribution in TensorBoard. You can see a detailed explanation of the TensorBoard histogram dashboard [here](https://www.tensorflow.org/get_started/tensorboard_histograms). The generated [`tf.Summary`]( tensorflow/core/framework/summary.proto) has one summary value containing a histogram for `values`. This op reports an `InvalidArgument` error if any value is not finite. Args: name: A name for the generated node. Will also serve as a series name in TensorBoard. values: A real numeric `Tensor`. Any shape. Values to use to build the histogram. family: Optional; if provided, used as the prefix of the summary tag name, which controls the tab name used for display on Tensorboard. Returns: A scalar `Tensor` of type `string`. The serialized `tf.Summary` protocol buffer. """
[docs] @abc.abstractmethod def audio(self, name, tensor, sample_rate, max_outputs=3, family=None): """Outputs a `tf.Summary` protocol buffer with audio. The summary has up to `max_outputs` summary values containing audio. The audio is built from `tensor` which must be 3-D with shape `[batch_size, frames, channels]` or 2-D with shape `[batch_size, frames]`. The values are assumed to be in the range of `[-1.0, 1.0]` with a sample rate of `sample_rate`. The `tag` in the outputted tf.Summary.Value protobufs is generated based on the name, with a suffix depending on the max_outputs setting: * If `max_outputs` is 1, the summary value tag is '*name*/audio'. * If `max_outputs` is greater than 1, the summary value tags are generated sequentially as '*name*/audio/0', '*name*/audio/1', etc Args: name: A name for the generated node. Will also serve as a series name in TensorBoard. tensor: A 3-D `float32` `Tensor` of shape `[batch_size, frames, channels]` or a 2-D `float32` `Tensor` of shape `[batch_size, frames]`. sample_rate: A Scalar `float32` `Tensor` indicating the sample rate of the signal in hertz. max_outputs: Max number of batch elements to generate audio for. family: Optional; if provided, used as the prefix of the summary tag name, which controls the tab name used for display on Tensorboard. Returns: A scalar `Tensor` of type `string`. The serialized `tf.Summary` protocol buffer. """
def _strip_scope(name, scope, additional_scope): """Returns the name with scope stripped from it.""" if additional_scope: name = name.replace("{}/".format(additional_scope), "") if scope: name = name.replace("{}/".format(scope), "", 1) return name class _ScopedSummary(Summary): """Records summaries in a given scope. Each scope gets assigned a different collection where summary ops gets added. This allows Tensorboard to display summaries with different scopes but the same name in the same charts. """ _TMP_COLLECTION_NAME = "_tmp_summaries" def __init__(self, scope=None, skip_summary=False): """Initializes a `_ScopedSummary`. Args: scope: String scope name. skip_summary: Whether to record summary ops. Returns: A `_ScopedSummary` instance. """ if tpu_function.get_tpu_context().number_of_shards: tf.logging.log_first_n( tf.logging.WARN, "Scoped summaries will be skipped since they do not support TPU", 1) skip_summary = True self._scope = scope self._additional_scope = None self._skip_summary = skip_summary self._actual_summary_scalar_fn = tf.summary.scalar self._actual_summary_image_fn = tf.summary.image self._actual_summary_histogram_fn = tf.summary.histogram self._actual_summary_audio_fn = tf.summary.audio @property def scope(self): """Returns scope string.""" return self._scope @contextlib.contextmanager def current_scope(self): """Registers the current context's scope to strip it from summary tags.""" self._additional_scope = tf.get_default_graph().get_name_scope() yield self._additional_scope = None @contextlib.contextmanager def _strip_tag_scope(self): """Monkey patches `summary_op_util.summary_scope` to strip tag scopes.""" original_summary_scope = summary_op_util.summary_scope @contextlib.contextmanager def strip_tag_scope_fn(name, family=None, default_name=None, values=None): tag, scope = (None, None) with original_summary_scope(name, family, default_name, values) as (t, s): tag = _strip_scope(t, self.scope, self._additional_scope) scope = s yield tag, scope summary_op_util.summary_scope = strip_tag_scope_fn yield summary_op_util.summary_scope = original_summary_scope def _collection_name(self): """Returns the collection for recording.""" if self._scope: return "_{}_summaries".format(self._scope) return "_global_summaries" def _prefix_scope(self, name): """Prefixes summary name with scope.""" if self._scope: if name[0] == "/": name = name[1:] return "{scope}/{name}".format(scope=self._scope, name=name) return name def scalar(self, name, tensor, family=None): """See `Summary`.""" if self._skip_summary: return tf.constant("") with self._strip_tag_scope(): summary = self._actual_summary_scalar_fn( name=self._prefix_scope(name), tensor=tensor, family=family, collections=[self._TMP_COLLECTION_NAME]) tf.add_to_collection(self._collection_name(), summary) return summary def image(self, name, tensor, max_outputs=3, family=None): """See `Summary`.""" if self._skip_summary: return tf.constant("") with self._strip_tag_scope(): summary = self._actual_summary_image_fn( name=self._prefix_scope(name), tensor=tensor, max_outputs=max_outputs, family=family, collections=[self._TMP_COLLECTION_NAME]) tf.add_to_collection(self._collection_name(), summary) return summary def histogram(self, name, values, family=None): """See `Summary`.""" if self._skip_summary: return tf.constant("") with self._strip_tag_scope(): summary = self._actual_summary_histogram_fn( name=self._prefix_scope(name), values=values, family=family, collections=[self._TMP_COLLECTION_NAME]) tf.add_to_collection(self._collection_name(), summary) return summary def audio(self, name, tensor, sample_rate, max_outputs=3, family=None): """See `Summary`.""" if self._skip_summary: return tf.constant("") with self._strip_tag_scope(): summary = self._actual_summary_audio_fn( name=self._prefix_scope(name), tensor=tensor, sample_rate=sample_rate, max_outputs=max_outputs, family=family, collections=[self._TMP_COLLECTION_NAME]) tf.add_to_collection(self._collection_name(), summary) return summary def merge_all(self): """Returns the list of summaries added using this _ScopedSummary. Note: this is an abuse of the tf.summary.merge_all API since it is expected to return a summary op with all summaries merged. However, ScopedSummary is only used in the internal implementation, so this should be OK. """ return tf.get_collection(key=self._collection_name())