Update.
authorFrançois Fleuret <francois@fleuret.org>
Wed, 24 Jul 2024 14:42:24 +0000 (16:42 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Wed, 24 Jul 2024 14:42:24 +0000 (16:42 +0200)
grids.py
main.py
problem.py
quiz_machine.py

index 99a9240..37ed6a0 100755 (executable)
--- a/grids.py
+++ b/grids.py
@@ -204,10 +204,11 @@ class Grids(problem.Problem):
             self.token_f_B: "f_B",
         }
 
-        self.nb_token_values = self.token_f_B + 1
-
         self.height = 10
         self.width = 10
+        self.seq_len = 4 * (1 + self.height * self.width)
+        self.nb_token_values = self.token_f_B + 1
+
         self.cache_rec_coo = {}
 
         all_tasks = [
@@ -1378,27 +1379,30 @@ class Grids(problem.Problem):
 
     ######################################################################
 
+    def create_empty_quizzes(self, nb, struct=("A", "f_A", "B", "f_B")):
+        S = self.height * self.width
+        quizzes = torch.zeros(nb, 4 * (S + 1), dtype=torch.int64)
+        quizzes[:, 0 * (S + 1)] = self.l2tok(struct[0])
+        quizzes[:, 1 * (S + 1)] = self.l2tok(struct[1])
+        quizzes[:, 2 * (S + 1)] = self.l2tok(struct[2])
+        quizzes[:, 3 * (S + 1)] = self.l2tok(struct[3])
+
+        return quizzes
+
     def generate_w_quizzes_(self, nb, tasks=None, progress_bar=False):
         if tasks is None:
             tasks = self.all_tasks
 
-        S = self.height * self.width
-        quizzes = torch.empty(nb, 4 * (S + 1), dtype=torch.int64)
+        quizzes = self.create_empty_quizzes(nb, ("A", "f_A", "B", "f_B"))
 
         if progress_bar:
             quizzes = tqdm.tqdm(
                 quizzes,
                 dynamic_ncols=True,
                 desc="world quizzes generation",
-                total=prompts.size(0),
+                total=quizzes.size(0),
             )
 
-        quizzes[...] = 0
-        quizzes[:, 0 * (S + 1)] = self.token_A
-        quizzes[:, 1 * (S + 1)] = self.token_f_A
-        quizzes[:, 2 * (S + 1)] = self.token_B
-        quizzes[:, 3 * (S + 1)] = self.token_f_B
-
         for quiz in quizzes:
             q = quiz.reshape(4, S + 1)[:, 1:].reshape(4, self.height, self.width)
             q[...] = 0
@@ -1412,9 +1416,9 @@ class Grids(problem.Problem):
         nb, nrow = 128, 4
         for t in self.all_tasks:
             print(t.__name__)
-            prompts, answers = self.generate_w_quizzes_(nb, tasks=[t])
+            quizzes = self.generate_w_quizzes_(nb, tasks=[t])
             self.save_quizzes_as_image(
-                result_dir, t.__name__, prompts[:nb], answers[:nb], nrow=nrow
+                result_dir, t.__name__ + ".png", quizzes, nrow=nrow
             )
 
 
@@ -1499,9 +1503,9 @@ if __name__ == "__main__":
     predicted_prompts = m * (torch.randint(2, (prompts.size(0),)) * 2 - 1)
     predicted_answers = (1 - m) * (torch.randint(2, (prompts.size(0),)) * 2 - 1)
 
-    grids.save_quiz_illustrations(
+    grids.save_quizzes_as_image(
         "/tmp",
-        "test",
+        "test.png",
         prompts[:nb],
         answers[:nb],
         # You can add a bool to put a frame around the predicted parts
diff --git a/main.py b/main.py
index 122dd31..fcca116 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -103,8 +103,6 @@ parser.add_argument("--nb_rounds", type=int, default=3)
 
 parser.add_argument("--c_quiz_validation_mode", type=str, default="predict")
 
-parser.add_argument("--p2a_only", action="store_true", default=False)
-
 parser.add_argument("--dirty_debug", action="store_true", default=False)
 
 ######################################################################
@@ -394,11 +392,9 @@ def one_epoch(model, quiz_machine, local_device=main_device):
         acc_train_loss += loss.item() * input.size(0)
 
         loss_per_samples = loss_per_token.detach().flatten(1).mean(dim=1)
-        n_p2a = input[:, 0] == quiz_machine.problem.token_forward
-        to_store = from_w & n_p2a.to("cpu")
-        if to_store.any():
+        if from_w.any():
             hard_w_quizzes.append(
-                (input[to_store].to("cpu"), loss_per_samples[to_store].to("cpu"))
+                (input[from_w].to("cpu"), loss_per_samples[from_w].to("cpu"))
             )
 
         nb_train_samples += input.size(0)
@@ -452,7 +448,6 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10
         c_quizzes = quiz_machine.generate_c_quizzes(
             nb_to_generate_per_iteration,
             model_for_generation=model_for_generation,
-            p2a_only=args.p2a_only,
             temperature_hot=args.temperature_hot,
             temperature_cold=args.temperature_cold,
         )
@@ -585,7 +580,6 @@ for k in range(args.nb_gpts):
         model=model,
         nb_train_samples=args.nb_train_samples,
         nb_test_samples=args.nb_test_samples,
-        p2a_only=args.p2a_only,
     )
 
     models.append(model)
@@ -729,7 +723,6 @@ for n_epoch in range(current_epoch, args.nb_epochs):
         c_quizzes = quiz_machine.generate_c_quizzes(
             128,
             model_for_generation=model,
-            p2a_only=args.p2a_only,
             temperature_hot=args.temperature_hot,
             temperature_cold=args.temperature_cold,
         )
@@ -741,7 +734,7 @@ for n_epoch in range(current_epoch, args.nb_epochs):
     # Renew the training samples
 
     for model in weakest_models:
-        quiz_machine.renew_train_w_quizzes(model=model, p2a_only=args.p2a_only)
+        quiz_machine.renew_train_w_quizzes(model=model)
 
     if args.log_command is not None:
         s = args.log_command.split()
index 61e4834..50376d6 100755 (executable)
@@ -25,46 +25,23 @@ class Problem:
         else:
             return self.queue.qsize() * self.chunk_size
 
-    def nb_token_values(self):
-        pass
-
-    def trivial_prompts_and_answers(self, prompts, answers):
-        pass
-
-    # The one to implement, returns two tensors nb x D and nb x D'
-    def generate_w_quizzes_(self, nb):
-        pass
-
-    # save a file to vizualize quizzes, you can save a txt or png file
-    def save_quiz_illustrations(
-        self,
-        result_dir,
-        filename_prefix,
-        prompts,
-        answers,
-        predicted_prompts=None,
-        predicted_answers=None,
-    ):
-        pass
-
     def fill_cache(self):
         while True:
-            prompts, answers = self.generate_w_quizzes_(self.chunk_size)
-
-            self.queue.put((prompts.to("cpu"), answers.to("cpu")), block=True)
+            quizzes = self.generate_w_quizzes_(self.chunk_size)
+            self.queue.put(quizzes.to("cpu"), block=True)
 
     def generate_w_quizzes(self, nb):
         if self.queue is None:
             return self.generate_w_quizzes_(nb)
 
         if self.rest is not None:
-            prompts, answers = rest
+            quizzes = rest
         else:
-            prompts, answers = [], []
+            quizzes = []
 
         self.rest = None
 
-        n = sum([p.size(0) for p in prompts])
+        n = sum([q.size(0) for q in quizzes])
 
         with tqdm.tqdm(
             total=nb,
@@ -72,22 +49,44 @@ class Problem:
             desc="world generation",
         ) as pbar:
             while n < nb:
-                p, s = self.queue.get(block=True)
-                prompts.append(p)
-                answers.append(s)
-                n += p.size(0)
-                pbar.update(p.size(0))
+                q = self.queue.get(block=True)
+                quizzes.append(q)
+                n += q.size(0)
+                pbar.update(q.size(0))
 
-        prompts, answers = torch.cat(prompts, dim=0), torch.cat(answers, dim=0)
-        assert n == prompts.size(0)
+        quizzes = torch.cat(quizzes, dim=0)
+        assert n == quizzes.size(0)
 
         k = n - nb
 
         if k > 0:
-            rest = (prompts[-k:], answers[-k:])
-            prompts, answers = prompts[:-k], answers[:-k]
+            rest = quizzes[-k:]
+            quizzes = quizzes[:-k]
 
-        return prompts, answers
+        return quizzes
+
+    ######################################################################
+
+    def trivial_prompts_and_answers(self, prompts, answers):
+        pass
+
+    # The one to implement, returns two tensors nb x D and nb x D'
+    def generate_w_quizzes_(self, nb):
+        pass
+
+    # save a file to vizualize quizzes, you can save a txt or png file
+    def save_quiz_illustrations(
+        self,
+        result_dir,
+        filename_prefix,
+        prompts,
+        answers,
+        predicted_prompts=None,
+        predicted_answers=None,
+    ):
+        pass
 
     def save_some_examples(self, result_dir):
         pass
+
+    ######################################################################
index bc2a358..bb62181 100755 (executable)
@@ -174,36 +174,36 @@ class QuizMachine:
 
     ######################################################################
 
-    def produce_results(
-        self, n_epoch, model, input, result_dir, deterministic_synthesis
-    ):
-        def predict(input, struct, mask):
-            ar_mask = self.problem.make_ar_mask(
-                quizzes=quizzes, struct=struct, mask=mask
-            )
-            result = quizzes * (1 - ar_mask)
-            seq_logproba = torch.empty(fwd_quizzes, device=self.device)
+    def predict(self, input, struct, mask):
+        ar_mask = self.problem.make_ar_mask(quizzes=quizzes, struct=struct, mask=mask)
+        result = quizzes * (1 - ar_mask)
 
-            masked_inplace_autoregression(
-                model=model,
-                batch_size=self.batch_size,
-                input=result,
-                ar_mask=ar_mask,
-                seq_logproba=seq_logproba,
-                deterministic_synthesis=deterministic_synthesis,
-                progress_bar_desc="accuracy",
-                device=self.device,
-            )
+        seq_logproba = torch.empty(fwd_quizzes, device=self.device)
 
-            nb_correct = (result == quizzes).min(dim=1).long()
+        masked_inplace_autoregression(
+            model=model,
+            batch_size=self.batch_size,
+            input=result,
+            ar_mask=ar_mask,
+            seq_logproba=seq_logproba,
+            deterministic_synthesis=deterministic_synthesis,
+            progress_bar_desc="accuracy",
+            device=self.device,
+        )
+
+        nb_correct = (result == quizzes).min(dim=1).long()
 
-            return result, correct
+        return result, correct
 
+    def produce_results(
+        self, n_epoch, model, input, result_dir, deterministic_synthesis
+    ):
         input = input.to(self.device)
         i = self.problem.indices_select(quizzes=input, struct=struct)
 
+        input_fwd = input[i]
         test_result_fwd, test_correct_fwd = predict(
-            input[i], ("A", "f_A", "B", "f_B"), (0, 0, 0, 1)
+            input_fwd, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1)
         )
 
         input_bck = self.problem.reconfigure(
@@ -211,8 +211,9 @@ class QuizMachine:
             struct=("A", "f_A", "B", "f_B"),
         )
 
-        l = input_bck.size(1)
+        l = input_bck.size(1) // 4
         input_bck[:, 3 * l :] = input[i == False][:, :l]
+
         test_result_bck, test_correct_bck = predict(
             input_bck, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1)
         )
