+)
+
+token2char = "_X" + "".join([str(n) for n in range(len(colors) - 2)]) + ">"
+
+
+def generate(
+ nb,
+ height,
+ width,
+ max_nb_obj=len(colors) - 2,
+ nb_iterations=2,
+):
+ f_start = torch.zeros(nb, height, width, dtype=torch.int64)
+ f_end = torch.zeros(nb, height, width, dtype=torch.int64)
+ n = torch.arange(f_start.size(0))
+
+ for n in range(nb):
+ nb_fish = torch.randint(max_nb_obj, (1,)).item() + 1
+ for c in range(nb_fish):
+ i, j = (
+ torch.randint(height - 2, (1,))[0] + 1,
+ torch.randint(width - 2, (1,))[0] + 1,
+ )
+ vm = torch.randint(4, (1,))[0]
+ vi, vj = (vm // 2) * (2 * (vm % 2) - 1), (1 - vm // 2) * (2 * (vm % 2) - 1)
+
+ f_start[n, i, j] = c + 2
+ f_start[n, i - vi, j - vj] = c + 2
+ f_start[n, i + vj, j - vi] = c + 2
+ f_start[n, i - vj, j + vi] = c + 2
+
+ for l in range(nb_iterations):
+ i += vi
+ j += vj
+ if i < 0 or i >= height or j < 0 or j >= width:
+ i -= vi
+ j -= vj
+ vi, vj = -vi, -vj
+ i += vi
+ j += vj
+
+ f_end[n, i, j] = c + 2
+ f_end[n, i - vi, j - vj] = c + 2
+ f_end[n, i + vj, j - vi] = c + 2
+ f_end[n, i - vj, j + vi] = c + 2
+
+ return torch.cat(
+ [
+ f_end.flatten(1),
+ torch.full((f_end.size(0), 1), len(colors)),
+ f_start.flatten(1),
+ ],
+ dim=1,
+ )
+
+
+def sample2img(seq, height, width):
+ f_start = seq[:, : height * width].reshape(-1, height, width)
+ f_start = (f_start >= len(colors)).long() + (f_start < len(colors)).long() * f_start
+ f_end = seq[:, height * width + 1 :].reshape(-1, height, width)
+ f_end = (f_end >= len(colors)).long() + (f_end < len(colors)).long() * f_end
+
+ img_f_start, img_f_end = colors[f_start], colors[f_end]
+
+ img = torch.cat(
+ [
+ img_f_start,
+ torch.full(
+ (img_f_start.size(0), img_f_start.size(1), 1, img_f_start.size(3)), 1
+ ),
+ img_f_end,
+ ],
+ dim=2,
+ )
+
+ return img.permute(0, 3, 1, 2)
+
+
+def seq2str(seq):
+ result = []
+ for s in seq:
+ result.append("".join([token2char[v] for v in s]))
+ return result