Update.
authorFrançois Fleuret <francois@fleuret.org>
Thu, 4 Jul 2024 01:48:26 +0000 (04:48 +0300)
committerFrançois Fleuret <francois@fleuret.org>
Thu, 4 Jul 2024 01:48:26 +0000 (04:48 +0300)
main.py
quizz_machine.py
reasoning.py

diff --git a/main.py b/main.py
index a954af6..be0d8e0 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -249,8 +249,10 @@ if args.problem == "sky":
         nb_iterations=args.sky_nb_iterations,
         speed=args.sky_speed,
     )
+    back_accuracy = False
 elif args.problem == "reasoning":
     problem = reasoning.Reasoning(device=device)
+    back_accuracy = True
 else:
     raise ValueError
 
@@ -258,6 +260,7 @@ quizz_machine = quizz_machine.QuizzMachine(
     problem=problem,
     nb_train_samples=args.nb_train_samples,
     nb_test_samples=args.nb_test_samples,
+    back_accuracy=back_accuracy,
     batch_size=args.physical_batch_size,
     result_dir=args.result_dir,
     logger=log_string,
index 90f288e..6e57fb4 100755 (executable)
@@ -202,6 +202,7 @@ class QuizzMachine:
         problem,
         nb_train_samples,
         nb_test_samples,
+        back_accuracy,
         batch_size,
         result_dir,
         logger,
@@ -215,6 +216,7 @@ class QuizzMachine:
         self.nb_token_values = v + 2
 
         self.problem = problem
+        self.back_accuracy = back_accuracy
         self.batch_size = batch_size
         self.device = device
         self.logger = logger
@@ -308,7 +310,6 @@ class QuizzMachine:
         self, n_epoch, model, result_dir, deterministic_synthesis, nmax=1000
     ):
         def compute_accuracy(input):
-            input = input[:nmax]
             ar_mask = self.make_ar_mask(input)
             result = input.clone() * (1 - ar_mask)
             seq_logproba = torch.empty(input.size(0), device=self.device)
@@ -325,18 +326,38 @@ class QuizzMachine:
                 device=self.device,
             )
 
-            nb_total = input.size(0)
-            nb_correct = (input == result).long().min(dim=1).values.sum()
+            if self.back_accuracy:
+                n_forward = input[:, 0] == self.token_forward
+                nb_total = input[n_forward].size(0)
+                nb_correct = (
+                    (input[n_forward] == result[n_forward])
+                    .long()
+                    .min(dim=1)
+                    .values.sum()
+                )
+
+                n_backward = input[:, 0] == self.token_backward
+                back_input = self.reverse_time(result[n_backward])
+                if back_input.size(0) > 0:
+                    back_input[:, 2 + self.prompt_len :] = input[
+                        n_backward, 2 + self.prompt_len :
+                    ]
+                    back_nb_total, back_nb_correct = compute_accuracy(back_input)
+                    nb_total += back_nb_total
+                    nb_correct += back_nb_correct
+            else:
+                nb_total = input.size(0)
+                nb_correct = (input == result).long().min(dim=1).values.sum()
 
             return nb_total, nb_correct
 
-        train_nb_total, train_nb_correct = compute_accuracy(self.train_w_quizzes)
+        train_nb_total, train_nb_correct = compute_accuracy(self.train_w_quizzes[:nmax])
 
         self.logger(
             f"accuracy_train {n_epoch} nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%"
         )
 
