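Update: rename the `prediction` argument of `save_quizzes` to `show_to_be_predicted` and add an optional `mistakes` argument marking each predicted quiz as correct (+1) or wrong (-1).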
[culture.git] / quizz_machine.py
index 62ae8ce..632c9ae 100755 (executable)
@@ -238,10 +238,17 @@ class QuizzMachine:
                 result_dir,
                 "culture_w_quizzes",
                 self.train_w_quizzes[:72],
-                prediction=True,
+                show_to_be_predicted=True,
             )
 
-    def save_quizzes(self, result_dir, filename_prefix, quizzes, prediction=False):
+    def save_quizzes(
+        self,
+        result_dir,
+        filename_prefix,
+        quizzes,
+        show_to_be_predicted=False,
+        mistakes=None,
+    ):
         quizzes = quizzes.clone()
         forward = quizzes[quizzes[:, 0] == self.token_forward]
         ib = quizzes[:, 0] == self.token_backward
@@ -249,9 +256,17 @@ class QuizzMachine:
         assert forward.size(0) + backward.size(0) == quizzes.size(0)
         quizzes[ib] = self.reverse_time(quizzes[ib])
 
-        if prediction:
-            predicted_prompts = ib
-            predicted_answers = torch.logical_not(ib)
+        if show_to_be_predicted:
+            predicted_prompts = ib.long()
+            predicted_answers = 1 - predicted_prompts
+            if mistakes is not None:
+                # 0/-1/+1 ~ not-to-predict / predicted wrong / predicted correct
+                predicted_prompts *= mistakes
+                predicted_answers *= mistakes
+            else:
+                # 0/2 ~ not-to-predict / to predict
+                predicted_prompts *= 2
+                predicted_answers *= 2
         else:
             predicted_prompts = None
             predicted_answers = None
@@ -409,11 +424,14 @@ class QuizzMachine:
             device=self.device,
         )
 
+        mistakes = (input == result).flatten(1).long().min(dim=1).values * 2 - 1
+
         self.save_quizzes(
             result_dir,
             f"culture_prediction_{n_epoch:04d}_{model.id:02d}",
             quizzes=result[:72],
-            prediction=True,
+            show_to_be_predicted=True,
+            mistakes=mistakes[:72],
         )
 
         return main_test_accuracy