Update.
authorFrançois Fleuret <francois@fleuret.org>
Tue, 13 Aug 2024 12:26:55 +0000 (14:26 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Tue, 13 Aug 2024 12:26:55 +0000 (14:26 +0200)
main.py

diff --git a/main.py b/main.py
index bd46948..dda62af 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -165,7 +165,7 @@ assert not args.grids_science_tasks or (
 default_args = {
     "model": "37M",
     "batch_size": 25,
-    "inference_batch_size": 50,
+    "inference_batch_size": 25,
     "nb_train_samples": 40000,
     "nb_test_samples": 1000,
 }
@@ -806,19 +806,13 @@ if args.resume:
             model.load_state_dict(d["state_dict"])
             model.optimizer.load_state_dict(d["optimizer_state_dict"])
             model.main_test_accuracy = d["main_test_accuracy"]
+            model.train_c_quiz_bags = d["train_c_quiz_bags"]
+            model.test_c_quiz_bags = d["test_c_quiz_bags"]
             log_string(f"successfully loaded {filename}")
         except FileNotFoundError:
             log_string(f"cannot find {filename}")
             pass
 
-    try:
-        filename = "c_quizzes.pth"
-        quiz_machine.load_c_quizzes(os.path.join(args.result_dir, filename))
-        log_string(f"successfully loaded {filename}")
-    except FileNotFoundError:
-        log_string(f"cannot find {filename}")
-        pass
-
     try:
         filename = "state.pth"
         state = torch.load(os.path.join(args.result_dir, filename))
@@ -878,10 +872,6 @@ for n_epoch in range(current_epoch, args.nb_epochs):
             args.nb_new_c_quizzes_for_test,
         )
 
-        filename = "c_quizzes.pth"
-        quiz_machine.save_c_quizzes(os.path.join(args.result_dir, filename))
-        log_string(f"wrote {filename}")
-
         # Force one epoch of training
         for model in models:
             model.main_test_accuracy = 0.0
@@ -918,6 +908,8 @@ for n_epoch in range(current_epoch, args.nb_epochs):
                 "state_dict": model.state_dict(),
                 "optimizer_state_dict": model.optimizer.state_dict(),
                 "main_test_accuracy": model.main_test_accuracy,
+                "train_c_quiz_bags": model.train_c_quiz_bags,
+                "test_c_quiz_bags": model.test_c_quiz_bags,
             },
             os.path.join(args.result_dir, filename),
         )