X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=sky.py;h=4ca4ba7136b40a5324dcb64ba4c6a3a19523b3c5;hb=24b4eceaf1d057636e8a209a2bf52ddc85d01b57;hp=abcd394580c8bd443e1d44599ec1a92de57606f3;hpb=bee6e628aabc1380772409f6aabffb024c0e70ab;p=culture.git diff --git a/sky.py b/sky.py index abcd394..4ca4ba7 100755 --- a/sky.py +++ b/sky.py @@ -112,11 +112,20 @@ class Sky(problem.Problem): break result = torch.zeros( - self.nb_iterations, self.height, self.width, dtype=torch.int64 + self.nb_iterations * self.speed, + self.height, + self.width, + dtype=torch.int64, ) - for l in range(self.nb_iterations): - fine = collision_okay() + fine = torch.empty(self.nb_iterations * self.speed) + + t_to_keep = ( + torch.arange(self.nb_iterations, device=result.device) * self.speed + ) + + for l in range(self.nb_iterations * self.speed): + fine[l] = collision_okay() for n in range(self.nb_birds): c = col[n] result[l, i[n], j[n]] = c @@ -136,7 +145,10 @@ class Sky(problem.Problem): i[n] += vi[n] j[n] += vj[n] - if fine: + result = result[t_to_keep] + fine = fine[t_to_keep] + + if fine[-1]: break frame_sequences.append(result) @@ -145,6 +157,12 @@ class Sky(problem.Problem): ###################################################################### + def generate_prompts_and_answers(self, nb): + frame_sequences = self.generate_frame_sequences(nb) + prompts = frame_sequences[:, : frame_sequences.size(0) // 2].flatten(1) + answers = frame_sequences[:, frame_sequences.size(0) // 2 :].flatten(1) + return prompts, answers + def generate_token_sequences(self, nb): frame_sequences = self.generate_frame_sequences(nb) @@ -289,7 +307,7 @@ class Sky(problem.Problem): if __name__ == "__main__": import time - sky = Sky(height=6, width=8, speed=2, nb_iterations=2) + sky = Sky(height=6, width=8, speed=4, nb_iterations=2) start_time = time.perf_counter() token_sequences = sky.generate_token_sequences(nb=64)