from collections import OrderedDict
import json
from numbers import Number
import os
import sys
from hetseq import distributed_utils
from hetseq.meters import AverageMeter, StopwatchMeter, TimeMeter
# Module-level global, initialised to None. Nothing in the visible part of
# this file reads or reassigns it.
# NOTE(review): presumably populated lazily by code further down — confirm.
g_tbmf_wrapper = None
def build_progress_bar(args, iterator, epoch=None, prefix=None, default='tqdm', no_progress_bar='none'):
    """Construct the progress bar selected by ``args.log_format``.

    ``args.log_format`` is resolved (and written back onto ``args``) first:

    * if it is ``None`` it becomes ``no_progress_bar`` when
      ``args.no_progress_bar`` is truthy, otherwise ``default``;
    * if it is ``'tqdm'`` while ``sys.stderr`` is not a TTY it is replaced
      by ``'simple'``.

    Args:
        args: namespace-like object providing ``log_format``,
            ``no_progress_bar`` and ``log_interval``.
        iterator: iterable handed to the progress bar.
        epoch: optional epoch number forwarded to the progress bar.
        prefix: optional label forwarded to the progress bar.
        default: format used when none is set and bars are not disabled.
        no_progress_bar: format used when none is set and bars are disabled.

    Returns:
        The progress bar instance for the resolved format.

    Raises:
        ValueError: if the resolved format is not one of ``'json'``,
            ``'none'``, ``'simple'`` or ``'tqdm'``.
    """
    if args.log_format is None:
        if args.no_progress_bar:
            args.log_format = no_progress_bar
        else:
            args.log_format = default

    # Downgrade to the plain-text logger when stderr is not a terminal.
    if args.log_format == 'tqdm' and not sys.stderr.isatty():
        args.log_format = 'simple'

    chosen = args.log_format
    if chosen == 'json':
        return json_progress_bar(iterator, epoch, prefix, args.log_interval)
    if chosen == 'none':
        return noop_progress_bar(iterator, epoch, prefix)
    if chosen == 'simple':
        return simple_progress_bar(iterator, epoch, prefix, args.log_interval)
    if chosen == 'tqdm':
        return tqdm_progress_bar(iterator, epoch, prefix)
    raise ValueError('Unknown log format: {}'.format(chosen))
def format_stat(stat):
    """Turn a single stat value into its display string.

    * ``Number``: general (``g``) format.
    * ``AverageMeter``: ``avg`` with three decimals.
    * ``TimeMeter``: ``avg`` rounded to the nearest integer, ``g`` format.
    * ``StopwatchMeter``: ``sum`` with four decimals.

    Any other value is returned unchanged.
    """
    if isinstance(stat, Number):
        return format(stat, 'g')
    if isinstance(stat, AverageMeter):
        return format(stat.avg, '.3f')
    if isinstance(stat, TimeMeter):
        return format(round(stat.avg), 'g')
    if isinstance(stat, StopwatchMeter):
        return format(stat.sum, '.4f')
    return stat
