sczhou's picture
init code
320e465
raw
history blame
1.47 kB
import time
class TimeEstimator:
    """Track elapsed time between ``update()`` calls and estimate time left.

    Two statistics are maintained:

    * ``avg_time_window`` -- the raw durations recorded since the last call
      to ``get_and_reset_avg_time()`` (window-based average).
    * ``exp_avg_time`` -- an exponential moving average of those durations,
      used by ``get_est_remaining()``.

    Both reporting methods divide by ``step_size``; presumably ``step_size``
    is the number of iterations between consecutive ``update()`` calls, so
    the results are expressed per iteration -- confirm against the caller.
    """

    def __init__(self, total_iter, step_size, alpha=0.7):
        """Create an estimator.

        Args:
            total_iter: Total number of iterations in the run.
            step_size: Divisor applied to measured durations (see class
                docstring).
            alpha: Weight given to the previous average in the exponential
                moving average; ``1 - alpha`` is given to the new sample.
        """
        self.avg_time_window = []  # window-based average
        self.exp_avg_time = None  # exponential moving average
        self.alpha = alpha  # for exponential moving average

        # Would not be accurate for the first iteration, since it measures
        # from construction rather than from the previous update.
        self.last_time = time.time()
        self.total_iter = total_iter
        self.step_size = step_size

        # True until the second update() call; see update().
        self.buffering_exp = True

    def update(self):
        """Record the time elapsed since the previous call (or construction).

        Call this at a fixed interval; it does not have to be every step.
        """
        now = time.time()
        elapsed = now - self.last_time
        self.last_time = now

        self.avg_time_window.append(elapsed)

        if self.buffering_exp:
            if self.exp_avg_time is not None:
                # Second call: the first measurement is discarded so it does
                # not pollute the EMA, which is re-seeded just below.
                self.buffering_exp = False
            self.exp_avg_time = elapsed
        else:
            self.exp_avg_time = (
                self.alpha * self.exp_avg_time + (1 - self.alpha) * elapsed
            )

    def get_est_remaining(self, it):
        """Return the estimated time remaining after iteration ``it``.

        Returns 0 when no ``update()`` has been recorded yet.
        """
        if self.exp_avg_time is None:
            return 0

        remaining_iter = self.total_iter - it
        return remaining_iter * self.exp_avg_time / self.step_size

    def get_and_reset_avg_time(self):
        """Return the windowed average duration divided by ``step_size``.

        The window is cleared afterwards. Returns 0.0 when the window is
        empty (no ``update()`` since construction or since the last reset)
        instead of dividing by zero.
        """
        window = self.avg_time_window
        if not window:
            return 0.0

        avg = sum(window) / len(window) / self.step_size
        self.avg_time_window = []
        return avg