3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
6 # Written by Francois Fleuret <francois@fleuret.org>
10 import torch, torchvision
13 from torch.nn import functional as F
15 ######################################################################
18 colors = torch.tensor(
36 nb_bird_tokens = colors.size(0) - 1
37 token_forward = first_bird_token + nb_bird_tokens
38 token_backward = token_forward + 1
40 token2char = "_" + "".join([chr(ord("A") + n) for n in range(len(colors) - 1)]) + "><"
# NOTE(review): orphaned constructor fragment — the enclosing class header
# and the method body are not in this excerpt; presumably it stores the
# grid dimensions. Confirm against the full file.
def __init__(self, height, width):
# NOTE(review): heavily truncated excerpt — the `def generate_seq(` header,
# the `while True:` retry loops, and many statements are missing.  The
# surviving lines are kept verbatim; indentation restored to conventional
# depths for readability.
#
# Intent (as far as visible): build `nb` token sequences, each
# "start frame | direction token | end frame" (order flipped with
# probability 0.5), where the end frame is the start frame after
# `nb_iterations` steps of bouncing 3-cell "birds"; samples where a
# collision occurred are rejected and redrawn.
nb, height, width, nb_birds=3, nb_iterations=2, return_iterations=False

for _ in tqdm.tqdm(range(nb), dynamic_ncols=True, desc="world generation"):
            # Empty start frame; token 0 is the background.
            f_start = torch.zeros(height, width, dtype=torch.int64)

            # Per-bird row, column and velocity components (i, j, vi, vj).
                torch.empty(nb_birds, dtype=torch.int64),
                torch.empty(nb_birds, dtype=torch.int64),
                torch.empty(nb_birds, dtype=torch.int64),
                torch.empty(nb_birds, dtype=torch.int64),

            # nb_birds distinct non-background tokens, sorted ascending.
            col = torch.randperm(colors.size(0) - 1)[:nb_birds].sort().values + 1

            for n in range(nb_birds):
                    # Rejection-sample a position and a diagonal velocity.
                    torch.randint(height, (1,))[0],
                    torch.randint(width, (1,))[0],
                    vm = torch.randint(4, (1,))[0]
                    # vm in {0..3} -> (vi, vj) in {-1, +1} x {-1, +1}.
                    vi[n], vj[n] = (vm % 2) * 2 - 1, (vm // 2) * 2 - 1
                    # Accept only if all three bird cells fit in the grid and
                    # are currently free (condition header not visible here).
                        and i[n] - vi[n] < height
                        and j[n] - vj[n] < width
                        and f_start[i[n], j[n]] == 0
                        and f_start[i[n] - vi[n], j[n]] == 0
                        and f_start[i[n], j[n] - vj[n]] == 0

                # Paint the bird: head cell plus two tail cells trailing the
                # velocity direction.
                f_start[i[n], j[n]] = c
                f_start[i[n] - vi[n], j[n]] = c
                f_start[i[n], j[n] - vj[n]] = c

            f_end = f_start.clone()

            # Simulate the motion, recording every intermediate frame.
            for l in range(nb_iterations):
                iterations.append(f_end.clone())

                for n in range(nb_birds):
                    # Bounce off the top/bottom walls...
                    if (i[n] == 0 and vi[n] == -1) or (
                        i[n] == height - 1 and vi[n] == 1
                    # ...and off the left/right walls.
                    if (j[n] == 0 and vj[n] == -1) or (
                        j[n] == width - 1 and vj[n] == 1

                    # Collision check: target cells must be free — presumably
                    # wrapped in `if not (...)`; the header line is missing.
                        f_end[i[n], j[n]] == 0
                        and f_end[i[n] - vi[n], j[n]] == 0
                        and f_end[i[n], j[n] - vj[n]] == 0

                    f_end[i[n], j[n]] = c
                    f_end[i[n] - vi[n], j[n]] = c
                    f_end[i[n], j[n] - vj[n]] = c

            iterations.append(f_end.clone())

            # Keep the sample only if no collision occurred; otherwise the
            # enclosing retry loop draws a fresh configuration.
            if nb_collisions == 0:

        kept_iterations.append(iterations)
        pairs.append((f_start, f_end))

    # Flatten each pair into one token sequence, forward or backward with
    # equal probability.
        if torch.rand(1) < 0.5:
                [p[0].flatten(), torch.tensor([token_forward]), p[1].flatten()],
                torch.tensor([token_backward]),

    if return_iterations:
        # iterations = torch.cat([ torch.cat([ x[None, None] for x in l], dim = 1) for l in kept_iterations ], dim=0)
        return torch.cat(result, dim=0), kept_iterations

    return torch.cat(result, dim=0)
168 ######################################################################
# NOTE(review): truncated excerpt — the parameter list, several loop bodies
# and the `result` bookkeeping are missing.  Legacy generator kept for
# reference: single birds with straight-line motion and no retry loop.
def generate_seq_old(
    for n in tqdm.tqdm(range(nb), dynamic_ncols=True, desc="world generation"):
        f_start = torch.zeros(height, width, dtype=torch.int64)
        f_end = torch.zeros(height, width, dtype=torch.int64)
        # NOTE(review): this rebinds the loop variable `n` — it looks
        # harmless only if the iteration index is unused afterward; confirm
        # against the full file.
        n = torch.arange(f_start.size(0))

        # Distinct bird tokens drawn without replacement.
            (torch.randperm(nb_bird_tokens) + first_bird_token)[:nb_birds]

            # Random interior start position (avoids the outer border).
                torch.randint(height - 2, (1,))[0] + 1,
                torch.randint(width - 2, (1,))[0] + 1,
            vm = torch.randint(4, (1,))[0]
            # Map vm in {0..3} to an axis-aligned unit velocity (vi, vj).
            vi, vj = (vm // 2) * (2 * (vm % 2) - 1), (1 - vm // 2) * (

            # Paint the bird: tail cell plus two cells perpendicular to the
            # velocity.
            f_start[i - vi, j - vj] = c
            f_start[i + vj, j - vi] = c
            f_start[i - vj, j + vi] = c

            # Advance the bird; positions leaving the grid are handled here
            # (branch body not visible in this excerpt).
            for l in range(nb_iterations):
                if i < 0 or i >= height or j < 0 or j >= width:

            f_end[i - vi, j - vj] = c
            f_end[i + vj, j - vi] = c
            f_end[i - vj, j + vi] = c

        pairs.append((f_start, f_end))

    # Flatten each pair with a direction token, forward or backward with
    # equal probability.
        if torch.rand(1) < 0.5:
                [p[0].flatten(), torch.tensor([token_forward]), p[1].flatten()],
                torch.tensor([token_backward]),

    return torch.cat(result, dim=0)
def frame2img(x, height, width, upscale=15):
    # Render token frames as RGB images, `upscale` pixels per grid cell,
    # with black grid lines; cells are crossed out under a condition whose
    # line is missing from this excerpt (presumably out-of-range tokens).
    # NOTE(review): truncated — `s = x.shape` (used below) and the final
    # `return` are among the missing lines.
    x = x.reshape(-1, height, width)
    # Mask of tokens that have a palette entry (background or bird).
    m = torch.logical_and(x >= 0, x < first_bird_token + nb_bird_tokens).long()
    # The mask product clamps out-of-range tokens to palette row 0.
    x = colors[x * m].permute(0, 3, 1, 2)
    x = x[:, :, :, None, :, None].expand(-1, -1, -1, upscale, -1, upscale)
    x = x.reshape(s[0], s[1], s[2] * upscale, s[3] * upscale)

    # Thin black grid lines every `upscale` pixels, both directions.
    x[:, :, :, torch.arange(0, x.size(3), upscale)] = 0
    x[:, :, torch.arange(0, x.size(2), upscale), :] = 0

    # Draw an X over flagged cells (guard condition not visible here).
    for n in range(m.size(0)):
        for i in range(m.size(1)):
            for j in range(m.size(2)):
                for k in range(2, upscale - 2):
                    x[n, :, i * upscale + k, j * upscale + k] = 0
                    x[n, :, i * upscale + upscale - 1 - k, j * upscale + k] = 0
def seq2img(seq, height, width, upscale=15):
    # Render a batch of token sequences as images: first frame | separator |
    # direction glyph | separator | second frame.
    # NOTE(review): truncated excerpt — dtype keywords, the glyph-drawing
    # index expressions and the final torch.cat are partially missing.
    f_first = seq[:, : height * width].reshape(-1, height, width)
    f_second = seq[:, height * width + 1 :].reshape(-1, height, width)
    # The single token between the two frames encodes the direction.
    direction = seq[:, height * width]

    # Blank strip in which the direction glyph is drawn.
    direction_symbol = torch.full(
        (direction.size(0), height * upscale - 1, upscale), 0
    direction_symbol = colors[direction_symbol].permute(0, 3, 1, 2)
    separator = torch.full((direction.size(0), 3, height * upscale - 1, 1), 0)

    for n in range(direction_symbol.size(0)):
        if direction[n] == token_forward:
            # Forward token: draw a right-pointing arrow glyph.
            for k in range(upscale):
                    (height * upscale) // 2 - upscale // 2 + k,
                    3 + upscale // 2 - abs(k - upscale // 2),
        elif direction[n] == token_backward:
            # Backward token: draw a left-pointing arrow glyph.
            for k in range(upscale):
                    (height * upscale) // 2 - upscale // 2 + k,
                    3 + abs(k - upscale // 2),

            # Fallback glyph for any other token value (the `else:` header
            # is missing from this excerpt).
            for k in range(2, upscale - 2):
                    n, :, (height * upscale) // 2 - upscale // 2 + k, k
                    (height * upscale) // 2 - upscale // 2 + k,

    # Assemble the final image row: frames rendered via frame2img.
            frame2img(f_first, height, width, upscale),
            frame2img(f_second, height, width, upscale),
# NOTE(review): interior fragment of `seq2str` (its `def` line is missing
# from this excerpt) — maps each token of a row to its glyph via token2char.
result.append("".join([token2char[v] for v in s]))
323 ######################################################################
if __name__ == "__main__":
    # Smoke test: generate a batch, print a few sequences as text, and dump
    # per-iteration frames plus a summary image under /tmp.
    # NOTE(review): `time`, `height` and `width` are defined on lines
    # missing from this excerpt.
    start_time = time.perf_counter()
    seq, it = generate_seq(
        nb=64, height=height, width=width, nb_iterations=100, return_iterations=True
    delay = time.perf_counter() - start_time
    # NOTE(review): format spec "02f" means width 2 with the default six
    # decimals; ".02f" (two decimals) was probably intended — confirm.
    print(f"{seq.size(0)/delay:02f} samples/s")

    print(seq2str(seq[:4]))

    # One PNG per simulation step, all samples stacked vertically.
    for t in range(len(it[0])):
        img = torch.cat([frame2img(f[t], height, width) for f in it], dim=0)
        torchvision.utils.save_image(
            f"/tmp/frame_{t:03d}.png",

    # m = (torch.rand(seq.size()) < 0.05).long()
    # seq = (1 - m) * seq + m * 23

    img = seq2img(seq, height, width)

    torchvision.utils.save_image(
        img.float() / 255.0, "/tmp/world.png", nrow=6, padding=6, pad_value=0