3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
6 # Written by Francois Fleuret <francois@fleuret.org>
8 import math, sys, tqdm, os
10 import torch, torchvision
13 from torch.nn import functional as F
15 ######################################################################
20 class Sky(problem.Problem):
21 colors = torch.tensor(
39 nb_bird_tokens = colors.size(0) - 1
42 "_" + "".join([chr(ord("A") + n) for n in range(len(colors) - 1)]) + "><"
def nb_token_values(self):
    """Return the number of entries in ``self.colors``.

    The result is ``len(self.colors)``, i.e. the size of the first
    dimension of the color table.
    NOTE(review): presumably one color row per token value -- confirm
    against the full ``colors`` definition.
    """
    token_count = len(self.colors)
    return token_count
59 self.nb_birds = nb_birds
61 self.nb_iterations = nb_iterations
62 self.avoid_collision = avoid_collision
64 def generate_frame_sequences(self, nb):
67 for _ in tqdm.tqdm(range(nb), dynamic_ncols=True, desc="world generation"):
69 torch.empty(self.nb_birds, dtype=torch.int64),
70 torch.empty(self.nb_birds, dtype=torch.int64),
71 torch.empty(self.nb_birds, dtype=torch.int64),
72 torch.empty(self.nb_birds, dtype=torch.int64),
76 if not self.avoid_collision:
79 count = torch.zeros(self.height, self.width, dtype=torch.int64)
81 for n in range(self.nb_birds):
82 count[i[n], j[n]] += 1
83 count[i[n] - vi[n], j[n]] += 1
84 count[i[n], j[n] - vj[n]] += 1
86 return count.max() <= 1
89 torch.randperm(self.colors.size(0) - 1)[: self.nb_birds].sort().values
95 for n in range(self.nb_birds):
97 i[n] = torch.randint(self.height, (1,))
98 j[n] = torch.randint(self.width, (1,))
99 vm = torch.randint(4, (1,))
100 vi[n], vj[n] = (vm % 2) * 2 - 1, (vm // 2) * 2 - 1
103 and i[n] - vi[n] < self.height
104 and j[n] - vj[n] >= 0
105 and j[n] - vj[n] < self.width
112 result = torch.zeros(
113 self.nb_iterations * self.speed,
119 fine = torch.empty(self.nb_iterations * self.speed)
122 torch.arange(self.nb_iterations, device=result.device) * self.speed
125 for l in range(self.nb_iterations * self.speed):
126 fine[l] = collision_okay()
127 for n in range(self.nb_birds):
129 result[l, i[n], j[n]] = c
130 result[l, i[n] - vi[n], j[n]] = c
131 result[l, i[n], j[n] - vj[n]] = c
133 if (i[n] == 0 and vi[n] == -1) or (
134 i[n] == self.height - 1 and vi[n] == 1
138 if (j[n] == 0 and vj[n] == -1) or (
139 j[n] == self.width - 1 and vj[n] == 1
146 result = result[t_to_keep]
147 fine = fine[t_to_keep]
152 frame_sequences.append(result)
154 return frame_sequences
156 ######################################################################
def generate_prompts_and_answers(self, nb):
    """Generate ``nb`` frame sequences and split each into prompt and answer.

    The sequences returned by ``self.generate_frame_sequences(nb)`` are
    stacked along a new leading dimension. Along dimension 1 the first
    ``size(1) // 2`` entries form the prompt and the remaining entries
    form the answer, so with an odd size the answer gets the extra entry.
    Both halves are flattened from dimension 1 onward.

    Returns:
        A ``(prompts, answers)`` pair of 2D tensors with ``nb`` rows each.
    """
    # Stacking requires every generated sequence to have the same shape.
    batch = torch.stack(self.generate_frame_sequences(nb), dim=0)
    split_at = batch.size(1) // 2
    first_half = batch[:, :split_at]
    second_half = batch[:, split_at:]
    return first_half.flatten(1), second_half.flatten(1)
165 ######################################################################
167 def frame2img(self, x, scale=15):
168 x = x.reshape(x.size(0), self.height, -1)
169 m = torch.logical_and(
170 x >= 0, x < self.first_bird_token + self.nb_bird_tokens
172 x = self.colors[x * m].permute(0, 3, 1, 2)
174 x = x[:, :, :, None, :, None].expand(-1, -1, -1, scale, -1, scale)
175 x = x.reshape(s[0], s[1], s[2] * scale, s[3] * scale)
177 x[:, :, :, torch.arange(0, x.size(3), scale)] = 0
178 x[:, :, torch.arange(0, x.size(2), scale), :] = 0
181 for n in range(m.size(0)):
182 for i in range(m.size(1)):
183 for j in range(m.size(2)):
185 for k in range(2, scale - 2):
187 x[n, :, i * scale + k, j * scale + k - l] = 0
189 n, :, i * scale + scale - 1 - k, j * scale + k - l
194 def seq2str(self, seq):
197 result.append("".join([self.token2char[v] for v in s]))
206 predicted_prompts=None,
207 predicted_answers=None,
209 if predicted_prompts is None:
210 predicted_prompts = 255
212 if predicted_answers is None:
213 predicted_answers = 255
215 def add_frame(x, c, margin):
217 (x.size(0), x.size(1), x.size(2) + 2 * margin, x.size(3) + 2 * margin),
223 c = c.long()[:, None]
224 c = c * torch.tensor([192, 192, 192], device=c.device) + (
226 ) * torch.tensor([255, 255, 255], device=c.device)
227 y[...] = c[:, :, None, None]
228 y[:, :, margin:-margin, margin:-margin] = x
233 img_prompts = add_frame(self.frame2img(prompts.to("cpu")), 0, 1)
234 img_answers = add_frame(self.frame2img(answers.to("cpu")), 0, 1)
236 # img_prompts = add_frame(img_prompts, 255, margin)
237 # img_answers = add_frame(img_answers, 255, margin)
239 img_prompts = add_frame(img_prompts, predicted_prompts, margin)
240 img_answers = add_frame(img_answers, predicted_answers, margin)
242 separator = img_prompts.new_full(
243 (img_prompts.size(0), img_prompts.size(1), img_prompts.size(2), margin), 255
246 img = torch.cat([img_prompts, img_answers], dim=3)
248 image_name = os.path.join(result_dir, filename)
249 torchvision.utils.save_image(
250 img.float() / 255.0, image_name, nrow=6, padding=margin * 2, pad_value=1.0
259 predicted_prompts=None,
260 predicted_answers=None,
264 filename_prefix + ".png",
272 ######################################################################
274 if __name__ == "__main__":
277 sky = Sky(height=6, width=8, speed=1, nb_iterations=4)
279 prompts, answers = sky.generate_prompts_and_answers(4)
281 predicted_prompts = torch.rand(prompts.size(0)) < 0.5
282 predicted_answers = torch.rand(answers.size(0)) < 0.5
285 "/tmp", "test", prompts, answers, predicted_prompts, predicted_answers
288 # start_time = time.perf_counter()
289 # token_sequences = sky.generate_token_sequences(nb=64)
290 # delay = time.perf_counter() - start_time
291 # print(f"{token_sequences.size(0)/delay:02f} seq/s")
293 # print(sky.seq2str(seq[:4]))
295 # for t in range(len(it[0])):
296 # img = torch.cat([sky.frame2img(f[t]) for f in it], dim=0)
297 # torchvision.utils.save_image(
298 # img.float() / 255.0,
299 # f"/tmp/frame_{t:03d}.png",
305 # m = (torch.rand(seq.size()) < 0.05).long()
306 # seq = (1 - m) * seq + m * 23
309 # img = sky.seq2img(token_sequences)
312 # torchvision.utils.save_image(
313 # img.float() / 255.0, "/tmp/world.png", nrow=6, padding=6, pad_value=0