+ def mosaic(x, upscale):
+ x = x.reshape(-1, height, width)
+ m = torch.logical_and(x >= 0, x < first_fish_token + nb_fish_tokens).long()
+ x = colors[x * m].permute(0, 3, 1, 2)
+ s = x.shape
+ x = x[:, :, :, None, :, None].expand(-1, -1, -1, upscale, -1, upscale)
+ x = x.reshape(s[0], s[1], s[2] * upscale, s[3] * upscale)
+
+ for n in range(m.size(0)):
+ for i in range(m.size(1)):
+ for j in range(m.size(2)):
+ if m[n, i, j] == 0:
+ for k in range(2, upscale - 2):
+ x[n, :, i * upscale + k, j * upscale + k] = 0
+ x[n, :, i * upscale + upscale - 1 - k, j * upscale + k] = 0
+
+ return x
+
+ return torch.cat([mosaic(f_start, upscale), mosaic(f_end, upscale)], dim=3)