Update.
authorFrançois Fleuret <francois@fleuret.org>
Fri, 19 Jul 2024 13:55:45 +0000 (15:55 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Fri, 19 Jul 2024 13:55:45 +0000 (15:55 +0200)
grids.py
main.py
quiz_machine.py

index c2ff0d1..e1eff00 100755 (executable)
--- a/grids.py
+++ b/grids.py
@@ -118,6 +118,61 @@ class Grids(problem.Problem):
         ("gray", [128, 128, 128]),
     ]
 
+    def make_ar_mask(self, quizzes, first=False):
+        S = self.height * self.width
+
+        assert (
+            (
+                (quizzes[:, 0] == self.token_forward)
+                | (quizzes[:, 0] == self.token_backward)
+            )
+            & (quizzes[:, 0] == quizzes[:, 1 * (S + 1)])
+            & (quizzes[:, 0] == quizzes[:, 2 * (S + 1)])
+            & (quizzes[:, 0] == quizzes[:, 3 * (S + 1)])
+        ).all()
+
+        T = torch.arange(quizzes.size(1), device=quizzes.device)
+
+        if first:
+            forward_mask = ((T % (S + 1) != 0) & (T < 3 * (S + 1))).long()
+            backward_mask = ((T % (S + 1) != 0) & (T < S + 1)).long()
+        else:
+            forward_mask = ((T % (S + 1) != 0) & (T >= 3 * (S + 1))).long()
+            backward_mask = ((T % (S + 1) != 0) & (T >= S + 1)).long()
+
+        is_forward = (quizzes[:, 0] == self.token_forward).long()
+
+        return (
+            is_forward[:, None] * forward_mask[None, :]
+            + (1 - is_forward)[:, None] * backward_mask[None, :]
+        )
+
+    def p_a_flip(self, quizzes):
+        S = self.height * self.width
+
+        assert (
+            (
+                (quizzes[:, 0] == self.token_forward)
+                | (quizzes[:, 0] == self.token_backward)
+            )
+            & (quizzes[:, 0] == quizzes[:, 1 * (S + 1)])
+            & (quizzes[:, 0] == quizzes[:, 2 * (S + 1)])
+            & (quizzes[:, 0] == quizzes[:, 3 * (S + 1)])
+        ).all()
+
+        flipped = torch.cat(
+            [quizzes[:, k * (S + 1) : (k + 1) * (S + 1)] for k in range(3, -1, -1)],
+            dim=1,
+        )
+
+        m = (flipped[:, 0] == self.token_forward).long()
+        flipped[:, 0 * (S + 1)] = m * self.token_backward + (1 - m) * self.token_forward
+        flipped[:, 1 * (S + 1)] = m * self.token_backward + (1 - m) * self.token_forward
+        flipped[:, 2 * (S + 1)] = m * self.token_backward + (1 - m) * self.token_forward
+        flipped[:, 3 * (S + 1)] = m * self.token_backward + (1 - m) * self.token_forward
+
+        return flipped
+
     def __init__(
         self,
         max_nb_cached_chunks=None,
@@ -158,29 +213,40 @@ class Grids(problem.Problem):
     ######################################################################
 
     def frame2img(self, x, scale=15):
-        x = x.reshape(x.size(0), self.height, -1)
+        x = x.reshape(x.size(0), self.height, self.width)
         m = torch.logical_and(x >= 0, x < len(self.colors)).long()
-        x = self.colors[x * m].permute(0, 3, 1, 2)
-        s = x.shape
-        x = x[:, :, :, None, :, None].expand(-1, -1, -1, scale, -1, scale)
-        x = x.reshape(s[0], s[1], s[2] * scale, s[3] * scale)
+        y = self.colors[x * m].permute(0, 3, 1, 2)
+        s = y.shape
+        y = y[:, :, :, None, :, None].expand(-1, -1, -1, scale, -1, scale)
+        y = y.reshape(s[0], s[1], s[2] * scale, s[3] * scale)
 
-        x[:, :, :, torch.arange(0, x.size(3), scale)] = 0
-        x[:, :, torch.arange(0, x.size(2), scale), :] = 0
-        x = x[:, :, 1:, 1:]
+        y[:, :, :, torch.arange(0, y.size(3), scale)] = 0
+        y[:, :, torch.arange(0, y.size(2), scale), :] = 0
+        y = y[:, :, 1:, 1:]
 
         for n in range(m.size(0)):
             for i in range(m.size(1)):
                 for j in range(m.size(2)):
-                    if m[n, i, j] == 0:
+                    if x[n, i, j] == self.token_forward:
+                        for k in range(2, scale - 2):
+                            y[
+                                n,
+                                :,
+                                i * scale + k,
+                                j * scale + scale - 5 - abs(k - scale // 2),
+                            ] = 0
+
+                    elif x[n, i, j] == self.token_backward:
                         for k in range(2, scale - 2):
-                            for l in [0, 1]:
-                                x[n, :, i * scale + k, j * scale + k - l] = 0
-                                x[
-                                    n, :, i * scale + scale - 1 - k, j * scale + k - l
-                                ] = 0
+                            y[
+                                n, :, i * scale + k, j * scale + 3 + abs(k - scale // 2)
+                            ] = 0
+                            # y[n, :, i * scale + k, j * scale + k - l] = 0
+                            # y[
+                            # n, :, i * scale + scale - 1 - k, j * scale + k - l
+                            # ] = 0
 
-        return x
+        return y
 
     def save_image(
         self,
@@ -1223,6 +1289,10 @@ class Grids(problem.Problem):
             )
             f_B = answer[1 : S + 1].view(self.height, self.width)
             task = tasks[torch.randint(len(tasks), (1,)).item()]
+            A[...] = 0
+            f_A[...] = 0
+            B[...] = 0
+            f_B[...] = 0
             task(A, f_A, B, f_B)
 
         return prompts.flatten(1), answers.flatten(1)
@@ -1277,23 +1347,24 @@ if __name__ == "__main__":
     # exit(0)
 
     # if True:
-    nb, nrow = 128, 4
+    nb, nrow = 8, 2
     # nb, nrow = 8, 2
 
     for t in grids.all_tasks:
         # for t in [grids.task_compute]:
         print(t.__name__)
         prompts, answers = grids.generate_prompts_and_answers_(nb, tasks=[t])
+        # prompts[...] = torch.randint(grids.nb_token_values(), prompts.size())
         grids.save_quiz_illustrations(
             "/tmp", t.__name__, prompts[:nb], answers[:nb], nrow=nrow
         )
 
-    exit(0)
+    exit(0)
 
     nb = 1000
 
-    for t in grids.all_tasks:
-    for t in [grids.task_compute]:
+    for t in grids.all_tasks:
+        # for t in [grids.task_compute]:
         start_time = time.perf_counter()
         prompts, answers = grids.generate_prompts_and_answers_(nb, tasks=[t])
         delay = time.perf_counter() - start_time
diff --git a/main.py b/main.py
index d9257db..0182e6a 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -3,6 +3,9 @@
 # Any copyright is dedicated to the Public Domain.
 # https://creativecommons.org/publicdomain/zero/1.0/
 
+# > A > f(A) > B ; > f(B)
+# < f(B) ; < B < f(A) < A
+
 # Written by Francois Fleuret <francois@fleuret.org>
 
 import math, sys, argparse, time, tqdm, os, datetime, warnings
@@ -496,11 +499,11 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10
 
     v_train = validated_quizzes[:nb_for_train]
     quiz_machine.store_c_quizzes(v_train, for_train=True)
-    quiz_machine.store_c_quizzes(quiz_machine.p_a_flip(v_train), for_train=True)
+    quiz_machine.store_c_quizzes(quiz_machine.problem.p_a_flip(v_train), for_train=True)
 
     v_test = validated_quizzes[nb_for_train:nb_to_create]
     quiz_machine.store_c_quizzes(v_test, for_train=False)
-    quiz_machine.store_c_quizzes(quiz_machine.p_a_flip(v_test), for_train=False)
+    quiz_machine.store_c_quizzes(quiz_machine.problem.p_a_flip(v_test), for_train=False)
 
     ######################################################################
     # save images
index cc81086..046ab73 100755 (executable)
@@ -154,56 +154,17 @@ class QuizMachine:
         n_p2a = quizzes[quizzes[:, 0] == self.problem.token_forward]
         n_a2p = quizzes[:, 0] == self.problem.token_backward
         a2p = quizzes[n_a2p]
-        quizzes[n_a2p] = self.p_a_flip(quizzes[n_a2p])
+        quizzes[n_a2p] = self.problem.p_a_flip(quizzes[n_a2p])
         return torch.logical_not(
             self.problem.trivial_prompts_and_answers(
                 quizzes[:, : self.prompt_len], quizzes[:, self.prompt_len :]
             )
         )
 
-    def p_a_flip(self, quizzes):
-        i_p2a, i_a2p = self.indices_p2a_and_a2p(quizzes)
-
-        p2a_to_a2p = torch.cat(
-            [quizzes[:, self.prompt_len :], quizzes[:, : self.prompt_len]],
-            dim=1,
-        )
-
-        p2a_to_a2p[:, 0] = self.problem.token_backward
-        p2a_to_a2p[:, self.answer_len] = self.problem.token_backward
-
-        a2p_to_p2a = torch.cat(
-            [quizzes[:, self.answer_len :], quizzes[:, : self.answer_len]],
-            dim=1,
-        )
-
-        a2p_to_p2a[:, 0] = self.problem.token_forward
-        a2p_to_p2a[:, self.prompt_len] = self.problem.token_forward
-
-        m = i_p2a.long()[:, None]
-
-        return m * p2a_to_a2p + (1 - m) * a2p_to_p2a
-
     def p_a_flip_half_in_place(self, quizzes):
         i = torch.rand(quizzes.size(0)) < 0.5
         if i.any():
-            quizzes[i] = self.p_a_flip(quizzes[i])
-
-    def make_ar_mask(self, quizzes, first=False):
-        i_p2a, i_a2p = self.indices_p2a_and_a2p(quizzes)
-
-        t = torch.arange(quizzes.size(1), device=quizzes.device)
-
-        if first:
-            m_p2a = (t >= 1).long() * (t < self.prompt_len).long()
-            m_a2p = (t >= 1).long() * (t < self.answer_len).long()
-        else:
-            m_p2a = (t >= 1 + self.prompt_len).long()
-            m_a2p = (t >= 1 + self.answer_len).long()
-
-        m = i_p2a.long()[:, None]
-
-        return m * m_p2a + (1 - m) * m_a2p
+            quizzes[i] = self.problem.p_a_flip(quizzes[i])
 
     def generate_token_sequences(self, nb):
         prompts, answers = self.problem.generate_prompts_and_answers(nb)
@@ -261,7 +222,7 @@ class QuizMachine:
         n_a2p = quizzes[:, 0] == self.problem.token_backward
         a2p = quizzes[n_a2p]
         assert n_p2a.size(0) + a2p.size(0) == quizzes.size(0)
-        quizzes[n_a2p] = self.p_a_flip(quizzes[n_a2p])
+        quizzes[n_a2p] = self.problem.p_a_flip(quizzes[n_a2p])
 
         if show_part_to_predict:
             predicted_prompts = n_a2p.long()
@@ -332,7 +293,7 @@ class QuizMachine:
     def produce_results(self, n_epoch, model, result_dir, deterministic_synthesis):
         def compute_accuracy(input, log_prefix=None):
             input = input.to(self.device)
-            ar_mask = self.make_ar_mask(input)
+            ar_mask = self.problem.make_ar_mask(input)
             result = input.clone() * (1 - ar_mask)
             seq_logproba = torch.empty(input.size(0), device=self.device)
 
@@ -357,7 +318,7 @@ class QuizMachine:
 
             if self.back_accuracy and n_a2p.any():
                 # accuracy of B->A*->B*=B instead of B->A*=A
-                back_input = self.p_a_flip(result[n_a2p])
+                back_input = self.problem.p_a_flip(result[n_a2p])
                 back_input[:, 1 + self.prompt_len :] = input[n_a2p, 1 : self.answer_len]
                 _, correct[n_a2p] = compute_accuracy(back_input)
 
@@ -471,7 +432,7 @@ class QuizMachine:
                     c_quizzes.split(self.batch_size), logproba.split(self.batch_size)
                 ):
                     input = input.to(self.device)
-                    ar_mask = self.make_ar_mask(input)
+                    ar_mask = self.problem.make_ar_mask(input)
                     output = model(mygpt.BracketedSequence(input)).x
                     l[:, model.id] = (
                         -F.cross_entropy(
@@ -506,7 +467,7 @@ class QuizMachine:
             c_quizzes = c_quizzes.to(self.device)
             result = c_quizzes.clone()
 
-            ar_mask = self.make_ar_mask(result)
+            ar_mask = self.problem.make_ar_mask(result)
 
             masked_inplace_autoregression(
                 model=model,
@@ -545,14 +506,13 @@ class QuizMachine:
         seq_logproba = torch.zeros(nb, device=self.device)
 
         if p2a_only:
-            c_quizzes[:, 0] = self.problem.token_forward
-            c_quizzes[:, self.prompt_len] = self.problem.token_forward
+            c_quizzes[...] = self.problem.token_forward
 
             masked_inplace_autoregression(
                 model=model_for_generation,
                 batch_size=self.batch_size,
                 input=c_quizzes,
-                ar_mask=self.make_ar_mask(c_quizzes, first=True),
+                ar_mask=self.problem.make_ar_mask(c_quizzes, first=True),
                 seq_logproba=seq_logproba,
                 temperature=temperature_hot,
                 deterministic_synthesis=False,
@@ -563,7 +523,7 @@ class QuizMachine:
                 model=model_for_generation,
                 batch_size=self.batch_size,
                 input=c_quizzes,
-                ar_mask=self.make_ar_mask(c_quizzes),
+                ar_mask=self.problem.make_ar_mask(c_quizzes),
                 seq_logproba=seq_logproba,
                 temperature=temperature_cold,
                 deterministic_synthesis=False,
@@ -571,14 +531,13 @@ class QuizMachine:
             )
 
         else:
-            c_quizzes[:, 0] = self.problem.token_backward
-            c_quizzes[:, self.answer_len] = self.problem.token_backward
+            c_quizzes[...] = self.problem.token_backward
 
             masked_inplace_autoregression(
                 model=model_for_generation,
                 batch_size=self.batch_size,
                 input=c_quizzes,
-                ar_mask=self.make_ar_mask(c_quizzes, first=True),
+                ar_mask=self.problem.make_ar_mask(c_quizzes, first=True),
                 seq_logproba=seq_logproba,
                 temperature=temperature_hot,
                 deterministic_synthesis=False,
@@ -589,20 +548,20 @@ class QuizMachine:
                 model=model_for_generation,
                 batch_size=self.batch_size,
                 input=c_quizzes,
-                ar_mask=self.make_ar_mask(c_quizzes),
+                ar_mask=self.problem.make_ar_mask(c_quizzes),
                 seq_logproba=seq_logproba,
                 temperature=temperature_cold,
                 deterministic_synthesis=False,
                 device=self.device,
             )
 
-            c_quizzes = self.p_a_flip(c_quizzes)
+            c_quizzes = self.problem.p_a_flip(c_quizzes)
 
             masked_inplace_autoregression(
                 model=model_for_generation,
                 batch_size=self.batch_size,
                 input=c_quizzes,
-                ar_mask=self.make_ar_mask(c_quizzes),
+                ar_mask=self.problem.make_ar_mask(c_quizzes),
                 seq_logproba=seq_logproba,
                 temperature=temperature_cold,
                 deterministic_synthesis=False,