Update.
authorFrançois Fleuret <francois@fleuret.org>
Sat, 20 Jul 2024 22:45:18 +0000 (00:45 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Sat, 20 Jul 2024 22:45:18 +0000 (00:45 +0200)
main.py

diff --git a/main.py b/main.py
index 7588a50..653f5f5 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -472,11 +472,19 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10
             temperature_cold=args.temperature_cold,
         )
 
-        recorded_too_simple.append(
-            keep_good_quizzes(models, c_quizzes, required_nb_failures=0)
+        c_quizzes = c_quizzes[quiz_machine.non_trivial(c_quizzes)]
+
+        nc = quiz_machine.solution_nb_correct(models, c_quizzes)
+
+        count_nc = tuple(
+            n.item() for n in F.one_hot(nc, num_classes=len(models) + 1).sum(dim=0)
         )
 
-        c_quizzes = keep_good_quizzes(models, c_quizzes)
+        log_string(f"nb_correct {count_nc}")
+
+        recorded_too_simple.append(c_quizzes[nc == len(models)])
+
+        c_quizzes = c_quizzes[nc == len(models) - 1]
 
         nb_validated[model_for_generation.id] += c_quizzes.size(0)
         total_nb_validated = nb_validated.sum().item()
@@ -517,7 +525,7 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10
     ######################################################################
     # save images
 
-    vq = validated_quizzes[:128]
+    vq = validated_quizzes[torch.randperm(validated_quizzes.size(0))[:128]]
 
     if vq.size(0) > 0:
         prefix = f"culture_c_quiz_{n_epoch:04d}"
@@ -525,7 +533,7 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10
             args.result_dir, prefix, vq, show_part_to_predict=False
         )
 
-    vq = too_simple_quizzes[:128]
+    vq = too_simple_quizzes
 
     if vq.size(0) > 0:
         prefix = f"culture_c_quiz_{n_epoch:04d}_too_simple"
@@ -642,6 +650,11 @@ if args.dirty_debug:
 ######################################################################
 
 for n_epoch in range(current_epoch, args.nb_epochs):
+    state = {"current_epoch": n_epoch}
+    filename = "state.pth"
+    torch.save(state, os.path.join(args.result_dir, filename))
+    log_string(f"wrote {filename}")
+
     log_string(f"--- epoch {n_epoch} ----------------------------------------")
 
     cta = " ".join([f"{float(m.main_test_accuracy):.04f}" for m in models])
@@ -700,11 +713,6 @@ for n_epoch in range(current_epoch, args.nb_epochs):
         )
         log_string(f"wrote {filename}")
 
-    state = {"current_epoch": n_epoch}
-    filename = "state.pth"
-    torch.save(state, os.path.join(args.result_dir, filename))
-    log_string(f"wrote {filename}")
-
     # Renew the training samples
 
     for model in weakest_models: