M = F.conv2d(Z[:, None, :, :], w, padding=1)
M = torch.cat([M[:, :1], M[:, 1:].min(dim=1, keepdim=True).values], dim=1)
M = ((M[:, 0] == 0) & (Z == 0)).long()
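# M now flags the empty cells with a zero channel-0 response, i.e. no occupied neighbour: the candidate positions for a new seed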
+ Q = (M.flatten(1).max(dim=1).values > 0).long()[:, None]
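+ # Q is 1 for the samples that still have at least one candidate cell, 0 otherwise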
M = M * torch.rand(M.size())
M = M.flatten(1)
M = F.one_hot(M.argmax(dim=1), num_classes=M.size(1))
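# random scores + argmax + one-hot pick a single candidate uniformly at random; with no candidate left, index 0 would be picked, hence the gating by Q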
- U += M
+ U += M * Q
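# growth phase: each iteration adds at most one more cell per sample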
for _ in range(nb_iterations):
M = F.conv2d(Z[:, None, :, :], w, padding=1)
M = torch.cat([M[:, :1], M[:, 1:].min(dim=1, keepdim=True).values], dim=1)
M = ((M[:, 1] >= 0) & (Z == 0)).long()
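# M now flags the growth candidates: empty cells whose minimum directional response is non-negative (with the kernels defined above, this should prevent a new cell from touching an island only diagonally)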
+ Q = (M.flatten(1).max(dim=1).values > 0).long()[:, None]
M = M * torch.rand(M.size())
M = M.flatten(1)
M = F.one_hot(M.argmax(dim=1), num_classes=M.size(1))
U = Z.flatten(1)
- U += M
+ U += M * Q
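# label the islands: every occupied cell gets a distinct positive id,
# then 3x3 max-pooling, masked to occupied cells, spreads the maximum id over each island until the labels stop changing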
M = Z.clone()
Z = Z * (torch.arange(Z.size(1) * Z.size(2)) + 1).reshape(1, Z.size(1), Z.size(2))
- for _ in range(100):
+ while True:
+ W = Z.clone()
Z = F.max_pool2d(Z, 3, 1, 1) * M
+ if Z.equal(W):
+ break
Z = Z.long()
U = Z.flatten(1)
return no, nq, nq_diag
def task_count(self, A, f_A, B, f_B):
- N = torch.randint(4, (1,)).item() + 2
- c = torch.randperm(len(self.colors) - 1)[:N] + 1
-
- for X, f_X in [(A, f_A), (B, f_B)]:
- l_q = torch.randperm(self.height * self.width)[
- : self.height * self.width // 20
- ]
- l_d = torch.randint(N, l_q.size())
- nb = torch.zeros(N, dtype=torch.int64)
-
- for q, e in zip(l_q, l_d):
- d = c[e]
- i, j = q % self.height, q // self.height
- if (
- nb[e] < self.width
- and X[max(0, i - 1) : i + 2, max(0, j - 1) : j + 2] == 0
- ).all():
- X[i, j] = d
- nb[e] += 1
-
- l_q = torch.randperm((self.height - 2) * (self.width - 2))[
- : self.height * self.width // 2
- ]
- l_d = torch.randint(N, l_q.size())
- for q, e in zip(l_q, l_d):
- d = c[e]
- i, j = q % (self.height - 2) + 1, q // (self.height - 2) + 1
- a1, a2, a3 = X[i - 1, j - 1 : j + 2]
- a8, a4 = X[i, j - 1], X[i, j + 1]
- a7, a6, a5 = X[i + 1, j - 1 : j + 2]
- if (
- X[i, j] == 0
- and nb[e] < self.width
- and (a2 == 0 or a2 == d)
- and (a4 == 0 or a4 == d)
- and (a6 == 0 or a6 == d)
- and (a8 == 0 or a8 == d)
- and (a1 == 0 or a2 == d or a8 == d)
- and (a3 == 0 or a4 == d or a2 == d)
- and (a5 == 0 or a6 == d or a4 == d)
- and (a7 == 0 or a8 == d or a6 == d)
- ):
- o = (
- (a2 != 0).long()
- + (a4 != 0).long()
- + (a6 != 0).long()
- + (a8 != 0).long()
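+ # rejection sampling: draw colours and island grids until a valid counting instance is produced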
+ while True:
+ error = False
+
+ N = torch.randint(5, (1,)).item() + 1
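+ # c[0] stays 0 (background); c[1:] holds N distinct non-background colours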
+ c = torch.zeros(N + 1)
+ c[1:] = torch.randperm(len(self.colors) - 1)[:N] + 1
+
+ for X, f_X in [(A, f_A), (B, f_B)]:
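+ # refill the cache with a batch of 1000 pre-grown island grids when it runs empty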
+ if not hasattr(self, "cache_count") or len(self.cache_count) == 0:
+ self.cache_count = list(
+ grow_islands(
+ 1000,
+ self.height,
+ self.width,
+ nb_seeds=self.height * self.width // 9,
+ nb_iterations=self.height * self.width // 20,
+ )
)
- if o <= 1:
- X[i, j] = d
- nb[e] += 1 - o
- for e in range(N):
- for j in range(nb[e]):
- f_X[e, j] = c[e]
+ X[...] = self.cache_count.pop()
+
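+ # V randomly maps every island label to one of the N colours, keeping the colour usage roughly balanced; label 0 (the background) keeps colour 0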
+ k = (X.max() + 1 + (c.size(0) - 1)).item()
+ V = torch.arange(k) // (c.size(0) - 1)
+ V = (V + torch.rand(V.size())).sort().indices[: X.max() + 1] % (
+ c.size(0) - 1
+ ) + 1
+ V[0] = 0
+ X[...] = c[V[X]]
+
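+ # keep this grid only if the background and all N colours actually appear in it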
+ if F.one_hot(X.flatten()).max(dim=0).values.sum().item() == N + 1:
+ f_X[...] = 0
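+ # write the answer: one row per colour, filled according to how many cells of that colour are in X; flag an error and resample if a row would overflow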
+ for e in range(1, N + 1):
+ for j in range((X == c[e]).sum() + 1):
+ if j < self.width:
+ f_X[e - 1, j] = c[e]
+ else:
+ error = True
+ break
+ else:
+ error = True
+ break
+
+ if not error:
+ break
# @torch.compile
def task_trajectory(self, A, f_A, B, f_B):
"/tmp", t.__name__, prompts[:nb], answers[:nb], nrow=nrow
)
- exit(0)
+ # exit(0)
nb = 1000
# for t in grids.all_tasks:
- for t in [grids.task_islands]:
+ for t in [grids.task_count]:
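# generate prompts/answers for the selected task and measure the generation time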
start_time = time.perf_counter()
prompts, answers = grids.generate_prompts_and_answers_(nb, tasks=[t])
delay = time.perf_counter() - start_time