######################################################################
def frame2img(self, x, scale=15):
- x = x.reshape(-1, self.height, self.width)
+ x = x.reshape(x.size(0), self.height, -1)
m = torch.logical_and(
x >= 0, x < self.first_bird_token + self.nb_bird_tokens
).long()
if __name__ == "__main__":
import time
- sky = Sky(height=6, width=8, speed=4, nb_iterations=2)
+ sky = Sky(height=6, width=8, speed=1, nb_iterations=4)
prompts, answers = sky.generate_prompts_and_answers(4)