From 3aeec6f942e595694d43355ac57b33931d1d2480 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Wed, 10 Jul 2024 15:01:54 +0200 Subject: [PATCH] Update. --- grids.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/grids.py b/grids.py index 85d640d..6e9e6c7 100755 --- a/grids.py +++ b/grids.py @@ -344,7 +344,11 @@ class Grids(problem.Problem): # @torch.compile def task_translate(self, A, f_A, B, f_B): - di, dj = torch.randint(3, (2,)) - 1 + while True: + di, dj = torch.randint(3, (2,)) - 1 + if di.abs() + dj.abs() > 0: + break + nb_rec = 3 c = torch.randperm(len(self.colors) - 1)[:nb_rec] + 1 for X, f_X in [(A, f_A), (B, f_B)]: @@ -871,13 +875,14 @@ if __name__ == "__main__": # exit(0) # if True: - nb = 72 + # nb,nrow = 72,4 + nb, nrow = 8, 2 for t in grids.all_tasks(): # for t in [grids.task_replace_color]: print(t.__name__) prompts, answers = grids.generate_prompts_and_answers_(nb, tasks=[t]) - grids.save_quizzes("/tmp", t.__name__, prompts[:nb], answers[:nb], nrow=4) + grids.save_quizzes("/tmp", t.__name__, prompts[:nb], answers[:nb], nrow=nrow) nb = 1000 -- 2.39.5