import numpy as np
class BenchmarkStats:
    """Track per-epoch statistics used for benchmarking.

    Each call to :meth:`update` records one epoch's values; :meth:`get`
    aggregates the most recently recorded epochs.
    """

    def __init__(self):
        # Parallel lists: index i holds the values recorded by the i-th
        # call to ``update``.
        self.utts = []
        self.times = []
        self.losses = []

    def update(self, utts, times, losses):
        """Record the statistics of one epoch.

        Args:
            utts: Amount of work processed during the epoch (presumably the
                number of utterances -- confirm with the caller). Summed by
                :meth:`get`.
            times: Elapsed time of the epoch, in whatever unit the caller
                uses; throughput is reported per that unit.
            losses: Loss value(s) of the epoch; averaged with ``np.mean``
                by :meth:`get`.
        """
        self.utts.append(utts)
        self.times.append(times)
        self.losses.append(losses)

    def get(self, n_epochs):
        """Aggregate the statistics of the last ``n_epochs`` recorded epochs.

        If fewer than ``n_epochs`` epochs have been recorded, all recorded
        epochs are used.

        Args:
            n_epochs: Number of most recent epochs to aggregate; must be >= 1.

        Returns:
            dict with keys:
                ``'throughput'``: total ``utts`` divided by total ``times``
                    over the selected epochs.
                ``'benchmark_epochs_num'``: ``n_epochs`` as requested.
                ``'loss'``: ``np.mean`` of the selected epochs' losses.

        Raises:
            ValueError: If ``n_epochs < 1``.
            ZeroDivisionError: If the summed time of the selected epochs is
                zero (e.g. nothing has been recorded yet).
        """
        # Guard needed because ``seq[-0:]`` is ``seq[0:]``: a zero would
        # silently aggregate *every* epoch, and a negative value would turn
        # into a positive start index that drops the oldest epochs.
        if n_epochs < 1:
            raise ValueError(f"n_epochs must be >= 1, got {n_epochs}")
        throughput = sum(self.utts[-n_epochs:]) / sum(self.times[-n_epochs:])
        return {'throughput': throughput, 'benchmark_epochs_num': n_epochs,
                'loss': np.mean(self.losses[-n_epochs:])}