-        test_nb_total, test_nb_correct = compute_accuracy(self.test_w_quizzes)
+        test_nb_total, test_nb_correct = compute_accuracy(self.test_w_quizzes[:nmax])
 
         self.logger(
             f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
index 57e8056..768c15c 100755 (executable)
@@ -42,6 +42,31 @@ class Reasoning(problem.Problem):
     ######################################################################
 
     def frame2img(self, x, scale=15):
+        x = x.reshape(x.size(0), self.height, -1)
+        m = torch.logical_and(x >= 0, x < self.nb_token_values()).long()
+        x = self.colors[x * m].permute(0, 3, 1, 2)
+        s = x.shape
+        x = x[:, :, :, None, :, None].expand(-1, -1, -1, scale, -1, scale)
+        x = x.reshape(s[0], s[1], s[2] * scale, s[3] * scale)
+
+        x[:, :, :, torch.arange(0, x.size(3), scale)] = 0
+        x[:, :, torch.arange(0, x.size(2), scale), :] = 0
+        x = x[:, :, 1:, 1:]
+
+        for n in range(m.size(0)):
+            for i in range(m.size(1)):
+                for j in range(m.size(2)):
+                    if m[n, i, j] == 0:
+                        for k in range(2, scale - 2):
+                            for l in [0, 1]:
+                                x[n, :, i * scale + k, j * scale + k - l] = 0
+                                x[
+                                    n, :, i * scale + scale - 1 - k, j * scale + k - l
+                                ] = 0
+
+        return x
+
+    def frame2img_(self, x, scale=15):
         x = x.reshape(x.size(0), self.height, -1)
         x = self.colors[x].permute(0, 3, 1, 2)
         s = x.shape
@@ -173,14 +198,13 @@ class Reasoning(problem.Problem):
     # non-overlapping rectangles quickly, but made the generation of
     # 100k samples go from 1h50 with a lame pure python code to 3min30s
     # with this one.
-    def rec_coo(self, x, n, min_height=3, min_width=3):
-        K = 3
-        N = 200
+    def rec_coo(self, nb_rec, min_height=3, min_width=3):
+        nb_trials = 200
 
         while True:
             v = (
                 (
-                    torch.rand(N * K, self.height + 1, device=self.device)
+                    torch.rand(nb_trials * nb_rec, self.height + 1, device=self.device)
                     .sort(dim=-1)
                     .indices
                     < 2
@@ -192,7 +216,7 @@ class Reasoning(problem.Problem):
 
             h = (
                 (
-                    torch.rand(N * K, self.width + 1, device=self.device)
+                    torch.rand(nb_trials * nb_rec, self.width + 1, device=self.device)
                     .sort(dim=-1)
                     .indices
                     < 2
@@ -207,10 +231,10 @@ class Reasoning(problem.Problem):
             )
 
             v, h = v[i], h[i]
-            v = v[: v.size(0) - v.size(0) % K]
-            h = h[: h.size(0) - h.size(0) % K]
-            v = v.reshape(v.size(0) // K, K, -1)
-            h = h.reshape(h.size(0) // K, K, -1)
+            v = v[: v.size(0) - v.size(0) % nb_rec]
+            h = h[: h.size(0) - h.size(0) % nb_rec]
+            v = v.reshape(v.size(0) // nb_rec, nb_rec, -1)
+            h = h.reshape(h.size(0) // nb_rec, nb_rec, -1)
 
             r = v[:, :, :, None] * h[:, :, None, :]
 
@@ -260,23 +284,23 @@ class Reasoning(problem.Problem):
     ######################################################################
 
     def task_replace_color(self, A, f_A, B, f_B):
-        N = 3
-        c = torch.randperm(len(self.colors) - 1)[: N + 1] + 1
+        nb_rec = 3
+        c = torch.randperm(len(self.colors) - 1)[: nb_rec + 1] + 1
         for X, f_X in [(A, f_A), (B, f_B)]:
-            r = self.rec_coo(X, N)
-            for n in range(N):
+            r = self.rec_coo(nb_rec)
+            for n in range(nb_rec):
                 i1, j1, i2, j2 = r[n]
                 X[i1:i2, j1:j2] = c[n]
                 f_X[i1:i2, j1:j2] = c[n if n > 0 else -1]
 
     def task_move(self, A, f_A, B, f_B):
         di, dj = torch.randint(2, (2,)) * 2 - 1
-        N = 3
-        c = torch.randperm(len(self.colors) - 1)[:N] + 1
+        nb_rec = 3
+        c = torch.randperm(len(self.colors) - 1)[:nb_rec] + 1
         for X, f_X in [(A, f_A), (B, f_B)]:
             while True:
-                r = self.rec_coo(X, N)
-                i1, j1, i2, j2 = r[N - 1]
+                r = self.rec_coo(nb_rec)
+                i1, j1, i2, j2 = r[nb_rec - 1]
                 if (
                     i1 + di >= 0
                     and i2 + di < X.size(0)
@@ -285,29 +309,29 @@ class Reasoning(problem.Problem):
                 ):
                     break
 
-            for n in range(N):
+            for n in range(nb_rec):
                 i1, j1, i2, j2 = r[n]
                 X[i1:i2, j1:j2] = c[n]
-                if n == N - 1:
+                if n == nb_rec - 1:
                     f_X[i1 + di : i2 + di, j1 + dj : j2 + dj] = c[n]
                 else:
                     f_X[i1:i2, j1:j2] = c[n]
 
     def task_grow(self, A, f_A, B, f_B):
         di, dj = torch.randint(2, (2,)) * 2 - 1
-        N = 3
-        c = torch.randperm(len(self.colors) - 1)[:N] + 1
+        nb_rec = 3
+        c = torch.randperm(len(self.colors) - 1)[:nb_rec] + 1
         direction = torch.randint(2, (1,))
         for X, f_X in [(A, f_A), (B, f_B)]:
             while True:
-                r = self.rec_coo(X, N)
-                i1, j1, i2, j2 = r[N - 1]
+                r = self.rec_coo(nb_rec)
+                i1, j1, i2, j2 = r[nb_rec - 1]
                 if i1 + 3 < i2 and j1 + 3 < j2:
                     break
 
-            for n in range(N):
+            for n in range(nb_rec):
                 i1, j1, i2, j2 = r[n]
-                if n == N - 1:
+                if n == nb_rec - 1:
                     if direction == 0:
                         X[i1 + 1 : i2 - 1, j1 + 1 : j2 - 1] = c[n]
                         f_X[i1:i2, j1:j2] = c[n]
@@ -320,12 +344,12 @@ class Reasoning(problem.Problem):
 
     def task_color_grow(self, A, f_A, B, f_B):
         di, dj = torch.randint(2, (2,)) * 2 - 1
-        N = 3
-        c = torch.randperm(len(self.colors) - 1)[: 2 * N] + 1
+        nb_rec = 3
+        c = torch.randperm(len(self.colors) - 1)[: 2 * nb_rec] + 1
         direction = torch.randint(4, (1,))
         for X, f_X in [(A, f_A), (B, f_B)]:
-            r = self.rec_coo(X, N)
-            for n in range(N):
+            r = self.rec_coo(nb_rec)
+            for n in range(nb_rec):
                 i1, j1, i2, j2 = r[n]
                 X[i1:i2, j1:j2] = c[2 * n]
                 f_X[i1:i2, j1:j2] = c[2 * n]
@@ -333,53 +357,54 @@ class Reasoning(problem.Problem):
                 if direction == 0:
                     i = (i1 + i2) // 2
                     X[i : i + 1, j1:j2] = c[2 * n + 1]
-                    if n == N - 1:
+                    if n == nb_rec - 1:
                         f_X[i:i2, j1:j2] = c[2 * n + 1]
                     else:
                         f_X[i : i + 1, j1:j2] = c[2 * n + 1]
                 elif direction == 1:
                     i = (i1 + i2 - 1) // 2
                     X[i : i + 1, j1:j2] = c[2 * n + 1]
-                    if n == N - 1:
+                    if n == nb_rec - 1:
                         f_X[i1 : i + 1, j1:j2] = c[2 * n + 1]
                     else:
                         f_X[i : i + 1, j1:j2] = c[2 * n + 1]
                 elif direction == 2:
                     j = (j1 + j2) // 2
                     X[i1:i2, j : j + 1] = c[2 * n + 1]
-                    if n == N - 1:
+                    if n == nb_rec - 1:
                         f_X[i1:i2, j:j2] = c[2 * n + 1]
                     else:
                         f_X[i1:i2, j : j + 1] = c[2 * n + 1]
                 elif direction == 3:
                     j = (j1 + j2 - 1) // 2
                     X[i1:i2, j : j + 1] = c[2 * n + 1]
-                    if n == N - 1:
+                    if n == nb_rec - 1:
                         f_X[i1:i2, j1 : j + 1] = c[2 * n + 1]
                     else:
                         f_X[i1:i2, j : j + 1] = c[2 * n + 1]
 
     def task_frame(self, A, f_A, B, f_B):
-        N = 3
-        c = torch.randperm(len(self.colors) - 1)[: N + 1] + 1
+        nb_rec = 3
+        c = torch.randperm(len(self.colors) - 1)[: nb_rec + 1] + 1
         for X, f_X in [(A, f_A), (B, f_B)]:
-            r = self.rec_coo(X, N)
-            for n in range(N):
+            r = self.rec_coo(nb_rec)
+            for n in range(nb_rec):
                 i1, j1, i2, j2 = r[n]
                 X[i1:i2, j1:j2] = c[n]
                 f_X[i1:i2, j1:j2] = c[n]
-                if n == N - 1:
+                if n == nb_rec - 1:
                     f_X[i1 + 1 : i2 - 1, j1 + 1 : j2 - 1] = 0
 
     def task_detect(self, A, f_A, B, f_B):
-        N = 3
-        c = torch.randperm(len(self.colors) - 1)[: N + 1] + 1
+        nb_rec = 3
+        c = torch.randperm(len(self.colors) - 1)[: nb_rec + 1] + 1
         for X, f_X in [(A, f_A), (B, f_B)]:
-            r = self.rec_coo(X, N)
-            for n in range(N):
+            r = self.rec_coo(nb_rec)
+            for n in range(nb_rec):
                 i1, j1, i2, j2 = r[n]
                 X[i1:i2, j1:j2] = c[n]
-                f_X[i1, j1] = c[-1]
+                if n < nb_rec - 1:
+                    f_X[i1, j1] = c[-1]
 
     ######################################################################
 
@@ -448,8 +473,8 @@ if __name__ == "__main__":
     reasoning.save_quizzes(
         "/tmp",
         "test",
-        prompts[:36],
-        answers[:36],
+        prompts[:64],
+        answers[:64],
         # You can add a bool to put a frame around the predicted parts
         # predicted_prompts, predicted_answers
     )