break
result = torch.zeros(
- self.nb_iterations, self.height, self.width, dtype=torch.int64
+ self.nb_iterations * self.speed,
+ self.height,
+ self.width,
+ dtype=torch.int64,
)
- for l in range(self.nb_iterations):
- fine = collision_okay()
+ fine = torch.empty(self.nb_iterations * self.speed)
+
+ t_to_keep = (
+ torch.arange(self.nb_iterations, device=result.device) * self.speed
+ )
+
+ for l in range(self.nb_iterations * self.speed):
+ fine[l] = collision_okay()
for n in range(self.nb_birds):
c = col[n]
result[l, i[n], j[n]] = c
i[n] += vi[n]
j[n] += vj[n]
- if fine:
+ result = result[t_to_keep]
+ fine = fine[t_to_keep]
+
+ if fine[-1]:
break
frame_sequences.append(result)
if __name__ == "__main__":
import time
- sky = Sky(height=6, width=8, speed=2, nb_iterations=2)
+ sky = Sky(height=6, width=8, speed=4, nb_iterations=2)
start_time = time.perf_counter()
token_sequences = sky.generate_token_sequences(nb=64)