Add --c_quiz_multiplier, increase the default number of new c-quizzes, and pick the models to improve at random once all pass the accuracy threshold.
author François Fleuret <francois@fleuret.org>
Sun, 18 Aug 2024 15:39:43 +0000 (17:39 +0200)
committer François Fleuret <francois@fleuret.org>
Sun, 18 Aug 2024 15:39:43 +0000 (17:39 +0200)
main.py
quiz_machine.py

diff --git a/main.py b/main.py
index db16214..046514d 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -61,6 +61,8 @@ parser.add_argument("--nb_new_c_quizzes_for_train", type=int, default=None)
 
 parser.add_argument("--nb_new_c_quizzes_for_test", type=int, default=None)
 
+parser.add_argument("--c_quiz_multiplier", type=int, default=1)
+
 parser.add_argument("--learning_rate", type=float, default=5e-4)
 
 parser.add_argument("--lambda_H", type=float, default=0.0)
@@ -389,7 +391,7 @@ def one_epoch(model, quiz_machine, local_device=main_device):
     nb_train_samples, acc_train_loss = 0, 0.0
 
     full_input, full_mask_loss = quiz_machine.data_input(
-        args.nb_train_samples, model.train_c_quiz_bags
+        args.nb_train_samples, model.train_c_quiz_bags, args.c_quiz_multiplier
     )
     src = zip(full_input.split(args.batch_size), full_mask_loss.split(args.batch_size))
 
@@ -900,10 +902,10 @@ log_string(f"nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)")
 ######################################################################
 
 if args.nb_new_c_quizzes_for_train is None:
-    args.nb_new_c_quizzes_for_train = args.nb_train_samples // 1000
+    args.nb_new_c_quizzes_for_train = args.nb_train_samples // 250
 
 if args.nb_new_c_quizzes_for_test is None:
-    args.nb_new_c_quizzes_for_test = args.nb_test_samples // 1000
+    args.nb_new_c_quizzes_for_test = args.nb_test_samples // 250
 
 log_string(
     f"nb_new_c_quizzes_for_train {args.nb_new_c_quizzes_for_train} nb_new_c_quizzes_for_test {args.nb_new_c_quizzes_for_test}"
@@ -1126,8 +1128,6 @@ for n_epoch in range(current_epoch, args.nb_epochs):
     log_string(f"current_best_test_accuracies {cta}")
 
     ##################################################
-    # If all the models are good enough, generate new quizzes and
-    # re-compute the test errors
 
     for model in models:
         if model.test_accuracy >= args.accuracy_to_make_c_quizzes:
@@ -1136,7 +1136,6 @@ for n_epoch in range(current_epoch, args.nb_epochs):
             )
             model.best_dict = copy.deepcopy(model.state_dict())
             model.best_test_accuracy = model.test_accuracy
-            model.test_accuracy = 0.0
 
     # we restart
     if total_time_generating_c_quizzes == 0:
@@ -1167,7 +1166,17 @@ for n_epoch in range(current_epoch, args.nb_epochs):
     # Select, improve, and eval the worst model(s)
 
     if total_time_training_models <= total_time_generating_c_quizzes:
-        ranked_models = sorted(models, key=lambda m: float(m.test_accuracy))
+        ranked_models = sorted(
+            models,
+            # This ugly recipe will pick the worst if there are some below
+            # args.accuracy_to_make_c_quizzes, or one at random if they
+            # are all above
+            key=lambda m: float(
+                m.test_accuracy
+                if m.test_accuracy < args.accuracy_to_make_c_quizzes
+                else args.accuracy_to_make_c_quizzes + torch.rand(1).item()
+            ),
+        )
 
         weakest_models = ranked_models[: len(gpus)]
 
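A minimal standalone sketch of how the new ranking key behaves (not part of this commit; the SimpleNamespace stand-ins and the 0.95 threshold are illustrative): models still below args.accuracy_to_make_c_quizzes keep their true accuracy as the key and sort first, while models at or above it get a random key above the threshold, so once every model passes, the "weakest" slots are filled at random.

import torch
from types import SimpleNamespace

accuracy_to_make_c_quizzes = 0.95  # hypothetical threshold

models = [
    SimpleNamespace(name=f"model_{i}", test_accuracy=a)
    for i, a in enumerate([0.97, 0.80, 0.99, 0.92])
]

ranked_models = sorted(
    models,
    key=lambda m: float(
        m.test_accuracy
        if m.test_accuracy < accuracy_to_make_c_quizzes
        else accuracy_to_make_c_quizzes + torch.rand(1).item()
    ),
)

# model_1 (0.80) and model_3 (0.92) always come first, in that order;
# model_0 and model_2 follow in a random order.
print([m.name for m in ranked_models])
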
diff --git a/quiz_machine.py b/quiz_machine.py
index 18136e8..a0b007a 100755 (executable)
--- a/quiz_machine.py
+++ b/quiz_machine.py
@@ -140,10 +140,23 @@ class QuizMachine:
 
     ######################################################################
 
-    def data_input(self, nb_samples, c_quiz_bags):
+    def data_input(self, nb_samples, c_quiz_bags, c_quiz_multiplier=1):
         if len(c_quiz_bags) > 0:
             c_quizzes = torch.cat(c_quiz_bags, dim=0)
 
+            if c_quiz_multiplier > 1:
+                n = min(c_quiz_multiplier, (nb_samples // 2) // c_quizzes.size(0))
+                body = c_quizzes.repeat(n, 1)
+                if n < c_quiz_multiplier:
+                    tail = c_quizzes[
+                        torch.randperm(c_quizzes.size(0))[
+                            : nb_samples // 2 - body.size(0)
+                        ]
+                    ]
+                    c_quizzes = torch.cat([body, tail], dim=0)
+                else:
+                    c_quizzes = body
+
             if c_quizzes.size(0) > nb_samples // 2:
                 i = torch.randperm(c_quizzes.size(0))[: nb_samples // 2]
                 c_quizzes = c_quizzes[i]
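
A self-contained sketch of the oversampling logic added to data_input (the helper name and the toy tensors below are illustrative, not from the repository): a bag of c-quizzes is repeated up to c_quiz_multiplier times; if full repetition would overshoot half the requested samples, the complete copies are kept and a random subset of quizzes fills the remainder, so the oversampled bag never exceeds nb_samples // 2.

import torch

def oversample_c_quizzes(c_quizzes, nb_samples, c_quiz_multiplier=1):
    # Repeat the bag up to c_quiz_multiplier times, capped at nb_samples // 2 rows.
    if c_quiz_multiplier > 1:
        n = min(c_quiz_multiplier, (nb_samples // 2) // c_quizzes.size(0))
        body = c_quizzes.repeat(n, 1)
        if n < c_quiz_multiplier:
            # The cap was hit: keep the n full copies and top up with a
            # random subset of the original quizzes.
            tail = c_quizzes[
                torch.randperm(c_quizzes.size(0))[: nb_samples // 2 - body.size(0)]
            ]
            c_quizzes = torch.cat([body, tail], dim=0)
        else:
            c_quizzes = body
    return c_quizzes

c_quizzes = torch.arange(6).reshape(3, 2)  # three toy quizzes of two tokens each
print(oversample_c_quizzes(c_quizzes, nb_samples=16, c_quiz_multiplier=4).size(0))  # 8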