X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=sky.py;h=4ca4ba7136b40a5324dcb64ba4c6a3a19523b3c5;hb=c9c018e4c19ce92892d7652082fb90719d57441c;hp=6ba3882f681fb81e99b617385bbb430c310562f8;hpb=29679cb42710602037fee650a5672f01a3292077;p=culture.git diff --git a/sky.py b/sky.py index 6ba3882..4ca4ba7 100755 --- a/sky.py +++ b/sky.py @@ -118,8 +118,14 @@ class Sky(problem.Problem): dtype=torch.int64, ) + fine = torch.empty(self.nb_iterations * self.speed) + + t_to_keep = ( + torch.arange(self.nb_iterations, device=result.device) * self.speed + ) + for l in range(self.nb_iterations * self.speed): - fine = collision_okay() + fine[l] = collision_okay() for n in range(self.nb_birds): c = col[n] result[l, i[n], j[n]] = c @@ -139,19 +145,24 @@ class Sky(problem.Problem): i[n] += vi[n] j[n] += vj[n] - if fine: + result = result[t_to_keep] + fine = fine[t_to_keep] + + if fine[-1]: break - frame_sequences.append( - result[ - torch.arange(self.nb_iterations, device=result.device) * self.speed - ] - ) + frame_sequences.append(result) return frame_sequences ###################################################################### + def generate_prompts_and_answers(self, nb): + frame_sequences = self.generate_frame_sequences(nb) + prompts = frame_sequences[:, : frame_sequences.size(0) // 2].flatten(1) + answers = frame_sequences[:, frame_sequences.size(0) // 2 :].flatten(1) + return prompts, answers + def generate_token_sequences(self, nb): frame_sequences = self.generate_frame_sequences(nb) @@ -296,7 +307,7 @@ class Sky(problem.Problem): if __name__ == "__main__": import time - sky = Sky(height=6, width=8, speed=2, nb_iterations=2) + sky = Sky(height=6, width=8, speed=4, nb_iterations=2) start_time = time.perf_counter() token_sequences = sky.generate_token_sequences(nb=64)