Update.
[culture.git] / main.py
diff --git a/main.py b/main.py
index 10c7b49..0a7be99 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -417,7 +417,9 @@ def create_c_quizzes(
         sum_logits += c_quizzes.size(0) * ave_seq_logproba
         sum_nb_c_quizzes += c_quizzes.size(0)
 
-        nb_correct = quizz_machine.compute_correctness(c_quizzes, models)
+        nb_correct = quizz_machine.compute_correctness(
+            c_quizzes, models, both_direction=True
+        )
 
         if args.dirty_debug:
             nb_correct = torch.randint(
@@ -435,13 +437,16 @@ def create_c_quizzes(
             f"keep c_quizzes kept {nv} nb_accumulated {nb_validated} / {nb_to_create}"
         )
 
-    # ------------------------------------------------------------
+    # store the new c_quizzes which have been validated
 
     new_c_quizzes = valid_c_quizzes(recorded, standard_validity)
 
     quizz_machine.store_c_quizzes(new_c_quizzes[:nb_for_train], for_train=True)
     quizz_machine.store_c_quizzes(new_c_quizzes[nb_for_train:], for_train=False)
 
+    # save a bunch of images to investigate what quizzes with a
+    # certain nb of correct predictions look like
+
     for n in range(len(models) + 1):
         s = (
             "_validated"