3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
6 # Written by Francois Fleuret <francois@fleuret.org>
10 import torch, torchvision
13 from torch.nn import functional as F
15 ######################################################################
# Token vocabulary shared by the generators and the renderers below.
# NOTE(review): the rows of `colors` and the definition of `first_bird_token`
# are not visible in this excerpt.
18 colors = torch.tensor(
# One bird token per row of `colors` except one; value 0 is what the
# generators test for when checking that a cell is free.
36 nb_bird_tokens = colors.size(0) - 1
# The two direction tokens come right after the bird tokens.
37 token_forward = first_bird_token + nb_bird_tokens
38 token_backward = token_forward + 1
# Character for each token value: "_" for 0, then one letter "A", "B", ... per
# remaining row of `colors`, then ">" and "<" (presumably token_forward and
# token_backward, assuming first_bird_token == 1 -- TODO confirm).
40 token2char = "_" + "".join([chr(ord("A") + n) for n in range(len(colors) - 1)]) + "><"
# NOTE(review): the `def` line and many body lines of this function are not
# visible in this excerpt. It is presumably `generate_seq`, which the
# `__main__` block calls with exactly these keyword arguments.
44 nb, height, width, nb_birds=3, nb_iterations=2, return_iterations=False
# One attempt per requested sample, with a progress bar.
49 for _ in tqdm.tqdm(range(nb), dynamic_ncols=True, desc="world generation"):
# Start frame: a height x width grid of zeros (0 is the "free cell" value
# tested below).
53 f_start = torch.zeros(height, width, dtype=torch.int64)
# Four per-bird int64 vectors; presumably unpacked as i, j, vi, vj
# (position and velocity), the assignment target is not visible -- TODO confirm.
56 torch.empty(nb_birds, dtype=torch.int64),
57 torch.empty(nb_birds, dtype=torch.int64),
58 torch.empty(nb_birds, dtype=torch.int64),
59 torch.empty(nb_birds, dtype=torch.int64),
# nb_birds distinct values drawn from 1 .. colors.size(0) - 1, sorted ascending.
62 col = torch.randperm(colors.size(0) - 1)[:nb_birds].sort().values + 1
64 for n in range(nb_birds):
# Uniformly random candidate row / column for bird n.
69 torch.randint(height, (1,))[0],
70 torch.randint(width, (1,))[0],
# vm in {0, 1, 2, 3} encodes a diagonal velocity: vi and vj are each -1 or +1.
72 vm = torch.randint(4, (1,))[0]
73 vi[n], vj[n] = (vm % 2) * 2 - 1, (vm // 2) * 2 - 1
# Placement test (partly elided): the two trailing cells must be inside the
# grid and all three cells of the bird must be free.
76 and i[n] - vi[n] < height
78 and j[n] - vj[n] < width
79 and f_start[i[n], j[n]] == 0
80 and f_start[i[n] - vi[n], j[n]] == 0
81 and f_start[i[n], j[n] - vj[n]] == 0
# A bird is three cells: (i, j) plus the cell one step back along each axis.
85 f_start[i[n], j[n]] = c
86 f_start[i[n] - vi[n], j[n]] = c
87 f_start[i[n], j[n] - vj[n]] = c
# The end frame starts as a copy of the start frame.
89 f_end = f_start.clone()
91 for l in range(nb_iterations):
# Snapshot the frame at the beginning of every iteration.
92 iterations.append(f_end.clone())
95 for n in range(nb_birds):
# Bird on a border and heading outwards along the row axis; the action taken
# is elided (presumably the velocity component is flipped -- TODO confirm).
105 if (i[n] == 0 and vi[n] == -1) or (
106 i[n] == height - 1 and vi[n] == 1
# Same test along the column axis.
109 if (j[n] == 0 and vj[n] == -1) or (
110 j[n] == width - 1 and vj[n] == 1
# The bird is drawn only if its three target cells are free; the alternative
# branch is elided (presumably it increments nb_collisions -- TODO confirm).
118 f_end[i[n], j[n]] == 0
119 and f_end[i[n] - vi[n], j[n]] == 0
120 and f_end[i[n], j[n] - vj[n]] == 0
124 f_end[i[n], j[n]] = c
125 f_end[i[n] - vi[n], j[n]] = c
126 f_end[i[n], j[n] - vj[n]] = c
# Snapshot of the final frame.
128 iterations.append(f_end.clone())
# Only collision-free samples are kept.
130 if nb_collisions == 0:
133 kept_iterations.append(iterations)
134 pairs.append((f_start, f_end))
# Each kept pair is serialised in one of two orders, chosen with probability
# 0.5: start / token_forward / end, or end / token_backward / start.
138 if torch.rand(1) < 0.5:
141 [p[0].flatten(), torch.tensor([token_forward]), p[1].flatten()],
148 [p[1].flatten(), torch.tensor([token_backward]), p[0].flatten()],
# Returns the concatenated sequences, plus the per-sample lists of frames
# when return_iterations is true.
153 if return_iterations:
154 # iterations = torch.cat([ torch.cat([ x[None, None] for x in l], dim = 1) for l in kept_iterations ], dim=0)
155 return torch.cat(result, dim=0), kept_iterations
157 return torch.cat(result, dim=0)
160 ######################################################################
# Older generator of (start frame, direction token, end frame) sequences.
# NOTE(review): the parameter list and several body lines are not visible in
# this excerpt.
163 def generate_seq_old(
172 for n in tqdm.tqdm(range(nb), dynamic_ncols=True, desc="world generation"):
# Both frames start as all-zero height x width grids.
173 f_start = torch.zeros(height, width, dtype=torch.int64)
174 f_end = torch.zeros(height, width, dtype=torch.int64)
# NOTE(review): this rebinds the loop variable `n` of the enclosing `for`.
175 n = torch.arange(f_start.size(0))
# nb_birds distinct tokens drawn from
# first_bird_token .. first_bird_token + nb_bird_tokens - 1, sorted ascending.
178 (torch.randperm(nb_bird_tokens) + first_bird_token)[:nb_birds].sort().values
# Position drawn strictly inside the grid: 1 .. height - 2 and 1 .. width - 2.
181 torch.randint(height - 2, (1,))[0] + 1,
182 torch.randint(width - 2, (1,))[0] + 1,
# vm in {0, 1, 2, 3} encodes an axis-aligned unit velocity (vi, vj):
# (0, -1), (0, +1), (-1, 0), (+1, 0).
184 vm = torch.randint(4, (1,))[0]
185 vi, vj = (vm // 2) * (2 * (vm % 2) - 1), (1 - vm // 2) * (2 * (vm % 2) - 1)
# Cells written around (i, j): one step back along the velocity, and that
# back step rotated a quarter turn to either side.
188 f_start[i - vi, j - vj] = c
189 f_start[i + vj, j - vi] = c
190 f_start[i - vj, j + vi] = c
192 for l in range(nb_iterations):
# Out-of-grid test; the update before it and the action taken are elided.
195 if i < 0 or i >= height or j < 0 or j >= width:
# Same three-cell pattern, written to the end frame.
203 f_end[i - vi, j - vj] = c
204 f_end[i + vj, j - vi] = c
205 f_end[i - vj, j + vi] = c
207 pairs.append((f_start, f_end))
# Each pair is serialised in one of two orders, chosen with probability 0.5:
# start / token_forward / end, or end / token_backward / start.
211 if torch.rand(1) < 0.5:
214 [p[0].flatten(), torch.tensor([token_forward]), p[1].flatten()],
221 [p[1].flatten(), torch.tensor([token_backward]), p[0].flatten()],
226 return torch.cat(result, dim=0)
# Renders token frames as images in which every cell is an
# upscale x upscale block of pixels.
# NOTE(review): some lines (the definition of `s`, the condition guarding the
# cross drawing, the return) are not visible in this excerpt.
229 def frame2img(x, height, width, upscale=15):
# Any leading shape is flattened into a batch of height x width frames.
230 x = x.reshape(-1, height, width)
# m is 1 where the token lies in [0, first_bird_token + nb_bird_tokens), else 0.
231 m = torch.logical_and(x >= 0, x < first_bird_token + nb_bird_tokens).long()
# Out-of-range tokens are looked up as row 0 of `colors`; the colour axis is
# then moved from last to second position.
232 x = colors[x * m].permute(0, 3, 1, 2)
# Every pixel is replicated upscale times along both spatial axes.
# `s` is presumably the shape of x before the expansion -- TODO confirm.
234 x = x[:, :, :, None, :, None].expand(-1, -1, -1, upscale, -1, upscale)
235 x = x.reshape(s[0], s[1], s[2] * upscale, s[3] * upscale)
# Grid lines: every upscale-th column and row is set to 0.
237 x[:, :, :, torch.arange(0, x.size(3), upscale)] = 0
238 x[:, :, torch.arange(0, x.size(2), upscale), :] = 0
# Draws both diagonals of a cell (a cross); the condition selecting the cells
# is elided (presumably those where m is 0 -- TODO confirm).
241 for n in range(m.size(0)):
242 for i in range(m.size(1)):
243 for j in range(m.size(2)):
245 for k in range(2, upscale - 2):
246 x[n, :, i * upscale + k, j * upscale + k] = 0
247 x[n, :, i * upscale + upscale - 1 - k, j * upscale + k] = 0
# Renders full sequences: the two frames plus a symbol for the direction token.
# NOTE(review): several lines (assignment targets, the final concatenation and
# return) are not visible in this excerpt.
252 def seq2img(seq, height, width, upscale=15):
# Layout of a sequence: height * width tokens of the first frame, one
# direction token, then the tokens of the second frame.
253 f_first = seq[:, : height * width].reshape(-1, height, width)
254 f_second = seq[:, height * width + 1 :].reshape(-1, height, width)
255 direction = seq[:, height * width]
# Blank strip, upscale pixels wide, filled with row 0 of `colors`, on which
# the direction symbol is drawn.
257 direction_symbol = torch.full((direction.size(0), height * upscale - 1, upscale), 0)
258 direction_symbol = colors[direction_symbol].permute(0, 3, 1, 2)
# One-pixel-wide, 3-channel column of zeros.
259 separator = torch.full((direction.size(0), 3, height * upscale - 1, 1), 0)
261 for n in range(direction_symbol.size(0)):
# Forward: the column index peaks at the middle row, i.e. a ">" shape.
262 if direction[n] == token_forward:
263 for k in range(upscale):
267 (height * upscale) // 2 - upscale // 2 + k,
268 3 + upscale // 2 - abs(k - upscale // 2),
# Backward: the column index is smallest at the middle row, i.e. a "<" shape.
270 elif direction[n] == token_backward:
271 for k in range(upscale):
275 (height * upscale) // 2 - upscale // 2 + k,
276 3 + abs(k - upscale // 2),
# Two diagonals (a cross); presumably the fallback for any other value of the
# direction token, the `else` line is elided -- TODO confirm.
279 for k in range(2, upscale - 2):
281 n, :, (height * upscale) // 2 - upscale // 2 + k, k
284 n, :, (height * upscale) // 2 - upscale // 2 + k, upscale - 1 - k
# Both frames are rendered with frame2img; how they are combined with the
# separator and the direction symbol is elided.
289 frame2img(f_first, height, width, upscale),
293 frame2img(f_second, height, width, upscale),
# NOTE(review): fragment of a function whose header is not visible in this
# excerpt (presumably `seq2str`, which the `__main__` block calls).
# Appends the string obtained by mapping every token of s through token2char.
302 result.append("".join([token2char[v] for v in s]))
306 ######################################################################
# Demo / smoke test: generates sequences, reports throughput, and writes
# images under /tmp.
# NOTE(review): some lines (e.g. the definitions of `height` and `width`) are
# not visible in this excerpt.
308 if __name__ == "__main__":
# Time the generation of the samples, keeping the intermediate frames.
312 start_time = time.perf_counter()
313 seq, it = generate_seq(
314 nb=64, height=height, width=width, nb_iterations=100, return_iterations=True
316 delay = time.perf_counter() - start_time
# NOTE(review): the format spec `:02f` means zero-padded minimum width 2 with
# the default precision; `.02f` (two decimals) was probably intended.
317 print(f"{seq.size(0)/delay:02f} samples/s")
319 print(seq2str(seq[:4]))
# One image per time step, showing frame t of every kept sample.
321 for t in range(len(it[0])):
322 img = torch.cat([frame2img(f[t], height, width) for f in it], dim=0)
323 torchvision.utils.save_image(
325 f"/tmp/frame_{t:03d}.png",
331 # m = (torch.rand(seq.size()) < 0.05).long()
332 # seq = (1 - m) * seq + m * 23
# Render every generated sequence and save them as a single image grid.
334 img = seq2img(seq, height, width)
337 torchvision.utils.save_image(
338 img.float() / 255.0, "/tmp/world.png", nrow=6, padding=6, pad_value=0