Update.
[culture.git] / quizz_machine.py
index 717e8ac..92b5799 100755 (executable)
@@ -238,7 +238,7 @@ class QuizzMachine:
                 result_dir,
                 "culture_w_quizzes",
                 self.train_w_quizzes[:72],
                 result_dir,
                 "culture_w_quizzes",
                 self.train_w_quizzes[:72],
-                show_to_be_predicted=True,
+                n_backward=self.train_w_quizzes[:72, 0] == self.token_backward,
             )
 
     def save_quizzes(
             )
 
     def save_quizzes(
@@ -246,7 +246,7 @@ class QuizzMachine:
         result_dir,
         filename_prefix,
         quizzes,
         result_dir,
         filename_prefix,
         quizzes,
-        show_to_be_predicted=False,
+        n_backward=None,
         mistakes=None,
     ):
         quizzes = quizzes.clone()
         mistakes=None,
     ):
         quizzes = quizzes.clone()
@@ -256,8 +256,11 @@ class QuizzMachine:
         assert forward.size(0) + backward.size(0) == quizzes.size(0)
         quizzes[ib] = self.reverse_time(quizzes[ib])
 
         assert forward.size(0) + backward.size(0) == quizzes.size(0)
         quizzes[ib] = self.reverse_time(quizzes[ib])
 
-        if show_to_be_predicted:
-            predicted_prompts = ib.long()
+        if n_backward is None:
+            predicted_prompts = None
+            predicted_answers = None
+        else:
+            predicted_prompts = n_backward.long()
             predicted_answers = 1 - predicted_prompts
             if mistakes is not None:
                 # 0/-1/+1 ~ not-to-predict / predicted wrong / predicted correct
             predicted_answers = 1 - predicted_prompts
             if mistakes is not None:
                 # 0/-1/+1 ~ not-to-predict / predicted wrong / predicted correct
@@ -267,9 +270,6 @@ class QuizzMachine:
                 # 0/2 ~ not-to-predict / to predict
                 predicted_prompts *= 2
                 predicted_answers *= 2
                 # 0/2 ~ not-to-predict / to predict
                 predicted_prompts *= 2
                 predicted_answers *= 2
-        else:
-            predicted_prompts = None
-            predicted_answers = None
 
         self.problem.save_quizzes(
             result_dir,
 
         self.problem.save_quizzes(
             result_dir,
@@ -360,28 +360,28 @@ class QuizzMachine:
                 result[n_backward], correct[n_backward] = compute_accuracy(back_input)
 
             if log_prefix is not None:
                 result[n_backward], correct[n_backward] = compute_accuracy(back_input)
 
             if log_prefix is not None:
-                nb_correct = correct[n_forward].sum()
-                nb_total = correct[n_forward].size(0)
-                back_nb_correct = correct[n_backward].sum()
-                back_nb_total = correct[n_backward].size(0)
+                forward_nb_correct = correct[n_forward].sum()
+                forward_nb_total = correct[n_forward].size(0)
+                backward_nb_correct = correct[n_backward].sum()
+                backward_nb_total = correct[n_backward].size(0)
 
                 self.logger(
 
                 self.logger(
-                    f"accuracy {log_prefix} {n_epoch} {model.id=} {nb_correct} / {nb_total}"
+                    f"forward_accuracy {log_prefix} {n_epoch} {model.id=} {forward_nb_correct} / {forward_nb_total}"
                 )
 
                 self.logger(
                 )
 
                 self.logger(
-                    f"back_accuracy {log_prefix} {n_epoch} {model.id=} {back_nb_correct} / {back_nb_total}"
+                    f"backward_accuracy {log_prefix} {n_epoch} {model.id=} {backward_nb_correct} / {backward_nb_total}"
                 )
 
             return result, correct
 
         compute_accuracy(self.train_w_quizzes[:nmax], log_prefix="train")
 
                 )
 
             return result, correct
 
         compute_accuracy(self.train_w_quizzes[:nmax], log_prefix="train")
 
-        result, correct = compute_accuracy(
+        test_result, test_correct = compute_accuracy(
             self.test_w_quizzes[:nmax], log_prefix="test"
         )
 
             self.test_w_quizzes[:nmax], log_prefix="test"
         )
 
-        main_test_accuracy = correct.sum() / correct.size(0)
+        main_test_accuracy = test_correct.sum() / test_correct.size(0)
         self.logger(f"main_test_accuracy {n_epoch} {main_test_accuracy}")
 
         ##############################
         self.logger(f"main_test_accuracy {n_epoch} {main_test_accuracy}")
 
         ##############################
@@ -389,9 +389,9 @@ class QuizzMachine:
         self.save_quizzes(
             result_dir,
             f"culture_prediction_{n_epoch:04d}_{model.id:02d}",
         self.save_quizzes(
             result_dir,
             f"culture_prediction_{n_epoch:04d}_{model.id:02d}",
-            quizzes=result[:72],
-            show_to_be_predicted=True,
-            mistakes=correct[:72] * 2 - 1,
+            quizzes=test_result[:72],
+            n_backward=self.test_w_quizzes[:72, 0] == self.token_backward,
+            mistakes=test_correct[:72] * 2 - 1,
         )
 
         return main_test_accuracy
         )
 
         return main_test_accuracy