Update.
authorFrançois Fleuret <francois@fleuret.org>
Wed, 3 Jul 2024 19:59:59 +0000 (22:59 +0300)
committerFrançois Fleuret <francois@fleuret.org>
Wed, 3 Jul 2024 19:59:59 +0000 (22:59 +0300)
main.py
reasoning.py

diff --git a/main.py b/main.py
index 8b5d9a4..a954af6 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -250,7 +250,7 @@ if args.problem == "sky":
         speed=args.sky_speed,
     )
 elif args.problem == "reasoning":
-    problem = reasoning.Reasoning()
+    problem = reasoning.Reasoning(device=device)
 else:
     raise ValueError
 
index 003806a..b8d39ee 100755 (executable)
@@ -32,13 +32,12 @@ class Reasoning(problem.Problem):
         ("gray", [192, 192, 192]),
     ]
 
-    def __init__(
-        self,
-    ):
+    def __init__(self, device=torch.device("cpu")):
         self.colors = torch.tensor([c for _, c in self.named_colors])
         self.name2color = dict([(p[0], i) for i, p in enumerate(self.named_colors)])
         self.height = 10
         self.width = 10
+        self.device = device
 
     ######################################################################
 
@@ -171,6 +170,68 @@ class Reasoning(problem.Problem):
         return len(self.colors)
 
     def rec_coo(self, x, n, min_height=3, min_width=3):
+        K = 3
+        N = 4000
+
+        while True:
+            v = (
+                (
+                    torch.rand(N * K, self.height + 1, device=self.device)
+                    .sort(dim=-1)
+                    .indices
+                    < 2
+                )
+                .long()
+                .cumsum(dim=1)
+                == 1
+            ).long()
+
+            h = (
+                (
+                    torch.rand(N * K, self.width + 1, device=self.device)
+                    .sort(dim=-1)
+                    .indices
+                    < 2
+                )
+                .long()
+                .cumsum(dim=1)
+                == 1
+            ).long()
+
+            i = torch.logical_and(
+                v.sum(dim=-1) >= min_height, h.sum(dim=-1) >= min_width
+            )
+
+            v, h = v[i], h[i]
+            v = v[: v.size(0) - v.size(0) % K]
+            h = h[: h.size(0) - h.size(0) % K]
+            v = v.reshape(v.size(0) // K, K, -1)
+            h = h.reshape(h.size(0) // K, K, -1)
+
+            r = v[:, :, :, None] * h[:, :, None, :]
+
+            valid = r.sum(dim=1).flatten(1).max(dim=-1).values == 1
+
+            v = v[valid]
+            h = h[valid]
+
+            if v.size(0) > 0:
+                break
+
+        av = torch.arange(v.size(2), device=self.device)[None, :]
+        ah = torch.arange(h.size(2), device=self.device)[None, :]
+
+        return [
+            (i1.item(), j1.item(), i2.item() + 1, j2.item() + 1)
+            for i1, j1, i2, j2 in zip(
+                v.size(2) - (v[0] * (v.size(2) - av)).max(dim=-1).values,
+                h.size(2) - (h[0] * (h.size(2) - ah)).max(dim=-1).values,
+                (v[0] * av).max(dim=-1).values,
+                (h[0] * ah).max(dim=-1).values,
+            )
+        ]
+
+    def rec_coo_(self, x, n, min_height=3, min_width=3):
         collision = x.new(x.size())
         while True:
             collision[...] = 0
@@ -295,7 +356,7 @@ class Reasoning(problem.Problem):
 
     ######################################################################
 
-    def generate_prompts_and_answers(self, nb):
+    def generate_prompts_and_answers(self, nb, device="cpu"):
         tasks = [
             self.task_replace_color,
             self.task_move,
@@ -304,8 +365,12 @@ class Reasoning(problem.Problem):
             self.task_frame,
             self.task_detect,
         ]
-        prompts = torch.zeros(nb, self.height, self.width * 3, dtype=torch.int64)
-        answers = torch.zeros(nb, self.height, self.width, dtype=torch.int64)
+        prompts = torch.zeros(
+            nb, self.height, self.width * 3, dtype=torch.int64, device=self.device
+        )
+        answers = torch.zeros(
+            nb, self.height, self.width, dtype=torch.int64, device=self.device
+        )
         w = self.width
 
         for prompt, answer in tqdm.tqdm(