Update.
[culture.git] / sky.py
diff --git a/sky.py b/sky.py
index 6ba3882..4ca4ba7 100755 (executable)
--- a/sky.py
+++ b/sky.py
@@ -118,8 +118,14 @@ class Sky(problem.Problem):
                     dtype=torch.int64,
                 )
 
+                fine = torch.empty(self.nb_iterations * self.speed)
+
+                t_to_keep = (
+                    torch.arange(self.nb_iterations, device=result.device) * self.speed
+                )
+
                 for l in range(self.nb_iterations * self.speed):
-                    fine = collision_okay()
+                    fine[l] = collision_okay()
                     for n in range(self.nb_birds):
                         c = col[n]
                         result[l, i[n], j[n]] = c
@@ -139,19 +145,24 @@ class Sky(problem.Problem):
                         i[n] += vi[n]
                         j[n] += vj[n]
 
-                if fine:
+                result = result[t_to_keep]
+                fine = fine[t_to_keep]
+
+                if fine[-1]:
                     break
 
-            frame_sequences.append(
-                result[
-                    torch.arange(self.nb_iterations, device=result.device) * self.speed
-                ]
-            )
+            frame_sequences.append(result)
 
         return frame_sequences
 
     ######################################################################
 
+    def generate_prompts_and_answers(self, nb):
+        frame_sequences = self.generate_frame_sequences(nb)
+        prompts = frame_sequences[:, : frame_sequences.size(0) // 2].flatten(1)
+        answers = frame_sequences[:, frame_sequences.size(0) // 2 :].flatten(1)
+        return prompts, answers
+
     def generate_token_sequences(self, nb):
         frame_sequences = self.generate_frame_sequences(nb)
 
@@ -296,7 +307,7 @@ class Sky(problem.Problem):
 if __name__ == "__main__":
     import time
 
-    sky = Sky(height=6, width=8, speed=2, nb_iterations=2)
+    sky = Sky(height=6, width=8, speed=4, nb_iterations=2)
 
     start_time = time.perf_counter()
     token_sequences = sky.generate_token_sequences(nb=64)