X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=sky.py;h=11641853d8e8081f1cc7d0cf63d112f1ba30b518;hb=6b4e192557e03528ffd10364123de454aa9c9f08;hp=abcd394580c8bd443e1d44599ec1a92de57606f3;hpb=bee6e628aabc1380772409f6aabffb024c0e70ab;p=culture.git diff --git a/sky.py b/sky.py index abcd394..1164185 100755 --- a/sky.py +++ b/sky.py @@ -112,11 +112,20 @@ class Sky(problem.Problem): break result = torch.zeros( - self.nb_iterations, self.height, self.width, dtype=torch.int64 + self.nb_iterations * self.speed, + self.height, + self.width, + dtype=torch.int64, ) - for l in range(self.nb_iterations): - fine = collision_okay() + fine = torch.empty(self.nb_iterations * self.speed) + + t_to_keep = ( + torch.arange(self.nb_iterations, device=result.device) * self.speed + ) + + for l in range(self.nb_iterations * self.speed): + fine[l] = collision_okay() for n in range(self.nb_birds): c = col[n] result[l, i[n], j[n]] = c @@ -136,7 +145,10 @@ class Sky(problem.Problem): i[n] += vi[n] j[n] += vj[n] - if fine: + result = result[t_to_keep] + fine = fine[t_to_keep] + + if fine[-1]: break frame_sequences.append(result) @@ -289,7 +301,7 @@ class Sky(problem.Problem): if __name__ == "__main__": import time - sky = Sky(height=6, width=8, speed=2, nb_iterations=2) + sky = Sky(height=6, width=8, speed=4, nb_iterations=2) start_time = time.perf_counter() token_sequences = sky.generate_token_sequences(nb=64)