Update.
[culture.git] / sky.py
diff --git a/sky.py b/sky.py
index abcd394..1164185 100755 (executable)
--- a/sky.py
+++ b/sky.py
@@ -112,11 +112,20 @@ class Sky(problem.Problem):
                         break
 
                 result = torch.zeros(
-                    self.nb_iterations, self.height, self.width, dtype=torch.int64
+                    self.nb_iterations * self.speed,
+                    self.height,
+                    self.width,
+                    dtype=torch.int64,
                 )
 
-                for l in range(self.nb_iterations):
-                    fine = collision_okay()
+                fine = torch.empty(self.nb_iterations * self.speed)
+
+                t_to_keep = (
+                    torch.arange(self.nb_iterations, device=result.device) * self.speed
+                )
+
+                for l in range(self.nb_iterations * self.speed):
+                    fine[l] = collision_okay()
                     for n in range(self.nb_birds):
                         c = col[n]
                         result[l, i[n], j[n]] = c
@@ -136,7 +145,10 @@ class Sky(problem.Problem):
                         i[n] += vi[n]
                         j[n] += vj[n]
 
-                if fine:
+                result = result[t_to_keep]
+                fine = fine[t_to_keep]
+
+                if fine[-1]:
                     break
 
             frame_sequences.append(result)
@@ -289,7 +301,7 @@ class Sky(problem.Problem):
 if __name__ == "__main__":
     import time
 
-    sky = Sky(height=6, width=8, speed=2, nb_iterations=2)
+    sky = Sky(height=6, width=8, speed=4, nb_iterations=2)
 
     start_time = time.perf_counter()
     token_sequences = sky.generate_token_sequences(nb=64)