# Any copyright is dedicated to the Public Domain.
# https://creativecommons.org/publicdomain/zero/1.0/

# Written by Francois Fleuret <francois@fleuret.org>
8 import math, sys, tqdm, os
10 import torch, torchvision
13 from torch.nn import functional as F
15 ######################################################################
20 class Sky(problem.Problem):
21 colors = torch.tensor(
39 nb_bird_tokens = colors.size(0) - 1
42 "_" + "".join([chr(ord("A") + n) for n in range(len(colors) - 1)]) + "><"
56 self.nb_birds = nb_birds
58 self.nb_iterations = nb_iterations
59 self.avoid_collision = avoid_collision
61 def generate_frame_sequences(self, nb):
64 for _ in tqdm.tqdm(range(nb), dynamic_ncols=True, desc="world generation"):
66 torch.empty(self.nb_birds, dtype=torch.int64),
67 torch.empty(self.nb_birds, dtype=torch.int64),
68 torch.empty(self.nb_birds, dtype=torch.int64),
69 torch.empty(self.nb_birds, dtype=torch.int64),
73 if not self.avoid_collision:
76 count = torch.zeros(self.height, self.width, dtype=torch.int64)
78 for n in range(self.nb_birds):
79 count[i[n], j[n]] += 1
80 count[i[n] - vi[n], j[n]] += 1
81 count[i[n], j[n] - vj[n]] += 1
83 return count.max() <= 1
86 torch.randperm(self.colors.size(0) - 1)[: self.nb_birds].sort().values
92 for n in range(self.nb_birds):
94 i[n] = torch.randint(self.height, (1,))
95 j[n] = torch.randint(self.width, (1,))
96 vm = torch.randint(4, (1,))
97 vi[n], vj[n] = (vm % 2) * 2 - 1, (vm // 2) * 2 - 1
100 and i[n] - vi[n] < self.height
101 and j[n] - vj[n] >= 0
102 and j[n] - vj[n] < self.width
109 result = torch.zeros(
110 self.nb_iterations * self.speed,
116 fine = torch.empty(self.nb_iterations * self.speed)
119 torch.arange(self.nb_iterations, device=result.device) * self.speed
122 for l in range(self.nb_iterations * self.speed):
123 fine[l] = collision_okay()
124 for n in range(self.nb_birds):
126 result[l, i[n], j[n]] = c
127 result[l, i[n] - vi[n], j[n]] = c
128 result[l, i[n], j[n] - vj[n]] = c
130 if (i[n] == 0 and vi[n] == -1) or (
131 i[n] == self.height - 1 and vi[n] == 1
135 if (j[n] == 0 and vj[n] == -1) or (
136 j[n] == self.width - 1 and vj[n] == 1
143 result = result[t_to_keep]
144 fine = fine[t_to_keep]
149 frame_sequences.append(result)
151 return frame_sequences
153 ######################################################################
155 def frame2img(self, x, scale=15):
156 x = x.reshape(x.size(0), self.height, -1)
157 m = torch.logical_and(
158 x >= 0, x < self.first_bird_token + self.nb_bird_tokens
160 x = self.colors[x * m].permute(0, 3, 1, 2)
162 x = x[:, :, :, None, :, None].expand(-1, -1, -1, scale, -1, scale)
163 x = x.reshape(s[0], s[1], s[2] * scale, s[3] * scale)
165 x[:, :, :, torch.arange(0, x.size(3), scale)] = 0
166 x[:, :, torch.arange(0, x.size(2), scale), :] = 0
169 for n in range(m.size(0)):
170 for i in range(m.size(1)):
171 for j in range(m.size(2)):
173 for k in range(2, scale - 2):
175 x[n, :, i * scale + k, j * scale + k - l] = 0
177 n, :, i * scale + scale - 1 - k, j * scale + k - l
182 def seq2str(self, seq):
185 result.append("".join([self.token2char[v] for v in s]))
194 predicted_prompts=None,
195 predicted_answers=None,
197 if predicted_prompts is None:
198 predicted_prompts = 255
200 if predicted_answers is None:
201 predicted_answers = 255
203 def add_frame(x, c, margin):
205 (x.size(0), x.size(1), x.size(2) + 2 * margin, x.size(3) + 2 * margin),
211 c = c.long()[:, None]
212 c = c * torch.tensor([192, 192, 192], device=c.device) + (
214 ) * torch.tensor([255, 255, 255], device=c.device)
215 y[...] = c[:, :, None, None]
216 y[:, :, margin:-margin, margin:-margin] = x
221 img_prompts = add_frame(self.frame2img(prompts.to("cpu")), 0, 1)
222 img_answers = add_frame(self.frame2img(answers.to("cpu")), 0, 1)
224 # img_prompts = add_frame(img_prompts, 255, margin)
225 # img_answers = add_frame(img_answers, 255, margin)
227 img_prompts = add_frame(img_prompts, predicted_prompts, margin)
228 img_answers = add_frame(img_answers, predicted_answers, margin)
230 separator = img_prompts.new_full(
231 (img_prompts.size(0), img_prompts.size(1), img_prompts.size(2), margin), 255
234 img = torch.cat([img_prompts, img_answers], dim=3)
236 image_name = os.path.join(result_dir, filename)
237 torchvision.utils.save_image(
238 img.float() / 255.0, image_name, nrow=6, padding=margin * 2, pad_value=1.0
241 ######################################################################
243 def nb_token_values(self):
244 return len(self.colors)
246 def generate_prompts_and_answers(self, nb):
247 frame_sequences = self.generate_frame_sequences(nb)
248 frame_sequences = torch.cat([x[None] for x in frame_sequences], dim=0)
249 prompts = frame_sequences[:, : frame_sequences.size(1) // 2].flatten(1)
250 answers = frame_sequences[:, frame_sequences.size(1) // 2 :].flatten(1)
251 return prompts, answers
259 predicted_prompts=None,
260 predicted_answers=None,
264 filename_prefix + ".png",
272 ######################################################################
274 if __name__ == "__main__":
277 sky = Sky(height=6, width=8, speed=1, nb_iterations=4)
279 prompts, answers = sky.generate_prompts_and_answers(4)
281 predicted_prompts = torch.rand(prompts.size(0)) < 0.5
282 predicted_answers = torch.rand(answers.size(0)) < 0.5
285 "/tmp", "test", prompts, answers, predicted_prompts, predicted_answers
288 # start_time = time.perf_counter()
289 # token_sequences = sky.generate_token_sequences(nb=64)
290 # delay = time.perf_counter() - start_time
291 # print(f"{token_sequences.size(0)/delay:02f} seq/s")
293 # print(sky.seq2str(seq[:4]))
295 # for t in range(len(it[0])):
296 # img = torch.cat([sky.frame2img(f[t]) for f in it], dim=0)
297 # torchvision.utils.save_image(
298 # img.float() / 255.0,
299 # f"/tmp/frame_{t:03d}.png",
305 # m = (torch.rand(seq.size()) < 0.05).long()
306 # seq = (1 - m) * seq + m * 23
309 # img = sky.seq2img(token_sequences)
312 # torchvision.utils.save_image(
313 # img.float() / 255.0, "/tmp/world.png", nrow=6, padding=6, pad_value=0