3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
6 # Written by Francois Fleuret <francois@fleuret.org>
8 import math, sys, tqdm, os
10 import torch, torchvision
13 from torch.nn import functional as F
15 ######################################################################
# Sky: a grid world of colored "birds" (see generate_seq), defined as a
# subclass of problem.Problem (presumably imported on a line not shown here).
# NOTE(review): this listing is fragmentary -- class-body lines between the
# ones shown (the colors table entries, first_bird_token, the token2char
# assignment) are missing, so these notes describe only the visible code.
20 class Sky(problem.Problem):
# Per-token color table, indexed by token id in frame2img/seq2img. Its
# entries are not visible here (presumably 0..255 channel values, since the
# rendered images are divided by 255 before saving -- TODO confirm).
21 colors = torch.tensor(
# All colors but one are bird colors; the remaining one is presumably the
# empty/background color at index 0 (empty cells are 0 in generate_seq).
39 nb_bird_tokens = colors.size(0) - 1
# The two direction marker tokens come right after the last bird token.
40 token_forward = first_bird_token + nb_bird_tokens
41 token_backward = token_forward + 1
# Character for each token id (seq2str reads it as self.token2char):
# "_" first, then "A", "B", ... one per bird color, then ">" and "<" for the
# forward/backward markers. The assignment line is not shown, and the
# id-to-character alignment assumes first_bird_token == 1 -- TODO confirm.
44 "_" + "".join([chr(ord("A") + n) for n in range(len(colors) - 1)]) + "><"
# Configure the world: grid size (height x width), number of birds per
# sample, and number of simulation steps used to go from the start frame to
# the end frame in generate_seq.
47 def __init__(self, height=6, width=8, nb_birds=3, nb_iterations=2):
# NOTE(review): the lines storing height/width (read as self.height and
# self.width elsewhere) are missing from this listing.
50 self.nb_birds = nb_birds
51 self.nb_iterations = nb_iterations
def direction_tokens(self):
    # Give back the two direction marker token ids as (forward, backward).
    # They are the ids placed between the two frames of a sequence, the
    # position seq2img reads as the direction.
    forward_token = self.token_forward
    backward_token = self.token_backward
    return forward_token, backward_token
# Generate `nb` samples of the bird world and return them as token
# sequences concatenated with torch.cat along dim 0.
#
# Per sample:
#   * nb_birds birds with distinct colors are placed on an empty
#     height x width grid (f_start);
#   * they are moved for nb_iterations steps to produce f_end;
#   * the (f_start, f_end) pair is kept only when nb_collisions == 0 after
#     the simulation;
#   * each kept pair is encoded with token_forward or token_backward,
#     chosen with probability 1/2.
# With return_iterations=True the per-step frames of every kept sample are
# returned as well, as a list of lists of frames.
#
# NOTE(review): many lines of this method are missing from this listing.
# Missing pieces include:
#   * the initialisation of pairs/kept_iterations/result/iterations;
#   * the assignment target of the four buffers;
#   * the wall-bounce updates;
#   * the collision counter;
#   * any retry loop;
#   * the sequence assembly.
# The comments below describe only the visible lines.
56 def generate_seq(self, nb, return_iterations=False):
# One pass per requested sample, with a tqdm progress bar.
60 for _ in tqdm.tqdm(range(nb), dynamic_ncols=True, desc="world generation"):
# Empty grid; 0 marks a free cell (see the == 0 occupancy tests below).
64 f_start = torch.zeros(self.height, self.width, dtype=torch.int64)
# Four length-nb_birds int64 buffers. From their use below these are
# presumably i, j (row, column) and vi, vj (per-axis velocity).
67 torch.empty(self.nb_birds, dtype=torch.int64),
68 torch.empty(self.nb_birds, dtype=torch.int64),
69 torch.empty(self.nb_birds, dtype=torch.int64),
70 torch.empty(self.nb_birds, dtype=torch.int64),
# Distinct random indices in [0, colors.size(0) - 1), one per bird.
# Presumably they are shifted to bird token ids on a line not shown, since
# 0 is the empty cell.
74 torch.randperm(self.colors.size(0) - 1)[: self.nb_birds]
# Place each bird at a random cell with a random diagonal velocity. The draw
# is accepted only if the bird's three cells are inside the grid and free;
# the lower-bound tests and the retry structure are on lines not shown.
80 for n in range(self.nb_birds):
85 torch.randint(self.height, (1,))[0],
86 torch.randint(self.width, (1,))[0],
# vm in {0, 1, 2, 3} selects one of the four diagonals:
# vi = +/-1 from vm % 2 and vj = +/-1 from vm // 2.
88 vm = torch.randint(4, (1,))[0]
89 vi[n], vj[n] = (vm % 2) * 2 - 1, (vm // 2) * 2 - 1
92 and i[n] - vi[n] < self.height
94 and j[n] - vj[n] < self.width
95 and f_start[i[n], j[n]] == 0
96 and f_start[i[n] - vi[n], j[n]] == 0
97 and f_start[i[n], j[n] - vj[n]] == 0
# Paint the bird with color c. It occupies its head (i, j) plus the cells
# behind it along each axis, (i - vi, j) and (i, j - vj).
101 f_start[i[n], j[n]] = c
102 f_start[i[n] - vi[n], j[n]] = c
103 f_start[i[n], j[n] - vj[n]] = c
# Simulate starting from a copy of the start frame.
105 f_end = f_start.clone()
# Record the frame at the beginning of every step (the final frame is
# appended after the loop).
107 for l in range(self.nb_iterations):
108 iterations.append(f_end.clone())
111 for n in range(self.nb_birds):
# Detect a bird on the top/bottom wall moving into it. The velocity update
# is on lines not shown (presumably vi is reversed).
121 if (i[n] == 0 and vi[n] == -1) or (
122 i[n] == self.height - 1 and vi[n] == 1
# Same test for the left/right walls and vj.
125 if (j[n] == 0 and vj[n] == -1) or (
126 j[n] == self.width - 1 and vj[n] == 1
# Occupancy test on the bird's three target cells in f_end. Presumably a
# failed test counts as a collision; the enclosing statement is not shown.
134 f_end[i[n], j[n]] == 0
135 and f_end[i[n] - vi[n], j[n]] == 0
136 and f_end[i[n], j[n] - vj[n]] == 0
# Paint the bird at its new position.
140 f_end[i[n], j[n]] = c
141 f_end[i[n] - vi[n], j[n]] = c
142 f_end[i[n], j[n] - vj[n]] = c
# Final frame of the simulation.
144 iterations.append(f_end.clone())
# Only collision-free samples are kept (presumably; the lines between this
# test and the appends are not shown).
146 if nb_collisions == 0:
149 kept_iterations.append(iterations)
150 pairs.append((f_start, f_end))
# Pick the direction marker for each pair with probability 1/2. The frame
# order used in each branch is on lines not shown.
154 if torch.rand(1) < 0.5:
159 torch.tensor([self.token_forward]),
170 torch.tensor([self.token_backward]),
# Optionally also return the recorded frames: one list of frames per kept
# sample.
177 if return_iterations:
178 # iterations = torch.cat([ torch.cat([ x[None, None] for x in l], dim = 1) for l in kept_iterations ], dim=0)
179 return torch.cat(result, dim=0), kept_iterations
181 return torch.cat(result, dim=0)
183 ######################################################################
# Earlier version of the world generator. The __main__ demo below calls
# generate_seq, not this method. Here birds are drawn away from the border,
# at positions in [1, height - 2] x [1, width - 2].
# NOTE(review): the parameter list and many body lines are missing from
# this listing; the comments describe only the visible lines.
185 def generate_seq_old(
191 for n in tqdm.tqdm(range(nb), dynamic_ncols=True, desc="world generation"):
192 f_start = torch.zeros(self.height, self.width, dtype=torch.int64)
193 f_end = torch.zeros(self.height, self.width, dtype=torch.int64)
# NOTE(review): this rebinds `n`, the variable of the enclosing for-loop,
# to a tensor. The number of iterations is unaffected, but any later use of
# `n` in this body sees the tensor; a distinct name would be clearer.
194 n = torch.arange(f_start.size(0))
# Distinct bird tokens: a permutation of the bird token ids, offset by
# first_bird_token (the slice bounds continue on lines not shown).
197 (torch.randperm(self.nb_bird_tokens) + self.first_bird_token)[
204 torch.randint(self.height - 2, (1,))[0] + 1,
205 torch.randint(self.width - 2, (1,))[0] + 1,
# The (vm // 2) and (1 - vm // 2) factors force one of vi / vj to 0, so the
# motion is axis-aligned (the rest of the vj expression is not shown).
207 vm = torch.randint(4, (1,))[0]
208 vi, vj = (vm // 2) * (2 * (vm % 2) - 1), (1 - vm // 2) * (
# Three of the bird's cells relative to (i, j): one behind it along the
# velocity, (i - vi, j - vj), and the two on either side perpendicular to
# it. Whether (i, j) itself is painted is on a line not shown.
213 f_start[i - vi, j - vj] = c
214 f_start[i + vj, j - vi] = c
215 f_start[i - vj, j + vi] = c
# Simulation; the reaction to leaving the grid is on lines not shown.
217 for l in range(self.nb_iterations):
220 if i < 0 or i >= self.height or j < 0 or j >= self.width:
228 f_end[i - vi, j - vj] = c
229 f_end[i + vj, j - vi] = c
230 f_end[i - vj, j + vi] = c
232 pairs.append((f_start, f_end))
# Pick the direction marker for each pair with probability 1/2.
236 if torch.rand(1) < 0.5:
241 torch.tensor([self.token_forward]),
252 torch.tensor([self.token_backward]),
259 return torch.cat(result, dim=0)
# Render token frames as images.
# x is reshaped to (N, height, width), and each cell becomes an
# upscale x upscale block of self.colors[token]. Tokens outside
# [0, token_forward) are drawn with color 0. The first pixel row and column
# of every cell are set to 0 to draw grid lines.
# NOTE(review): the definition of `s`, the condition guarding the "X"
# drawing loop, and the return are on lines not shown.
261 def frame2img(self, x, upscale=15):
262 x = x.reshape(-1, self.height, self.width)
# True for tokens below token_forward (= first_bird_token + nb_bird_tokens),
# i.e. tokens that are not direction markers or other out-of-range values.
263 m = torch.logical_and(
264 x >= 0, x < self.first_bird_token + self.nb_bird_tokens
# x * m sends masked-out tokens to index 0 before the color lookup. The
# permute moves the color channel from the last axis to axis 1, giving
# (N, C, H, W); this assumes colors is a 2-D (token, channel) table -- TODO
# confirm.
266 x = self.colors[x * m].permute(0, 3, 1, 2)
# Nearest-neighbour upscaling: repeat every pixel upscale times along both
# spatial axes. `s` is presumably the pre-upscaling shape, defined on a line
# not shown.
268 x = x[:, :, :, None, :, None].expand(-1, -1, -1, upscale, -1, upscale)
269 x = x.reshape(s[0], s[1], s[2] * upscale, s[3] * upscale)
# Zero the first pixel column/row of every cell to draw grid lines.
271 x[:, :, :, torch.arange(0, x.size(3), upscale)] = 0
272 x[:, :, torch.arange(0, x.size(2), upscale), :] = 0
# Draw both diagonals of the cell in black (an "X"), leaving 2 pixels free
# at each end. This applies to cells that meet a condition on a line not
# shown (presumably m[n, i, j] being False, i.e. an out-of-range token).
275 for n in range(m.size(0)):
276 for i in range(m.size(1)):
277 for j in range(m.size(2)):
279 for k in range(2, upscale - 2):
280 x[n, :, i * upscale + k, j * upscale + k] = 0
281 x[n, :, i * upscale + upscale - 1 - k, j * upscale + k] = 0
# Render full sequences as images: the first frame, a direction glyph, and
# the second frame. Each row of `seq` is assumed to be laid out as
#   [height*width tokens of frame 1, 1 direction token,
#    height*width tokens of frame 2]
# NOTE(review): several lines are missing from this listing: the glyph pixel
# assignments, the header of the last branch, and the final concatenation
# and return.
285 def seq2img(self, seq, upscale=15):
286 f_first = seq[:, : self.height * self.width].reshape(
287 -1, self.height, self.width
289 f_second = seq[:, self.height * self.width + 1 :].reshape(
290 -1, self.height, self.width
292 direction = seq[:, self.height * self.width]
# Glyph canvas filled with token 0, height * upscale - 1 pixels tall and
# upscale wide, converted to colors with channels on axis 1.
294 direction_symbol = torch.full(
295 (direction.size(0), self.height * upscale - 1, upscale), 0
297 direction_symbol = self.colors[direction_symbol].permute(0, 3, 1, 2)
# Zero-valued separator, 3 channels and 1 pixel wide.
298 separator = torch.full((direction.size(0), 3, self.height * upscale - 1, 1), 0)
# Draw the glyph for each sample.
300 for n in range(direction_symbol.size(0)):
# token_forward: pixels at column 3 + upscale//2 - |k - upscale//2| over
# upscale vertically-centred rows, i.e. a ">" chevron.
301 if direction[n] == self.token_forward:
302 for k in range(upscale):
306 (self.height * upscale) // 2 - upscale // 2 + k,
307 3 + upscale // 2 - abs(k - upscale // 2),
# token_backward: column 3 + |k - upscale//2|, i.e. a "<" chevron.
309 elif direction[n] == self.token_backward:
310 for k in range(upscale):
314 (self.height * upscale) // 2 - upscale // 2 + k,
315 3 + abs(k - upscale // 2),
# Presumably the fallback for any other token: an "X" drawn from the two
# diagonals of a centred square (pixel values set on lines not shown).
318 for k in range(2, upscale - 2):
320 n, :, (self.height * upscale) // 2 - upscale // 2 + k, k
325 (self.height * upscale) // 2 - upscale // 2 + k,
# The two rendered frames, placed around the glyph (the enclosing
# concatenation is not shown).
331 self.frame2img(f_first, upscale),
335 self.frame2img(f_second, upscale),
# Convert each token sequence in `seq` to a string using self.token2char,
# one string per sequence.
# NOTE(review): the result-list initialisation, the loop binding `s` to each
# row, and the return are on lines not shown.
340 def seq2str(self, seq):
343 result.append("".join([self.token2char[v] for v in s]))
def save_image(self, input, result_dir, filename, logger):
    # Render the sequences in `input` with seq2img (after moving them to the
    # CPU). Write them as a single image grid, 6 images per row with 4-pixel
    # padding, to result_dir/filename, then report the path via `logger`.
    # Pixel values are divided by 255 to bring them into [0, 1] for
    # torchvision (assumes seq2img yields 0..255 colors).
    rendered = self.seq2img(input.to("cpu"))
    output_path = os.path.join(result_dir, filename)
    torchvision.utils.save_image(
        rendered.float() / 255.0, output_path, nrow=6, padding=4
    )
    logger(f"wrote {output_path}")
def save_quizzes(self, input, result_dir, filename_prefix, logger):
    # Quizzes are written as a single PNG named "<filename_prefix>.png" in
    # result_dir, delegating rendering and logging to save_image.
    png_filename = filename_prefix + ".png"
    self.save_image(input, result_dir, png_filename, logger)
356 ######################################################################
# Demo / benchmark. Steps:
#   * generate 64 samples with 100 simulation steps each;
#   * print the generation throughput and the first 4 sequences as strings;
#   * write one image per recorded time step to /tmp/frame_NNN.png;
#   * write the rendered sequences to /tmp/world.png.
# NOTE(review): `time` is not among the imports visible at the top of this
# listing. It is presumably imported on a line not shown (e.g. right after
# this guard); confirm it, or time.perf_counter() raises NameError.
358 if __name__ == "__main__":
361 sky = Sky(height=6, width=8, nb_iterations=100)
363 start_time = time.perf_counter()
364 seq, it = sky.generate_seq(nb=64, return_iterations=True)
365 delay = time.perf_counter() - start_time
# NOTE(review): the format spec ":02f" means "zero-pad to width 2" with the
# default 6 decimals (e.g. "123.456789"). ".2f" (two decimals) was probably
# intended.
366 print(f"{seq.size(0)/delay:02f} samples/s")
368 print(sky.seq2str(seq[:4]))
# One image per recorded time step, stacking that step's frame from every
# kept sample. This assumes every sample recorded as many frames as the
# first one.
370 for t in range(len(it[0])):
371 img = torch.cat([sky.frame2img(f[t]) for f in it], dim=0)
372 torchvision.utils.save_image(
374 f"/tmp/frame_{t:03d}.png",
# Disabled experiment: replace each token with token 23 with probability
# 0.05.
380 # m = (torch.rand(seq.size()) < 0.05).long()
381 # seq = (1 - m) * seq + m * 23
383 img = sky.seq2img(seq)
386 torchvision.utils.save_image(
387 img.float() / 255.0, "/tmp/world.png", nrow=6, padding=6, pad_value=0