Update.
[culture.git] / lang.py
diff --git a/lang.py b/lang.py
index 43550d7..abb7ca2 100755 (executable)
--- a/lang.py
+++ b/lang.py
@@ -34,13 +34,11 @@ class Lang(problem.Problem):
 
     def __init__(
         self,
-        nb_iterations=2,
     ):
         self.colors = torch.tensor([c for _, c in self.named_colors])
         self.name2color = dict([(p[0], i) for i, p in enumerate(self.named_colors)])
         self.height = 10
         self.width = 10
-        self.nb_iterations = nb_iterations
 
     ######################################################################
 
@@ -173,8 +171,9 @@ class Lang(problem.Problem):
         return len(self.colors)
 
     def rec_coo(self, x, n, min_height=3, min_width=3):
+        collision = x.new(x.size())
         while True:
-            collision = x.new_zeros(x.size())
+            collision[...] = 0
             result = []
             for _ in range(n):
                 while True:
@@ -263,13 +262,14 @@ class Lang(problem.Problem):
             r = self.rec_coo(X, N)
             for n in range(N):
                 i1, j1, i2, j2 = r[n]
-                X[i1 : (i1 + i2) // 2, j1:j2] = c[2 * n]
-                f_X[i1 : (i1 + i2) // 2, j1:j2] = c[2 * n]
-                X[(i1 + i2) // 2 : (i1 + i2) // 2 + 1, j1:j2] = c[2 * n + 1]
+                i = (i1 + i2) // 2
+                X[i1:i2, j1:j2] = c[2 * n]
+                X[i : i + 1, j1:j2] = c[2 * n + 1]
+                f_X[i1:i2, j1:j2] = c[2 * n]
                 if n == N - 1:
-                    f_X[(i1 + i2) // 2 : i2, j1:j2] = c[2 * n + 1]
+                    f_X[i:i2, j1:j2] = c[2 * n + 1]
                 else:
-                    f_X[(i1 + i2) // 2 : (i1 + i2) // 2 + 1, j1:j2] = c[2 * n + 1]
+                    f_X[i : i + 1, j1:j2] = c[2 * n + 1]
 
     def task_frame(self, A, f_A, B, f_B):
         N = 3
@@ -301,7 +301,8 @@ class Lang(problem.Problem):
             f_A = prompt[:, 1 * w : 2 * w]
             B = prompt[:, 2 * w : 3 * w]
             f_B = answer
-            tasks[torch.randint(len(tasks), (1,))](A, f_A, B, f_B)
+            task = tasks[torch.randint(len(tasks), (1,))]
+            task(A, f_A, B, f_B)
         return prompts.flatten(1), answers.flatten(1)
 
     def save_quizzes(
@@ -328,9 +329,12 @@ class Lang(problem.Problem):
 if __name__ == "__main__":
     import time
 
-    lang = Lang(nb_iterations=4)
+    lang = Lang()
 
-    prompts, answers = lang.generate_prompts_and_answers(36)
+    start_time = time.perf_counter()
+    prompts, answers = lang.generate_prompts_and_answers(100)
+    delay = time.perf_counter() - start_time
+    print(f"{prompts.size(0)/delay:02f} seq/s")
 
     # predicted_prompts = torch.rand(prompts.size(0)) < 0.5
     # predicted_answers = torch.logical_not(predicted_prompts)
@@ -338,8 +342,8 @@ if __name__ == "__main__":
     lang.save_quizzes(
         "/tmp",
         "test",
-        prompts,
-        answers,
+        prompts[:36],
+        answers[:36],
         # You can add a bool to put a frame around the predicted parts
         # predicted_prompts, predicted_answers
     )