@@ -221,11 +222,14 @@ class QuizMachine:
 
         ##############################
 
+        test_result = torch.cat([test_result_fwd[:64], test_result_bck[:64]], dim=0)
+        test_correct = torch.cat([test_correct_fwd[:64], test_correct_bck[:64]], dim=0)
+
         self.save_quiz_illustrations(
             result_dir,
             f"culture_prediction_{n_epoch:04d}_{model.id:02d}",
-            quizzes=test_result[:128],
-            mistakes=test_correct[:128] * 2 - 1,
+            quizzes=test_result,
+            # mistakes=test_correct,
         )
 
         return main_test_accuracy
@@ -233,12 +237,16 @@ class QuizMachine:
     ######################################################################
 
     def flip_half_in_place(self, quizzes):
-        r = torch.randint(quizzes.size(0), device=quizzes.device) < 0.5
-        i = self.problem.indices_select(quizzes=input, struct=("A", "f_A", "B", "f_B"))
+        r = torch.rand(quizzes.size(0), device=quizzes.device) < 0.5
+        i = self.problem.indices_select(
+            quizzes=quizzes, struct=("A", "f_A", "B", "f_B")
+        )
         quizzes[i & r] = self.problem.reconfigure(
             quizzes[i & r], struct=("f_B", "f_A", "B", "A")
         )
-        j = self.problem.indices_select(quizzes=input, struct=("f_B", "f_A", "B", "A"))
+        j = self.problem.indices_select(
+            quizzes=quizzes, struct=("f_B", "f_A", "B", "A")
+        )
         quizzes[j & r] = self.problem.reconfigure(
             quizzes[j & r], struct=("A", "f_A", "B", "f_B")
         )
@@ -403,7 +411,7 @@ class QuizMachine:
     ):
         c_quizzes = torch.empty(
             nb,
-            self.prompt_len + self.answer_len,
+            self.problem.seq_len,
             device=self.device,
             dtype=torch.int64,
         )