Update.
[culture.git] / quiz_machine.py
index 4f704a0..eab41dc 100755 (executable)
@@ -241,7 +241,7 @@ class QuizMachine:
         self.train_c_quizzes = []
         self.test_c_quizzes = []
 
-    def save_quizzes(
+    def save_quiz_illustrations(
         self,
         result_dir,
         filename_prefix,
@@ -266,7 +266,7 @@ class QuizMachine:
             predicted_prompts *= 2
             predicted_answers *= 2
 
-        self.problem.save_quizzes(
+        self.problem.save_quiz_illustrations(
             result_dir,
             filename_prefix,
             quizzes[:, 1 : 1 + self.prompt_len],
@@ -373,7 +373,7 @@ class QuizMachine:
 
             return result, correct
 
-        compute_accuracy(model.train_w_quizzes[:nmax], log_prefix="train")
+        compute_accuracy(model.train_w_quizzes[:nmax], log_prefix="train")
 
         test_result, test_correct = compute_accuracy(
             model.test_w_quizzes[:nmax], log_prefix="test"
@@ -384,7 +384,7 @@ class QuizMachine:
 
         ##############################
 
-        self.save_quizzes(
+        self.save_quiz_illustrations(
             result_dir,
             f"culture_prediction_{n_epoch:04d}_{model.id:02d}",
             quizzes=test_result[:72],
@@ -412,11 +412,17 @@ class QuizMachine:
             else:
                 self.test_c_quizzes.append(new_c_quizzes.to("cpu"))
 
+    def save_c_quizzes(self, filename):
+        torch.save((self.train_c_quizzes, self.test_c_quizzes), filename)
+
+    def load_c_quizzes(self, filename):
+        self.train_c_quizzes, self.test_c_quizzes = torch.load(filename)
+
     ######################################################################
 
     def logproba_of_solutions(self, models, c_quizzes):
         logproba = c_quizzes.new_zeros(
-            c_quizzes.size(0), len(models), device=self.device
+            c_quizzes.size(0), len(models), device=self.device, dtype=torch.float32
         )
 
         for model in models: