From 981c10ef7b0a7a9af9488d9b21925a36077dc80c Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Thu, 8 Aug 2024 12:26:06 +0200 Subject: [PATCH] Update. --- main.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/main.py b/main.py index c77a7f3..8bca425 100755 --- a/main.py +++ b/main.py @@ -44,6 +44,7 @@ parser.add_argument("--max_percents_of_test_in_train", type=int, default=-1) parser.add_argument("--log_command", type=str, default=None) # ---------------------------------- + parser.add_argument("--nb_epochs", type=int, default=10000) parser.add_argument("--batch_size", type=int, default=None) @@ -62,6 +63,8 @@ parser.add_argument("--nb_new_c_quizzes_for_test", type=int, default=None) parser.add_argument("--learning_rate", type=float, default=5e-4) +parser.add_argument("--schedule_free", action="store_true", default=False) + # ---------------------------------- parser.add_argument("--model", type=str, default=None) @@ -362,7 +365,8 @@ log_string(f"vocabulary_size {vocabulary_size}") def run_tests(model, quiz_machine, local_device=main_device): with torch.autograd.no_grad(): - model.eval().to(local_device) + model.to(local_device).eval() + model.optimizer.eval() nb_test_samples, acc_test_loss = 0, 0.0 nb_samples_accumulated = 0 @@ -394,6 +398,7 @@ def run_tests(model, quiz_machine, local_device=main_device): def one_epoch(model, quiz_machine, local_device=main_device): model.to(local_device).train() + model.optimizer.train() nb_train_samples, acc_train_loss = 0, 0.0 @@ -995,6 +1000,9 @@ def compute_causal_attzero(t_q, t_k): return t_q < t_k +if args.schedule_free: + import schedulefree + for k in range(args.nb_gpts): log_string(f"creating model {k} and its w_quizzes") @@ -1011,7 +1019,13 @@ for k in range(args.nb_gpts): model.id = k - model.optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate) + if args.schedule_free: + model.optimizer = schedulefree.AdamWScheduleFree( + model.parameters(), lr=args.learning_rate + ) + else: + model.optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate) + model.main_test_accuracy = 0.0 model.train_w_quizzes = quiz_machine.problem.generate_w_quizzes( -- 2.20.1