Update.
[culture.git] / quiz_machine.py
index cdfba85..34c09a7 100755 (executable)
@@ -285,6 +285,11 @@ class QuizMachine:
             predicted_answers,
         )
 
             predicted_answers,
         )
 
+    def vocabulary_size(self):
+        return self.nb_token_values
+
+    ######################################################################
+
     def batches(self, model, split="train", desc=None):
         assert split in {"train", "test"}
         if split == "train":
     def batches(self, model, split="train", desc=None):
         assert split in {"train", "test"}
         if split == "train":
@@ -324,8 +329,7 @@ class QuizMachine:
         ):
             yield batch
 
         ):
             yield batch
 
-    def vocabulary_size(self):
-        return self.nb_token_values
+    ######################################################################
 
     def produce_results(
         self, n_epoch, model, result_dir, deterministic_synthesis, nmax=1000
 
     def produce_results(
         self, n_epoch, model, result_dir, deterministic_synthesis, nmax=1000
@@ -400,6 +404,8 @@ class QuizMachine:
 
         return main_test_accuracy
 
 
         return main_test_accuracy
 
+    ######################################################################
+
     def renew_w_quizzes(self, model, nb, for_train=True):
         input = model.train_w_quizzes if for_train else model.test_w_quizzes
         nb = min(nb, input.size(0))
     def renew_w_quizzes(self, model, nb, for_train=True):
         input = model.train_w_quizzes if for_train else model.test_w_quizzes
         nb = min(nb, input.size(0))
@@ -408,13 +414,17 @@ class QuizMachine:
         self.reverse_random_half_in_place(fresh_w_quizzes)
         input[-nb:] = fresh_w_quizzes.to(self.device)
 
         self.reverse_random_half_in_place(fresh_w_quizzes)
         input[-nb:] = fresh_w_quizzes.to(self.device)
 
+    ######################################################################
+
     def store_c_quizzes(self, new_c_quizzes, for_train=True):
         if for_train:
             self.train_c_quizzes.append(new_c_quizzes)
         else:
             self.test_c_quizzes.append(new_c_quizzes)
 
     def store_c_quizzes(self, new_c_quizzes, for_train=True):
         if for_train:
             self.train_c_quizzes.append(new_c_quizzes)
         else:
             self.test_c_quizzes.append(new_c_quizzes)
 
-    def logproba_solution(self, models, c_quizzes):
+    ######################################################################
+
+    def logproba_of_solutions(self, models, c_quizzes):
         logproba = c_quizzes.new_zeros(c_quizzes.size(0), len(models))
 
         for model in models:
         logproba = c_quizzes.new_zeros(c_quizzes.size(0), len(models))
 
         for model in models: