# @torch.compile
def task_translate(self, A, f_A, B, f_B):
- di, dj = torch.randint(3, (2,)) - 1
+ while True:
+ di, dj = torch.randint(3, (2,)) - 1
+ if di.abs() + dj.abs() > 0:
+ break
+
nb_rec = 3
c = torch.randperm(len(self.colors) - 1)[:nb_rec] + 1
for X, f_X in [(A, f_A), (B, f_B)]:
# exit(0)
# if True:
- nb = 72
+ # nb,nrow = 72,4
+ nb, nrow = 8, 2
for t in grids.all_tasks():
# for t in [grids.task_replace_color]:
print(t.__name__)
prompts, answers = grids.generate_prompts_and_answers_(nb, tasks=[t])
- grids.save_quizzes("/tmp", t.__name__, prompts[:nb], answers[:nb], nrow=4)
+ grids.save_quizzes("/tmp", t.__name__, prompts[:nb], answers[:nb], nrow=nrow)
nb = 1000