Add a fast vectorized rec_coo rectangle sampler (keeping the old pure-Python version as rec_coo_), make the device configurable in __init__, and accept a device argument in generate_prompts_and_answers.
[culture.git] / reasoning.py
index 003806a..c442947 100755 (executable)
@@ -32,13 +32,12 @@ class Reasoning(problem.Problem):
         ("gray", [192, 192, 192]),
     ]
 
-    def __init__(
-        self,
-    ):
+    def __init__(self, device=torch.device("cpu")):
         self.colors = torch.tensor([c for _, c in self.named_colors])
         self.name2color = dict([(p[0], i) for i, p in enumerate(self.named_colors)])
         self.height = 10
         self.width = 10
+        self.device = device
 
     ######################################################################
 
@@ -170,7 +169,73 @@ class Reasoning(problem.Problem):
     def nb_token_values(self):
         return len(self.colors)
 
+    # This is quite a tensorial spaghetti mess for sampling
+    # non-overlapping rectangles quickly, but it cut the generation of
+    # 100k samples from 1h50 (with a lame pure-Python implementation)
+    # down to 4min.
     def rec_coo(self, x, n, min_height=3, min_width=3):
+        K = 3
+        N = 1000
+
+        while True:
+            v = (
+                (
+                    torch.rand(N * K, self.height + 1, device=self.device)
+                    .sort(dim=-1)
+                    .indices
+                    < 2
+                )
+                .long()
+                .cumsum(dim=1)
+                == 1
+            ).long()
+
+            h = (
+                (
+                    torch.rand(N * K, self.width + 1, device=self.device)
+                    .sort(dim=-1)
+                    .indices
+                    < 2
+                )
+                .long()
+                .cumsum(dim=1)
+                == 1
+            ).long()
+
+            i = torch.logical_and(
+                v.sum(dim=-1) >= min_height, h.sum(dim=-1) >= min_width
+            )
+
+            v, h = v[i], h[i]
+            v = v[: v.size(0) - v.size(0) % K]
+            h = h[: h.size(0) - h.size(0) % K]
+            v = v.reshape(v.size(0) // K, K, -1)
+            h = h.reshape(h.size(0) // K, K, -1)
+
+            r = v[:, :, :, None] * h[:, :, None, :]
+
+            valid = r.sum(dim=1).flatten(1).max(dim=-1).values == 1
+
+            v = v[valid]
+            h = h[valid]
+
+            if v.size(0) > 0:
+                break
+
+        av = torch.arange(v.size(2), device=self.device)[None, :]
+        ah = torch.arange(h.size(2), device=self.device)[None, :]
+
+        return [
+            (i1.item(), j1.item(), i2.item() + 1, j2.item() + 1)
+            for i1, j1, i2, j2 in zip(
+                v.size(2) - (v[0] * (v.size(2) - av)).max(dim=-1).values,
+                h.size(2) - (h[0] * (h.size(2) - ah)).max(dim=-1).values,
+                (v[0] * av).max(dim=-1).values,
+                (h[0] * ah).max(dim=-1).values,
+            )
+        ]
+
+    def rec_coo_(self, x, n, min_height=3, min_width=3):
         collision = x.new(x.size())
         while True:
             collision[...] = 0
@@ -295,7 +360,7 @@ class Reasoning(problem.Problem):
 
     ######################################################################
 
-    def generate_prompts_and_answers(self, nb):
+    def generate_prompts_and_answers(self, nb, device="cpu"):
         tasks = [
             self.task_replace_color,
             self.task_move,
@@ -320,6 +385,7 @@ class Reasoning(problem.Problem):
             f_B = answer
             task = tasks[torch.randint(len(tasks), (1,))]
             task(A, f_A, B, f_B)
+
         return prompts.flatten(1), answers.flatten(1)
 
     def save_quizzes(