Update.
authorFrançois Fleuret <francois@fleuret.org>
Sat, 3 Aug 2024 04:58:21 +0000 (06:58 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Sat, 3 Aug 2024 04:58:21 +0000 (06:58 +0200)
main.py
quiz_machine.py

diff --git a/main.py b/main.py
index 059a29d..63597b4 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -470,6 +470,8 @@ c_quizzes_procedure = [
     (("f_B", "f_A", "A", "B"), (1, 0, 0, 0), model_transformer_hot),
     (("f_B", "f_A", "A", "B"), (0, 1, 1, 1), model_transformer_cold),
     (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), model_transformer_cold),
+    # (("B", "f_B", "A", "f_A"), (0, 0, 1, 1), model_transformer_cold),
+    # (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), model_transformer_cold),
 ]
 
 ######################################################################
@@ -477,15 +479,30 @@ c_quizzes_procedure = [
 
 def save_additional_results(models, science_w_quizzes):
     for model in models:
+        recorder = []
+
         c_quizzes = quiz_machine.generate_c_quizzes(
-            128, model_for_generation=model, procedure=c_quizzes_procedure
+            32,
+            model_for_generation=model,
+            procedure=c_quizzes_procedure,
+            recorder=recorder,
         )
 
+        c_quizzes = torch.cat([c[:, None, :] for c, _, in recorder], dim=1)
+        predicted_parts = torch.cat([t[:, None, :] for _, t in recorder], dim=1)
+        nrow = c_quizzes.size(1)
+        c_quizzes = c_quizzes.reshape(-1, c_quizzes.size(-1))
+        predicted_parts = predicted_parts.reshape(-1, predicted_parts.size(-1))
+
+        filename = f"non_validated_{n_epoch:04d}_{model.id:02d}.png"
         quiz_machine.problem.save_quizzes_as_image(
             args.result_dir,
-            f"non_validated_{n_epoch:04d}_{model.id:02d}.png",
-            c_quizzes,
+            filename,
+            quizzes=c_quizzes,
+            predicted_parts=predicted_parts,
+            nrow=nrow,
         )
+        log_string(f"wrote {filename}")
 
     ######################################################################
 
index 015f6d2..3fc1066 100755 (executable)
@@ -374,7 +374,9 @@ class QuizMachine:
 
     ######################################################################
 
-    def generate_c_quizzes(self, nb, model_for_generation, procedure, to_recycle=None):
+    def generate_c_quizzes(
+        self, nb, model_for_generation, procedure, to_recycle=None, recorder=None
+    ):
         seq_logproba = torch.zeros(nb, device=self.device)
 
         c_quizzes = None
@@ -399,6 +401,13 @@ class QuizMachine:
 
             model_for_generation.reset_transformations()
 
+            if recorder is not None:
+                x = c_quizzes.clone()
+                t = torch.tensor(m, device=x.device)[None, :].expand(x.size(0), -1)
+                recorder.append(
+                    self.problem.reconfigure([x, t], ("A", "f_A", "B", "f_B"))
+                )
+
             if to_recycle is not None and to_recycle.size(0) > 0:
                 to_recycle = self.problem.reconfigure(to_recycle, s)
                 c_quizzes[: to_recycle.size(0)] = to_recycle