class progress_bar(object):
    """Abstract class for progress bars.

    Wraps an iterable and carries a textual prefix built from the optional
    epoch number and label. Subclasses must implement ``__iter__``, ``log``
    and ``print``; the formatting helpers here are shared by them.
    """

    def __init__(self, iterable, epoch=None, prefix=None):
        """Store the iterable and build the display prefix.

        Args:
            iterable: object to iterate over; its ``offset`` attribute, if
                present, is copied to ``self.offset`` (0 otherwise).
            epoch: optional epoch number, rendered zero-padded to 3 digits.
            prefix: optional label appended after the epoch part.
        """
        self.iterable = iterable
        self.offset = getattr(iterable, 'offset', 0)
        self.epoch = epoch
        self.prefix = ''
        if epoch is not None:
            self.prefix += '| epoch {:03d}'.format(epoch)
        if prefix is not None:
            self.prefix += ' | {}'.format(prefix)

    def __len__(self):
        """Return the length of the wrapped iterable."""
        return len(self.iterable)

    def __enter__(self):
        """Allow use as a context manager; returns the bar itself."""
        return self

    def __exit__(self, *exc):
        """Never suppress an exception raised inside the ``with`` block."""
        return False

    def __iter__(self):
        """Iterate over the wrapped iterable (subclass responsibility)."""
        raise NotImplementedError

    def log(self, stats, tag='', step=None):
        """Log intermediate stats according to log_interval."""
        raise NotImplementedError

    def print(self, stats, tag='', step=None):
        """Print end-of-epoch stats."""
        raise NotImplementedError

    def _str_commas(self, stats):
        """Join already-formatted stats as ``key=value, key=value``."""
        return ', '.join(key + '=' + value.strip()
                         for key, value in stats.items())

    def _str_pipes(self, stats):
        """Join already-formatted stats as ``key value | key value``."""
        return ' | '.join(key + ' ' + value.strip()
                          for key, value in stats.items())

    def _format_stats(self, stats):
        """Return an ``OrderedDict`` mapping each stat key to its string form.

        ``stats`` may be anything ``OrderedDict`` accepts; each value is
        passed through ``format_stat`` and then ``str``.
        """
        # Preprocess stats according to datatype
        return OrderedDict(
            (key, str(format_stat(value)))
            for key, value in OrderedDict(stats).items()
        )
class noop_progress_bar(progress_bar):
    """No logging.

    Iterates over the wrapped iterable unchanged and ignores every stats
    call. The constructor is inherited from ``progress_bar``; it takes
    ``(iterable, epoch=None, prefix=None)``.
    """

    def __iter__(self):
        """Yield every item of the wrapped iterable, in order."""
        yield from self.iterable

    def log(self, stats, tag='', step=None):
        """Log intermediate stats according to log_interval."""
        pass

    def print(self, stats, tag='', step=None):
        """Print end-of-epoch stats."""
        pass
class simple_progress_bar(progress_bar):
    """A minimal logger for non-TTY environments.

    While iterating, prints the most recently logged stats as one
    comma-separated line every ``log_interval`` items. End-of-epoch stats
    are printed as one pipe-separated line.
    """

    def __init__(self, iterable, epoch=None, prefix=None, log_interval=1000):
        """Set up the bar.

        Args:
            iterable: object to iterate over.
            epoch: optional epoch number used in the line prefix.
            prefix: optional label used in the line prefix.
            log_interval: print every this many items; ``None`` or ``0``
                disables the periodic output.
        """
        super().__init__(iterable, epoch, prefix)
        self.log_interval = log_interval
        # Stats from the latest log() call, already formatted as strings;
        # stays None until log() has been called.
        self.stats = None

    def __iter__(self):
        """Yield items, printing the latest stats at each interval boundary.

        The item counter starts at ``self.offset``. A line is printed after
        an item has been yielded, and only when stats have been logged, the
        counter is positive and it is a multiple of ``log_interval``.
        """
        size = len(self.iterable)
        for i, obj in enumerate(self.iterable, start=self.offset):
            yield obj
            if self.stats is None or i <= 0:
                continue
            # A falsy interval (None or 0) means "no periodic output".
            # Checking for 0 here also keeps the modulo below from raising
            # ZeroDivisionError.
            if not self.log_interval or i % self.log_interval != 0:
                continue
            postfix = self._str_commas(self.stats)
            print('{}: {:5d} / {:d} {}'.format(self.prefix, i, size, postfix),
                  flush=True)

    def log(self, stats, tag='', step=None):
        """Log intermediate stats according to log_interval."""
        # Only remember the formatted stats; __iter__ decides when to print.
        self.stats = self._format_stats(stats)

    def print(self, stats, tag='', step=None):
        """Print end-of-epoch stats."""
        postfix = self._str_pipes(self._format_stats(stats))
        print('{} | {}'.format(self.prefix, postfix), flush=True)