X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=grids.py;h=6e9e6c7cdc36809d0be5ce354fc9375a592ef4ce;hb=3aeec6f942e595694d43355ac57b33931d1d2480;hp=85d640dda9b3c6550cf7891037be1759a17a734e;hpb=428fd9169ecc3d03c9e8282d319682ddab0f098d;p=culture.git diff --git a/grids.py b/grids.py index 85d640d..6e9e6c7 100755 --- a/grids.py +++ b/grids.py @@ -344,7 +344,11 @@ class Grids(problem.Problem): # @torch.compile def task_translate(self, A, f_A, B, f_B): - di, dj = torch.randint(3, (2,)) - 1 + while True: + di, dj = torch.randint(3, (2,)) - 1 + if di.abs() + dj.abs() > 0: + break + nb_rec = 3 c = torch.randperm(len(self.colors) - 1)[:nb_rec] + 1 for X, f_X in [(A, f_A), (B, f_B)]: @@ -871,13 +875,14 @@ if __name__ == "__main__": # exit(0) # if True: - nb = 72 + # nb,nrow = 72,4 + nb, nrow = 8, 2 for t in grids.all_tasks(): # for t in [grids.task_replace_color]: print(t.__name__) prompts, answers = grids.generate_prompts_and_answers_(nb, tasks=[t]) - grids.save_quizzes("/tmp", t.__name__, prompts[:nb], answers[:nb], nrow=4) + grids.save_quizzes("/tmp", t.__name__, prompts[:nb], answers[:nb], nrow=nrow) nb = 1000