From 06b079511fb1aff6e684e5021b791de9b5efac7b Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Tue, 23 Jul 2024 11:12:25 +0200 Subject: [PATCH] Update. --- grids.py | 119 ++++++++++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 108 insertions(+), 11 deletions(-) diff --git a/grids.py b/grids.py index ba0131d..5778f85 100755 --- a/grids.py +++ b/grids.py @@ -214,14 +214,17 @@ class Grids(problem.Problem): self.task_half_fill, self.task_frame, self.task_detect, - # self.task_count, # NOT REVERSIBLE - self.task_trajectory, - self.task_bounce, self.task_scale, self.task_symbols, - self.task_isometry, self.task_corners, self.task_contact, + self.task_path, + self.task_fill, + ############################################ hard ones + self.task_isometry, + self.task_trajectory, + self.task_bounce, + # self.task_count, # NOT REVERSIBLE # self.task_islands, # TOO MESSY ] @@ -467,13 +470,13 @@ class Grids(problem.Problem): j[:, 1, 0], j[:, 1, 1], ) - no_overlap = torch.logical_not( + no_overlap = ( (A_i1 >= B_i2) - & (A_i2 <= B_i1) - & (A_j1 >= B_j1) - & (A_j2 <= B_j1) + | (A_i2 <= B_i1) + | (A_j1 >= B_j2) + | (A_j2 <= B_j1) ) - i, j = i[no_overlap], j[no_overlap] + i, j = (i[no_overlap], j[no_overlap]) elif nb_rec == 3: A_i1, A_i2, A_j1, A_j2 = ( i[:, 0, 0], @@ -1322,6 +1325,100 @@ class Grids(problem.Problem): X[i2 - 1, j1] = c[n] f_X[i1:i2, j1:j2] = c[n] + def compdist(self, X, i, j): + dd = X.new_full((self.height + 2, self.width + 2), self.height * self.width) + d = dd[1:-1, 1:-1] + m = (X > 0).long() + d[i, j] = 0 + e = d.clone() + while True: + e[...] = d + d[...] = ( + d.min(dd[:-2, 1:-1] + 1) + .min(dd[2:, 1:-1] + 1) + .min(dd[1:-1, :-2] + 1) + .min(dd[1:-1, 2:] + 1) + ) + d[...] = (1 - m) * d + m * self.height * self.width + if e.equal(d): + break + + return d + + # @torch.compile + def task_path(self, A, f_A, B, f_B): + nb_rec = 2 + c = torch.randperm(len(self.colors) - 1)[: nb_rec + 2] + 1 + for X, f_X in [(A, f_A), (B, f_B)]: + while True: + X[...] = 0 + f_X[...] = 0 + + r = self.rec_coo(nb_rec, prevent_overlap=True) + for n in range(nb_rec): + i1, j1, i2, j2 = r[n] + X[i1:i2, j1:j2] = c[n] + f_X[i1:i2, j1:j2] = c[n] + + while True: + i1, i2 = torch.randint(self.height, (2,)) + j1, j2 = torch.randint(self.width, (2,)) + if ( + abs(i1 - i2) + abs(j1 - j2) > 2 + and X[i1, j1] == 0 + and X[i2, j2] == 0 + ): + break + + d2 = self.compdist(X, i2, j2) + d = self.compdist(X, i1, j1) + + if d2[i1, j1] < 2 * self.width: + break + + m = ((d + d2) == d[i2, j2]).long() + f_X[...] = m * c[-1] + (1 - m) * f_X + + X[i1, j1] = c[-2] + X[i2, j2] = c[-2] + f_X[i1, j1] = c[-2] + f_X[i2, j2] = c[-2] + + # @torch.compile + def task_fill(self, A, f_A, B, f_B): + nb_rec = 3 + c = torch.randperm(len(self.colors) - 1)[: nb_rec + 1] + 1 + for X, f_X in [(A, f_A), (B, f_B)]: + accept_full = torch.rand(1) < 0.5 + + while True: + X[...] = 0 + f_X[...] = 0 + + r = self.rec_coo(nb_rec, prevent_overlap=True) + for n in range(nb_rec): + i1, j1, i2, j2 = r[n] + X[i1:i2, j1:j2] = c[n] + f_X[i1:i2, j1:j2] = c[n] + + while True: + i, j = ( + torch.randint(self.height, (1,)).item(), + torch.randint(self.width, (1,)).item(), + ) + if X[i, j] == 0: + break + + d = self.compdist(X, i, j) + m = (d < self.height * self.width).long() + X[i, j] = c[-1] + f_X[...] = m * c[-1] + (1 - m) * f_X + + if accept_full or (d * (X == 0)).max() == self.height * self.width: + break + + # end_tasks + ###################################################################### def trivial_prompts_and_answers(self, prompts, answers): @@ -1418,11 +1515,11 @@ if __name__ == "__main__": # exit(0) # if True: - nb, nrow = 8, 2 + nb, nrow = 128, 4 # nb, nrow = 8, 2 # for t in grids.all_tasks: - for t in [grids.task_count]: + for t in [grids.task_fill]: # for t in [grids.task_symbols]: print(t.__name__) prompts, answers = grids.generate_prompts_and_answers_(nb, tasks=[t]) -- 2.39.5