# Source code for hetseq.progress_bar

from collections import OrderedDict
import json
from numbers import Number
import os
import sys

from hetseq import distributed_utils
from hetseq.meters import AverageMeter, StopwatchMeter, TimeMeter

# Module-level global, initialised to None at import time.
# NOTE(review): presumably a cached wrapper object that is created lazily
# elsewhere -- confirm against the code that assigns it.
g_tbmf_wrapper = None


def build_progress_bar(args, iterator, epoch=None, prefix=None, default='tqdm', no_progress_bar='none'):
    """Create the progress bar selected by ``args.log_format``.

    Args:
        args: options object; ``log_format``, ``no_progress_bar`` and
            ``log_interval`` are read from it.  ``log_format`` is written
            back when it has to be resolved (see below).
        iterator: iterable to be wrapped by the progress bar.
        epoch: optional epoch number forwarded to the progress bar.
        prefix: optional text prefix forwarded to the progress bar.
        default: format used when ``args.log_format`` is None and
            ``args.no_progress_bar`` is false.
        no_progress_bar: format used when ``args.log_format`` is None and
            ``args.no_progress_bar`` is true.

    Returns:
        The progress bar instance for the resolved format.

    Raises:
        ValueError: if the resolved format is not one of ``'json'``,
            ``'none'``, ``'simple'`` or ``'tqdm'``.
    """
    # Resolve an unset format from the caller-supplied fallbacks.
    if args.log_format is None:
        if args.no_progress_bar:
            args.log_format = no_progress_bar
        else:
            args.log_format = default

    # Fall back to the plain-text logger when stderr is not a terminal.
    if args.log_format == 'tqdm' and not sys.stderr.isatty():
        args.log_format = 'simple'

    if args.log_format == 'json':
        return json_progress_bar(iterator, epoch, prefix, args.log_interval)
    if args.log_format == 'none':
        return noop_progress_bar(iterator, epoch, prefix)
    if args.log_format == 'simple':
        return simple_progress_bar(iterator, epoch, prefix, args.log_interval)
    if args.log_format == 'tqdm':
        return tqdm_progress_bar(iterator, epoch, prefix)

    raise ValueError('Unknown log format: {}'.format(args.log_format))


def format_stat(stat):
    """Convert a statistic to its display string.

    Plain numbers are rendered with ``'{:g}'``.  Meter objects are rendered
    from one of their attributes: ``AverageMeter`` uses ``avg`` with three
    decimals, ``TimeMeter`` uses the rounded ``avg`` with ``'{:g}'``, and
    ``StopwatchMeter`` uses ``sum`` with four decimals.  Any other value is
    returned unchanged.
    """
    if isinstance(stat, Number):
        return '{:g}'.format(stat)
    if isinstance(stat, AverageMeter):
        return '{:.3f}'.format(stat.avg)
    if isinstance(stat, TimeMeter):
        return '{:g}'.format(round(stat.avg))
    if isinstance(stat, StopwatchMeter):
        return '{:.4f}'.format(stat.sum)
    return stat


class progress_bar(object):
    """Abstract class for progress bars.

    Wraps an iterable and builds a textual prefix from the optional epoch
    number and prefix string.  Subclasses must implement ``__iter__``,
    ``log`` and ``print``.
    """

    def __init__(self, iterable, epoch=None, prefix=None):
        """Store the iterable and build the log prefix.

        Args:
            iterable: object to iterate over; its ``offset`` attribute is
                used as the starting position when present, otherwise 0.
            epoch: optional epoch number, rendered as ``| epoch NNN``.
            prefix: optional text appended to the prefix as `` | text``.
        """
        self.iterable = iterable
        self.offset = getattr(iterable, 'offset', 0)
        self.epoch = epoch
        self.prefix = ''
        if epoch is not None:
            self.prefix += '| epoch {:03d}'.format(epoch)
        if prefix is not None:
            self.prefix += ' | {}'.format(prefix)

    def __len__(self):
        return len(self.iterable)

    def __enter__(self):
        return self

    def __exit__(self, *exc):
        # Returning False lets any exception raised in the with-body propagate.
        return False

    def __iter__(self):
        raise NotImplementedError

    def log(self, stats, tag='', step=None):
        """Log intermediate stats according to log_interval."""
        raise NotImplementedError

    def print(self, stats, tag='', step=None):
        """Print end-of-epoch stats."""
        raise NotImplementedError

    def _str_commas(self, stats):
        """Join already-formatted stats as ``key=value`` separated by commas."""
        return ', '.join(key + '=' + value.strip() for key, value in stats.items())

    def _str_pipes(self, stats):
        """Join already-formatted stats as ``key value`` separated by pipes."""
        return ' | '.join(key + ' ' + value.strip() for key, value in stats.items())

    def _format_stats(self, stats):
        """Return an OrderedDict copy of *stats* with every value as a string."""
        postfix = OrderedDict(stats)
        # Preprocess stats according to datatype
        for key in postfix:
            postfix[key] = str(format_stat(postfix[key]))
        return postfix


class noop_progress_bar(progress_bar):
    """No logging.

    Iterates over the wrapped iterable unchanged; ``log`` and ``print``
    accept the usual arguments and do nothing.
    """

    def __iter__(self):
        for obj in self.iterable:
            yield obj

    def log(self, stats, tag='', step=None):
        """Log intermediate stats according to log_interval."""
        pass

    def print(self, stats, tag='', step=None):
        """Print end-of-epoch stats."""
        pass


class simple_progress_bar(progress_bar):
    """A minimal logger for non-TTY environments."""

    def __init__(self, iterable, epoch=None, prefix=None, log_interval=1000):
        """Set up the bar.

        Args:
            iterable: object to iterate over.
            epoch: optional epoch number used in the prefix.
            prefix: optional text used in the prefix.
            log_interval: a progress line is printed whenever the running
                index is a positive multiple of this value; ``None``
                disables the periodic lines.
        """
        super().__init__(iterable, epoch, prefix)
        self.log_interval = log_interval
        # Most recent formatted stats passed to log(); None until log() is called.
        self.stats = None

    def __iter__(self):
        size = len(self.iterable)
        # Indices start at self.offset rather than 0.
        for i, obj in enumerate(self.iterable, start=self.offset):
            yield obj
            # The check runs after the consumer has processed `obj`, so stats
            # recorded through log() during that step are the ones printed.
            if (
                self.stats is not None
                and i > 0
                and self.log_interval is not None
                and i % self.log_interval == 0
            ):
                postfix = self._str_commas(self.stats)
                print('{}: {:5d} / {:d} {}'.format(self.prefix, i, size, postfix), flush=True)

    def log(self, stats, tag='', step=None):
        """Log intermediate stats according to log_interval.

        Only stores the formatted stats; the line itself is printed from
        ``__iter__``.  ``tag`` and ``step`` are accepted but unused.
        """
        self.stats = self._format_stats(stats)

    def print(self, stats, tag='', step=None):
        """Print end-of-epoch stats.

        ``tag`` and ``step`` are accepted but unused.
        """
        postfix = self._str_pipes(self._format_stats(stats))
        print('{} | {}'.format(self.prefix, postfix), flush=True)