-def sample2img(seq, height, width, upscale=15):
- f_first = seq[:, : height * width].reshape(-1, height, width)
- f_second = seq[:, height * width + 1 :].reshape(-1, height, width)
- direction = seq[:, height * width]
+def frame2img(x, height, width, upscale=15):
+ x = x.reshape(-1, height, width)
+ m = torch.logical_and(x >= 0, x < first_bird_token + nb_bird_tokens).long()
+ x = colors[x * m].permute(0, 3, 1, 2)
+ s = x.shape
+ x = x[:, :, :, None, :, None].expand(-1, -1, -1, upscale, -1, upscale)
+ x = x.reshape(s[0], s[1], s[2] * upscale, s[3] * upscale)