Update.
authorFrançois Fleuret <francois@fleuret.org>
Thu, 8 Aug 2024 10:26:06 +0000 (12:26 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Thu, 8 Aug 2024 10:26:06 +0000 (12:26 +0200)
main.py

diff --git a/main.py b/main.py
index c77a7f3..8bca425 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -44,6 +44,7 @@ parser.add_argument("--max_percents_of_test_in_train", type=int, default=-1)
 parser.add_argument("--log_command", type=str, default=None)
 
 # ----------------------------------
+
 parser.add_argument("--nb_epochs", type=int, default=10000)
 
 parser.add_argument("--batch_size", type=int, default=None)
@@ -62,6 +63,8 @@ parser.add_argument("--nb_new_c_quizzes_for_test", type=int, default=None)
 
 parser.add_argument("--learning_rate", type=float, default=5e-4)
 
+parser.add_argument("--schedule_free", action="store_true", default=False)
+
 # ----------------------------------
 parser.add_argument("--model", type=str, default=None)
 
@@ -362,7 +365,8 @@ log_string(f"vocabulary_size {vocabulary_size}")
 
 def run_tests(model, quiz_machine, local_device=main_device):
     with torch.autograd.no_grad():
-        model.eval().to(local_device)
+        model.to(local_device).eval()
+        model.optimizer.eval()
 
         nb_test_samples, acc_test_loss = 0, 0.0
         nb_samples_accumulated = 0
@@ -394,6 +398,7 @@ def run_tests(model, quiz_machine, local_device=main_device):
 
 def one_epoch(model, quiz_machine, local_device=main_device):
     model.to(local_device).train()
+    model.optimizer.train()
 
     nb_train_samples, acc_train_loss = 0, 0.0
 
@@ -995,6 +1000,9 @@ def compute_causal_attzero(t_q, t_k):
     return t_q < t_k
 
 
+if args.schedule_free:
+    import schedulefree
+
 for k in range(args.nb_gpts):
     log_string(f"creating model {k} and its w_quizzes")
 
@@ -1011,7 +1019,13 @@ for k in range(args.nb_gpts):
 
     model.id = k
 
-    model.optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
+    if args.schedule_free:
+        model.optimizer = schedulefree.AdamWScheduleFree(
+            model.parameters(), lr=args.learning_rate
+        )
+    else:
+        model.optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
+
     model.main_test_accuracy = 0.0
 
     model.train_w_quizzes = quiz_machine.problem.generate_w_quizzes(