Update.
[culture.git] / quiz_machine.py
index 8ab5696..88fd9f1 100755 (executable)
@@ -241,7 +241,7 @@ class QuizMachine:
         self.train_c_quizzes = []
         self.test_c_quizzes = []
 
-    def save_quizzes(
+    def save_quiz_illustrations(
         self,
         result_dir,
         filename_prefix,
@@ -266,7 +266,7 @@ class QuizMachine:
             predicted_prompts *= 2
             predicted_answers *= 2
 
-        self.problem.save_quizzes(
+        self.problem.save_quiz_illustrations(
             result_dir,
             filename_prefix,
             quizzes[:, 1 : 1 + self.prompt_len],
@@ -368,11 +368,7 @@ class QuizMachine:
                 backward_nb_total = correct[n_backward].size(0)
 
                 self.logger(
-                    f"{log_prefix}_forward_accuracy {n_epoch} model {model.id} nb_correct {forward_nb_correct} / {forward_nb_total} ({forward_nb_correct*100/forward_nb_total} %)"
-                )
-
-                self.logger(
-                    f"{log_prefix}_backward_accuracy {n_epoch} model {model.id} nb_correct {backward_nb_correct} / {backward_nb_total} ({backward_nb_correct*100/backward_nb_total} %)"
+                    f"{log_prefix}_accuracy {n_epoch} model {model.id} forward {forward_nb_correct} / {forward_nb_total} backward {backward_nb_correct} / {backward_nb_total}"
                 )
 
             return result, correct
@@ -388,7 +384,7 @@ class QuizMachine:
 
         ##############################
 
-        self.save_quizzes(
+        self.save_quiz_illustrations(
             result_dir,
             f"culture_prediction_{n_epoch:04d}_{model.id:02d}",
             quizzes=test_result[:72],
@@ -420,7 +416,7 @@ class QuizMachine:
 
     def logproba_of_solutions(self, models, c_quizzes):
         logproba = c_quizzes.new_zeros(
-            c_quizzes.size(0), len(models), device=self.device
+            c_quizzes.size(0), len(models), device=self.device, dtype=torch.float32
         )
 
         for model in models: