Update.
[culture.git] / quiz_machine.py
index 88fd9f1..eab41dc 100755 (executable)
@@ -373,7 +373,7 @@ class QuizMachine:
 
             return result, correct
 
-        compute_accuracy(model.train_w_quizzes[:nmax], log_prefix="train")
+        compute_accuracy(model.train_w_quizzes[:nmax], log_prefix="train")
 
         test_result, test_correct = compute_accuracy(
             model.test_w_quizzes[:nmax], log_prefix="test"
@@ -412,6 +412,12 @@ class QuizMachine:
             else:
                 self.test_c_quizzes.append(new_c_quizzes.to("cpu"))
 
+    def save_c_quizzes(self, filename):
+        torch.save((self.train_c_quizzes, self.test_c_quizzes), filename)
+
+    def load_c_quizzes(self, filename):
+        self.train_c_quizzes, self.test_c_quizzes = torch.load(filename)
+
     ######################################################################
 
     def logproba_of_solutions(self, models, c_quizzes):