Update.
[culture.git] / quiz_machine.py
index de1e8d1..eab41dc 100755 (executable)
@@ -15,6 +15,8 @@ from torch.nn import functional as F
 import mygpt
 from mygpt import BracketedSequence
 
 import mygpt
 from mygpt import BracketedSequence
 
+import threading
+
 ######################################################################
 
 # ar_mask is a tensor with 0s and 1s, of same shape as input, with
 ######################################################################
 
 # ar_mask is a tensor with 0s and 1s, of same shape as input, with
@@ -27,8 +29,8 @@ def one_batch_masked_inplace_autoregression(
     input,
     ar_mask,
     seq_logproba,
     input,
     ar_mask,
     seq_logproba,
-    temperature=1.0,
-    deterministic_synthesis=False,
+    temperature,
+    deterministic_synthesis,
 ):
     to_generate = (ar_mask.sum(0) > 0).nonzero()
 
 ):
     to_generate = (ar_mask.sum(0) > 0).nonzero()
 
@@ -50,7 +52,8 @@ def one_batch_masked_inplace_autoregression(
             t_next = dist.sample()
 
         all_n = torch.arange(t_next.size(0))
             t_next = dist.sample()
 
         all_n = torch.arange(t_next.size(0))
-        seq_logproba += logits[all_n, t_next].sum(dim=-1)
+
+        seq_logproba += logits[all_n, t_next]
 
         input[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * input[:, s]
 
 
         input[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * input[:, s]
 
@@ -116,6 +119,19 @@ class QuizMachine:
         ).all()
         return i_forward, i_backward
 
         ).all()
         return i_forward, i_backward
 
+    def non_trivial(self, quizzes):
+        quizzes = quizzes.clone()
+        n_forward = quizzes[quizzes[:, 0] == self.token_forward]
+        n_backward = quizzes[:, 0] == self.token_backward
+        backward = quizzes[n_backward]
+        quizzes[n_backward] = self.reverse_time(quizzes[n_backward])
+        return torch.logical_not(
+            self.problem.trivial_prompts_and_answers(
+                quizzes[:, 1 : 1 + self.prompt_len],
+                quizzes[:, 2 + self.prompt_len :],
+            )
+        )
+
     def reverse_time(self, quizzes):
         i_forward, i_backward = self.indices_forward_and_backward(quizzes)
 
     def reverse_time(self, quizzes):
         i_forward, i_backward = self.indices_forward_and_backward(quizzes)
 
@@ -221,32 +237,18 @@ class QuizMachine:
         self.prompt_len = None
         self.answer_len = None
 
         self.prompt_len = None
         self.answer_len = None
 
-        self.train_w_quizzes = self.generate_token_sequences(nb_train_samples)
-        self.reverse_random_half_in_place(self.train_w_quizzes)
-        self.train_w_quizzes = self.train_w_quizzes.to(device)
-
-        self.test_w_quizzes = self.generate_token_sequences(nb_test_samples).to(device)
-        self.reverse_random_half_in_place(self.test_w_quizzes)
-        self.test_w_quizzes = self.test_w_quizzes.to(device)
-
+        self.LOCK_C_QUIZZES = threading.Lock()
         self.train_c_quizzes = []
         self.test_c_quizzes = []
 
         self.train_c_quizzes = []
         self.test_c_quizzes = []
 
-        if result_dir is not None:
-            self.save_quizzes(
-                result_dir,
-                "culture_w_quizzes",
-                self.train_w_quizzes[:72],
-            )
-
-    def save_quizzes(
+    def save_quiz_illustrations(
         self,
         result_dir,
         filename_prefix,
         quizzes,
         mistakes=None,
     ):
         self,
         result_dir,
         filename_prefix,
         quizzes,
         mistakes=None,
     ):
-        quizzes = quizzes.clone()
+        quizzes = quizzes.clone().to("cpu")
         n_forward = quizzes[quizzes[:, 0] == self.token_forward]
         n_backward = quizzes[:, 0] == self.token_backward
         backward = quizzes[n_backward]
         n_forward = quizzes[quizzes[:, 0] == self.token_forward]
         n_backward = quizzes[:, 0] == self.token_backward
         backward = quizzes[n_backward]
@@ -257,14 +259,14 @@ class QuizMachine:
         predicted_answers = 1 - predicted_prompts
         if mistakes is not None:
             # 0/-1/+1 ~ not-to-predict / predicted wrong / predicted correct
         predicted_answers = 1 - predicted_prompts
         if mistakes is not None:
             # 0/-1/+1 ~ not-to-predict / predicted wrong / predicted correct
-            predicted_prompts *= mistakes
-            predicted_answers *= mistakes
+            predicted_prompts *= mistakes.to("cpu")
+            predicted_answers *= mistakes.to("cpu")
         else:
             # 0/2 ~ not-to-predict / to predict
             predicted_prompts *= 2
             predicted_answers *= 2
 
         else:
             # 0/2 ~ not-to-predict / to predict
             predicted_prompts *= 2
             predicted_answers *= 2
 
-        self.problem.save_quizzes(
+        self.problem.save_quiz_illustrations(
             result_dir,
             filename_prefix,
             quizzes[:, 1 : 1 + self.prompt_len],
             result_dir,
             filename_prefix,
             quizzes[:, 1 : 1 + self.prompt_len],
@@ -273,34 +275,41 @@ class QuizMachine:
             predicted_answers,
         )
 
             predicted_answers,
         )
 
-    def batches(self, split="train", desc=None):
-        assert split in {"train", "test"}
-        if split == "train":
-            w_quizzes = self.train_w_quizzes
-            c_quizzes = self.train_c_quizzes
-        else:
-            w_quizzes = self.test_w_quizzes
-            c_quizzes = self.test_c_quizzes
+    def vocabulary_size(self):
+        return self.nb_token_values
 
 
-        if len(c_quizzes) > 0:
-            c_quizzes = torch.cat(c_quizzes, dim=0)
-            if c_quizzes.size(0) > w_quizzes.size(0) // 2:
-                i = torch.randperm(c_quizzes.size(0))[: w_quizzes.size(0) // 2]
-                c_quizzes = c_quizzes[i]
+    ######################################################################
 
 
-            i = torch.randperm(w_quizzes.size(0))[
-                : w_quizzes.size(0) - c_quizzes.size(0)
-            ]
-            w_quizzes = w_quizzes[i]
+    def batches(self, model, split="train", desc=None):
+        assert split in {"train", "test"}
 
 
-            self.nb_batch_w_quizzes = w_quizzes.size(0)
-            self.nb_batch_c_quizzes = c_quizzes.size(0)
+        with self.LOCK_C_QUIZZES:
+            if split == "train":
+                w_quizzes = model.train_w_quizzes
+                c_quizzes = self.train_c_quizzes
+            else:
+                w_quizzes = model.test_w_quizzes
+                c_quizzes = self.test_c_quizzes
+
+            if len(c_quizzes) > 0:
+                c_quizzes = torch.cat(c_quizzes, dim=0)
+                if c_quizzes.size(0) > w_quizzes.size(0) // 2:
+                    i = torch.randperm(c_quizzes.size(0))[: w_quizzes.size(0) // 2]
+                    c_quizzes = c_quizzes[i]
+
+                i = torch.randperm(w_quizzes.size(0))[
+                    : w_quizzes.size(0) - c_quizzes.size(0)
+                ]
+                w_quizzes = w_quizzes[i]
 
 
-            input = torch.cat([w_quizzes, c_quizzes], dim=0)
-        else:
-            input = w_quizzes
-            self.nb_batch_w_quizzes = w_quizzes.size(0)
-            self.nb_batch_c_quizzes = 0
+                self.nb_batch_w_quizzes = w_quizzes.size(0)
+                self.nb_batch_c_quizzes = c_quizzes.size(0)
+
+                input = torch.cat([w_quizzes, c_quizzes], dim=0)
+            else:
+                input = w_quizzes
+                self.nb_batch_w_quizzes = w_quizzes.size(0)
+                self.nb_batch_c_quizzes = 0
 
         # Shuffle
         input = input[torch.randperm(input.size(0))]
 
         # Shuffle
         input = input[torch.randperm(input.size(0))]
@@ -312,13 +321,13 @@ class QuizMachine:
         ):
             yield batch
 
         ):
             yield batch
 
-    def vocabulary_size(self):
-        return self.nb_token_values
+    ######################################################################
 
     def produce_results(
         self, n_epoch, model, result_dir, deterministic_synthesis, nmax=1000
     ):
         def compute_accuracy(input, log_prefix=None):
 
     def produce_results(
         self, n_epoch, model, result_dir, deterministic_synthesis, nmax=1000
     ):
         def compute_accuracy(input, log_prefix=None):
+            input = input.to(self.device)
             ar_mask = self.make_ar_mask(input)
             result = input.clone() * (1 - ar_mask)
             seq_logproba = torch.empty(input.size(0), device=self.device)
             ar_mask = self.make_ar_mask(input)
             result = input.clone() * (1 - ar_mask)
             seq_logproba = torch.empty(input.size(0), device=self.device)
@@ -359,19 +368,15 @@ class QuizMachine:
                 backward_nb_total = correct[n_backward].size(0)
 
                 self.logger(
                 backward_nb_total = correct[n_backward].size(0)
 
                 self.logger(
-                    f"{log_prefix}_forward_accuracy {n_epoch} model {model.id} nb_correct {forward_nb_correct} / {forward_nb_total}"
-                )
-
-                self.logger(
-                    f"{log_prefix}_backward_accuracy {n_epoch} model {model.id} nb_correct {backward_nb_correct} / {backward_nb_total}"
+                    f"{log_prefix}_accuracy {n_epoch} model {model.id} forward {forward_nb_correct} / {forward_nb_total} backward {backward_nb_correct} / {backward_nb_total}"
                 )
 
             return result, correct
 
                 )
 
             return result, correct
 
-        compute_accuracy(self.train_w_quizzes[:nmax], log_prefix="train")
+        # compute_accuracy(model.train_w_quizzes[:nmax], log_prefix="train")
 
         test_result, test_correct = compute_accuracy(
 
         test_result, test_correct = compute_accuracy(
-            self.test_w_quizzes[:nmax], log_prefix="test"
+            model.test_w_quizzes[:nmax], log_prefix="test"
         )
 
         main_test_accuracy = test_correct.sum() / test_correct.size(0)
         )
 
         main_test_accuracy = test_correct.sum() / test_correct.size(0)
@@ -379,7 +384,7 @@ class QuizMachine:
 
         ##############################
 
 
         ##############################
 
-        self.save_quizzes(
+        self.save_quiz_illustrations(
             result_dir,
             f"culture_prediction_{n_epoch:04d}_{model.id:02d}",
             quizzes=test_result[:72],
             result_dir,
             f"culture_prediction_{n_epoch:04d}_{model.id:02d}",
             quizzes=test_result[:72],
@@ -388,19 +393,60 @@ class QuizMachine:
 
         return main_test_accuracy
 
 
         return main_test_accuracy
 
-    def renew_w_quizzes(self, nb, for_train=True):
-        input = self.train_w_quizzes if for_train else self.test_w_quizzes
+    ######################################################################
+
+    def renew_w_quizzes(self, model, nb, for_train=True):
+        input = model.train_w_quizzes if for_train else model.test_w_quizzes
         nb = min(nb, input.size(0))
         input[:-nb] = input[nb:].clone()
         fresh_w_quizzes = self.generate_token_sequences(nb)
         self.reverse_random_half_in_place(fresh_w_quizzes)
         nb = min(nb, input.size(0))
         input[:-nb] = input[nb:].clone()
         fresh_w_quizzes = self.generate_token_sequences(nb)
         self.reverse_random_half_in_place(fresh_w_quizzes)
-        input[-nb:] = fresh_w_quizzes.to(self.device)
+        input[-nb:] = fresh_w_quizzes.to("cpu")
+
+    ######################################################################
 
     def store_c_quizzes(self, new_c_quizzes, for_train=True):
 
     def store_c_quizzes(self, new_c_quizzes, for_train=True):
-        if for_train:
-            self.train_c_quizzes.append(new_c_quizzes)
-        else:
-            self.test_c_quizzes.append(new_c_quizzes)
+        with self.LOCK_C_QUIZZES:
+            if for_train:
+                self.train_c_quizzes.append(new_c_quizzes.to("cpu"))
+            else:
+                self.test_c_quizzes.append(new_c_quizzes.to("cpu"))
+
+    def save_c_quizzes(self, filename):
+        torch.save((self.train_c_quizzes, self.test_c_quizzes), filename)
+
+    def load_c_quizzes(self, filename):
+        self.train_c_quizzes, self.test_c_quizzes = torch.load(filename)
+
+    ######################################################################
+
+    def logproba_of_solutions(self, models, c_quizzes):
+        logproba = c_quizzes.new_zeros(
+            c_quizzes.size(0), len(models), device=self.device, dtype=torch.float32
+        )
+
+        for model in models:
+            with torch.autograd.no_grad():
+                t = model.training
+                model.eval()
+
+                for input, l in zip(
+                    c_quizzes.split(self.batch_size), logproba.split(self.batch_size)
+                ):
+                    input = input.to(self.device)
+                    ar_mask = self.make_ar_mask(input)
+                    output = model(mygpt.BracketedSequence(input)).x
+                    ce = (
+                        F.cross_entropy(output.transpose(1, 2), input, reduction="none")
+                        * ar_mask
+                    )
+                    l[:, model.id] = -ce.sum(dim=-1)
+
+                model.train(t)
+
+        return logproba.to("cpu")
+
+    ###############################################################
 
     def compute_correctness(
         self,
 
     def compute_correctness(
         self,
@@ -420,11 +466,11 @@ class QuizMachine:
 
         nb_correct = 0
 
 
         nb_correct = 0
 
+        seq_logproba[...] = 0.0
+
         for model in models_for_validation:
             result = c_quizzes.clone()
 
         for model in models_for_validation:
             result = c_quizzes.clone()
 
-            seq_logproba[...] = 0.0
-
             ar_mask = self.make_ar_mask(result)
 
             masked_inplace_autoregression(
             ar_mask = self.make_ar_mask(result)
 
             masked_inplace_autoregression(
@@ -474,7 +520,10 @@ class QuizMachine:
 
     def generate_quizzes(self, nb, model_for_generation, temperature=1.0):
         c_quizzes = torch.empty(
 
     def generate_quizzes(self, nb, model_for_generation, temperature=1.0):
         c_quizzes = torch.empty(
-            nb, self.train_w_quizzes.size(1), device=self.device, dtype=torch.int64
+            nb,
+            self.prompt_len + self.answer_len + 2,
+            device=self.device,
+            dtype=torch.int64,
         )
 
         seq_logproba = torch.zeros(nb, device=self.device)
         )
 
         seq_logproba = torch.zeros(nb, device=self.device)
@@ -524,4 +573,4 @@ class QuizMachine:
             device=self.device,
         )
 
             device=self.device,
         )
 
-        return c_quizzes
+        return c_quizzes.to("cpu")