From a09ee76c8283b7daf4c914df47f86d1964fc25d4 Mon Sep 17 00:00:00 2001
From: François Fleuret
Date: Sat, 6 Jan 2024 14:30:38 +0100
Subject: [PATCH] Update.

---
 main.py  | 27 ++++++++++++++++++++++++---
 tasks.py |  4 ++--
 2 files changed, 26 insertions(+), 5 deletions(-)

diff --git a/main.py b/main.py
index 74e70b2..1a17e51 100755
--- a/main.py
+++ b/main.py
@@ -66,6 +66,16 @@ parser.add_argument("--learning_rate", type=float, default=6e-4)
 
 parser.add_argument("--min_learning_rate", type=float, default=6e-5)
 
+# legacy
+
+parser.add_argument("--legacy_lr_schedule", action="store_true", default=False)
+
+parser.add_argument("--legacy_learning_rate", type=float, default=1e-4)
+
+parser.add_argument("--legacy_min_learning_rate", type=float, default=2e-5)
+
+parser.add_argument("--nb_large_lr_epochs", type=int, default=10)
+
 ########################################
 
 parser.add_argument("--model", type=str, default=None)
@@ -460,10 +470,21 @@ for n in vars(args):
 
 ######################################################################
 
-# from nanoGPT
 
+def get_lr(n_epoch, it):
+    if args.legacy_lr_schedule:
+        # crude schedule to compare against the previous baseline,
+        # with warmup added
+
+        if it < args.nb_warmup_iter:
+            return args.legacy_learning_rate * it / args.nb_warmup_iter
+        elif n_epoch < args.nb_large_lr_epochs:
+            return args.legacy_learning_rate
+        else:
+            return args.legacy_min_learning_rate
+
+    # from nanoGPT
 
-def get_lr(it):
     # 1) linear warmup for warmup_iter steps
     if it < args.nb_warmup_iter:
         return args.learning_rate * it / args.nb_warmup_iter
@@ -848,7 +869,7 @@ for n_epoch in range(nb_epochs_finished, nb_epochs):
         total_loss = loss + (args.rho * inner_loss if args.rho > 0 else 0.0)
 
         it += 1
-        lr = get_lr(it)
+        lr = get_lr(n_epoch, it)
 
         for param_group in optimizer.param_groups:
             param_group["lr"] = lr
diff --git a/tasks.py b/tasks.py
index 58638ed..afad8af 100755
--- a/tasks.py
+++ b/tasks.py
@@ -58,7 +58,7 @@ def masked_inplace_autoregression(
 
 
 class Task:
-    def batches(self, split="train"):
+    def batches(self, split="train", desc=None):
         pass
 
     def vocabulary_size(self):
@@ -328,7 +328,7 @@ class PicoCLVR(Task):
         self.train_input = self.tensorize(self.train_descr)
         self.test_input = self.tensorize(self.test_descr)
 
-    def batches(self, split="train"):
+    def batches(self, split="train", desc=None):
         assert split in {"train", "test"}
         input = self.train_input if split == "train" else self.test_input
         for batch in tqdm.tqdm(
-- 
2.20.1
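
For context, here is a minimal standalone sketch of the legacy schedule this patch introduces: linear warmup over the first warmup iterations, a constant "large" learning rate for the first nb_large_lr_epochs epochs, then a constant "small" learning rate for the rest of training. Only the argument names and the 1e-4/2e-5/10 defaults come from the patch; the nb_warmup_iter default of 100, the helper name legacy_lr, and the loop below are illustrative assumptions, not code from the repository.

def legacy_lr(n_epoch, it, nb_warmup_iter=100,
              large_lr=1e-4, small_lr=2e-5, nb_large_lr_epochs=10):
    # 1) linear warmup over the first nb_warmup_iter optimizer steps
    if it < nb_warmup_iter:
        return large_lr * it / nb_warmup_iter
    # 2) constant "large" LR for the first nb_large_lr_epochs epochs
    elif n_epoch < nb_large_lr_epochs:
        return large_lr
    # 3) constant "small" LR for the remainder of training
    else:
        return small_lr

# Illustrative only: assumes roughly 1000 iterations per epoch.
for n_epoch, it in [(0, 50), (0, 999), (5, 5500), (15, 15500)]:
    print(f"epoch {n_epoch:2d}, iter {it:5d}: lr = {legacy_lr(n_epoch, it):.2e}")

Note that the schedule switches on the epoch counter rather than the iteration counter after warmup, which is why get_lr in the patch takes n_epoch as an extra argument.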