Update.
authorFrançois Fleuret <francois@fleuret.org>
Thu, 4 Jul 2024 21:09:08 +0000 (00:09 +0300)
committerFrançois Fleuret <francois@fleuret.org>
Thu, 4 Jul 2024 21:09:08 +0000 (00:09 +0300)
reasoning.py

index 54a4203..9e26d64 100755 (executable)
@@ -407,6 +407,65 @@ class Reasoning(problem.Problem):
                     f_X[i1, j1] = c[-1]
 
     def task_count(self, A, f_A, B, f_B):
+        N = torch.randint(4, (1,)) + 2
+        c = torch.randperm(len(self.colors) - 1)[:N] + 1
+
+        for X, f_X in [(A, f_A), (B, f_B)]:
+
+            def contact(i, j, q):
+                nq, nq_diag = 0, 0
+                no = 0
+
+                for ii, jj in [
+                    (i - 1, j - 1),
+                    (i - 1, j),
+                    (i - 1, j + 1),
+                    (i, j - 1),
+                    (i, j + 1),
+                    (i + 1, j - 1),
+                    (i + 1, j),
+                    (i + 1, j + 1),
+                ]:
+                    if ii >= 0 and ii < self.height and jj >= 0 and jj < self.width:
+                        if X[ii, jj] != 0 and X[ii, jj] != q:
+                            no += 1
+
+                for ii, jj in [
+                    (i - 1, j - 1),
+                    (i - 1, j + 1),
+                    (i + 1, j - 1),
+                    (i + 1, j + 1),
+                ]:
+                    if ii >= 0 and ii < self.height and jj >= 0 and jj < self.width:
+                        if X[ii, jj] == q and X[i, jj] != q and X[ii, j] != q:
+                            nq_diag += 1
+
+                for ii, jj in [(i - 1, j), (i, j - 1), (i, j + 1), (i + 1, j)]:
+                    if ii >= 0 and ii < self.height and jj >= 0 and jj < self.width:
+                        if X[ii, jj] == q:
+                            nq += 1
+
+                return no, nq, nq_diag
+
+            nb = torch.zeros(N, dtype=torch.int64)
+            q = torch.randint(N, (self.height * self.width,))
+            k = torch.randperm(self.height * self.width)
+            for p in range(self.height * self.width):
+                i, j = k[p] % self.height, k[p] // self.height
+                no, nq, nq_diag = contact(i, j, c[q[p]])
+                if no == 0 and nq_diag == 0:
+                    if nq == 0:
+                        if nb[q[p]] < self.width:
+                            X[i, j] = c[q[p]]
+                            nb[q[p]] += 1
+                    if nq == 1:
+                        X[i, j] = c[q[p]]
+
+            for n in range(N):
+                for j in range(nb[n]):
+                    f_X[n, j] = c[n]
+
+    def task_count_(self, A, f_A, B, f_B):
         N = torch.randint(3, (1,)) + 1
         c = torch.randperm(len(self.colors) - 1)[:N] + 1
         for X, f_X in [(A, f_A), (B, f_B)]: