Update.
[culture.git] / quiz_machine.py
index cdfba85..1f1046d 100755 (executable)
@@ -15,6 +15,8 @@ from torch.nn import functional as F
 import mygpt
 from mygpt import BracketedSequence
 
+import threading
+
 ######################################################################
 
 # ar_mask is a tensor with 0s and 1s, of same shape as input, with
@@ -235,22 +237,10 @@ class QuizMachine:
         self.prompt_len = None
         self.answer_len = None
 
-        # self.train_w_quizzes = self.generate_token_sequences(nb_train_samples)
-        # self.reverse_random_half_in_place(self.train_w_quizzes)
-
-        # self.test_w_quizzes = self.generate_token_sequences(nb_test_samples).to(device)
-        # self.reverse_random_half_in_place(self.test_w_quizzes)
-
+        self.LOCK_C_QUIZZES = threading.Lock()
         self.train_c_quizzes = []
         self.test_c_quizzes = []
 
-        # if result_dir is not None:
-        # self.save_quizzes(
-        # result_dir,
-        # "culture_w_quizzes",
-        # self.train_w_quizzes[:72],
-        # )
-
     def save_quizzes(
         self,
         result_dir,
@@ -285,34 +275,41 @@ class QuizMachine:
             predicted_answers,
         )
 
+    def vocabulary_size(self):
+        return self.nb_token_values
+
+    ######################################################################
+
     def batches(self, model, split="train", desc=None):
         assert split in {"train", "test"}
-        if split == "train":
-            w_quizzes = model.train_w_quizzes
-            c_quizzes = self.train_c_quizzes
-        else:
-            w_quizzes = model.test_w_quizzes
-            c_quizzes = self.test_c_quizzes
-
-        if len(c_quizzes) > 0:
-            c_quizzes = torch.cat(c_quizzes, dim=0)
-            if c_quizzes.size(0) > w_quizzes.size(0) // 2:
-                i = torch.randperm(c_quizzes.size(0))[: w_quizzes.size(0) // 2]
-                c_quizzes = c_quizzes[i]
 
-            i = torch.randperm(w_quizzes.size(0))[
-                : w_quizzes.size(0) - c_quizzes.size(0)
-            ]
-            w_quizzes = w_quizzes[i]
+        with self.LOCK_C_QUIZZES:
+            if split == "train":
+                w_quizzes = model.train_w_quizzes
+                c_quizzes = self.train_c_quizzes
+            else:
+                w_quizzes = model.test_w_quizzes
+                c_quizzes = self.test_c_quizzes
+
+            if len(c_quizzes) > 0:
+                c_quizzes = torch.cat(c_quizzes, dim=0)
+                if c_quizzes.size(0) > w_quizzes.size(0) // 2:
+                    i = torch.randperm(c_quizzes.size(0))[: w_quizzes.size(0) // 2]
+                    c_quizzes = c_quizzes[i]
+
+                i = torch.randperm(w_quizzes.size(0))[
+                    : w_quizzes.size(0) - c_quizzes.size(0)
+                ]
+                w_quizzes = w_quizzes[i]
 
-            self.nb_batch_w_quizzes = w_quizzes.size(0)
-            self.nb_batch_c_quizzes = c_quizzes.size(0)
+                self.nb_batch_w_quizzes = w_quizzes.size(0)
+                self.nb_batch_c_quizzes = c_quizzes.size(0)
 
-            input = torch.cat([w_quizzes, c_quizzes], dim=0)
-        else:
-            input = w_quizzes
-            self.nb_batch_w_quizzes = w_quizzes.size(0)
-            self.nb_batch_c_quizzes = 0
+                input = torch.cat([w_quizzes, c_quizzes], dim=0)
+            else:
+                input = w_quizzes
+                self.nb_batch_w_quizzes = w_quizzes.size(0)
+                self.nb_batch_c_quizzes = 0
 
         # Shuffle
         input = input[torch.randperm(input.size(0))]
@@ -324,8 +321,7 @@ class QuizMachine:
         ):
             yield batch
 
-    def vocabulary_size(self):
-        return self.nb_token_values
+    ######################################################################
 
     def produce_results(
         self, n_epoch, model, result_dir, deterministic_synthesis, nmax=1000
@@ -400,6 +396,8 @@ class QuizMachine:
 
         return main_test_accuracy
 
+    ######################################################################
+
     def renew_w_quizzes(self, model, nb, for_train=True):
         input = model.train_w_quizzes if for_train else model.test_w_quizzes
         nb = min(nb, input.size(0))
@@ -408,13 +406,18 @@ class QuizMachine:
         self.reverse_random_half_in_place(fresh_w_quizzes)
         input[-nb:] = fresh_w_quizzes.to(self.device)
 
+    ######################################################################
+
     def store_c_quizzes(self, new_c_quizzes, for_train=True):
-        if for_train:
-            self.train_c_quizzes.append(new_c_quizzes)
-        else:
-            self.test_c_quizzes.append(new_c_quizzes)
+        with self.LOCK_C_QUIZZES:
+            if for_train:
+                self.train_c_quizzes.append(new_c_quizzes)
+            else:
+                self.test_c_quizzes.append(new_c_quizzes)
+
+    ######################################################################
 
-    def logproba_solution(self, models, c_quizzes):
+    def logproba_of_solutions(self, models, c_quizzes):
         logproba = c_quizzes.new_zeros(c_quizzes.size(0), len(models))
 
         for model in models: