From a113de0d0ba103b6fb1bfdec69b550147a2a262f Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Tue, 28 Mar 2023 22:17:16 +0200 Subject: [PATCH] Update --- beaver.py | 56 ++++++++++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 47 insertions(+), 9 deletions(-) diff --git a/beaver.py b/beaver.py index f395d22..f5b3563 100755 --- a/beaver.py +++ b/beaver.py @@ -264,7 +264,7 @@ def oneshot(gpt, learning_rate_scheduler, task): learning_rate_scheduler.reset() for n_epoch in range(args.nb_epochs): - learning_rate = learning_rate_scheduler.learning_rate() + learning_rate = learning_rate_scheduler.get_learning_rate() optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate) acc_train_loss, nb_train_samples = 0, 0 @@ -342,7 +342,7 @@ def oneshot(gpt, learning_rate_scheduler, task): class LearningRateScheduler: - def learning_rate(self): + def get_learning_rate(self): pass def update(self, nb_finished_epochs, loss): @@ -355,7 +355,8 @@ class LearningRateScheduler: return vars(self) def set_state(self, state): - for k, v in state.item(): + print(f"{state=}") + for k, v in state.items(): setattr(self, k, v) @@ -364,12 +365,47 @@ class StepWiseScheduler(LearningRateScheduler): self.nb_finished_epochs = 0 self.schedule = schedule - def learning_rate(self): + def get_learning_rate(self): return self.schedule[self.nb_finished_epochs] + def update(self, nb_finished_epochs, loss): + self.nb_finished_epochs = nb_finished_epochs + def reset(self): self.nb_finished_epochs = 0 + def get_state(self): + return {"nb_finished_epochs": self.nb_finished_epochs} + + +class AutoScheduler(LearningRateScheduler): + def __init__(self, learning_rate_init, growth=1.0, degrowth=0.2): + self.learning_rate_init = learning_rate_init + self.learning_rate = learning_rate_init + self.growth = growth + self.degrowth = degrowth + self.pred_loss = None + + def get_learning_rate(self): + return self.learning_rate + + def update(self, nb_finished_epochs, loss): + if self.pred_loss is not None: + if loss >= self.pred_loss: + self.learning_rate *= self.degrowth + else: + self.learning_rate *= self.growth + self.pred_loss = loss + + def reset(self): + self.learning_rate = self.learning_rate_init + + def get_state(self): + return { + "learning_rate_init": self.learning_rate_init, + "pred_loss": self.pred_loss, + } + ###################################################################### @@ -589,7 +625,7 @@ log_string(f"nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)") ###################################################################### if args.learning_rate_schedule == "auto": - pass + learning_rate_scheduler = AutoScheduler(args.learning_rate) elif args.learning_rate_schedule == "cos": schedule = {} @@ -629,6 +665,7 @@ else: checkpoint = torch.load(checkpoint_name) nb_epochs_finished = checkpoint["nb_epochs_finished"] model.load_state_dict(checkpoint["model_state"]) + learning_rate_scheduler.set_state(checkpoint["learning_rate_scheduler_state"]) torch.set_rng_state(checkpoint["rng_state"]) if torch.cuda.is_available(): torch.cuda.set_rng_state(checkpoint["cuda_rng_state"]) @@ -638,9 +675,9 @@ else: except FileNotFoundError: log_string("starting from scratch.") - except: - log_string("error when loading the checkpoint.") - exit(1) + # except: + # log_string("error when loading the checkpoint.") + # exit(1) ###################################################################### @@ -673,7 +710,7 @@ if nb_epochs_finished >= args.nb_epochs: learning_rate_scheduler.reset() for n_epoch in range(nb_epochs_finished, args.nb_epochs): - learning_rate = learning_rate_scheduler.learning_rate() + learning_rate = learning_rate_scheduler.get_learning_rate() log_string(f"learning_rate {learning_rate}") @@ -721,6 +758,7 @@ for n_epoch in range(nb_epochs_finished, args.nb_epochs): checkpoint = { "nb_epochs_finished": n_epoch + 1, "model_state": model.state_dict(), + "learning_rate_scheduler_state": learning_rate_scheduler.get_state(), "rng_state": torch.get_rng_state(), } -- 2.20.1