- def generate_seq_old(
- nb,
- height,
- width,
- nb_birds=3,
- nb_iterations=2,
- ):
- pairs = []
-
- for n in tqdm.tqdm(range(nb), dynamic_ncols=True, desc="world generation"):
- f_start = torch.zeros(height, width, dtype=torch.int64)
- f_end = torch.zeros(height, width, dtype=torch.int64)
- n = torch.arange(f_start.size(0))
-
- for c in (
- (torch.randperm(nb_bird_tokens) + first_bird_token)[:nb_birds]
- .sort()
- .values
- ):
- i, j = (
- torch.randint(height - 2, (1,))[0] + 1,
- torch.randint(width - 2, (1,))[0] + 1,
- )
- vm = torch.randint(4, (1,))[0]
- vi, vj = (vm // 2) * (2 * (vm % 2) - 1), (1 - vm // 2) * (
- 2 * (vm % 2) - 1
- )
-
- f_start[i, j] = c
- f_start[i - vi, j - vj] = c
- f_start[i + vj, j - vi] = c
- f_start[i - vj, j + vi] = c
-
- for l in range(nb_iterations):
- i += vi
- j += vj
- if i < 0 or i >= height or j < 0 or j >= width:
- i -= vi
- j -= vj
- vi, vj = -vi, -vj
- i += vi
- j += vj
-
- f_end[i, j] = c
- f_end[i - vi, j - vj] = c
- f_end[i + vj, j - vi] = c
- f_end[i - vj, j + vi] = c
-
- pairs.append((f_start, f_end))
-
- result = []
- for p in pairs:
- if torch.rand(1) < 0.5:
- result.append(
- torch.cat(
- [p[0].flatten(), torch.tensor([token_forward]), p[1].flatten()],
- dim=0,
- )[None, :]
- )
- else:
- result.append(
- torch.cat(
- [
- p[1].flatten(),
- torch.tensor([token_backward]),
- p[0].flatten(),
- ],
- dim=0,
- )[None, :]
- )
-
- return torch.cat(result, dim=0)
-
- def frame2img(x, height, width, upscale=15):
- x = x.reshape(-1, height, width)
- m = torch.logical_and(x >= 0, x < first_bird_token + nb_bird_tokens).long()
- x = colors[x * m].permute(0, 3, 1, 2)
+ def frame2img(self, x, scale=15):
+ x = x.reshape(x.size(0), self.height, -1)
+ m = torch.logical_and(
+ x >= 0, x < self.first_bird_token + self.nb_bird_tokens
+ ).long()
+ x = self.colors[x * m].permute(0, 3, 1, 2)