3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
6 # Written by Francois Fleuret <francois@fleuret.org>
8 import math, sys, tqdm, os, warnings
10 import torch, torchvision
13 from torch.nn import functional as F
15 ######################################################################
20 class Sky(problem.Problem):
21 colors = torch.tensor(
39 nb_bird_tokens = colors.size(0) - 1
42 "_" + "".join([chr(ord("A") + n) for n in range(len(colors) - 1)]) + "><"
53 max_nb_cached_chunks=None,
57 super().__init__(max_nb_cached_chunks, chunk_size, nb_threads)
60 self.nb_birds = nb_birds
62 self.nb_iterations = nb_iterations
63 self.avoid_collision = avoid_collision
65 def generate_frame_sequences(self, nb):
68 for _ in tqdm.tqdm(range(nb), dynamic_ncols=True, desc="world generation"):
70 torch.empty(self.nb_birds, dtype=torch.int64),
71 torch.empty(self.nb_birds, dtype=torch.int64),
72 torch.empty(self.nb_birds, dtype=torch.int64),
73 torch.empty(self.nb_birds, dtype=torch.int64),
77 if not self.avoid_collision:
80 count = torch.zeros(self.height, self.width, dtype=torch.int64)
82 for n in range(self.nb_birds):
83 count[i[n], j[n]] += 1
84 count[i[n] - vi[n], j[n]] += 1
85 count[i[n], j[n] - vj[n]] += 1
87 return count.max() <= 1
90 torch.randperm(self.colors.size(0) - 1)[: self.nb_birds].sort().values
96 for n in range(self.nb_birds):
98 i[n] = torch.randint(self.height, (1,))
99 j[n] = torch.randint(self.width, (1,))
100 vm = torch.randint(4, (1,))
101 vi[n], vj[n] = (vm % 2) * 2 - 1, (vm // 2) * 2 - 1
104 and i[n] - vi[n] < self.height
105 and j[n] - vj[n] >= 0
106 and j[n] - vj[n] < self.width
113 result = torch.zeros(
114 self.nb_iterations * self.speed,
120 fine = torch.empty(self.nb_iterations * self.speed)
123 torch.arange(self.nb_iterations, device=result.device) * self.speed
126 for l in range(self.nb_iterations * self.speed):
127 fine[l] = collision_okay()
128 for n in range(self.nb_birds):
130 result[l, i[n], j[n]] = c
131 result[l, i[n] - vi[n], j[n]] = c
132 result[l, i[n], j[n] - vj[n]] = c
134 if (i[n] == 0 and vi[n] == -1) or (
135 i[n] == self.height - 1 and vi[n] == 1
139 if (j[n] == 0 and vj[n] == -1) or (
140 j[n] == self.width - 1 and vj[n] == 1
147 result = result[t_to_keep]
148 fine = fine[t_to_keep]
153 frame_sequences.append(result)
155 return frame_sequences
157 ######################################################################
159 def frame2img(self, x, scale=15):
160 x = x.reshape(x.size(0), self.height, -1)
161 m = torch.logical_and(
162 x >= 0, x < self.first_bird_token + self.nb_bird_tokens
164 x = self.colors[x * m].permute(0, 3, 1, 2)
166 x = x[:, :, :, None, :, None].expand(-1, -1, -1, scale, -1, scale)
167 x = x.reshape(s[0], s[1], s[2] * scale, s[3] * scale)
169 x[:, :, :, torch.arange(0, x.size(3), scale)] = 0
170 x[:, :, torch.arange(0, x.size(2), scale), :] = 0
173 for n in range(m.size(0)):
174 for i in range(m.size(1)):
175 for j in range(m.size(2)):
177 for k in range(2, scale - 2):
179 x[n, :, i * scale + k, j * scale + k - l] = 0
181 n, :, i * scale + scale - 1 - k, j * scale + k - l
186 def seq2str(self, seq):
189 result.append("".join([self.token2char[v] for v in s]))
198 predicted_prompts=None,
199 predicted_answers=None,
201 if predicted_prompts is None:
202 predicted_prompts = 255
204 if predicted_answers is None:
205 predicted_answers = 255
207 def add_frame(x, c, margin, bottom=False):
209 h, w, di, dj = x.size(2) + margin, x.size(3), 0, 0
212 x.size(2) + 2 * margin,
213 x.size(3) + 2 * margin,
218 y = x.new_full((x.size(0), x.size(1), h, w), 0)
223 c = c.long()[:, None]
225 (c == 1).long() * torch.tensor([0, 255, 0], device=c.device)
226 + (c == 0).long() * torch.tensor([255, 255, 255], device=c.device)
227 + (c == -1).long() * torch.tensor([255, 0, 0], device=c.device)
229 y[...] = c[:, :, None, None]
231 y[:, :, di : di + x.size(2), dj : dj + x.size(3)] = x
237 img_prompts = add_frame(self.frame2img(prompts.to("cpu")), c=0, margin=1)
238 h = img_prompts.size(2)
239 img_answers = add_frame(self.frame2img(answers.to("cpu")), c=0, margin=1)
241 img_prompts = add_frame(img_prompts, c=255, margin=margin, bottom=True)
242 img_answers = add_frame(img_answers, c=255, margin=margin, bottom=True)
244 img_prompts = add_frame(
245 img_prompts, c=predicted_prompts, margin=margin, bottom=True
247 img_answers = add_frame(
248 img_answers, c=predicted_answers, margin=margin, bottom=True
253 separator = img_prompts.new_full(
263 separator[:, :, 0] = 0
264 separator[:, :, h - 1] = 0
266 for k in range(1, 2 * marker_size - 8):
267 i = k - (marker_size - 4)
268 j = marker_size - 5 - abs(i)
269 separator[:, :, h // 2 - 1 + i, 2 + j] = 0
270 separator[:, :, h // 2 - 1 + i + 1, 2 + j] = 0
272 img = torch.cat([img_prompts, separator, img_answers], dim=3)
274 image_name = os.path.join(result_dir, filename)
275 torchvision.utils.save_image(
276 img.float() / 255.0, image_name, nrow=6, padding=margin * 4, pad_value=1.0
279 ######################################################################
def nb_token_values(self):
    """Return how many distinct token values exist.

    This is the number of entries along the first dimension of the
    class color table `self.colors`.
    """
    palette = self.colors
    return palette.size(0)
def generate_prompts_and_answers(self, nb):
    """Generate `nb` frame sequences and split each into prompt and answer.

    The sequences returned by `self.generate_frame_sequences(nb)` are
    batched along a new leading dimension. The first half of the frames
    (rounded down) becomes the prompt and the remaining frames the
    answer; each part is flattened to one row per sequence.

    Returns a pair `(prompts, answers)`.
    """
    # Stacking on a new dim 0 matches concatenating each sequence
    # after giving it a leading singleton dimension.
    batch = torch.stack(list(self.generate_frame_sequences(nb)), dim=0)

    # With an odd number of frames the answer gets the extra one.
    split = batch.size(1) // 2

    prompts = batch[:, :split].flatten(1)
    answers = batch[:, split:].flatten(1)

    return prompts, answers
309 predicted_prompts=None,
310 predicted_answers=None,
314 filename_prefix + ".png",
322 ######################################################################
324 if __name__ == "__main__":
327 sky = Sky(height=6, width=8, speed=1, nb_iterations=4)
329 prompts, answers = sky.generate_prompts_and_answers(4)
331 predicted_prompts = torch.randint(3, (prompts.size(0),)) - 1
332 predicted_answers = torch.randint(3, (prompts.size(0),)) - 1
335 "/tmp", "test", prompts, answers, predicted_prompts, predicted_answers
338 # start_time = time.perf_counter()
339 # token_sequences = sky.generate_token_sequences(nb=64)
340 # delay = time.perf_counter() - start_time
341 # print(f"{token_sequences.size(0)/delay:02f} seq/s")
343 # print(sky.seq2str(seq[:4]))
345 # for t in range(len(it[0])):
346 # img = torch.cat([sky.frame2img(f[t]) for f in it], dim=0)
347 # torchvision.utils.save_image(
348 # img.float() / 255.0,
349 # f"/tmp/frame_{t:03d}.png",
355 # m = (torch.rand(seq.size()) < 0.05).long()
356 # seq = (1 - m) * seq + m * 23
359 # img = sky.seq2img(token_sequences)
362 # torchvision.utils.save_image(
363 # img.float() / 255.0, "/tmp/world.png", nrow=6, padding=6, pad_value=0