Update.
authorFrançois Fleuret <francois@fleuret.org>
Wed, 10 Jul 2024 13:01:54 +0000 (15:01 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Wed, 10 Jul 2024 13:01:54 +0000 (15:01 +0200)
grids.py

index 85d640d..6e9e6c7 100755 (executable)
--- a/grids.py
+++ b/grids.py
@@ -344,7 +344,11 @@ class Grids(problem.Problem):
 
     # @torch.compile
     def task_translate(self, A, f_A, B, f_B):
-        di, dj = torch.randint(3, (2,)) - 1
+        while True:
+            di, dj = torch.randint(3, (2,)) - 1
+            if di.abs() + dj.abs() > 0:
+                break
+
         nb_rec = 3
         c = torch.randperm(len(self.colors) - 1)[:nb_rec] + 1
         for X, f_X in [(A, f_A), (B, f_B)]:
@@ -871,13 +875,14 @@ if __name__ == "__main__":
     # exit(0)
 
     # if True:
-    nb = 72
+    # nb,nrow = 72,4
+    nb, nrow = 8, 2
 
     for t in grids.all_tasks():
         # for t in [grids.task_replace_color]:
         print(t.__name__)
         prompts, answers = grids.generate_prompts_and_answers_(nb, tasks=[t])
-        grids.save_quizzes("/tmp", t.__name__, prompts[:nb], answers[:nb], nrow=4)
+        grids.save_quizzes("/tmp", t.__name__, prompts[:nb], answers[:nb], nrow=nrow)
 
     nb = 1000