return torch.cat(result, dim=0)
- def frame2img(self, x, upscale=15):
+ def frame2img(self, x, scale=15):
x = x.reshape(-1, self.height, self.width)
m = torch.logical_and(
x >= 0, x < self.first_bird_token + self.nb_bird_tokens
).long()
x = self.colors[x * m].permute(0, 3, 1, 2)
s = x.shape
- x = x[:, :, :, None, :, None].expand(-1, -1, -1, upscale, -1, upscale)
- x = x.reshape(s[0], s[1], s[2] * upscale, s[3] * upscale)
+ x = x[:, :, :, None, :, None].expand(-1, -1, -1, scale, -1, scale)
+ x = x.reshape(s[0], s[1], s[2] * scale, s[3] * scale)
- x[:, :, :, torch.arange(0, x.size(3), upscale)] = 0
- x[:, :, torch.arange(0, x.size(2), upscale), :] = 0
+ x[:, :, :, torch.arange(0, x.size(3), scale)] = 0
+ x[:, :, torch.arange(0, x.size(2), scale), :] = 0
x = x[:, :, 1:, 1:]
for n in range(m.size(0)):
for i in range(m.size(1)):
for j in range(m.size(2)):
if m[n, i, j] == 0:
- for k in range(2, upscale - 2):
- x[n, :, i * upscale + k, j * upscale + k] = 0
- x[n, :, i * upscale + upscale - 1 - k, j * upscale + k] = 0
+ for k in range(2, scale - 2):
+ for l in [0, 1]:
+ x[n, :, i * scale + k, j * scale + k - l] = 0
+ x[
+ n, :, i * scale + scale - 1 - k, j * scale + k - l
+ ] = 0
return x
- def seq2img(self, seq, upscale=15):
+ def seq2img(self, seq, scale=15):
f_first = seq[:, : self.height * self.width].reshape(
-1, self.height, self.width
)
direction = seq[:, self.height * self.width]
direction_symbol = torch.full(
- (direction.size(0), self.height * upscale - 1, upscale), 0
+ (direction.size(0), self.height * scale - 1, scale), 0
)
direction_symbol = self.colors[direction_symbol].permute(0, 3, 1, 2)
- separator = torch.full((direction.size(0), 3, self.height * upscale - 1, 1), 0)
+ separator = torch.full((direction.size(0), 3, self.height * scale - 1, 1), 0)
for n in range(direction_symbol.size(0)):
if direction[n] == self.token_forward:
- for k in range(upscale):
- direction_symbol[
- n,
- :,
- (self.height * upscale) // 2 - upscale // 2 + k,
- 3 + upscale // 2 - abs(k - upscale // 2),
- ] = 0
+ for k in range(scale):
+ for l in [0, 1]:
+ direction_symbol[
+ n,
+ :,
+ (self.height * scale) // 2 - scale // 2 + k - l,
+ 3 + scale // 2 - abs(k - scale // 2),
+ ] = 0
elif direction[n] == self.token_backward:
- for k in range(upscale):
- direction_symbol[
- n,
- :,
- (self.height * upscale) // 2 - upscale // 2 + k,
- 3 + abs(k - upscale // 2),
- ] = 0
+ for k in range(scale):
+ for l in [0, 1]:
+ direction_symbol[
+ n,
+ :,
+ (self.height * scale) // 2 - scale // 2 + k - l,
+ 3 + abs(k - scale // 2),
+ ] = 0
else:
- for k in range(2, upscale - 2):
- direction_symbol[
- n, :, (self.height * upscale) // 2 - upscale // 2 + k, k
- ] = 0
- direction_symbol[
- n,
- :,
- (self.height * upscale) // 2 - upscale // 2 + k,
- upscale - 1 - k,
- ] = 0
+ for k in range(2, scale - 2):
+ for l in [0, 1]:
+ direction_symbol[
+ n,
+ :,
+ (self.height * scale) // 2 - scale // 2 + k - l,
+ k,
+ ] = 0
+ direction_symbol[
+ n,
+ :,
+ (self.height * scale) // 2 - scale // 2 + k - l,
+ scale - 1 - k,
+ ] = 0
return torch.cat(
[
- self.frame2img(f_first, upscale),
+ self.frame2img(f_first, scale),
separator,
direction_symbol,
separator,
- self.frame2img(f_second, upscale),
+ self.frame2img(f_second, scale),
],
dim=3,
)