From f631304a61037c523a6a74387c940a29e662c905 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sun, 21 Jul 2024 09:13:48 +0200 Subject: [PATCH] Update. --- grids.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/grids.py b/grids.py index 0b789ef..d3e7dcc 100755 --- a/grids.py +++ b/grids.py @@ -222,7 +222,7 @@ class Grids(problem.Problem): self.task_symbols, self.task_isometry, self.task_corners, - self.task_proximity, + self.task_contact, # self.task_islands, # TOO MESSY ] @@ -1277,7 +1277,7 @@ class Grids(problem.Problem): # @torch.compile # [ai1,ai2] [bi1,bi2] - def task_proximity(self, A, f_A, B, f_B): + def task_contact(self, A, f_A, B, f_B): def rec_dist(a, b): ai1, aj1, ai2, aj2 = a bi1, bj1, bi2, bj2 = b @@ -1297,10 +1297,12 @@ class Grids(problem.Problem): for n in range(nb_rec): i1, j1, i2, j2 = r[n] X[i1:i2, j1:j2] = c[n] + f_X[i1:i2, j1:j2] = c[n] if d[n] == 0: - f_X[i1:i2, j1:j2] = c[0] - else: - f_X[i1:i2, j1:j2] = c[n] + f_X[i1, j1:j2] = c[0] + f_X[i2 - 1, j1:j2] = c[0] + f_X[i1:i2, j1] = c[0] + f_X[i1:i2, j2 - 1] = c[0] # @torch.compile # [ai1,ai2] [bi1,bi2] @@ -1421,8 +1423,8 @@ if __name__ == "__main__": # nb, nrow = 8, 2 # for t in grids.all_tasks: - # for t in [grids.task_proximity, grids.task_corners]: - for t in [grids.task_symbols]: + for t in [grids.task_contact]: + # for t in [grids.task_symbols]: print(t.__name__) prompts, answers = grids.generate_prompts_and_answers_(nb, tasks=[t]) # prompts[...] = torch.randint(grids.nb_token_values(), prompts.size()) -- 2.20.1