Update
author    François Fleuret <francois@fleuret.org>
Tue, 28 Mar 2023 20:17:16 +0000 (22:17 +0200)
committer François Fleuret <francois@fleuret.org>
Tue, 28 Mar 2023 20:17:16 +0000 (22:17 +0200)
beaver.py

index f395d22..f5b3563 100755 (executable)
--- a/beaver.py
+++ b/beaver.py
@@ -264,7 +264,7 @@ def oneshot(gpt, learning_rate_scheduler, task):
     learning_rate_scheduler.reset()
 
     for n_epoch in range(args.nb_epochs):
-        learning_rate = learning_rate_scheduler.learning_rate()
+        learning_rate = learning_rate_scheduler.get_learning_rate()
         optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
 
         acc_train_loss, nb_train_samples = 0, 0
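For context, a rough usage sketch of the renamed accessor in a per-epoch loop; the toy model and data, and the end-of-epoch update() call, are illustrative assumptions rather than code from this commit.

    import torch

    # Toy stand-ins; beaver.py drives its own model and task data.
    model = torch.nn.Linear(8, 1)
    x, y = torch.randn(64, 8), torch.randn(64, 1)

    learning_rate_scheduler = AutoScheduler(1e-3)  # class added further down in this commit
    learning_rate_scheduler.reset()

    for n_epoch in range(5):
        learning_rate = learning_rate_scheduler.get_learning_rate()
        optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
        loss = torch.nn.functional.mse_loss(model(x), y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Assumed convention: report the number of finished epochs and their loss.
        learning_rate_scheduler.update(n_epoch + 1, loss.item())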
@@ -342,7 +342,7 @@ def oneshot(gpt, learning_rate_scheduler, task):
 
 
 class LearningRateScheduler:
-    def learning_rate(self):
+    def get_learning_rate(self):
         pass
 
     def update(self, nb_finished_epochs, loss):
@@ -355,7 +355,8 @@ class LearningRateScheduler:
         return vars(self)
 
     def set_state(self, state):
-        for k, v in state.item():
+        print(f"{state=}")
+        for k, v in state.items():
             setattr(self, k, v)
 
 
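A quick round-trip illustration of what the items() fix enables, using the StepWiseScheduler from the next hunk; the two-epoch schedule is only an example, and the constructor is assumed to take the schedule dict directly.

    schedule = {0: 1e-3, 1: 1e-4}

    scheduler = StepWiseScheduler(schedule)
    scheduler.update(1, loss=0.42)
    state = scheduler.get_state()            # {"nb_finished_epochs": 1}

    restored = StepWiseScheduler(schedule)
    restored.set_state(state)                # writes each entry back as an attribute
    assert restored.get_learning_rate() == 1e-4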
@@ -364,12 +365,47 @@ class StepWiseScheduler(LearningRateScheduler):
         self.nb_finished_epochs = 0
         self.schedule = schedule
 
-    def learning_rate(self):
+    def get_learning_rate(self):
         return self.schedule[self.nb_finished_epochs]
 
+    def update(self, nb_finished_epochs, loss):
+        self.nb_finished_epochs = nb_finished_epochs
+
     def reset(self):
         self.nb_finished_epochs = 0
 
+    def get_state(self):
+        return {"nb_finished_epochs": self.nb_finished_epochs}
+
+
+class AutoScheduler(LearningRateScheduler):
+    def __init__(self, learning_rate_init, growth=1.0, degrowth=0.2):
+        self.learning_rate_init = learning_rate_init
+        self.learning_rate = learning_rate_init
+        self.growth = growth
+        self.degrowth = degrowth
+        self.pred_loss = None
+
+    def get_learning_rate(self):
+        return self.learning_rate
+
+    def update(self, nb_finished_epochs, loss):
+        if self.pred_loss is not None:
+            if loss >= self.pred_loss:
+                self.learning_rate *= self.degrowth
+            else:
+                self.learning_rate *= self.growth
+        self.pred_loss = loss
+
+    def reset(self):
+        self.learning_rate = self.learning_rate_init
+
+    def get_state(self):
+        return {
+            "learning_rate_init": self.learning_rate_init,
+            "pred_loss": self.pred_loss,
+        }
+
 
 ######################################################################
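A small illustration, not from the commit, of how the new AutoScheduler reacts to the reported loss, with non-default growth/degrowth so both branches show up.

    s = AutoScheduler(learning_rate_init=1e-3, growth=1.2, degrowth=0.5)

    s.update(1, loss=2.0)   # first report: pred_loss was None, rate stays at 1e-3
    s.update(2, loss=1.5)   # loss improved           -> rate *= growth   (now 1e-3 * 1.2)
    s.update(3, loss=1.8)   # loss got worse (>= prev) -> rate *= degrowth (halved again)
    print(s.get_learning_rate())

    s.reset()               # back to learning_rate_init for the next run

With the default arguments (growth=1.0, degrowth=0.2), the rate never increases; it only decays when an epoch fails to improve on the previous loss.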
 
@@ -589,7 +625,7 @@ log_string(f"nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)")
 ######################################################################
 
 if args.learning_rate_schedule == "auto":
-    pass
+    learning_rate_scheduler = AutoScheduler(args.learning_rate)
 
 elif args.learning_rate_schedule == "cos":
     schedule = {}
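The non-auto branches presumably fill schedule with one rate per epoch index and hand it to StepWiseScheduler; the hunk stops before that, so the cosine construction below is only a hedged guess at the shape, not the formula beaver.py actually uses.

    import math

    # Hypothetical cosine-annealed schedule: epoch index -> learning rate.
    nb_epochs, learning_rate = 25, 1e-4
    schedule = {
        e: learning_rate * 0.5 * (1 + math.cos(math.pi * e / nb_epochs))
        for e in range(nb_epochs)
    }
    learning_rate_scheduler = StepWiseScheduler(schedule)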
@@ -629,6 +665,7 @@ else:
         checkpoint = torch.load(checkpoint_name)
         nb_epochs_finished = checkpoint["nb_epochs_finished"]
         model.load_state_dict(checkpoint["model_state"])
+        learning_rate_scheduler.set_state(checkpoint["learning_rate_scheduler_state"])
         torch.set_rng_state(checkpoint["rng_state"])
         if torch.cuda.is_available():
             torch.cuda.set_rng_state(checkpoint["cuda_rng_state"])
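With the scheduler state now in the checkpoint, a restart picks up pred_loss (and, for StepWiseScheduler, the finished-epoch count). A minimal round-trip sketch with an illustrative file name; note that AutoScheduler.get_state() does not include the current learning_rate, so a restored AutoScheduler resumes from learning_rate_init.

    import torch

    s = AutoScheduler(1e-3, growth=1.1, degrowth=0.5)
    s.update(1, loss=1.0)
    s.update(2, loss=0.8)                     # rate is now 1e-3 * 1.1

    torch.save({"learning_rate_scheduler_state": s.get_state()}, "/tmp/ckpt.pth")

    t = AutoScheduler(1e-3, growth=1.1, degrowth=0.5)
    t.set_state(torch.load("/tmp/ckpt.pth")["learning_rate_scheduler_state"])
    # t.pred_loss == 0.8, but t.get_learning_rate() is back to 1e-3,
    # since get_state() only saves learning_rate_init and pred_loss.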
@@ -638,9 +675,10 @@ else:
     except FileNotFoundError:
         log_string("starting from scratch.")
 
-    except:
-        log_string("error when loading the checkpoint.")
-        exit(1)
+    except:
+        # log_string("error when loading the checkpoint.")
+        # exit(1)
+        pass
 
 ######################################################################
 
@@ -673,7 +710,7 @@ if nb_epochs_finished >= args.nb_epochs:
 learning_rate_scheduler.reset()
 
 for n_epoch in range(nb_epochs_finished, args.nb_epochs):
-    learning_rate = learning_rate_scheduler.learning_rate()
+    learning_rate = learning_rate_scheduler.get_learning_rate()
 
     log_string(f"learning_rate {learning_rate}")
 
@@ -721,6 +758,7 @@ for n_epoch in range(nb_epochs_finished, args.nb_epochs):
     checkpoint = {
         "nb_epochs_finished": n_epoch + 1,
         "model_state": model.state_dict(),
+        "learning_rate_scheduler_state": learning_rate_scheduler.get_state(),
         "rng_state": torch.get_rng_state(),
     }
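The hunk is cut off mid-dict; presumably the save side mirrors the loading code above, along the lines of the sketch below. The cuda_rng_state entry and the torch.save call are inferred from the load path rather than shown in this diff; the other names come from the surrounding script.

    checkpoint = {
        "nb_epochs_finished": n_epoch + 1,
        "model_state": model.state_dict(),
        "learning_rate_scheduler_state": learning_rate_scheduler.get_state(),
        "rng_state": torch.get_rng_state(),
    }
    if torch.cuda.is_available():
        checkpoint["cuda_rng_state"] = torch.cuda.get_rng_state()

    torch.save(checkpoint, checkpoint_name)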