Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

data: expose downsampling preferences to plugins #3271

Merged
merged 1 commit into from
Feb 21, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 13 additions & 8 deletions tensorboard/backend/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,17 +97,20 @@
logger = tb_logging.get_logger()


def tensor_size_guidance_from_flags(flags):
"""Apply user per-summary size guidance overrides."""

tensor_size_guidance = dict(DEFAULT_TENSOR_SIZE_GUIDANCE)
def _parse_samples_per_plugin(flags):
result = {}
if not flags or not flags.samples_per_plugin:
return tensor_size_guidance

return result
for token in flags.samples_per_plugin.split(","):
k, v = token.strip().split("=")
tensor_size_guidance[k] = int(v)
result[k] = int(v)
return result


def _apply_tensor_size_guidance(sampling_hints):
"""Apply user per-summary size guidance overrides."""
tensor_size_guidance = dict(DEFAULT_TENSOR_SIZE_GUIDANCE)
tensor_size_guidance.update(sampling_hints)
return tensor_size_guidance


Expand Down Expand Up @@ -151,9 +154,10 @@ def standard_tensorboard_wsgi(flags, plugin_loaders, assets_zip_provider):
multiplexer = _DbModeMultiplexer(flags.db, db_connection_provider)
else:
# Regular logdir loading mode.
sampling_hints = _parse_samples_per_plugin(flags)
multiplexer = event_multiplexer.EventMultiplexer(
size_guidance=DEFAULT_SIZE_GUIDANCE,
tensor_size_guidance=tensor_size_guidance_from_flags(flags),
tensor_size_guidance=_apply_tensor_size_guidance(sampling_hints),
purge_orphaned_data=flags.purge_orphaned_data,
max_reload_threads=flags.max_reload_threads,
event_file_active_filter=_get_event_file_active_filter(flags),
Expand Down Expand Up @@ -238,6 +242,7 @@ def TensorBoardWSGIApp(
multiplexer=deprecated_multiplexer,
assets_zip_provider=assets_zip_provider,
plugin_name_to_instance=plugin_name_to_instance,
sampling_hints=_parse_samples_per_plugin(flags),
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Parsing the `--samples_per_plugin` flag twice (once in `standard_tensorboard_wsgi` and again here) felt easier than adding an extra parameter to this
method, but I’d be happy to change that if others object.

window_title=flags.window_title,
)
tbplugins = []
Expand Down
6 changes: 6 additions & 0 deletions tensorboard/plugins/base_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,7 @@ def __init__(
logdir=None,
multiplexer=None,
plugin_name_to_instance=None,
sampling_hints=None,
window_title=None,
):
"""Instantiates magic container.
Expand Down Expand Up @@ -291,6 +292,10 @@ def __init__(
plugin may be absent from this mapping until it is registered. Plugin
logic should handle cases in which a plugin is absent from this
mapping, lest a KeyError is raised.
sampling_hints: Map from plugin name to `int` or `NoneType`, where
the value represents the user-specified downsampling limit as
given to the `--samples_per_plugin` flag, or `None` if none was
explicitly given for this plugin.
window_title: A string specifying the window title.
"""
self.assets_zip_provider = assets_zip_provider
Expand All @@ -301,6 +306,7 @@ def __init__(
self.logdir = logdir
self.multiplexer = multiplexer
self.plugin_name_to_instance = plugin_name_to_instance
self.sampling_hints = sampling_hints
self.window_title = window_title


Expand Down
19 changes: 13 additions & 6 deletions tensorboard/plugins/histogram/histograms_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@
from tensorboard.util import tensor_util


_DEFAULT_DOWNSAMPLING = 500 # histograms per time series


class HistogramsPlugin(base_plugin.TBPlugin):
"""Histograms Plugin for TensorBoard.

Expand All @@ -62,6 +65,9 @@ def __init__(self, context):
"""
self._multiplexer = context.multiplexer
self._db_connection_provider = context.db_connection_provider
self._downsample_to = (context.sampling_hints or {}).get(
self.plugin_name, _DEFAULT_DOWNSAMPLING
)
if context.flags and context.flags.generic_data == "true":
self._data_provider = context.data_provider
else:
Expand Down Expand Up @@ -174,20 +180,21 @@ def histograms_impl(self, tag, run, experiment, downsample_to=None):
"""Result of the form `(body, mime_type)`.

At most `downsample_to` events will be returned. If this value is
`None`, then no downsampling will be performed.
`None`, then default downsampling will be performed.

Raises:
tensorboard.errors.PublicError: On invalid request.
"""
if self._data_provider:
# Downsample reads to 500 histograms per time series, which is
# the default size guidance for histograms under the multiplexer
# loading logic.
SAMPLE_COUNT = downsample_to if downsample_to is not None else 500
sample_count = (
downsample_to
if downsample_to is not None
else self._downsample_to
)
all_histograms = self._data_provider.read_tensors(
experiment_id=experiment,
plugin_name=metadata.PLUGIN_NAME,
downsample=SAMPLE_COUNT,
downsample=sample_count,
run_tag_filter=provider.RunTagFilter(runs=[run], tags=[tag]),
)
histograms = all_histograms.get(run, {}).get(tag, None)
Expand Down
10 changes: 5 additions & 5 deletions tensorboard/plugins/image/images_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
}

_DEFAULT_IMAGE_MIMETYPE = "application/octet-stream"
_DEFAULT_DOWNSAMPLING = 10 # images per time series


# Extend imghdr.tests to include svg.
Expand All @@ -69,6 +70,9 @@ def __init__(self, context):
"""
self._multiplexer = context.multiplexer
self._db_connection_provider = context.db_connection_provider
self._downsample_to = (context.sampling_hints or {}).get(
self.plugin_name, _DEFAULT_DOWNSAMPLING
)
if context.flags and context.flags.generic_data == "true":
self._data_provider = context.data_provider
else:
Expand Down Expand Up @@ -239,14 +243,10 @@ def _image_response_for_run(self, experiment, run, tag, sample):
parameters.
"""
if self._data_provider:
# Downsample reads to 10 images per time series, which is the
# default size guidance for images under the multiplexer loading
# logic.
SAMPLE_COUNT = 10
all_images = self._data_provider.read_blob_sequences(
experiment_id=experiment,
plugin_name=metadata.PLUGIN_NAME,
downsample=SAMPLE_COUNT,
downsample=self._downsample_to,
run_tag_filter=provider.RunTagFilter(runs=[run], tags=[tag]),
)
images = all_images.get(run, {}).get(tag, None)
Expand Down
12 changes: 7 additions & 5 deletions tensorboard/plugins/scalar/scalars_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@
from tensorboard.util import tensor_util


_DEFAULT_DOWNSAMPLING = 1000 # scalars per time series


class OutputFormat(object):
"""An enum used to list the valid output formats for API calls."""

Expand All @@ -60,6 +63,9 @@ def __init__(self, context):
"""
self._multiplexer = context.multiplexer
self._db_connection_provider = context.db_connection_provider
self._downsample_to = (context.sampling_hints or {}).get(
self.plugin_name, _DEFAULT_DOWNSAMPLING
)
if context.flags and context.flags.generic_data != "false":
self._data_provider = context.data_provider
else:
Expand Down Expand Up @@ -169,14 +175,10 @@ def index_impl(self, experiment=None):
def scalars_impl(self, tag, run, experiment, output_format):
"""Result of the form `(body, mime_type)`."""
if self._data_provider:
# Downsample reads to 1000 scalars per time series, which is the
# default size guidance for scalars under the multiplexer loading
# logic.
SAMPLE_COUNT = 1000
all_scalars = self._data_provider.read_scalars(
experiment_id=experiment,
plugin_name=metadata.PLUGIN_NAME,
downsample=SAMPLE_COUNT,
downsample=self._downsample_to,
run_tag_filter=provider.RunTagFilter(runs=[run], tags=[tag]),
)
scalars = all_scalars.get(run, {}).get(tag, None)
Expand Down
7 changes: 6 additions & 1 deletion tensorboard/plugins/text/text_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@
2d tables are supported. Showing a 2d slice of the data instead."""
)

_DEFAULT_DOWNSAMPLING = 100 # text tensors per time series


def make_table_row(contents, tag="td"):
"""Given an iterable of string contents, make a table row.
Expand Down Expand Up @@ -212,6 +214,9 @@ def __init__(self, context):
context: A base_plugin.TBContext instance.
"""
self._multiplexer = context.multiplexer
self._downsample_to = (context.sampling_hints or {}).get(
self.plugin_name, _DEFAULT_DOWNSAMPLING
)
if context.flags and context.flags.generic_data == "true":
self._data_provider = context.data_provider
else:
Expand Down Expand Up @@ -261,7 +266,7 @@ def text_impl(self, run, tag, experiment):
all_text = self._data_provider.read_tensors(
experiment_id=experiment,
plugin_name=metadata.PLUGIN_NAME,
downsample=100,
downsample=self._downsample_to,
run_tag_filter=provider.RunTagFilter(runs=[run], tags=[tag]),
)
text = all_text.get(run, {}).get(tag, None)
Expand Down