Color prediction frames by correctness, add per-task quiz dumps.

Rename the save_quizzes "prediction" flag to "show_to_be_predicted" and add a
"mistakes" argument so that frames mark correctly predicted quizzes in green
and wrong ones in red. Thread an "nrow" option through save_image/save_quizzes,
factor the Reasoning task list into all_tasks(), let generate_prompts_and_answers()
run on a subset of tasks, and make the reasoning.py test block dump one image per task.
author     François Fleuret <francois@fleuret.org>
           Fri, 5 Jul 2024 18:56:42 +0000 (21:56 +0300)
committer  François Fleuret <francois@fleuret.org>
           Fri, 5 Jul 2024 18:56:42 +0000 (21:56 +0300)
quizz_machine.py
reasoning.py
sky.py

diff --git a/quizz_machine.py b/quizz_machine.py
index 62ae8ce..632c9ae 100755 (executable)
--- a/quizz_machine.py
+++ b/quizz_machine.py
@@ -238,10 +238,17 @@ class QuizzMachine:
                 result_dir,
                 "culture_w_quizzes",
                 self.train_w_quizzes[:72],
-                prediction=True,
+                show_to_be_predicted=True,
             )
 
-    def save_quizzes(self, result_dir, filename_prefix, quizzes, prediction=False):
+    def save_quizzes(
+        self,
+        result_dir,
+        filename_prefix,
+        quizzes,
+        show_to_be_predicted=False,
+        mistakes=None,
+    ):
         quizzes = quizzes.clone()
         forward = quizzes[quizzes[:, 0] == self.token_forward]
         ib = quizzes[:, 0] == self.token_backward
@@ -249,9 +256,17 @@ class QuizzMachine:
         assert forward.size(0) + backward.size(0) == quizzes.size(0)
         quizzes[ib] = self.reverse_time(quizzes[ib])
 
-        if prediction:
-            predicted_prompts = ib
-            predicted_answers = torch.logical_not(ib)
+        if show_to_be_predicted:
+            predicted_prompts = ib.long()
+            predicted_answers = 1 - predicted_prompts
+            if mistakes is not None:
+                # 0/-1/+1 ~ not-to-predict / predicted wrong / predicted correct
+                predicted_prompts *= mistakes
+                predicted_answers *= mistakes
+            else:
+                # 0/2 ~ not-to-predict / to predict
+                predicted_prompts *= 2
+                predicted_answers *= 2
         else:
             predicted_prompts = None
             predicted_answers = None
@@ -409,11 +424,14 @@ class QuizzMachine:
             device=self.device,
         )
 
+        mistakes = (input == result).flatten(1).long().min(dim=1).values * 2 - 1
+
         self.save_quizzes(
             result_dir,
             f"culture_prediction_{n_epoch:04d}_{model.id:02d}",
             quizzes=result[:72],
-            prediction=True,
+            show_to_be_predicted=True,
+            mistakes=mistakes[:72],
         )
 
         return main_test_accuracy
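
For reference, the new "mistakes" tensor collapses a token-wise comparison of the
model's output against the reference into a single per-quiz flag: +1 if every token
was reproduced exactly, -1 otherwise. A minimal standalone sketch of the same
expression, on toy tensors that are not from the repository:

    import torch

    # Two 3-token quizzes: the first is reproduced exactly, the second has one wrong token.
    reference = torch.tensor([[1, 2, 3], [4, 5, 6]])
    result = torch.tensor([[1, 2, 3], [4, 0, 6]])

    # The min over the per-token matches is 1 only if all tokens match;
    # "* 2 - 1" then maps {0, 1} to {-1, +1}.
    mistakes = (reference == result).flatten(1).long().min(dim=1).values * 2 - 1
    print(mistakes)  # tensor([ 1, -1])
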
diff --git a/reasoning.py b/reasoning.py
index cd726cb..5499bdf 100755 (executable)
--- a/reasoning.py
+++ b/reasoning.py
@@ -87,6 +87,7 @@ class Reasoning(problem.Problem):
         answers,
         predicted_prompts=None,
         predicted_answers=None,
+        nrow=4,
     ):
         prompts = prompts.reshape(prompts.size(0), self.height, -1)
         answers = answers.reshape(answers.size(0), self.height, -1)
@@ -114,9 +115,13 @@ class Reasoning(problem.Problem):
                 y[...] = c
             else:
                 c = c.long()[:, None]
-                c = c * torch.tensor([192, 192, 192], device=c.device) + (
-                    1 - c
-                ) * torch.tensor([255, 255, 255], device=c.device)
+                c = (
+                    (1 - ((c == 1).long() + (c == 0).long() + (c == -1).long()))
+                    * torch.tensor([192, 192, 192], device=c.device)
+                    + (c == 1).long() * torch.tensor([0, 255, 0], device=c.device)
+                    + (c == 0).long() * torch.tensor([255, 255, 255], device=c.device)
+                    + (c == -1).long() * torch.tensor([255, 0, 0], device=c.device)
+                )
                 y[...] = c[:, :, None, None]
 
             y[:, :, di : di + x.size(2), dj : dj + x.size(3)] = x
@@ -186,7 +191,11 @@ class Reasoning(problem.Problem):
 
         image_name = os.path.join(result_dir, filename)
         torchvision.utils.save_image(
-            img.float() / 255.0, image_name, nrow=4, padding=margin * 4, pad_value=1.0
+            img.float() / 255.0,
+            image_name,
+            nrow=nrow,
+            padding=margin * 4,
+            pad_value=1.0,
         )
 
     ######################################################################
@@ -581,8 +590,8 @@ class Reasoning(problem.Problem):
 
     ######################################################################
 
-    def generate_prompts_and_answers(self, nb, device="cpu"):
-        tasks = [
+    def all_tasks(self):
+        return [
             self.task_replace_color,
             self.task_translate,
             self.task_grow,
@@ -594,6 +603,11 @@ class Reasoning(problem.Problem):
             self.task_bounce,
             self.task_scale,
         ]
+
+    def generate_prompts_and_answers(self, nb, tasks=None, device="cpu"):
+        if tasks is None:
+            tasks = self.all_tasks()
+
         prompts = torch.zeros(nb, self.height, self.width * 3, dtype=torch.int64)
         answers = torch.zeros(nb, self.height, self.width, dtype=torch.int64)
         w = self.width
@@ -621,6 +635,7 @@ class Reasoning(problem.Problem):
         answers,
         predicted_prompts=None,
         predicted_answers=None,
+        nrow=4,
     ):
         self.save_image(
             result_dir,
@@ -629,6 +644,7 @@ class Reasoning(problem.Problem):
             answers,
             predicted_prompts,
             predicted_answers,
+            nrow,
         )
 
 
@@ -637,22 +653,32 @@ class Reasoning(problem.Problem):
 if __name__ == "__main__":
     import time
 
+    nb = 4
+
     reasoning = Reasoning()
 
+    for t in reasoning.all_tasks():
+        print(t.__name__)
+        prompts, answers = reasoning.generate_prompts_and_answers(nb, tasks=[t])
+        reasoning.save_quizzes("/tmp", t.__name__, prompts[:nb], answers[:nb], nrow=1)
+
+    exit(0)
+
     start_time = time.perf_counter()
-    prompts, answers = reasoning.generate_prompts_and_answers(100)
+    prompts, answers = reasoning.generate_prompts_and_answers(nb)
     delay = time.perf_counter() - start_time
     print(f"{prompts.size(0)/delay:02f} seq/s")
 
-    predicted_prompts = torch.rand(prompts.size(0)) < 0.5
-    predicted_answers = torch.logical_not(predicted_prompts)
+    # m = torch.randint(2, (prompts.size(0),))
+    # predicted_prompts = m * (torch.randint(2, (prompts.size(0),)) * 2 - 1)
+    # predicted_answers = (1 - m) * (torch.randint(2, (prompts.size(0),)) * 2 - 1)
 
     reasoning.save_quizzes(
         "/tmp",
         "test",
 
     reasoning.save_quizzes(
         "/tmp",
         "test",
-        prompts[:64],
-        answers[:64],
+        prompts[:nb],
+        answers[:nb],
         # You can add a bool to put a frame around the predicted parts
-        # predicted_prompts[:64],
-        # predicted_answers[:64],
+        # predicted_prompts[:nb],
+        # predicted_answers[:nb],
     )
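
With that encoding, the frame colouring in reasoning.py's save_image keys on the flag
value: -1 (predicted wrong) becomes red, 0 (not to be predicted) white, +1 (predicted
correct) green, and any other value, such as the 2 used when no correctness information
is given, falls back to gray. A small sketch of just that mapping, on a toy flag tensor:

    import torch

    c = torch.tensor([-1, 0, 1, 2])[:, None]  # one flag per frame
    rgb = (
        (1 - ((c == 1).long() + (c == 0).long() + (c == -1).long()))
        * torch.tensor([192, 192, 192])                    # anything else (e.g. 2): gray
        + (c == 1).long() * torch.tensor([0, 255, 0])      # correct: green
        + (c == 0).long() * torch.tensor([255, 255, 255])  # not to predict: white
        + (c == -1).long() * torch.tensor([255, 0, 0])     # wrong: red
    )
    print(rgb)  # rows: red, white, green, gray
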
diff --git a/sky.py b/sky.py
index 6ef8a3a..ed440d3 100755 (executable)
--- a/sky.py
+++ b/sky.py
@@ -217,9 +217,11 @@ class Sky(problem.Problem):
                 y[...] = c
             else:
                 c = c.long()[:, None]
-                c = c * torch.tensor([0, 0, 0], device=c.device) + (
-                    1 - c
-                ) * torch.tensor([255, 255, 255], device=c.device)
+                c = (
+                    (c == 1).long() * torch.tensor([0, 255, 0], device=c.device)
+                    + (c == 0).long() * torch.tensor([255, 255, 255], device=c.device)
+                    + (c == -1).long() * torch.tensor([255, 0, 0], device=c.device)
+                )
                 y[...] = c[:, :, None, None]
 
             y[:, :, di : di + x.size(2), dj : dj + x.size(3)] = x
@@ -322,8 +324,8 @@ if __name__ == "__main__":
 
     prompts, answers = sky.generate_prompts_and_answers(4)
 
-    predicted_prompts = torch.rand(prompts.size(0)) < 0.5
-    predicted_answers = torch.rand(answers.size(0)) < 0.5
+    predicted_prompts = torch.randint(3, (prompts.size(0),)) - 1
+    predicted_answers = torch.randint(3, (prompts.size(0),)) - 1
 
     sky.save_quizzes(
         "/tmp", "test", prompts, answers, predicted_prompts, predicted_answers
 
     sky.save_quizzes(
         "/tmp", "test", prompts, answers, predicted_prompts, predicted_answers