Update.
[culture.git] / sky.py
diff --git a/sky.py b/sky.py
index 6ba3882..1164185 100755 (executable)
--- a/sky.py
+++ b/sky.py
@@ -118,8 +118,14 @@ class Sky(problem.Problem):
                     dtype=torch.int64,
                 )
 
+                fine = torch.empty(self.nb_iterations * self.speed)
+
+                t_to_keep = (
+                    torch.arange(self.nb_iterations, device=result.device) * self.speed
+                )
+
                 for l in range(self.nb_iterations * self.speed):
-                    fine = collision_okay()
+                    fine[l] = collision_okay()
                     for n in range(self.nb_birds):
                         c = col[n]
                         result[l, i[n], j[n]] = c
@@ -139,14 +145,13 @@ class Sky(problem.Problem):
                         i[n] += vi[n]
                         j[n] += vj[n]
 
-                if fine:
+                result = result[t_to_keep]
+                fine = fine[t_to_keep]
+
+                if fine[-1]:
                     break
 
-            frame_sequences.append(
-                result[
-                    torch.arange(self.nb_iterations, device=result.device) * self.speed
-                ]
-            )
+            frame_sequences.append(result)
 
         return frame_sequences
 
@@ -296,7 +301,7 @@ class Sky(problem.Problem):
 if __name__ == "__main__":
     import time
 
-    sky = Sky(height=6, width=8, speed=2, nb_iterations=2)
+    sky = Sky(height=6, width=8, speed=4, nb_iterations=2)
 
     start_time = time.perf_counter()
     token_sequences = sky.generate_token_sequences(nb=64)