Update.
[culture.git] / quiz_machine.py
index 1f1046d..c39bf7a 100755 (executable)
@@ -241,7 +241,7 @@ class QuizMachine:
         self.train_c_quizzes = []
         self.test_c_quizzes = []
 
-    def save_quizzes(
+    def save_quiz_illustrations(
         self,
         result_dir,
         filename_prefix,
@@ -266,7 +266,7 @@ class QuizMachine:
             predicted_prompts *= 2
             predicted_answers *= 2
 
-        self.problem.save_quizzes(
+        self.problem.save_quiz_illustrations(
             result_dir,
             filename_prefix,
             quizzes[:, 1 : 1 + self.prompt_len],
@@ -327,6 +327,7 @@ class QuizMachine:
         self, n_epoch, model, result_dir, deterministic_synthesis, nmax=1000
     ):
         def compute_accuracy(input, log_prefix=None):
+            input = input.to(self.device)
             ar_mask = self.make_ar_mask(input)
             result = input.clone() * (1 - ar_mask)
             seq_logproba = torch.empty(input.size(0), device=self.device)
@@ -367,11 +368,7 @@ class QuizMachine:
                 backward_nb_total = correct[n_backward].size(0)
 
                 self.logger(
-                    f"{log_prefix}_forward_accuracy {n_epoch} model {model.id} nb_correct {forward_nb_correct} / {forward_nb_total} ({forward_nb_correct*100/forward_nb_total} %)"
-                )
-
-                self.logger(
-                    f"{log_prefix}_backward_accuracy {n_epoch} model {model.id} nb_correct {backward_nb_correct} / {backward_nb_total} ({backward_nb_correct*100/backward_nb_total} %)"
+                    f"{log_prefix}_accuracy {n_epoch} model {model.id} forward {forward_nb_correct} / {forward_nb_total} backward {backward_nb_correct} / {backward_nb_total}"
                 )
 
             return result, correct
@@ -387,7 +384,7 @@ class QuizMachine:
 
         ##############################
 
-        self.save_quizzes(
+        self.save_quiz_illustrations(
             result_dir,
             f"culture_prediction_{n_epoch:04d}_{model.id:02d}",
             quizzes=test_result[:72],
@@ -404,35 +401,50 @@ class QuizMachine:
         input[:-nb] = input[nb:].clone()
         fresh_w_quizzes = self.generate_token_sequences(nb)
         self.reverse_random_half_in_place(fresh_w_quizzes)
-        input[-nb:] = fresh_w_quizzes.to(self.device)
+        input[-nb:] = fresh_w_quizzes.to("cpu")
 
     ######################################################################
 
     def store_c_quizzes(self, new_c_quizzes, for_train=True):
         with self.LOCK_C_QUIZZES:
             if for_train:
-                self.train_c_quizzes.append(new_c_quizzes)
+                self.train_c_quizzes.append(new_c_quizzes.to("cpu"))
             else:
-                self.test_c_quizzes.append(new_c_quizzes)
+                self.test_c_quizzes.append(new_c_quizzes.to("cpu"))
+
+    def save_c_quizzes(self, filename):
+        torch.save((self.train_c_quizzes, self.test_c_quizzes), filename)
+
+    def load_c_quizzes(self, filename):
+        self.train_c_quizzes, self.test_c_quizzes = torch.load(filename)
 
     ######################################################################
 
     def logproba_of_solutions(self, models, c_quizzes):
-        logproba = c_quizzes.new_zeros(c_quizzes.size(0), len(models))
+        logproba = c_quizzes.new_zeros(
+            c_quizzes.size(0), len(models), device=self.device, dtype=torch.float32
+        )
 
         for model in models:
-            for input, l in zip(
-                c_quizzes.split(self.batch_size), logproba.split(self.batch_size)
-            ):
-                ar_mask = self.make_ar_mask(input)
-                output = model(mygpt.BracketedSequence(input)).x
-                ce = (
-                    F.cross_entropy(output.transpose(1, 2), input, reduction="none")
-                    * ar_mask
-                )
-                l[:, model.id] = -ce.sum(dim=-1)
-
-        return logproba
+            with torch.autograd.no_grad():
+                t = model.training
+                model.eval()
+
+                for input, l in zip(
+                    c_quizzes.split(self.batch_size), logproba.split(self.batch_size)
+                ):
+                    input = input.to(self.device)
+                    ar_mask = self.make_ar_mask(input)
+                    output = model(mygpt.BracketedSequence(input)).x
+                    ce = (
+                        F.cross_entropy(output.transpose(1, 2), input, reduction="none")
+                        * ar_mask
+                    )
+                    l[:, model.id] = -ce.sum(dim=-1)
+
+                model.train(t)
+
+        return logproba.to("cpu")
 
     ###############################################################
 
@@ -561,4 +573,4 @@ class QuizMachine:
             device=self.device,
         )
 
-        return c_quizzes
+        return c_quizzes.to("cpu")