Update.
authorFrançois Fleuret <francois@fleuret.org>
Thu, 4 Jul 2024 16:25:10 +0000 (19:25 +0300)
committerFrançois Fleuret <francois@fleuret.org>
Thu, 4 Jul 2024 16:25:10 +0000 (19:25 +0300)
quizz_machine.py
reasoning.py

index 65b6000..62ae8ce 100755 (executable)
@@ -346,10 +346,6 @@ class QuizzMachine:
                     .item()
                 )
 
-                self.logger(
-                    f"back_accuracy {n_epoch=} {model.id=} {nb_correct=} {nb_total=}"
-                )
-
                 n_backward = input[:, 0] == self.token_backward
                 back_input = self.reverse_time(result[n_backward])
 
@@ -358,11 +354,20 @@ class QuizzMachine:
                         n_backward, 1 : 1 + self.answer_len
                     ]
                     back_nb_total, back_nb_correct = compute_accuracy(back_input)
+
+                    self.logger(
+                        f"accuracy {n_epoch=} {model.id=} {nb_correct} / {nb_total}"
+                    )
                     self.logger(
-                        f"back_accuracy {n_epoch=} {model.id=} {back_nb_correct=} {back_nb_total=}"
+                        f"back_accuracy {n_epoch=} {model.id=} {back_nb_correct} / {back_nb_total}"
                     )
+
                     nb_total += back_nb_total
                     nb_correct += back_nb_correct
+                else:
+                    self.logger(
+                        f"accuracy {n_epoch=} {model.id=} {nb_correct} / {nb_total}"
+                    )
 
             else:
                 nb_total = input.size(0)
index 2874adc..54a4203 100755 (executable)
@@ -293,7 +293,7 @@ class Reasoning(problem.Problem):
                 X[i1:i2, j1:j2] = c[n]
                 f_X[i1:i2, j1:j2] = c[n if n > 0 else -1]
 
-    def task_move(self, A, f_A, B, f_B):
+    def task_translate(self, A, f_A, B, f_B):
         di, dj = torch.randint(3, (2,)) - 1
         nb_rec = 3
         c = torch.randperm(len(self.colors) - 1)[:nb_rec] + 1
@@ -406,16 +406,31 @@ class Reasoning(problem.Problem):
                 if n < nb_rec - 1:
                     f_X[i1, j1] = c[-1]
 
+    def task_count(self, A, f_A, B, f_B):
+        N = torch.randint(3, (1,)) + 1
+        c = torch.randperm(len(self.colors) - 1)[:N] + 1
+        for X, f_X in [(A, f_A), (B, f_B)]:
+            nb = torch.randint(self.width, (3,)) + 1
+            k = torch.randperm(self.height * self.width)[: nb.sum()]
+            p = 0
+            for n in range(N):
+                for m in range(nb[n]):
+                    i, j = k[p] % self.height, k[p] // self.height
+                    X[i, j] = c[n]
+                    f_X[n, m] = c[n]
+                    p += 1
+
     ######################################################################
 
     def generate_prompts_and_answers(self, nb, device="cpu"):
         tasks = [
             self.task_replace_color,
-            self.task_move,
+            self.task_translate,
             self.task_grow,
             self.task_color_grow,
             self.task_frame,
             self.task_detect,
+            self.task_count,
         ]
         prompts = torch.zeros(nb, self.height, self.width * 3, dtype=torch.int64)
         answers = torch.zeros(nb, self.height, self.width, dtype=torch.int64)
@@ -476,6 +491,6 @@ if __name__ == "__main__":
         prompts[:64],
         answers[:64],
         # You can add a bool to put a frame around the predicted parts
-        predicted_prompts[:64],
-        predicted_answers[:64],
+        predicted_prompts[:64],
+        predicted_answers[:64],
     )