self.task_symbols,
self.task_isometry,
self.task_corners,
- self.task_proximity,
+ self.task_contact,
# self.task_islands, # TOO MESSY
]
# @torch.compile
# [ai1,ai2] [bi1,bi2]
- def task_proximity(self, A, f_A, B, f_B):
+ def task_contact(self, A, f_A, B, f_B):
def rec_dist(a, b):
ai1, aj1, ai2, aj2 = a
bi1, bj1, bi2, bj2 = b
for n in range(nb_rec):
i1, j1, i2, j2 = r[n]
X[i1:i2, j1:j2] = c[n]
+ f_X[i1:i2, j1:j2] = c[n]
if d[n] == 0:
- f_X[i1:i2, j1:j2] = c[0]
- else:
- f_X[i1:i2, j1:j2] = c[n]
+ f_X[i1, j1:j2] = c[0]
+ f_X[i2 - 1, j1:j2] = c[0]
+ f_X[i1:i2, j1] = c[0]
+ f_X[i1:i2, j2 - 1] = c[0]
# @torch.compile
# [ai1,ai2] [bi1,bi2]
# nb, nrow = 8, 2
# for t in grids.all_tasks:
- # for t in [grids.task_proximity, grids.task_corners]:
- for t in [grids.task_symbols]:
+ for t in [grids.task_contact]:
+ # for t in [grids.task_symbols]:
print(t.__name__)
prompts, answers = grids.generate_prompts_and_answers_(nb, tasks=[t])
# prompts[...] = torch.randint(grids.nb_token_values(), prompts.size())