Update.
authorFrançois Fleuret <francois@fleuret.org>
Sat, 6 Jul 2024 04:37:52 +0000 (07:37 +0300)
committerFrançois Fleuret <francois@fleuret.org>
Sat, 6 Jul 2024 04:37:52 +0000 (07:37 +0300)
quizz_machine.py

index c6c2f95..92b5799 100755 (executable)
@@ -238,7 +238,7 @@ class QuizzMachine:
                 result_dir,
                 "culture_w_quizzes",
                 self.train_w_quizzes[:72],
-                show_to_be_predicted=True,
+                n_backward=self.train_w_quizzes[:72, 0] == self.token_backward,
             )
 
     def save_quizzes(
@@ -246,7 +246,7 @@ class QuizzMachine:
         result_dir,
         filename_prefix,
         quizzes,
-        show_to_be_predicted=False,
+        n_backward=None,
         mistakes=None,
     ):
         quizzes = quizzes.clone()
@@ -256,8 +256,11 @@ class QuizzMachine:
         assert forward.size(0) + backward.size(0) == quizzes.size(0)
         quizzes[ib] = self.reverse_time(quizzes[ib])
 
-        if show_to_be_predicted:
-            predicted_prompts = ib.long()
+        if n_backward is None:
+            predicted_prompts = None
+            predicted_answers = None
+        else:
+            predicted_prompts = n_backward.long()
             predicted_answers = 1 - predicted_prompts
             if mistakes is not None:
                 # 0/-1/+1 ~ not-to-predict / predicted wrong / predicted correct
@@ -267,9 +270,6 @@ class QuizzMachine:
                 # 0/2 ~ not-to-predict / to predict
                 predicted_prompts *= 2
                 predicted_answers *= 2
-        else:
-            predicted_prompts = None
-            predicted_answers = None
 
         self.problem.save_quizzes(
             result_dir,
@@ -390,7 +390,7 @@ class QuizzMachine:
             result_dir,
             f"culture_prediction_{n_epoch:04d}_{model.id:02d}",
             quizzes=test_result[:72],
-            show_to_be_predicted=True,
+            n_backward=self.test_w_quizzes[:72, 0] == self.token_backward,
             mistakes=test_correct[:72] * 2 - 1,
         )