3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
6 # Written by Francois Fleuret <francois@fleuret.org>
8 import math, sys, tqdm, os, warnings
10 import torch, torchvision
13 from torch.nn import functional as F
15 ######################################################################
20 class Lang(problem.Problem):
22 ("white", [255, 255, 255]),
24 ("green", [0, 192, 0]),
25 ("blue", [0, 0, 255]),
26 ("orange", [255, 192, 0]),
27 ("cyan", [0, 255, 255]),
28 ("violet", [255, 0, 255]),
29 ("lightgreen", [192, 255, 192]),
30 ("pink", [255, 192, 192]),
31 ("lightblue", [192, 192, 255]),
32 ("gray", [192, 192, 192]),
39 self.colors = torch.tensor([c for _, c in self.named_colors])
40 self.name2color = dict([(p[0], i) for i, p in enumerate(self.named_colors)])
43 self.nb_iterations = nb_iterations
45 ######################################################################
47 def frame2img(self, x, scale=15):
48 x = x.reshape(x.size(0), self.height, -1)
49 x = self.colors[x].permute(0, 3, 1, 2)
51 x = x[:, :, :, None, :, None].expand(-1, -1, -1, scale, -1, scale)
52 x = x.reshape(s[0], s[1], s[2] * scale, s[3] * scale)
54 x[:, :, :, torch.arange(0, x.size(3), scale)] = 0
55 x[:, :, torch.arange(0, x.size(2), scale), :] = 0
66 predicted_prompts=None,
67 predicted_answers=None,
69 prompts = prompts.reshape(prompts.size(0), self.height, -1)
70 answers = answers.reshape(answers.size(0), self.height, -1)
72 if predicted_prompts is None:
73 predicted_prompts = 255
75 if predicted_answers is None:
76 predicted_answers = 255
78 def add_frame(x, c, margin, bottom=False):
80 h, w, di, dj = x.size(2) + margin, x.size(3), 0, 0
83 x.size(2) + 2 * margin,
84 x.size(3) + 2 * margin,
89 y = x.new_full((x.size(0), x.size(1), h, w), 0)
95 c = c * torch.tensor([192, 192, 192], device=c.device) + (
97 ) * torch.tensor([255, 255, 255], device=c.device)
98 y[...] = c[:, :, None, None]
100 y[:, :, di : di + x.size(2), dj : dj + x.size(3)] = x
106 img_prompts = torch.cat(
109 add_frame(self.frame2img(x), c=0, margin=1),
113 for x in prompts.to("cpu").split(split_size=self.width, dim=2)
118 h = img_prompts.size(2)
119 img_answers = add_frame(
120 add_frame(self.frame2img(answers.to("cpu")), c=0, margin=1),
125 separator_size = 2 * margin
127 separator = img_prompts.new_full(
137 marker = img_prompts.new_full(
147 # marker[:, :, 0] = 0
148 # marker[:, :, h - 1] = 0
150 for k in range(1, 2 * separator_size - 8):
151 i = k - (separator_size - 4)
152 j = separator_size - 5 - abs(i)
153 marker[:, :, h // 2 - 1 + i, 2 + j] = 0
154 marker[:, :, h // 2 - 1 + i + 1, 2 + j] = 0
165 image_name = os.path.join(result_dir, filename)
166 torchvision.utils.save_image(
167 img.float() / 255.0, image_name, nrow=4, padding=margin * 4, pad_value=1.0
170 ######################################################################
def nb_token_values(self):
    """Return the number of distinct token values.

    This is the length of ``self.colors``, i.e. one token value per
    entry of the color table held by the instance.
    """
    nb_values = len(self.colors)
    return nb_values
175 def rec_coo(self, x, n, min_height=3, min_width=3):
177 collision = x.new_zeros(x.size())
181 i1, i2 = torch.randint(x.size(0), (2,))
182 if i1 + min_height <= i2:
185 j1, j2 = torch.randint(x.size(1), (2,))
186 if j1 + min_width <= j2:
188 collision[i1:i2, j1:j2] += 1
189 if collision.max() > 1:
191 result.append((i1, j1, i2, j2))
192 if collision.max() == 1:
196 ######################################################################
198 def task_replace_color(self, A, f_A, B, f_B):
200 c = torch.randperm(len(self.colors) - 1)[: N + 1] + 1
201 for X, f_X in [(A, f_A), (B, f_B)]:
202 r = self.rec_coo(X, N)
204 i1, j1, i2, j2 = r[n]
205 X[i1:i2, j1:j2] = c[n]
206 f_X[i1:i2, j1:j2] = c[n if n > 0 else -1]
208 def task_move(self, A, f_A, B, f_B):
209 di, dj = torch.randint(2, (2,)) * 2 - 1
211 c = torch.randperm(len(self.colors) - 1)[:N] + 1
212 for X, f_X in [(A, f_A), (B, f_B)]:
214 r = self.rec_coo(X, N)
215 i1, j1, i2, j2 = r[N - 1]
218 and i2 + di < X.size(0)
220 and j2 + dj < X.size(1)
225 i1, j1, i2, j2 = r[n]
226 X[i1:i2, j1:j2] = c[n]
228 f_X[i1 + di : i2 + di, j1 + dj : j2 + dj] = c[n]
230 f_X[i1:i2, j1:j2] = c[n]
232 def task_grow(self, A, f_A, B, f_B):
233 di, dj = torch.randint(2, (2,)) * 2 - 1
235 c = torch.randperm(len(self.colors) - 1)[:N] + 1
236 direction = torch.randint(2, (1,))
237 for X, f_X in [(A, f_A), (B, f_B)]:
239 r = self.rec_coo(X, N)
240 i1, j1, i2, j2 = r[N - 1]
241 if i1 + 3 < i2 and j1 + 3 < j2:
245 i1, j1, i2, j2 = r[n]
248 X[i1 + 1 : i2 - 1, j1 + 1 : j2 - 1] = c[n]
249 f_X[i1:i2, j1:j2] = c[n]
251 X[i1:i2, j1:j2] = c[n]
252 f_X[i1 + 1 : i2 - 1, j1 + 1 : j2 - 1] = c[n]
254 X[i1:i2, j1:j2] = c[n]
255 f_X[i1:i2, j1:j2] = c[n]
257 def task_color_grow(self, A, f_A, B, f_B):
258 di, dj = torch.randint(2, (2,)) * 2 - 1
260 c = torch.randperm(len(self.colors) - 1)[: 2 * N] + 1
261 direction = torch.randint(2, (1,))
262 for X, f_X in [(A, f_A), (B, f_B)]:
263 r = self.rec_coo(X, N)
265 i1, j1, i2, j2 = r[n]
266 X[i1 : (i1 + i2) // 2, j1:j2] = c[2 * n]
267 f_X[i1 : (i1 + i2) // 2, j1:j2] = c[2 * n]
268 X[(i1 + i2) // 2 : (i1 + i2) // 2 + 1, j1:j2] = c[2 * n + 1]
270 f_X[(i1 + i2) // 2 : i2, j1:j2] = c[2 * n + 1]
272 f_X[(i1 + i2) // 2 : (i1 + i2) // 2 + 1, j1:j2] = c[2 * n + 1]
274 def task_frame(self, A, f_A, B, f_B):
276 c = torch.randperm(len(self.colors) - 1)[: N + 1] + 1
277 for X, f_X in [(A, f_A), (B, f_B)]:
278 r = self.rec_coo(X, N)
280 i1, j1, i2, j2 = r[n]
281 X[i1:i2, j1:j2] = c[n]
282 f_X[i1:i2, j1:j2] = c[n]
284 f_X[i1 + 1 : i2 - 1, j1 + 1 : j2 - 1] = 0
286 ######################################################################
288 def generate_prompts_and_answers(self, nb):
290 self.task_replace_color,
293 self.task_color_grow,
296 prompts = torch.zeros(nb, self.height, self.width * 3, dtype=torch.int64)
297 answers = torch.zeros(nb, self.height, self.width, dtype=torch.int64)
299 for prompt, answer in zip(prompts, answers):
300 A = prompt[:, 0 * w : 1 * w]
301 f_A = prompt[:, 1 * w : 2 * w]
302 B = prompt[:, 2 * w : 3 * w]
304 tasks[torch.randint(len(tasks), (1,))](A, f_A, B, f_B)
305 return prompts.flatten(1), answers.flatten(1)
313 predicted_prompts=None,
314 predicted_answers=None,
318 filename_prefix + ".png",
326 ######################################################################
328 if __name__ == "__main__":
331 lang = Lang(nb_iterations=4)
333 prompts, answers = lang.generate_prompts_and_answers(36)
335 # predicted_prompts = torch.rand(prompts.size(0)) < 0.5
336 # predicted_answers = torch.logical_not(predicted_prompts)
343 # You can add a bool to put a frame around the predicted parts
344 # predicted_prompts, predicted_answers