dtype=torch.int64,
)
+ fine = torch.empty(self.nb_iterations * self.speed)
+
+ t_to_keep = (
+ torch.arange(self.nb_iterations, device=result.device) * self.speed
+ )
+
for l in range(self.nb_iterations * self.speed):
- fine = collision_okay()
+ fine[l] = collision_okay()
for n in range(self.nb_birds):
c = col[n]
result[l, i[n], j[n]] = c
i[n] += vi[n]
j[n] += vj[n]
- if fine:
+ result = result[t_to_keep]
+ fine = fine[t_to_keep]
+
+ if fine[-1]:
break
- frame_sequences.append(
- result[
- torch.arange(self.nb_iterations, device=result.device) * self.speed
- ]
- )
+ frame_sequences.append(result)
return frame_sequences
######################################################################
+ def generate_prompts_and_answers(self, nb):
+ frame_sequences = self.generate_frame_sequences(nb)
+ prompts = frame_sequences[:, : frame_sequences.size(0) // 2].flatten(1)
+ answers = frame_sequences[:, frame_sequences.size(0) // 2 :].flatten(1)
+ return prompts, answers
+
def generate_token_sequences(self, nb):
frame_sequences = self.generate_frame_sequences(nb)
if __name__ == "__main__":
import time
- sky = Sky(height=6, width=8, speed=2, nb_iterations=2)
+ sky = Sky(height=6, width=8, speed=4, nb_iterations=2)
start_time = time.perf_counter()
token_sequences = sky.generate_token_sequences(nb=64)