+ dist[1:-1, 1:-1] += (X != 0).long() * self.height * self.width
+ dist[0, :] = self.height * self.width
+ dist[-1, :] = self.height * self.width
+ dist[:, 0] = self.height * self.width
+ dist[:, -1] = self.height * self.width
+ # dist += torch.rand(dist.size())
+
+ i, j = i0 + 1, j0 + 1
+ while i != i1 + 1 or j != j1 + 1:
+ f_X[i - 1, j - 1] = c[2]
+ r, s, t, u = (
+ dist[i - 1, j],
+ dist[i, j - 1],
+ dist[i + 1, j],
+ dist[i, j + 1],
+ )
+ m = min(r, s, t, u)
+ if r == m:
+ i = i - 1
+ elif t == m:
+ i = i + 1
+ elif s == m:
+ j = j - 1
+ else:
+ j = j + 1
+
+ X[i0, j0] = c[2]
+ # f_X[i0, j0] = c[1]
+
+ X[i1, j1] = c[1]
+ f_X[i1, j1] = c[1]
+
+ # for X, f_X in [(A, f_A), (B, f_B)]:
+ # n = torch.arange(self.height * self.width).reshape(self.height, self.width)
+ # k = torch.randperm(self.height * self.width)
+ # X[...]=-1
+ # for q in k:
+ # i,j=q%self.height,q//self.height
+ # if
+
+ # @torch.compile
+ def task_puzzle(self, A, f_A, B, f_B):
+ S = 4
+ i0, j0 = (self.height - S) // 2, (self.width - S) // 2
+ c = torch.randperm(len(self.colors) - 1)[:4] + 1
+ for X, f_X in [(A, f_A), (B, f_B)]:
+ while True:
+ f_X[...] = 0
+ h = list(torch.randperm(c.size(0)))
+ n = torch.zeros(c.max() + 1)
+ for _ in range(2):
+ k = torch.randperm(S * S)
+ for q in k:
+ i, j = q % S + i0, q // S + j0
+ if f_X[i, j] == 0:
+ r, s, t, u = (
+ f_X[i - 1, j],
+ f_X[i, j - 1],
+ f_X[i + 1, j],
+ f_X[i, j + 1],
+ )
+ r, s, t, u = torch.tensor([r, s, t, u])[torch.randperm(4)]
+ if r > 0 and n[r] < 6:
+ n[r] += 1
+ f_X[i, j] = r
+ elif s > 0 and n[s] < 6:
+ n[s] += 1
+ f_X[i, j] = s
+ elif t > 0 and n[t] < 6:
+ n[t] += 1
+ f_X[i, j] = t
+ elif u > 0 and n[u] < 6:
+ n[u] += 1
+ f_X[i, j] = u
+ else:
+ if len(h) > 0:
+ d = c[h.pop()]
+ n[d] += 1
+ f_X[i, j] = d
+
+ if n.sum() == S * S:
+ break
+
+ k = 0
+ for d in range(4):
+ while True:
+ ii, jj = (
+ torch.randint(self.height, (1,)).item(),
+ torch.randint(self.width, (1,)).item(),
+ )
+ e = 0
+ for i in range(S):
+ for j in range(S):
+ if (
+ ii + i >= self.height
+ or jj + j >= self.width
+ or (
+ f_X[i + i0, j + j0] == c[d]
+ and X[ii + i, jj + j] > 0
+ )
+ ):
+ e = 1
+ if e == 0:
+ break
+ for i in range(S):
+ for j in range(S):
+ if f_X[i + i0, j + j0] == c[d]:
+ X[ii + i, jj + j] = c[d]
+
+ def task_islands(self, A, f_A, B, f_B):
+ c = torch.randperm(len(self.colors) - 1)[:2] + 1
+ for X, f_X in [(A, f_A), (B, f_B)]:
+ if not hasattr(self, "cache_islands") or len(self.cache_islands) == 0:
+ self.cache_islands = list(
+ grow_islands(
+ 1000,
+ self.height,
+ self.width,
+ nb_seeds=self.height * self.width // 20,
+ nb_iterations=self.height * self.width // 2,
+ )
+ )
+
+ A = self.cache_islands.pop()
+
+ while True:
+ i, j = (
+ torch.randint(self.height // 2, (1,)).item(),
+ torch.randint(self.width // 2, (1,)).item(),
+ )
+ if A[i, j] > 0:
+ break
+
+ X[...] = (A > 0) * c[0]
+ X[i, j] = c[1]
+ f_X[...] = (A == A[i, j]) * c[1] + ((A > 0) & (A != A[i, j])) * c[0]
+