3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
6 # Written by Francois Fleuret <francois@fleuret.org>
10 import torch, torchvision
13 from torch.nn import functional as F
15 ######################################################################
# NOTE(review): non-contiguous excerpt — the embedded numbers (19, 37, 38, 39, 42)
# are original-file line numbers and jump, so the RGB rows of the torch.tensor(
# literal and the surrounding context (e.g. the definition of first_bird_token
# and of the string these characters join into) are elided. Not runnable as-is.
19 colors = torch.tensor(
# One token per color, minus one for the background color at index 0.
37 nb_bird_tokens = colors.size(0) - 1
# Two extra tokens placed right after the bird tokens: "forward" then "backward"
# direction markers (rendered as ">" and "<" below — see the "><"  suffix).
38 token_forward = first_bird_token + nb_bird_tokens
39 token_backward = token_forward + 1
# Token-to-character map: "_" for background, "A", "B", ... for birds, "><" for
# the two direction tokens. Presumably assigned to a token2char attribute —
# the assignment target is on an elided line; confirm against the full file.
42 "_" + "".join([chr(ord("A") + n) for n in range(len(colors) - 1)]) + "><"
# NOTE(review): excerpt of what appears to be Sky.__init__ — the `class` header
# and the self.height / self.width assignments (original lines ~46-47, 50) are
# elided; only two of the four stored parameters are visible here.
45 def __init__(self, height=6, width=8, nb_birds=3, nb_iterations=2):
# Number of 3-cell "bird" sprites placed per generated world.
48 self.nb_birds = nb_birds
# Number of simulation steps between the start and end frames of a pair.
49 self.nb_iterations = nb_iterations
# Generate `nb` (start-frame, end-frame) sequence pairs by rejection sampling:
# place birds, simulate nb_iterations steps, and (judging by the visible
# `if nb_collisions == 0` guard) keep only collision-free rollouts.
# NOTE(review): heavily elided excerpt — retry loops, collision counting,
# bounce handling and the sequence-assembly code are mostly missing; every
# comment below is grounded only in the visible lines.
51 def generate_seq(self, nb, return_iterations=False):
55 for _ in tqdm.tqdm(range(nb), dynamic_ncols=True, desc="world generation"):
# Start frame: 0 (background token) everywhere.
59 f_start = torch.zeros(self.height, self.width, dtype=torch.int64)
# Four parallel per-bird tensors — presumably i, j, vi, vj (row, column,
# and the two velocity components); the assignment targets are elided.
62 torch.empty(self.nb_birds, dtype=torch.int64),
63 torch.empty(self.nb_birds, dtype=torch.int64),
64 torch.empty(self.nb_birds, dtype=torch.int64),
65 torch.empty(self.nb_birds, dtype=torch.int64),
# Distinct random color token per bird (skips the background color).
69 torch.randperm(self.colors.size(0) - 1)[: self.nb_birds]
75 for n in range(self.nb_birds):
# Random head position ...
80 torch.randint(self.height, (1,))[0],
81 torch.randint(self.width, (1,))[0],
# ... and one of four diagonal velocities: vi, vj each in {-1, +1}.
83 vm = torch.randint(4, (1,))[0]
84 vi[n], vj[n] = (vm % 2) * 2 - 1, (vm // 2) * 2 - 1
# Placement is accepted only if the head and the two "wing" cells
# (one step back along each axis) are in bounds and currently empty.
87 and i[n] - vi[n] < self.height
89 and j[n] - vj[n] < self.width
90 and f_start[i[n], j[n]] == 0
91 and f_start[i[n] - vi[n], j[n]] == 0
92 and f_start[i[n], j[n] - vj[n]] == 0
# Paint the 3-cell V-shaped bird with its color token c.
96 f_start[i[n], j[n]] = c
97 f_start[i[n] - vi[n], j[n]] = c
98 f_start[i[n], j[n] - vj[n]] = c
100 f_end = f_start.clone()
# Simulate the world; snapshot the frame before each step.
102 for l in range(self.nb_iterations):
103 iterations.append(f_end.clone())
106 for n in range(self.nb_birds):
# Bounce off the top/bottom edge (flip vi — flip code elided).
116 if (i[n] == 0 and vi[n] == -1) or (
117 i[n] == self.height - 1 and vi[n] == 1
# Bounce off the left/right edge (flip vj — flip code elided).
120 if (j[n] == 0 and vj[n] == -1) or (
121 j[n] == self.width - 1 and vj[n] == 1
# Re-check the three target cells are free; the elided else-branch
# presumably increments nb_collisions — confirm against full file.
129 f_end[i[n], j[n]] == 0
130 and f_end[i[n] - vi[n], j[n]] == 0
131 and f_end[i[n], j[n] - vj[n]] == 0
135 f_end[i[n], j[n]] = c
136 f_end[i[n] - vi[n], j[n]] = c
137 f_end[i[n], j[n] - vj[n]] = c
# Final frame snapshot after the last step.
139 iterations.append(f_end.clone())
# Keep only collision-free rollouts (rejection sampling).
141 if nb_collisions == 0:
144 kept_iterations.append(iterations)
145 pairs.append((f_start, f_end))
# Coin flip: emit the pair as forward (start, marker, end) or backward —
# the flattening/concatenation code between these lines is elided.
149 if torch.rand(1) < 0.5:
154 torch.tensor([self.token_forward]),
165 torch.tensor([self.token_backward]),
172 if return_iterations:
173 # iterations = torch.cat([ torch.cat([ x[None, None] for x in l], dim = 1) for l in kept_iterations ], dim=0)
174 return torch.cat(result, dim=0), kept_iterations
176 return torch.cat(result, dim=0)
178 ######################################################################
# Legacy single-bird generator kept alongside generate_seq.
# NOTE(review): visible defect — the loop variable `n` bound by the
# `for n in tqdm.tqdm(range(nb), ...)` below is overwritten inside the loop
# body by `n = torch.arange(f_start.size(0))`; the loop counter is never used
# afterward, so this shadowing is harmless but confusing. Also note this
# version has no collision/rejection logic: birds leaving the board appear to
# restart the sample (the branch body after the bounds check is elided).
180 def generate_seq_old(
186 for n in tqdm.tqdm(range(nb), dynamic_ncols=True, desc="world generation"):
187 f_start = torch.zeros(self.height, self.width, dtype=torch.int64)
188 f_end = torch.zeros(self.height, self.width, dtype=torch.int64)
# Shadows the loop variable `n` above (see NOTE).
189 n = torch.arange(f_start.size(0))
# Random distinct bird color token(s); the indexing suffix is elided.
192 (torch.randperm(self.nb_bird_tokens) + self.first_bird_token)[
# Head position drawn from the interior (1 .. height-2 / width-2) so the
# three neighbor cells written below stay in bounds at t=0.
199 torch.randint(self.height - 2, (1,))[0] + 1,
200 torch.randint(self.width - 2, (1,))[0] + 1,
# One of four axis-aligned unit velocities: exactly one of vi, vj is ±1.
202 vm = torch.randint(4, (1,))[0]
203 vi, vj = (vm // 2) * (2 * (vm % 2) - 1), (1 - vm // 2) * (
# Paint the bird: tail cell plus the two cells perpendicular to motion.
208 f_start[i - vi, j - vj] = c
209 f_start[i + vj, j - vi] = c
210 f_start[i - vj, j + vi] = c
212 for l in range(self.nb_iterations):
# Out-of-bounds head — elided branch presumably aborts/retries this sample.
215 if i < 0 or i >= self.height or j < 0 or j >= self.width:
223 f_end[i - vi, j - vj] = c
224 f_end[i + vj, j - vi] = c
225 f_end[i - vj, j + vi] = c
227 pairs.append((f_start, f_end))
# Same forward/backward coin flip as generate_seq.
231 if torch.rand(1) < 0.5:
236 torch.tensor([self.token_forward]),
247 torch.tensor([self.token_backward]),
254 return torch.cat(result, dim=0)
# Render a batch of token frames as RGB images, `upscale` pixels per cell,
# with black grid lines and a diagonal-cross overlay on some cells.
256 def frame2img(self, x, upscale=15):
257 x = x.reshape(-1, self.height, self.width)
# Mask of valid drawable tokens (background + bird colors); anything outside
# that range (e.g. direction tokens) is clamped to index 0 via `x * m` below.
258 m = torch.logical_and(
259 x >= 0, x < self.first_bird_token + self.nb_bird_tokens
# Token -> RGB lookup, then NCHW layout.
261 x = self.colors[x * m].permute(0, 3, 1, 2)
# NOTE(review): `s` is read here but its assignment (presumably s = x.shape,
# original line ~262) is elided — confirm against the full file.
263 x = x[:, :, :, None, :, None].expand(-1, -1, -1, upscale, -1, upscale)
264 x = x.reshape(s[0], s[1], s[2] * upscale, s[3] * upscale)
# Black 1-px grid lines every `upscale` pixels, both axes.
266 x[:, :, :, torch.arange(0, x.size(3), upscale)] = 0
267 x[:, :, torch.arange(0, x.size(2), upscale), :] = 0
# Draw a black X inside each cell selected by the (elided) condition —
# presumably cells where the mask/token test at original line 273 holds.
270 for n in range(m.size(0)):
271 for i in range(m.size(1)):
272 for j in range(m.size(2)):
274 for k in range(2, upscale - 2):
275 x[n, :, i * upscale + k, j * upscale + k] = 0
276 x[n, :, i * upscale + upscale - 1 - k, j * upscale + k] = 0
# Render full sequences as images: [first frame | direction glyph | second
# frame], where the single token between the two h*w frame halves selects a
# ">" (forward) or "<" (backward) arrow glyph.
280 def seq2img(self, seq, upscale=15):
281 f_first = seq[:, : self.height * self.width].reshape(
282 -1, self.height, self.width
# Skip the one direction token sitting between the two frames.
284 f_second = seq[:, self.height * self.width + 1 :].reshape(
285 -1, self.height, self.width
287 direction = seq[:, self.height * self.width]
# Blank glyph strip, one per sample; filled with token 0 (background) then
# colorized through the same palette as the frames.
289 direction_symbol = torch.full(
290 (direction.size(0), self.height * upscale - 1, upscale), 0
292 direction_symbol = self.colors[direction_symbol].permute(0, 3, 1, 2)
# 1-px black vertical separator between panels.
293 separator = torch.full((direction.size(0), 3, self.height * upscale - 1, 1), 0)
294 
295 for n in range(direction_symbol.size(0)):
296 if direction[n] == self.token_forward:
# Draw a ">" arrowhead, vertically centered in the strip.
297 for k in range(upscale):
301 (self.height * upscale) // 2 - upscale // 2 + k,
302 3 + upscale // 2 - abs(k - upscale // 2),
304 elif direction[n] == self.token_backward:
# Draw a "<" arrowhead.
305 for k in range(upscale):
309 (self.height * upscale) // 2 - upscale // 2 + k,
310 3 + abs(k - upscale // 2),
# Any other token: draw an X to flag an unexpected direction value.
313 for k in range(2, upscale - 2):
315 n, :, (self.height * upscale) // 2 - upscale // 2 + k, k
320 (self.height * upscale) // 2 - upscale // 2 + k,
# Final horizontal concatenation of the panels (torch.cat args elided).
326 self.frame2img(f_first, upscale),
330 self.frame2img(f_second, upscale),
# Decode token sequences to strings via the token2char map ("_", "A".."", ">", "<").
# NOTE(review): the loop over `seq` producing `s` and the return statement are elided.
335 def seq2str(self, seq):
338 result.append("".join([self.token2char[v] for v in s]))
342 ######################################################################
# Demo / smoke test: generate sequences, report throughput, and dump PNGs of
# every simulation frame plus the assembled sequence image to /tmp.
344 if __name__ == "__main__":
347 sky = Sky(height=6, width=8, nb_iterations=100)
349 start_time = time.perf_counter()
350 seq, it = sky.generate_seq(nb=64, return_iterations=True)
351 delay = time.perf_counter() - start_time
# Throughput in samples/second; "%02f" here is just "%f" with a no-op width.
352 print(f"{seq.size(0)/delay:02f} samples/s")
354 print(sky.seq2str(seq[:4]))
# One image per timestep, each row a different kept rollout.
356 for t in range(len(it[0])):
357 img = torch.cat([sky.frame2img(f[t]) for f in it], dim=0)
358 torchvision.utils.save_image(
360 f"/tmp/frame_{t:03d}.png",
# Commented-out token-corruption experiment (5% random noise) — kept as-is.
366 # m = (torch.rand(seq.size()) < 0.05).long()
367 # seq = (1 - m) * seq + m * 23
369 img = sky.seq2img(seq)
372 torchvision.utils.save_image(
373 img.float() / 255.0, "/tmp/world.png", nrow=6, padding=6, pad_value=0