Update.
authorFrançois Fleuret <francois@fleuret.org>
Wed, 26 Jun 2024 15:36:33 +0000 (17:36 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Wed, 26 Jun 2024 15:36:33 +0000 (17:36 +0200)
main.py

diff --git a/main.py b/main.py
index b7b55b5..d063423 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -364,10 +364,18 @@ def create_c_quizzes(
     model_indexes = []
     sum_logits, sum_nb_c_quizzes = 0, 0
 
-    while (
-        sum([x.size(0) for x in recorded[args.nb_correct_to_validate]])
-        < nb_for_train + nb_for_test
-    ):
+    def nb_generated():
+        return sum([sum([x.size(0) for x in recorded[n]]) for n in recorded.keys()])
+
+    def nb_validated():
+        return sum(
+            [
+                sum([x.size(0) for x in recorded[n]])
+                for n in range(args.nb_correct_to_validate, len(models))
+            ]
+        )
+
+    while nb_validated() < nb_for_train + nb_for_test:
         nb_to_validate = nb_for_train + nb_for_test
 
         if len(model_indexes) == 0:
@@ -396,13 +404,8 @@ def create_c_quizzes(
         for n in range(nb_correct.max() + 1):
             recorded[n].append(new_c_quizzes[nb_correct == n].clone())
 
-        nb_validated = sum([x.size(0) for x in recorded[args.nb_correct_to_validate]])
-        nb_generated = sum(
-            [sum([x.size(0) for x in recorded[n]]) for n in recorded.keys()]
-        )
-
         log_string(
-            f"keep c_quizzes {nb_validated*100/nb_generated:.02f}% kept total {nb_validated}/{nb_to_validate}"
+            f"keep c_quizzes {nb_validated()*100/nb_generated():.02f}% kept total {nb_validated()} / {nb_to_validate}"
         )
 
     # concatenate and shuffle
@@ -414,13 +417,21 @@ def create_c_quizzes(
         else:
             del recorded[n]
 
-    new_c_quizzes = recorded[args.nb_correct_to_validate][: nb_for_train + nb_for_test]
+    new_c_quizzes = torch.cat(
+        [recorded[n] for n in range(args.nb_correct_to_validate, len(models))], dim=0
+    )
+
+    new_c_quizzes = new_c_quizzes[
+        torch.randperm(new_c_quizzes.size(0), device=new_c_quizzes.device)[
+            : nb_for_train + nb_for_test
+        ]
+    ]
 
     quizz_machine.store_c_quizzes(new_c_quizzes[:nb_for_train], for_train=True)
     quizz_machine.store_c_quizzes(new_c_quizzes[nb_for_train:], for_train=False)
 
     for n in recorded.keys():
-        s = "_validated" if n == args.nb_correct_to_validate else ""
+        s = "_validated" if n >= args.nb_correct_to_validate and n < len(models) else ""
         quizz_machine.problem.save_quizzes(
             recorded[n][:72],
             args.result_dir,