self.task_half_fill,
self.task_frame,
self.task_detect,
- # self.task_count, # NOT REVERSIBLE
- self.task_trajectory,
- self.task_bounce,
self.task_scale,
self.task_symbols,
- self.task_isometry,
self.task_corners,
self.task_contact,
+ self.task_path,
+ self.task_fill,
+ ############################################ hard ones
+ self.task_isometry,
+ self.task_trajectory,
+ self.task_bounce,
+ # self.task_count, # NOT REVERSIBLE
# self.task_islands, # TOO MESSY
]
j[:, 1, 0],
j[:, 1, 1],
)
- no_overlap = torch.logical_not(
+ no_overlap = (
(A_i1 >= B_i2)
- & (A_i2 <= B_i1)
- & (A_j1 >= B_j1)
- & (A_j2 <= B_j1)
+ | (A_i2 <= B_i1)
+ | (A_j1 >= B_j2)
+ | (A_j2 <= B_j1)
)
- i, j = i[no_overlap], j[no_overlap]
+ i, j = (i[no_overlap], j[no_overlap])
elif nb_rec == 3:
A_i1, A_i2, A_j1, A_j2 = (
i[:, 0, 0],
X[i2 - 1, j1] = c[n]
f_X[i1:i2, j1:j2] = c[n]
+ def compdist(self, X, i, j):
+ dd = X.new_full((self.height + 2, self.width + 2), self.height * self.width)
+ d = dd[1:-1, 1:-1]
+ m = (X > 0).long()
+ d[i, j] = 0
+ e = d.clone()
+ while True:
+ e[...] = d
+ d[...] = (
+ d.min(dd[:-2, 1:-1] + 1)
+ .min(dd[2:, 1:-1] + 1)
+ .min(dd[1:-1, :-2] + 1)
+ .min(dd[1:-1, 2:] + 1)
+ )
+ d[...] = (1 - m) * d + m * self.height * self.width
+ if e.equal(d):
+ break
+
+ return d
+
+ # @torch.compile
+ def task_path(self, A, f_A, B, f_B):
+ nb_rec = 2
+ c = torch.randperm(len(self.colors) - 1)[: nb_rec + 2] + 1
+ for X, f_X in [(A, f_A), (B, f_B)]:
+ while True:
+ X[...] = 0
+ f_X[...] = 0
+
+ r = self.rec_coo(nb_rec, prevent_overlap=True)
+ for n in range(nb_rec):
+ i1, j1, i2, j2 = r[n]
+ X[i1:i2, j1:j2] = c[n]
+ f_X[i1:i2, j1:j2] = c[n]
+
+ while True:
+ i1, i2 = torch.randint(self.height, (2,))
+ j1, j2 = torch.randint(self.width, (2,))
+ if (
+ abs(i1 - i2) + abs(j1 - j2) > 2
+ and X[i1, j1] == 0
+ and X[i2, j2] == 0
+ ):
+ break
+
+ d2 = self.compdist(X, i2, j2)
+ d = self.compdist(X, i1, j1)
+
+ if d2[i1, j1] < 2 * self.width:
+ break
+
+ m = ((d + d2) == d[i2, j2]).long()
+ f_X[...] = m * c[-1] + (1 - m) * f_X
+
+ X[i1, j1] = c[-2]
+ X[i2, j2] = c[-2]
+ f_X[i1, j1] = c[-2]
+ f_X[i2, j2] = c[-2]
+
+ # @torch.compile
+ def task_fill(self, A, f_A, B, f_B):
+ nb_rec = 3
+ c = torch.randperm(len(self.colors) - 1)[: nb_rec + 1] + 1
+ for X, f_X in [(A, f_A), (B, f_B)]:
+ accept_full = torch.rand(1) < 0.5
+
+ while True:
+ X[...] = 0
+ f_X[...] = 0
+
+ r = self.rec_coo(nb_rec, prevent_overlap=True)
+ for n in range(nb_rec):
+ i1, j1, i2, j2 = r[n]
+ X[i1:i2, j1:j2] = c[n]
+ f_X[i1:i2, j1:j2] = c[n]
+
+ while True:
+ i, j = (
+ torch.randint(self.height, (1,)).item(),
+ torch.randint(self.width, (1,)).item(),
+ )
+ if X[i, j] == 0:
+ break
+
+ d = self.compdist(X, i, j)
+ m = (d < self.height * self.width).long()
+ X[i, j] = c[-1]
+ f_X[...] = m * c[-1] + (1 - m) * f_X
+
+ if accept_full or (d * (X == 0)).max() == self.height * self.width:
+ break
+
+ # end_tasks
+
######################################################################
def trivial_prompts_and_answers(self, prompts, answers):
# exit(0)
# if True:
- nb, nrow = 8, 2
+ nb, nrow = 128, 4
# nb, nrow = 8, 2
# for t in grids.all_tasks:
- for t in [grids.task_count]:
+ for t in [grids.task_fill]:
# for t in [grids.task_symbols]:
print(t.__name__)
prompts, answers = grids.generate_prompts_and_answers_(nb, tasks=[t])