Files
2021-08-19 12:40:54 +08:00

27 lines
922 B
Python

class Meters:
    """Track a running average and a patience-style early-stopping state.

    Attributes:
        eps: Relative-change threshold. An observation counts towards the
            patience counter when ``(val - min_loss) / |min_loss| >= eps``.
            With the default of ``-1e-3`` this means "did not improve on the
            best value by more than 0.1%".
        stop_threshold: Number of consecutive counted observations after
            which ``early_stop`` becomes true.
        avg: Running weighted mean maintained by ``update_avg``.
        cnt: Total weight accumulated by ``update_avg``.
        min_loss: Smallest value seen by ``update_early_stop`` since the last
            ``reset_early_stop`` (``inf`` before the first observation).
        satis_num: Current length of the streak of counted observations.
        update_res: True when the most recent observation set a new minimum.
        early_stop: True once ``satis_num >= stop_threshold``.
    """

    def __init__(self, eps=-1e-3, stop_threshold=10) -> None:
        """Store the thresholds and start with empty statistics."""
        self.eps = eps
        self.stop_threshold = stop_threshold
        self.avg = 0
        self.cnt = 0
        self.reset_early_stop()

    def reset_early_stop(self):
        """Reset the early-stopping state; the running average is untouched."""
        self.min_loss = float('inf')
        self.satis_num = 0
        self.update_res = True
        self.early_stop = False

    def update_avg(self, val, k=1):
        """Fold ``val`` into the running mean with weight ``k``.

        Args:
            val: New value (already averaged over the ``k`` items it stands for).
            k: Weight of ``val``, e.g. the number of samples it summarises.
        """
        total = self.cnt + k
        if total == 0:
            # Nothing has been accumulated and nothing is being added (k == 0
            # on a fresh meter); the mean is undefined, so leave state as is
            # instead of dividing by zero.
            return
        # Incremental weighted-mean update; avoids keeping a running sum.
        self.avg = self.avg + (val - self.avg) * k / total
        self.cnt = total

    def update_early_stop(self, val):
        """Record a new loss value and refresh the early-stopping flags.

        Updates ``min_loss``, ``update_res``, ``satis_num`` and ``early_stop``.

        Args:
            val: Scalar loss; anything accepted by ``float()``.
        """
        val = float(val)
        best = self.min_loss
        diff = val - best
        # Normalise by the magnitude of the best value so the sign of the
        # relative change always follows the sign of ``diff`` (a negative
        # best loss must not flip improvements into regressions).
        scale = abs(best)

        if scale == float('inf'):
            # No finite reference yet (first observation after a reset):
            # never counts towards the patience streak.
            stalled = False
        else:
            if scale == 0.0:
                # Relative change against zero is unbounded; keep its sign,
                # and treat "no change" as a zero relative change.
                if diff == 0:
                    delta = 0.0
                else:
                    delta = float('inf') if diff > 0 else -float('inf')
            else:
                delta = diff / scale
            stalled = delta >= self.eps

        if val < best:
            self.min_loss = val
            self.update_res = True
        else:
            self.update_res = False

        # The streak grows while observations fail to beat the best value by
        # the required margin, and restarts on any sufficient improvement.
        self.satis_num = self.satis_num + 1 if stalled else 0
        self.early_stop = self.satis_num >= self.stop_threshold