Update.
authorFrançois Fleuret <francois@fleuret.org>
Mon, 1 Jul 2024 11:09:03 +0000 (14:09 +0300)
committerFrançois Fleuret <francois@fleuret.org>
Mon, 1 Jul 2024 11:09:03 +0000 (14:09 +0300)
main.py
quizz_machine.py

diff --git a/main.py b/main.py
index 10c7b49..eb0ef27 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -435,13 +435,16 @@ def create_c_quizzes(
             f"keep c_quizzes kept {nv} nb_accumulated {nb_validated} / {nb_to_create}"
         )
 
             f"keep c_quizzes kept {nv} nb_accumulated {nb_validated} / {nb_to_create}"
         )
 
-    # ------------------------------------------------------------
+    # store the new c_quizzes which have been validated
 
     new_c_quizzes = valid_c_quizzes(recorded, standard_validity)
 
     quizz_machine.store_c_quizzes(new_c_quizzes[:nb_for_train], for_train=True)
     quizz_machine.store_c_quizzes(new_c_quizzes[nb_for_train:], for_train=False)
 
 
     new_c_quizzes = valid_c_quizzes(recorded, standard_validity)
 
     quizz_machine.store_c_quizzes(new_c_quizzes[:nb_for_train], for_train=True)
     quizz_machine.store_c_quizzes(new_c_quizzes[nb_for_train:], for_train=False)
 
+    # save a bunch of images to investigate what quizzes with a
+    # certain nb of correct predictions look like
+
     for n in range(len(models) + 1):
         s = (
             "_validated"
     for n in range(len(models) + 1):
         s = (
             "_validated"
index 88f2c1c..d591d79 100755 (executable)
@@ -398,4 +398,16 @@ class QuizzMachine:
                 device=self.device,
             )
 
                 device=self.device,
             )
 
+            c_quizzes = self.reverse_time(c_quizzes)
+            masked_inplace_autoregression(
+                model=model_for_generation,
+                batch_size=self.batch_size,
+                input=c_quizzes,
+                ar_mask=ar_mask_solve,
+                seq_logproba=seq_logproba,
+                temperature=temperature,
+                deterministic_synthesis=True,
+                device=self.device,
+            )
+
         return c_quizzes, seq_logproba.mean()
         return c_quizzes, seq_logproba.mean()