- f_end[i, j] = c
- f_end[i - vi, j - vj] = c
- f_end[i + vj, j - vi] = c
- f_end[i - vj, j + vi] = c
-
- pairs.append((f_start, f_end))
-
- result = []
- for p in pairs:
- if torch.rand(1) < 0.5:
- result.append(
- torch.cat(
- [p[0].flatten(), torch.tensor([token_forward]), p[1].flatten()],
- dim=0,
- )[None, :]
- )
- else:
- result.append(
- torch.cat(
- [p[1].flatten(), torch.tensor([token_backward]), p[0].flatten()],
- dim=0,
- )[None, :]
- )
-
- return torch.cat(result, dim=0)
-
-
-def frame2img(x, height, width, upscale=15):
- x = x.reshape(-1, height, width)
- m = torch.logical_and(x >= 0, x < first_bird_token + nb_bird_tokens).long()
- x = colors[x * m].permute(0, 3, 1, 2)
- s = x.shape
- x = x[:, :, :, None, :, None].expand(-1, -1, -1, upscale, -1, upscale)
- x = x.reshape(s[0], s[1], s[2] * upscale, s[3] * upscale)
-
- x[:, :, :, torch.arange(0, x.size(3), upscale)] = 0
- x[:, :, torch.arange(0, x.size(2), upscale), :] = 0
- x = x[:, :, 1:, 1:]
-
- for n in range(m.size(0)):
- for i in range(m.size(1)):
- for j in range(m.size(2)):
- if m[n, i, j] == 0:
- for k in range(2, upscale - 2):
- x[n, :, i * upscale + k, j * upscale + k] = 0
- x[n, :, i * upscale + upscale - 1 - k, j * upscale + k] = 0
-
- return x
-
-
-def seq2img(seq, height, width, upscale=15):
- f_first = seq[:, : height * width].reshape(-1, height, width)
- f_second = seq[:, height * width + 1 :].reshape(-1, height, width)
- direction = seq[:, height * width]
-
- direction_symbol = torch.full((direction.size(0), height * upscale - 1, upscale), 0)
- direction_symbol = colors[direction_symbol].permute(0, 3, 1, 2)
- separator = torch.full((direction.size(0), 3, height * upscale - 1, 1), 0)
-
- for n in range(direction_symbol.size(0)):
- if direction[n] == token_forward:
- for k in range(upscale):
- direction_symbol[
- n,
- :,
- (height * upscale) // 2 - upscale // 2 + k,
- 3 + upscale // 2 - abs(k - upscale // 2),
- ] = 0
- elif direction[n] == token_backward:
- for k in range(upscale):
- direction_symbol[
- n,
- :,
- (height * upscale) // 2 - upscale // 2 + k,
- 3 + abs(k - upscale // 2),
- ] = 0
- else:
- for k in range(2, upscale - 2):
- direction_symbol[
- n, :, (height * upscale) // 2 - upscale // 2 + k, k
- ] = 0
- direction_symbol[
- n, :, (height * upscale) // 2 - upscale // 2 + k, upscale - 1 - k
- ] = 0
-
- return torch.cat(
- [
- frame2img(f_first, height, width, upscale),
- separator,
- direction_symbol,
- separator,
- frame2img(f_second, height, width, upscale),
- ],
- dim=3,
- )