3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
6 # Written by Francois Fleuret <francois@fleuret.org>
8 import math, sys, tqdm, os
10 import torch, torchvision
13 from torch.nn import functional as F
15 ######################################################################
20 class Sky(problem.Problem):
21 colors = torch.tensor(
39 nb_bird_tokens = colors.size(0) - 1
40 token_forward = first_bird_token + nb_bird_tokens
41 token_backward = token_forward + 1
44 "_" + "".join([chr(ord("A") + n) for n in range(len(colors) - 1)]) + "><"
58 self.nb_birds = nb_birds
60 self.nb_iterations = nb_iterations
61 self.avoid_collision = avoid_collision
def direction_tokens(self):
    """Return the pair of direction-marker tokens as (forward, backward)."""
    forward, backward = self.token_forward, self.token_backward
    return forward, backward
66 def generate_frame_sequences(self, nb):
69 for _ in tqdm.tqdm(range(nb), dynamic_ncols=True, desc="world generation"):
71 torch.empty(self.nb_birds, dtype=torch.int64),
72 torch.empty(self.nb_birds, dtype=torch.int64),
73 torch.empty(self.nb_birds, dtype=torch.int64),
74 torch.empty(self.nb_birds, dtype=torch.int64),
78 if not self.avoid_collision:
81 count = torch.zeros(self.height, self.width, dtype=torch.int64)
83 for n in range(self.nb_birds):
84 count[i[n], j[n]] += 1
85 count[i[n] - vi[n], j[n]] += 1
86 count[i[n], j[n] - vj[n]] += 1
88 return count.max() <= 1
91 torch.randperm(self.colors.size(0) - 1)[: self.nb_birds].sort().values
97 for n in range(self.nb_birds):
99 i[n] = torch.randint(self.height, (1,))
100 j[n] = torch.randint(self.width, (1,))
101 vm = torch.randint(4, (1,))
102 vi[n], vj[n] = (vm % 2) * 2 - 1, (vm // 2) * 2 - 1
105 and i[n] - vi[n] < self.height
106 and j[n] - vj[n] >= 0
107 and j[n] - vj[n] < self.width
114 result = torch.zeros(
115 self.nb_iterations * self.speed,
121 fine = torch.empty(self.nb_iterations * self.speed)
124 torch.arange(self.nb_iterations, device=result.device) * self.speed
127 for l in range(self.nb_iterations * self.speed):
128 fine[l] = collision_okay()
129 for n in range(self.nb_birds):
131 result[l, i[n], j[n]] = c
132 result[l, i[n] - vi[n], j[n]] = c
133 result[l, i[n], j[n] - vj[n]] = c
135 if (i[n] == 0 and vi[n] == -1) or (
136 i[n] == self.height - 1 and vi[n] == 1
140 if (j[n] == 0 and vj[n] == -1) or (
141 j[n] == self.width - 1 and vj[n] == 1
148 result = result[t_to_keep]
149 fine = fine[t_to_keep]
154 frame_sequences.append(result)
156 return frame_sequences
158 ######################################################################
def generate_prompts_and_answers(self, nb):
    """Generate `nb` frame sequences and split each one in time.

    The first half of the frames of every sequence (rounded down) forms
    the prompt, the remaining frames form the answer. Both are returned
    flattened to one row per sequence.

    Args:
        nb: number of sequences to generate; forwarded unchanged to
            `generate_frame_sequences`.

    Returns:
        A `(prompts, answers)` pair of 2-D tensors with `nb` rows each.
    """
    frame_sequences = self.generate_frame_sequences(nb)

    # generate_frame_sequences accumulates its results in a Python list,
    # which supports neither multi-axis slicing nor .size(); stack the
    # per-sequence tensors along a new leading batch axis first.
    # NOTE(review): assumes every sequence has the same shape -- confirm.
    if not torch.is_tensor(frame_sequences):
        frame_sequences = torch.stack(list(frame_sequences), dim=0)

    # The split point is taken on the time axis (axis 1); axis 0 is the
    # batch axis and must not determine where a sequence is cut.
    half = frame_sequences.size(1) // 2
    prompts = frame_sequences[:, :half].flatten(1)
    answers = frame_sequences[:, half:].flatten(1)
    return prompts, answers
166 def generate_token_sequences(self, nb):
167 frame_sequences = self.generate_frame_sequences(nb)
171 for frame_sequence in frame_sequences:
173 if torch.rand(1) < 0.5:
174 for frame in frame_sequence:
176 a.append(torch.tensor([self.token_forward]))
177 a.append(frame.flatten())
179 for frame in reversed(frame_sequence):
181 a.append(torch.tensor([self.token_backward]))
182 a.append(frame.flatten())
184 result.append(torch.cat(a, dim=0)[None, :])
186 return torch.cat(result, dim=0)
188 ######################################################################
190 def frame2img(self, x, scale=15):
191 x = x.reshape(-1, self.height, self.width)
192 m = torch.logical_and(
193 x >= 0, x < self.first_bird_token + self.nb_bird_tokens
195 x = self.colors[x * m].permute(0, 3, 1, 2)
197 x = x[:, :, :, None, :, None].expand(-1, -1, -1, scale, -1, scale)
198 x = x.reshape(s[0], s[1], s[2] * scale, s[3] * scale)
200 x[:, :, :, torch.arange(0, x.size(3), scale)] = 0
201 x[:, :, torch.arange(0, x.size(2), scale), :] = 0
204 for n in range(m.size(0)):
205 for i in range(m.size(1)):
206 for j in range(m.size(2)):
208 for k in range(2, scale - 2):
210 x[n, :, i * scale + k, j * scale + k - l] = 0
212 n, :, i * scale + scale - 1 - k, j * scale + k - l
217 def seq2img(self, seq, scale=15):
220 seq[:, : self.height * self.width].reshape(-1, self.height, self.width),
225 separator = torch.full((seq.size(0), 3, self.height * scale - 1, 1), 0)
227 t = self.height * self.width
229 while t < seq.size(1):
230 direction_tokens = seq[:, t]
233 direction_images = self.colors[
235 (direction_tokens.size(0), self.height * scale - 1, scale), 0
237 ].permute(0, 3, 1, 2)
239 for n in range(direction_tokens.size(0)):
240 if direction_tokens[n] == self.token_forward:
241 for k in range(scale):
246 (self.height * scale) // 2 - scale // 2 + k - l,
247 3 + scale // 2 - abs(k - scale // 2),
249 elif direction_tokens[n] == self.token_backward:
250 for k in range(scale):
255 (self.height * scale) // 2 - scale // 2 + k - l,
256 3 + abs(k - scale // 2),
259 for k in range(2, scale - 2):
264 (self.height * scale) // 2 - scale // 2 + k - l,
270 (self.height * scale) // 2 - scale // 2 + k - l,
279 seq[:, t : t + self.height * self.width].reshape(
280 -1, self.height, self.width
286 t += self.height * self.width
288 return torch.cat(all, dim=3)
290 def seq2str(self, seq):
293 result.append("".join([self.token2char[v] for v in s]))
def save_image(self, input, result_dir, filename):
    """Render the token sequences `input` and write the picture to disk.

    The sequences are moved to the CPU, converted to an image with
    `seq2img`, rescaled from 0..255 to 0..1 and saved as
    `result_dir/filename`.
    """
    rendered = self.seq2img(input.to("cpu"))
    destination = os.path.join(result_dir, filename)
    pixels = rendered.float() / 255.0
    torchvision.utils.save_image(pixels, destination, nrow=6, padding=4)
def save_quizzes(self, input, result_dir, filename_prefix):
    """Save `input` as a PNG named `filename_prefix` + ".png" in `result_dir`."""
    png_name = filename_prefix + ".png"
    self.save_image(input, result_dir, png_name)
305 ######################################################################
307 if __name__ == "__main__":
310 sky = Sky(height=6, width=8, speed=4, nb_iterations=2)
312 start_time = time.perf_counter()
313 token_sequences = sky.generate_token_sequences(nb=64)
314 delay = time.perf_counter() - start_time
315 print(f"{token_sequences.size(0)/delay:02f} seq/s")
317 # print(sky.seq2str(seq[:4]))
319 # for t in range(len(it[0])):
320 # img = torch.cat([sky.frame2img(f[t]) for f in it], dim=0)
321 # torchvision.utils.save_image(
322 # img.float() / 255.0,
323 # f"/tmp/frame_{t:03d}.png",
329 # m = (torch.rand(seq.size()) < 0.05).long()
330 # seq = (1 - m) * seq + m * 23
333 img = sky.seq2img(token_sequences)
336 torchvision.utils.save_image(
337 img.float() / 255.0, "/tmp/world.png", nrow=6, padding=6, pad_value=0