# Any copyright is dedicated to the Public Domain.
# https://creativecommons.org/publicdomain/zero/1.0/

# Written by Francois Fleuret <francois@fleuret.org>
8 import math, sys, tqdm, os, warnings
10 import torch, torchvision
13 from torch.nn import functional as F
15 ######################################################################
20 class Sky(problem.Problem):
21 colors = torch.tensor(
39 nb_bird_tokens = colors.size(0) - 1
42 "_" + "".join([chr(ord("A") + n) for n in range(len(colors) - 1)]) + "><"
56 self.nb_birds = nb_birds
58 self.nb_iterations = nb_iterations
59 self.avoid_collision = avoid_collision
61 def generate_frame_sequences(self, nb):
64 for _ in tqdm.tqdm(range(nb), dynamic_ncols=True, desc="world generation"):
66 torch.empty(self.nb_birds, dtype=torch.int64),
67 torch.empty(self.nb_birds, dtype=torch.int64),
68 torch.empty(self.nb_birds, dtype=torch.int64),
69 torch.empty(self.nb_birds, dtype=torch.int64),
73 if not self.avoid_collision:
76 count = torch.zeros(self.height, self.width, dtype=torch.int64)
78 for n in range(self.nb_birds):
79 count[i[n], j[n]] += 1
80 count[i[n] - vi[n], j[n]] += 1
81 count[i[n], j[n] - vj[n]] += 1
83 return count.max() <= 1
86 torch.randperm(self.colors.size(0) - 1)[: self.nb_birds].sort().values
92 for n in range(self.nb_birds):
94 i[n] = torch.randint(self.height, (1,))
95 j[n] = torch.randint(self.width, (1,))
96 vm = torch.randint(4, (1,))
97 vi[n], vj[n] = (vm % 2) * 2 - 1, (vm // 2) * 2 - 1
100 and i[n] - vi[n] < self.height
101 and j[n] - vj[n] >= 0
102 and j[n] - vj[n] < self.width
109 result = torch.zeros(
110 self.nb_iterations * self.speed,
116 fine = torch.empty(self.nb_iterations * self.speed)
119 torch.arange(self.nb_iterations, device=result.device) * self.speed
122 for l in range(self.nb_iterations * self.speed):
123 fine[l] = collision_okay()
124 for n in range(self.nb_birds):
126 result[l, i[n], j[n]] = c
127 result[l, i[n] - vi[n], j[n]] = c
128 result[l, i[n], j[n] - vj[n]] = c
130 if (i[n] == 0 and vi[n] == -1) or (
131 i[n] == self.height - 1 and vi[n] == 1
135 if (j[n] == 0 and vj[n] == -1) or (
136 j[n] == self.width - 1 and vj[n] == 1
143 result = result[t_to_keep]
144 fine = fine[t_to_keep]
149 frame_sequences.append(result)
151 return frame_sequences
153 ######################################################################
155 def frame2img(self, x, scale=15):
156 x = x.reshape(x.size(0), self.height, -1)
157 m = torch.logical_and(
158 x >= 0, x < self.first_bird_token + self.nb_bird_tokens
160 x = self.colors[x * m].permute(0, 3, 1, 2)
162 x = x[:, :, :, None, :, None].expand(-1, -1, -1, scale, -1, scale)
163 x = x.reshape(s[0], s[1], s[2] * scale, s[3] * scale)
165 x[:, :, :, torch.arange(0, x.size(3), scale)] = 0
166 x[:, :, torch.arange(0, x.size(2), scale), :] = 0
169 for n in range(m.size(0)):
170 for i in range(m.size(1)):
171 for j in range(m.size(2)):
173 for k in range(2, scale - 2):
175 x[n, :, i * scale + k, j * scale + k - l] = 0
177 n, :, i * scale + scale - 1 - k, j * scale + k - l
182 def seq2str(self, seq):
185 result.append("".join([self.token2char[v] for v in s]))
194 predicted_prompts=None,
195 predicted_answers=None,
197 if predicted_prompts is None:
198 predicted_prompts = 255
200 if predicted_answers is None:
201 predicted_answers = 255
203 def add_frame(x, c, margin, bottom=False):
205 h, w, di, dj = x.size(2) + margin, x.size(3), 0, 0
208 x.size(2) + 2 * margin,
209 x.size(3) + 2 * margin,
214 y = x.new_full((x.size(0), x.size(1), h, w), 0)
219 c = c.long()[:, None]
221 (c == 1).long() * torch.tensor([0, 255, 0], device=c.device)
222 + (c == 0).long() * torch.tensor([255, 255, 255], device=c.device)
223 + (c == -1).long() * torch.tensor([255, 0, 0], device=c.device)
225 y[...] = c[:, :, None, None]
227 y[:, :, di : di + x.size(2), dj : dj + x.size(3)] = x
233 img_prompts = add_frame(self.frame2img(prompts.to("cpu")), c=0, margin=1)
234 h = img_prompts.size(2)
235 img_answers = add_frame(self.frame2img(answers.to("cpu")), c=0, margin=1)
237 img_prompts = add_frame(img_prompts, c=255, margin=margin, bottom=True)
238 img_answers = add_frame(img_answers, c=255, margin=margin, bottom=True)
240 img_prompts = add_frame(
241 img_prompts, c=predicted_prompts, margin=margin, bottom=True
243 img_answers = add_frame(
244 img_answers, c=predicted_answers, margin=margin, bottom=True
249 separator = img_prompts.new_full(
259 separator[:, :, 0] = 0
260 separator[:, :, h - 1] = 0
262 for k in range(1, 2 * marker_size - 8):
263 i = k - (marker_size - 4)
264 j = marker_size - 5 - abs(i)
265 separator[:, :, h // 2 - 1 + i, 2 + j] = 0
266 separator[:, :, h // 2 - 1 + i + 1, 2 + j] = 0
268 img = torch.cat([img_prompts, separator, img_answers], dim=3)
270 image_name = os.path.join(result_dir, filename)
271 torchvision.utils.save_image(
272 img.float() / 255.0, image_name, nrow=6, padding=margin * 4, pad_value=1.0
275 ######################################################################
def nb_token_values(self):
    """Size of the token vocabulary: one token per entry of the class palette."""
    return self.colors.size(0)
def generate_prompts_and_answers(self, nb):
    """Build `nb` (prompt, answer) pairs of flattened token sequences.

    Each generated frame sequence is split in half along the time axis:
    the first half of the frames becomes the prompt and the second half
    the answer, each flattened to one token row per sample.
    """
    sequences = torch.stack(self.generate_frame_sequences(nb), dim=0)
    half = sequences.size(1) // 2
    prompts = sequences[:, :half].flatten(1)
    answers = sequences[:, half:].flatten(1)
    return prompts, answers
305 predicted_prompts=None,
306 predicted_answers=None,
310 filename_prefix + ".png",
318 ######################################################################
320 if __name__ == "__main__":
323 sky = Sky(height=6, width=8, speed=1, nb_iterations=4)
325 prompts, answers = sky.generate_prompts_and_answers(4)
327 predicted_prompts = torch.randint(3, (prompts.size(0),)) - 1
328 predicted_answers = torch.randint(3, (prompts.size(0),)) - 1
331 "/tmp", "test", prompts, answers, predicted_prompts, predicted_answers
334 # start_time = time.perf_counter()
335 # token_sequences = sky.generate_token_sequences(nb=64)
336 # delay = time.perf_counter() - start_time
337 # print(f"{token_sequences.size(0)/delay:02f} seq/s")
339 # print(sky.seq2str(seq[:4]))
341 # for t in range(len(it[0])):
342 # img = torch.cat([sky.frame2img(f[t]) for f in it], dim=0)
343 # torchvision.utils.save_image(
344 # img.float() / 255.0,
345 # f"/tmp/frame_{t:03d}.png",
351 # m = (torch.rand(seq.size()) < 0.05).long()
352 # seq = (1 - m) * seq + m * 23
355 # img = sky.seq2img(token_sequences)
358 # torchvision.utils.save_image(
359 # img.float() / 255.0, "/tmp/world.png", nrow=6, padding=6, pad_value=0