# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Meters and a standalone layer for aggregating metrics.

The first half defines the :class:`Meter` hierarchy and
:class:`MetersDict` (fairseq/meters.py). The second half
(fairseq/metrics.py) lets metrics be logged from anywhere using the
``log_*`` functions. The logged values are aggregated dynamically based
on the aggregation context in which the logging occurs. See the
:func:`aggregate` context manager for more details.
"""

import bisect
from collections import OrderedDict
import contextlib
import time
from typing import Callable, Dict, List, Optional
import uuid


def _round_value(value, ndigits: Optional[int]):
    """Round *value* to *ndigits* unless either of them is ``None``."""
    if ndigits is not None and value is not None:
        return round(value, ndigits)
    return value


# ---------------------------------------------------------------------------
# fairseq/meters.py
# ---------------------------------------------------------------------------


class Meter(object):
    """Base class for Meters."""

    def __init__(self):
        pass

    def state_dict(self):
        """Return the serializable state of the meter (empty by default)."""
        return {}

    def load_state_dict(self, state_dict):
        """Restore the meter from *state_dict* (no-op by default)."""
        pass

    def reset(self):
        """Reset the meter; must be implemented by subclasses."""
        raise NotImplementedError

    @property
    def smoothed_value(self) -> float:
        """Smoothed value used for logging."""
        raise NotImplementedError


class AverageMeter(Meter):
    """Computes and stores the average and current value"""

    def __init__(self, round: Optional[int] = None):
        self.round = round
        self.reset()

    def reset(self):
        self.val = None  # most recent update
        self.sum = 0  # sum from all updates
        self.count = 0  # total n from all updates

    def update(self, val, n=1):
        """Record *val* with weight *n*.

        A ``None`` value is ignored entirely. A weight that is not
        positive only replaces the most recent value and does not
        contribute to the running average.
        """
        if val is not None:
            self.val = val
            if n > 0:
                self.sum += val * n
                self.count += n

    def state_dict(self):
        return {
            'val': self.val,
            'sum': self.sum,
            'count': self.count,
            'round': self.round,
        }

    def load_state_dict(self, state_dict):
        self.val = state_dict['val']
        self.sum = state_dict['sum']
        self.count = state_dict['count']
        self.round = state_dict.get('round', None)

    @property
    def avg(self):
        """Weighted average, or the latest value if nothing was weighted."""
        return self.sum / self.count if self.count > 0 else self.val

    @property
    def smoothed_value(self) -> float:
        return _round_value(self.avg, self.round)


class TimeMeter(Meter):
    """Computes the average occurrence of some event per second"""

    def __init__(self, init: int = 0, n: int = 0, round: Optional[int] = None):
        self.round = round
        self.reset(init, n)

    def reset(self, init=0, n=0):
        self.init = init  # seconds already elapsed before this (re)start
        self.start = time.time()
        self.n = n

    def update(self, val=1):
        self.n += val

    def state_dict(self):
        return {
            # the elapsed time is stored as the new offset, so the wall
            # clock start does not need to be serialized
            'init': self.elapsed_time,
            'n': self.n,
            'round': self.round,
        }

    def load_state_dict(self, state_dict):
        if 'start' in state_dict:
            # backwards compatibility for old state_dicts
            self.reset(init=state_dict['init'])
        else:
            self.reset(init=state_dict['init'], n=state_dict['n'])
            self.round = state_dict.get('round', None)

    @property
    def avg(self):
        return self.n / self.elapsed_time

    @property
    def elapsed_time(self):
        return self.init + (time.time() - self.start)

    @property
    def smoothed_value(self) -> float:
        return _round_value(self.avg, self.round)


class StopwatchMeter(Meter):
    """Computes the sum/avg duration of some event in seconds"""

    def __init__(self, round: Optional[int] = None):
        self.round = round
        self.sum = 0
        self.n = 0
        self.start_time = None

    def start(self):
        self.start_time = time.time()

    def stop(self, n=1):
        """Add the time since :meth:`start` to the sum, with weight *n*.

        Does nothing if the stopwatch was never started. The start time
        is kept, so it is only replaced by the next :meth:`start`.
        """
        if self.start_time is not None:
            delta = time.time() - self.start_time
            self.sum += delta
            self.n += n

    def reset(self):
        self.sum = 0  # cumulative time during which stopwatch was active
        self.n = 0  # total n across all start/stop
        self.start()

    def state_dict(self):
        return {
            'sum': self.sum,
            'n': self.n,
            'round': self.round,
        }

    def load_state_dict(self, state_dict):
        self.sum = state_dict['sum']
        self.n = state_dict['n']
        self.start_time = None
        self.round = state_dict.get('round', None)

    @property
    def avg(self):
        """Average duration, or the plain sum when the total weight is 0."""
        return self.sum / self.n if self.n > 0 else self.sum

    @property
    def elapsed_time(self):
        """Seconds since the last :meth:`start`, or 0 if never started."""
        if self.start_time is None:
            return 0.
        return time.time() - self.start_time

    @property
    def smoothed_value(self) -> float:
        # before anything was recorded, report the running time instead
        val = self.avg if self.sum > 0 else self.elapsed_time
        return _round_value(val, self.round)


class MetersDict(OrderedDict):
    """A sorted dictionary of :class:`Meters`.

    Meters are sorted according to a priority that is given when the
    meter is first added to the dictionary.
    """

    def __init__(self, *args, **kwargs):
        # must exist before the base initializer runs, because initial
        # items are inserted through __setitem__
        self.priorities = []
        super().__init__(*args, **kwargs)

    def __setitem__(self, key, value):
        """Insert a meter; *value* must be a ``(priority, meter)`` pair."""
        assert key not in self, "MetersDict doesn't support reassignment"
        priority, value = value
        # the insertion index breaks ties, so equal priorities keep
        # their insertion order
        bisect.insort(self.priorities, (priority, len(self.priorities), key))
        super().__setitem__(key, value)
        for _, _, key in self.priorities:  # reorder dict to match priorities
            self.move_to_end(key)

    def add_meter(self, key, meter, priority):
        self.__setitem__(key, (priority, meter))

    def state_dict(self):
        """Return ``(priority, key, class name, meter state)`` tuples."""
        return [
            (pri, key, self[key].__class__.__name__, self[key].state_dict())
            for pri, _, key in self.priorities
            # can't serialize DerivedMeter instances
            if not isinstance(self[key], MetersDict._DerivedMeter)
        ]

    def load_state_dict(self, state_dict):
        """Replace the contents with the meters described by *state_dict*."""
        self.clear()
        self.priorities.clear()
        for pri, key, meter_cls, meter_state in state_dict:
            # NOTE(review): the class is looked up by name in this
            # module's globals, so only load state from trusted sources.
            meter = globals()[meter_cls]()
            meter.load_state_dict(meter_state)
            self.add_meter(key, meter, pri)

    def get_smoothed_value(self, key: str) -> float:
        """Get a single smoothed value."""
        meter = self[key]
        if isinstance(meter, MetersDict._DerivedMeter):
            return meter.fn(self)
        else:
            return meter.smoothed_value

    def get_smoothed_values(self) -> Dict[str, float]:
        """Get all smoothed values."""
        return OrderedDict([
            (key, self.get_smoothed_value(key))
            for key in self.keys()
        ])

    class _DerivedMeter(Meter):
        """A Meter whose values are derived from other Meters."""

        def __init__(self, fn):
            self.fn = fn

        def reset(self):
            pass


# ---------------------------------------------------------------------------
# fairseq/metrics.py
# ---------------------------------------------------------------------------

# The meter classes used below are defined above in this same module.

# Aggregation contexts are considered "active" when inside the scope
# created by the :func:`aggregate` context manager.
_aggregators = OrderedDict()
_active_aggregators = OrderedDict()


# The "default" aggregator observes all logged values.
_aggregators["default"] = MetersDict()
_active_aggregators["default"] = _aggregators["default"]


@contextlib.contextmanager
def aggregate(name: Optional[str] = None, exclusive: bool = False):
    """Context manager to aggregate metrics under a given name.

    Aggregations can be nested. If *exclusive* is ``False``, then logged
    metrics will be recorded along the entire stack of nested
    aggregators, including a global "default" aggregator. If *exclusive*
    is ``True``, then only the most recent aggregator will be used.

    Note that aggregation contexts are uniquely identified by their
    *name* (e.g., train, valid). Creating a context with an existing
    name will reuse the corresponding :class:`MetersDict` instance.
    If no name is given then a temporary aggregator is created; it is
    not registered globally and is only active inside the context.

    The set of active aggregators is restored when the context exits,
    including when the body raises. Re-entering a name that is already
    active leaves it active until the outermost context exits.

    Usage::

        with metrics.aggregate("train"):
            for step, batch in enumerate(epoch):
                with metrics.aggregate() as agg:
                    metrics.log_scalar("loss", get_loss(batch))
                    if step % log_interval == 0:
                        print(agg.get_smoothed_value("loss"))
            print(metrics.get_smoothed_values("train")["loss"])

    Args:
        name (str): name of the aggregation. Defaults to a
            random/temporary name if not given explicitly.
        exclusive (bool): only log to the most recent aggregation
            context, instead of all nested aggregations.
    """
    if name is None:
        # generate a temporary name
        name = str(uuid.uuid4())
        assert name not in _aggregators
        agg = MetersDict()
    else:
        assert name != "default"
        agg = _aggregators.setdefault(name, MetersDict())

    if exclusive:
        backup_aggregators = _active_aggregators.copy()
        _active_aggregators.clear()

    # only the context that activated the name may deactivate it
    already_active = name in _active_aggregators
    _active_aggregators[name] = agg
    try:
        yield agg
    finally:
        if not already_active:
            del _active_aggregators[name]

        if exclusive:
            _active_aggregators.clear()
            _active_aggregators.update(backup_aggregators)


def get_active_aggregators() -> List[MetersDict]:
    """Return the currently active aggregators, outermost first."""
    return list(_active_aggregators.values())


def log_scalar(
    key: str,
    value: float,
    weight: float = 1,
    priority: int = 10,
    round: Optional[int] = None,
):
    """Log a scalar value.

    Args:
        key (str): name of the field to log
        value (float): value to log
        weight (float): weight that this value contributes to the average.
            A weight of 0 will always log the latest value.
        priority (int): smaller values are logged earlier in the output
        round (Optional[int]): number of digits to round to when displaying
    """
    for agg in get_active_aggregators():
        if key not in agg:
            agg.add_meter(key, AverageMeter(round=round), priority)
        agg[key].update(value, weight)


def log_derived(key: str, fn: Callable[[MetersDict], float], priority: int = 20):
    """Log a scalar value derived from other meters.

    Args:
        key (str): name of the field to log
        fn (Callable[[MetersDict], float]): function that takes a single
            argument *meters* and returns the derived value
        priority (int): smaller values are logged earlier in the output
    """
    for agg in get_active_aggregators():
        if key not in agg:
            agg.add_meter(key, MetersDict._DerivedMeter(fn), priority)


def log_speed(key: str, value: float, priority: int = 30, round: Optional[int] = None):
    """Log the rate of some quantity per second.

    The first call for a *key* only creates and starts the meter; the
    *value* of that call is not counted.

    Args:
        key (str): name of the field to log
        value (float): value to log
        priority (int): smaller values are logged earlier in the output
        round (Optional[int]): number of digits to round to when displaying
    """
    for agg in get_active_aggregators():
        if key not in agg:
            agg.add_meter(key, TimeMeter(round=round), priority)
            agg[key].reset()  # reset meter on the first call
        else:
            agg[key].update(value)


def log_start_time(key: str, priority: int = 40, round: Optional[int] = None):
    """Log the duration of some event in seconds.

    The duration will be computed once :func:`log_stop_time` is called.

    Args:
        key (str): name of the field to log
        priority (int): smaller values are logged earlier in the output
        round (Optional[int]): number of digits to round to when displaying
    """
    for agg in get_active_aggregators():
        if key not in agg:
            agg.add_meter(key, StopwatchMeter(round=round), priority)
        agg[key].start()


def log_stop_time(key: str, weight: float = 0.):
    """Log the duration of some event in seconds.

    The duration will be computed since :func:`log_start_time` was called.
    Set weight > 0 to report the average time instead of the sum.
    Aggregators that have no meter for *key* are skipped.

    Args:
        key (str): name of the field to log
        weight (float): weight that this time contributes to the average
    """
    for agg in get_active_aggregators():
        # an aggregator activated after log_start_time has no such meter
        if key in agg:
            agg[key].stop(weight)


def reset_meters(name: str):
    """Reset Meter instances aggregated under a given *name*.

    Unknown names are ignored.
    """
    meters = _aggregators.get(name)
    if meters is None:
        return
    for meter in meters.values():
        meter.reset()


def get_meter(name: str, key: str) -> Meter:
    """Get Meter instance."""
    return _aggregators[name][key]


def get_meters(name: str) -> MetersDict:
    """Get Meter instances aggregated under a given *name*."""
    return _aggregators[name]


def get_smoothed_value(name: str, key: str) -> float:
    """Get a single smoothed value."""
    meters = get_meters(name)
    return meters.get_smoothed_value(key)


def get_smoothed_values(name: str) -> Dict[str, float]:
    """Get smoothed values aggregated under a given *name*."""
    meters = get_meters(name)
    return meters.get_smoothed_values()


def state_dict():
    """Return the serializable state of every registered aggregator."""
    return OrderedDict([
        (name, agg.state_dict())
        for name, agg in _aggregators.items()
    ])


def load_state_dict(state_dict):
    """Restore registered aggregators from *state_dict*.

    Existing :class:`MetersDict` instances are refilled in place, so
    aggregators that are currently active (including "default") keep
    receiving logged values after the state is loaded.
    """
    for name, agg_state in state_dict.items():
        _aggregators.setdefault(name, MetersDict()).load_state_dict(agg_state)