3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
6 # Written by Francois Fleuret <francois@fleuret.org>
10 import torch, torchvision
13 from torch.nn import functional as F
15 ######################################################################
# Color palette for rendering the world grid. The tensor literal is truncated
# in this excerpt; presumably rows are RGB triples with index 0 = background —
# TODO confirm against the full file.
18 colors = torch.tensor(
# One token per non-background palette entry.
36 nb_bird_tokens = colors.size(0) - 1
# Two extra tokens placed right after the bird-color tokens; `first_bird_token`
# is defined on a line not visible here — confirm its value in the full file.
37 token_forward = first_bird_token + nb_bird_tokens
38 token_backward = token_forward + 1
# Printable form of each token: '_' for empty, 'A', 'B', ... for the
# non-background colors, then '>' and '<' for the two direction tokens.
40 token2char = "_" + "".join([chr(ord("A") + n) for n in range(len(colors) - 1)]) + "><"
# NOTE(review): this region is the body of a generation routine whose `def`
# line is not visible in this excerpt (presumably `generate_seq`, called from
# __main__ below). Interior lines are missing throughout, so the comments are
# hedged reconstructions — verify every claim against the full file.
52 for _ in tqdm.tqdm(range(nb), dynamic_ncols=True, desc="world generation"):
# Start frame: height x width grid of int64 tokens, 0 = empty cell.
54 f_start = torch.zeros(height, width, dtype=torch.int64)
# Four uninitialized per-bird int64 vectors; the assignment target is on a
# missing line — presumably positions i, j and velocities vi, vj. Confirm.
57 torch.empty(nb_birds, dtype=torch.int64),
58 torch.empty(nb_birds, dtype=torch.int64),
59 torch.empty(nb_birds, dtype=torch.int64),
60 torch.empty(nb_birds, dtype=torch.int64),
# Distinct random color tokens in 1..colors.size(0)-1, sorted ascending.
63 col = torch.randperm(colors.size(0) - 1)[:nb_birds].sort().values + 1
65 for n in range(nb_birds):
# Candidate random cell (assignment target on a missing line — confirm).
70 torch.randint(height, (1,))[0],
71 torch.randint(width, (1,))[0],
# vm in 0..3 encodes one of four diagonal velocities: both components are
# in {-1, +1} by construction of the two expressions below.
73 vm = torch.randint(4, (1,))[0]
74 vi[n], vj[n] = (vm % 2) * 2 - 1, (vm // 2) * 2 - 1
# Placement acceptance test (partially visible): the trailing cells at
# (i-vi, j) and (i, j-vj) must be in bounds and all three cells empty.
77 and i[n] - vi[n] < height
79 and j[n] - vj[n] < width
80 and f_start[i[n], j[n]] == 0
81 and f_start[i[n] - vi[n], j[n]] == 0
82 and f_start[i[n], j[n] - vj[n]] == 0
# Stamp the 3-cell bird (head plus two trailing cells) with its color c.
86 f_start[i[n], j[n]] = c
87 f_start[i[n] - vi[n], j[n]] = c
88 f_start[i[n], j[n] - vj[n]] = c
# End frame starts as a copy and is advanced nb_iterations steps.
90 f_end = f_start.clone()
92 for l in range(nb_iterations):
95 for n in range(nb_birds):
# Bounce: flip the vertical / horizontal velocity component at the
# grid borders (the flip statements are on missing lines — confirm).
105 if (i[n] == 0 and vi[n] == -1) or (
106 i[n] == height - 1 and vi[n] == 1
109 if (j[n] == 0 and vj[n] == -1) or (
110 j[n] == width - 1 and vj[n] == 1
# Destination-free test on the end frame before stamping the moved bird.
118 f_end[i[n], j[n]] == 0
119 and f_end[i[n] - vi[n], j[n]] == 0
120 and f_end[i[n], j[n] - vj[n]] == 0
124 f_end[i[n], j[n]] = c
125 f_end[i[n] - vi[n], j[n]] = c
126 f_end[i[n], j[n] - vj[n]] = c
# Keep only collision-free worlds; `nb_collisions` is maintained on lines
# not visible here — confirm its bookkeeping in the full file.
128 if nb_collisions == 0:
131 pairs.append((f_start, f_end))
# Each pair becomes one flat token row, randomly either
# [start, token_forward, end] or [end, token_backward, start].
135 if torch.rand(1) < 0.5:
138 [p[0].flatten(), torch.tensor([token_forward]), p[1].flatten()],
145 [p[1].flatten(), torch.tensor([token_backward]), p[0].flatten()],
# Stack all rows into a single (nb, 2*height*width + 1) tensor — presumably;
# the construction of `result` is on missing lines.
150 return torch.cat(result, dim=0)
153 ######################################################################
# Legacy single-bird variant of the sequence generator, judging by its name.
# Interior lines are missing in this excerpt; comments are hedged — confirm
# against the full file. Signature (parameter list) is on missing lines.
156 def generate_seq_old(
165 for n in tqdm.tqdm(range(nb), dynamic_ncols=True, desc="world generation"):
166 f_start = torch.zeros(height, width, dtype=torch.int64)
167 f_end = torch.zeros(height, width, dtype=torch.int64)
# NOTE(review): this rebinds the sample-index loop variable `n` to a range
# tensor — looks accidental (shadowing); verify whether `n` is used later.
168 n = torch.arange(f_start.size(0))
# Random distinct bird-color tokens, sorted (assignment target not visible).
171 (torch.randperm(nb_bird_tokens) + first_bird_token)[:nb_birds].sort().values
# Random interior cell, 1..height-2 / 1..width-2 (target not visible).
174 torch.randint(height - 2, (1,))[0] + 1,
175 torch.randint(width - 2, (1,))[0] + 1,
# vm in 0..3 picks an axis-aligned unit velocity: exactly one of vi, vj is
# nonzero (+-1), since vm//2 is 0 or 1.
177 vm = torch.randint(4, (1,))[0]
178 vi, vj = (vm // 2) * (2 * (vm % 2) - 1), (1 - vm // 2) * (2 * (vm % 2) - 1)
# Stamp the bird shape around (i, j): one trailing cell and two wing cells
# perpendicular to the velocity (the head-cell assignment may be on a
# missing line — confirm).
181 f_start[i - vi, j - vj] = c
182 f_start[i + vj, j - vi] = c
183 f_start[i - vj, j + vi] = c
# Advance the single bird nb_iterations steps.
185 for l in range(nb_iterations):
# Out-of-bounds handling; the action taken is on missing lines — confirm.
188 if i < 0 or i >= height or j < 0 or j >= width:
# Stamp the final bird shape into the end frame (same pattern as above).
196 f_end[i - vi, j - vj] = c
197 f_end[i + vj, j - vi] = c
198 f_end[i - vj, j + vi] = c
200 pairs.append((f_start, f_end))
# Same row encoding as the newer generator: random 50/50 forward/backward
# ordering with the matching direction token in the middle.
204 if torch.rand(1) < 0.5:
207 [p[0].flatten(), torch.tensor([token_forward]), p[1].flatten()],
214 [p[1].flatten(), torch.tensor([token_backward]), p[0].flatten()],
219 return torch.cat(result, dim=0)
# Render a batch of token sequences as RGB images. Per the slicing below, a
# row of `seq` is laid out as [first grid (height*width tokens), one direction
# token, second grid (height*width tokens)]. Interior lines are missing in
# this excerpt; comments are hedged — confirm against the full file.
222 def sample2img(seq, height, width, upscale=15):
223 f_first = seq[:, : height * width].reshape(-1, height, width)
224 f_second = seq[:, height * width + 1 :].reshape(-1, height, width)
225 direction = seq[:, height * width]
# Turn one grid of tokens into an upscaled RGB mosaic with grid lines.
227 def mosaic(x, upscale):
228 x = x.reshape(-1, height, width)
# Mask of tokens that are valid palette indices; invalid tokens are
# forced to index 0 (background) before the palette lookup.
229 m = torch.logical_and(x >= 0, x < first_bird_token + nb_bird_tokens).long()
230 x = colors[x * m].permute(0, 3, 1, 2)
# Replicate each cell into an upscale x upscale tile (`s` is assigned on
# a missing line — presumably s = x.shape; confirm).
232 x = x[:, :, :, None, :, None].expand(-1, -1, -1, upscale, -1, upscale)
233 x = x.reshape(s[0], s[1], s[2] * upscale, s[3] * upscale)
# Black 1-pixel grid lines at every tile boundary.
235 x[:, :, :, torch.arange(0, x.size(3), upscale)] = 0
236 x[:, :, torch.arange(0, x.size(2), upscale), :] = 0
# Per-cell marking pass; the guarding condition is on a missing line —
# presumably cells with m == 0 (invalid tokens) get an X drawn over them.
239 for n in range(m.size(0)):
240 for i in range(m.size(1)):
241 for j in range(m.size(2)):
# Two diagonal strokes across the tile, inset by 2 pixels.
243 for k in range(2, upscale - 2):
244 x[n, :, i * upscale + k, j * upscale + k] = 0
245 x[n, :, i * upscale + upscale - 1 - k, j * upscale + k] = 0
# Background strip (token 0 -> palette color) that will carry the
# direction arrow, plus a thin black vertical separator column.
249 direction_symbol = torch.full((direction.size(0), height * upscale - 1, upscale), 0)
250 direction_symbol = colors[direction_symbol].permute(0, 3, 1, 2)
251 separator = torch.full((direction.size(0), 3, height * upscale - 1, 1), 0)
# Draw '>' for token_forward, '<' for token_backward (the assignment
# targets of these vertically-centered strokes are on missing lines).
253 for n in range(direction_symbol.size(0)):
254 if direction[n] == token_forward:
255 for k in range(upscale):
259 (height * upscale) // 2 - upscale // 2 + k,
260 3 + upscale // 2 - abs(k - upscale // 2),
262 elif direction[n] == token_backward:
263 for k in range(upscale):
267 (height * upscale) // 2 - upscale // 2 + k,
268 3 + abs(k - upscale // 2),
# Fallback branch, presumably for an unrecognized direction token: draw
# an X on the strip instead of an arrow — confirm.
271 for k in range(2, upscale - 2):
273 n, :, (height * upscale) // 2 - upscale // 2 + k, k
276 n, :, (height * upscale) // 2 - upscale // 2 + k, upscale - 1 - k
# Final image: first mosaic, direction strip, second mosaic, joined with
# separators (the torch.cat call is on missing lines — confirm).
281 mosaic(f_first, upscale),
285 mosaic(f_second, upscale),
294 result.append("".join([token2char[v] for v in s]))
298 ######################################################################
300 if __name__ == "__main__":
304 start_time = time.perf_counter()
305 seq = generate_seq(nb=90, height=height, width=width)
306 delay = time.perf_counter() - start_time
307 print(f"{seq.size(0)/delay:02f} samples/s")
309 print(seq2str(seq[:4]))
311 # m = (torch.rand(seq.size()) < 0.05).long()
312 # seq = (1 - m) * seq + m * 23
314 img = sample2img(seq, height, width)
317 torchvision.utils.save_image(
318 img.float() / 255.0, "/tmp/world.png", nrow=6, padding=6, pad_value=0