+ def get_state(self):
+ return {"nb_finished_epochs": self.nb_finished_epochs}
+
+
+class AutoScheduler(LearningRateScheduler):
+ def __init__(self, learning_rate_init, growth=1.0, degrowth=0.2):
+ self.learning_rate_init = learning_rate_init
+ self.learning_rate = learning_rate_init
+ self.growth = growth
+ self.degrowth = degrowth
+ self.pred_loss = None
+
+ def get_learning_rate(self):
+ return self.learning_rate
+
+ def update(self, nb_finished_epochs, loss):
+ if self.pred_loss is not None:
+ if loss >= self.pred_loss:
+ self.learning_rate *= self.degrowth
+ else:
+ self.learning_rate *= self.growth
+ self.pred_loss = loss
+
+ def reset(self):
+ self.learning_rate = self.learning_rate_init
+
+ def get_state(self):
+ return {
+ "learning_rate_init": self.learning_rate_init,
+ "pred_loss": self.pred_loss,
+ }
+