3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
6 # Written by Francois Fleuret <francois@fleuret.org>
8 import math, sys, tqdm, os
10 import torch, torchvision
13 from torch.nn import functional as F
15 ######################################################################
19 def generate_seq(self, nb_train_samples):
22 def save_quizzes(self, input, result_dir, filename_prefix, logger):
25 def direction_tokens(self):
30 colors = torch.tensor(
48 nb_bird_tokens = colors.size(0) - 1
49 token_forward = first_bird_token + nb_bird_tokens
50 token_backward = token_forward + 1
53 "_" + "".join([chr(ord("A") + n) for n in range(len(colors) - 1)]) + "><"
56 def __init__(self, height=6, width=8, nb_birds=3, nb_iterations=2):
59 self.nb_birds = nb_birds
60 self.nb_iterations = nb_iterations
def direction_tokens(self):
    """Return the pair of direction tokens as ``(token_forward, token_backward)``.

    Both values are read from attributes of ``self``; nothing is computed
    or modified here.
    """
    forward = self.token_forward
    backward = self.token_backward
    return forward, backward
65 def generate_seq(self, nb, return_iterations=False):
69 for _ in tqdm.tqdm(range(nb), dynamic_ncols=True, desc="world generation"):
73 f_start = torch.zeros(self.height, self.width, dtype=torch.int64)
76 torch.empty(self.nb_birds, dtype=torch.int64),
77 torch.empty(self.nb_birds, dtype=torch.int64),
78 torch.empty(self.nb_birds, dtype=torch.int64),
79 torch.empty(self.nb_birds, dtype=torch.int64),
83 torch.randperm(self.colors.size(0) - 1)[: self.nb_birds]
89 for n in range(self.nb_birds):
94 torch.randint(self.height, (1,))[0],
95 torch.randint(self.width, (1,))[0],
97 vm = torch.randint(4, (1,))[0]
98 vi[n], vj[n] = (vm % 2) * 2 - 1, (vm // 2) * 2 - 1
101 and i[n] - vi[n] < self.height
102 and j[n] - vj[n] >= 0
103 and j[n] - vj[n] < self.width
104 and f_start[i[n], j[n]] == 0
105 and f_start[i[n] - vi[n], j[n]] == 0
106 and f_start[i[n], j[n] - vj[n]] == 0
110 f_start[i[n], j[n]] = c
111 f_start[i[n] - vi[n], j[n]] = c
112 f_start[i[n], j[n] - vj[n]] = c
114 f_end = f_start.clone()
116 for l in range(self.nb_iterations):
117 iterations.append(f_end.clone())
120 for n in range(self.nb_birds):
130 if (i[n] == 0 and vi[n] == -1) or (
131 i[n] == self.height - 1 and vi[n] == 1
134 if (j[n] == 0 and vj[n] == -1) or (
135 j[n] == self.width - 1 and vj[n] == 1
143 f_end[i[n], j[n]] == 0
144 and f_end[i[n] - vi[n], j[n]] == 0
145 and f_end[i[n], j[n] - vj[n]] == 0
149 f_end[i[n], j[n]] = c
150 f_end[i[n] - vi[n], j[n]] = c
151 f_end[i[n], j[n] - vj[n]] = c
153 iterations.append(f_end.clone())
155 if nb_collisions == 0:
158 kept_iterations.append(iterations)
159 pairs.append((f_start, f_end))
163 if torch.rand(1) < 0.5:
168 torch.tensor([self.token_forward]),
179 torch.tensor([self.token_backward]),
186 if return_iterations:
187 # iterations = torch.cat([ torch.cat([ x[None, None] for x in l], dim = 1) for l in kept_iterations ], dim=0)
188 return torch.cat(result, dim=0), kept_iterations
190 return torch.cat(result, dim=0)
192 ######################################################################
194 def generate_seq_old(
200 for n in tqdm.tqdm(range(nb), dynamic_ncols=True, desc="world generation"):
201 f_start = torch.zeros(self.height, self.width, dtype=torch.int64)
202 f_end = torch.zeros(self.height, self.width, dtype=torch.int64)
203 n = torch.arange(f_start.size(0))
206 (torch.randperm(self.nb_bird_tokens) + self.first_bird_token)[
213 torch.randint(self.height - 2, (1,))[0] + 1,
214 torch.randint(self.width - 2, (1,))[0] + 1,
216 vm = torch.randint(4, (1,))[0]
217 vi, vj = (vm // 2) * (2 * (vm % 2) - 1), (1 - vm // 2) * (
222 f_start[i - vi, j - vj] = c
223 f_start[i + vj, j - vi] = c
224 f_start[i - vj, j + vi] = c
226 for l in range(self.nb_iterations):
229 if i < 0 or i >= self.height or j < 0 or j >= self.width:
237 f_end[i - vi, j - vj] = c
238 f_end[i + vj, j - vi] = c
239 f_end[i - vj, j + vi] = c
241 pairs.append((f_start, f_end))
245 if torch.rand(1) < 0.5:
250 torch.tensor([self.token_forward]),
261 torch.tensor([self.token_backward]),
268 return torch.cat(result, dim=0)
270 def frame2img(self, x, upscale=15):
271 x = x.reshape(-1, self.height, self.width)
272 m = torch.logical_and(
273 x >= 0, x < self.first_bird_token + self.nb_bird_tokens
275 x = self.colors[x * m].permute(0, 3, 1, 2)
277 x = x[:, :, :, None, :, None].expand(-1, -1, -1, upscale, -1, upscale)
278 x = x.reshape(s[0], s[1], s[2] * upscale, s[3] * upscale)
280 x[:, :, :, torch.arange(0, x.size(3), upscale)] = 0
281 x[:, :, torch.arange(0, x.size(2), upscale), :] = 0
284 for n in range(m.size(0)):
285 for i in range(m.size(1)):
286 for j in range(m.size(2)):
288 for k in range(2, upscale - 2):
289 x[n, :, i * upscale + k, j * upscale + k] = 0
290 x[n, :, i * upscale + upscale - 1 - k, j * upscale + k] = 0
294 def seq2img(self, seq, upscale=15):
295 f_first = seq[:, : self.height * self.width].reshape(
296 -1, self.height, self.width
298 f_second = seq[:, self.height * self.width + 1 :].reshape(
299 -1, self.height, self.width
301 direction = seq[:, self.height * self.width]
303 direction_symbol = torch.full(
304 (direction.size(0), self.height * upscale - 1, upscale), 0
306 direction_symbol = self.colors[direction_symbol].permute(0, 3, 1, 2)
307 separator = torch.full((direction.size(0), 3, self.height * upscale - 1, 1), 0)
309 for n in range(direction_symbol.size(0)):
310 if direction[n] == self.token_forward:
311 for k in range(upscale):
315 (self.height * upscale) // 2 - upscale // 2 + k,
316 3 + upscale // 2 - abs(k - upscale // 2),
318 elif direction[n] == self.token_backward:
319 for k in range(upscale):
323 (self.height * upscale) // 2 - upscale // 2 + k,
324 3 + abs(k - upscale // 2),
327 for k in range(2, upscale - 2):
329 n, :, (self.height * upscale) // 2 - upscale // 2 + k, k
334 (self.height * upscale) // 2 - upscale // 2 + k,
340 self.frame2img(f_first, upscale),
344 self.frame2img(f_second, upscale),
349 def seq2str(self, seq):
352 result.append("".join([self.token2char[v] for v in s]))
def save_image(self, input, result_dir, filename, logger):
    """Render ``input`` with ``self.seq2img`` and write it as an image file.

    Args:
        input: object moved to the CPU with ``.to("cpu")`` and handed to
            ``self.seq2img`` (presumably a batch of token sequences --
            confirm against the callers).
        result_dir: directory in which the image is written.
        filename: file name joined to ``result_dir``.
        logger: callable receiving a one-line message with the written path.
    """
    # Rendering is done on the CPU regardless of where ``input`` lives.
    cpu_input = input.to("cpu")
    rendered = self.seq2img(cpu_input)

    image_name = os.path.join(result_dir, filename)

    # NOTE(review): the division by 255 presumably maps 0-255 colour
    # values into the [0, 1] float range -- confirm against seq2img.
    scaled = rendered.float() / 255.0

    # Grid layout: 6 images per row, 4 pixels of padding between them.
    torchvision.utils.save_image(scaled, image_name, nrow=6, padding=4)

    logger(f"wrote {image_name}")
def save_quizzes(self, input, result_dir, filename_prefix, logger):
    """Save ``input`` as ``<filename_prefix>.png`` in ``result_dir``.

    Thin wrapper that appends the ``.png`` extension to ``filename_prefix``
    and forwards everything to ``self.save_image``.
    """
    png_name = filename_prefix + ".png"
    self.save_image(input, result_dir, png_name, logger)
365 ######################################################################
367 if __name__ == "__main__":
370 sky = Sky(height=6, width=8, nb_iterations=100)
372 start_time = time.perf_counter()
373 seq, it = sky.generate_seq(nb=64, return_iterations=True)
374 delay = time.perf_counter() - start_time
375 print(f"{seq.size(0)/delay:02f} samples/s")
377 print(sky.seq2str(seq[:4]))
379 for t in range(len(it[0])):
380 img = torch.cat([sky.frame2img(f[t]) for f in it], dim=0)
381 torchvision.utils.save_image(
383 f"/tmp/frame_{t:03d}.png",
389 # m = (torch.rand(seq.size()) < 0.05).long()
390 # seq = (1 - m) * seq + m * 23
392 img = sky.seq2img(seq)
395 torchvision.utils.save_image(
396 img.float() / 255.0, "/tmp/world.png", nrow=6, padding=6, pad_value=0