3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
6 # Written by Francois Fleuret <francois@fleuret.org>
8 import math, sys, tqdm, os, warnings
10 import torch, torchvision
13 from torch.nn import functional as F
15 ######################################################################
20 class Lang(problem.Problem):
22 ("white", [255, 255, 255]),
24 ("green", [0, 192, 0]),
25 ("blue", [0, 0, 255]),
26 ("orange", [255, 192, 0]),
27 ("cyan", [0, 255, 255]),
28 ("violet", [255, 0, 255]),
29 ("lightgreen", [192, 255, 192]),
30 ("pink", [255, 192, 192]),
31 ("lightblue", [192, 192, 255]),
32 ("gray", [192, 192, 192]),
39 self.colors = torch.tensor([c for _, c in self.named_colors])
40 self.name2color = dict([(p[0], i) for i, p in enumerate(self.named_colors)])
43 self.nb_iterations = nb_iterations
45 ######################################################################
47 def frame2img(self, x, scale=15):
# Turns a batch of color-index grids into RGB images in which every grid
# cell is drawn as a scale x scale square with a black line on its first
# pixel row and column.
# NOTE(review): the visible lines read `s` without defining it and contain no
# return statement -- presumably both are on lines missing from this view;
# confirm against the full file before relying on this description.
# Lay each sample out as (self.height, inferred width).
48 x = x.reshape(x.size(0), self.height, -1)
# Replace every index by its entry of self.colors, then move that trailing
# color axis to dim 1 (channel-first layout).
49 x = self.colors[x].permute(0, 3, 1, 2)
# Insert two singleton axes and expand them so each cell is repeated `scale`
# times along both spatial directions (expand does not copy the data).
51 x = x[:, :, :, None, :, None].expand(-1, -1, -1, scale, -1, scale)
# Fold the repeats into the spatial axes; `s` is presumably the shape of x
# before the expansion -- TODO confirm.
52 x = x.reshape(s[0], s[1], s[2] * scale, s[3] * scale)
# Zero every scale-th column, then every scale-th row: these are the first
# pixel lines of each cell and form the grid pattern.
54 x[:, :, :, torch.arange(0, x.size(3), scale)] = 0
55 x[:, :, torch.arange(0, x.size(2), scale), :] = 0
66 predicted_prompts=None,
67 predicted_answers=None,
69 if predicted_prompts is None:
70 predicted_prompts = 255
72 if predicted_answers is None:
73 predicted_answers = 255
75 def add_frame(x, c, margin, bottom=False):
77 h, w, di, dj = x.size(2) + margin, x.size(3), 0, 0
80 x.size(2) + 2 * margin,
81 x.size(3) + 2 * margin,
86 y = x.new_full((x.size(0), x.size(1), h, w), 0)
92 c = c * torch.tensor([0, 0, 0], device=c.device) + (
94 ) * torch.tensor([255, 255, 255], device=c.device)
95 y[...] = c[:, :, None, None]
97 y[:, :, di : di + x.size(2), dj : dj + x.size(3)] = x
103 img_prompts = add_frame(self.frame2img(prompts.to("cpu")), c=0, margin=1)
104 h = img_prompts.size(2)
105 img_answers = add_frame(self.frame2img(answers.to("cpu")), c=0, margin=1)
107 img_prompts = add_frame(img_prompts, c=255, margin=margin, bottom=True)
108 img_answers = add_frame(img_answers, c=255, margin=margin, bottom=True)
110 img_prompts = add_frame(
111 img_prompts, c=predicted_prompts, margin=margin, bottom=True
113 img_answers = add_frame(
114 img_answers, c=predicted_answers, margin=margin, bottom=True
119 separator = img_prompts.new_full(
129 separator[:, :, 0] = 0
130 separator[:, :, h - 1] = 0
132 for k in range(1, 2 * marker_size - 8):
133 i = k - (marker_size - 4)
134 j = marker_size - 5 - abs(i)
135 separator[:, :, h // 2 - 1 + i, 2 + j] = 0
136 separator[:, :, h // 2 - 1 + i + 1, 2 + j] = 0
138 img = torch.cat([img_prompts, separator, img_answers], dim=3)
140 image_name = os.path.join(result_dir, filename)
141 torchvision.utils.save_image(
142 img.float() / 255.0, image_name, nrow=4, padding=margin * 4, pad_value=1.0
145 ######################################################################
def nb_token_values(self):
    """Return the number of distinct token values, one per entry of the color table."""
    palette = self.colors
    return len(palette)
150 def rec_coo(self, x):
# Draws rectangle coordinates (i1, j1, i2, j2) inside the 2D grid x.
# NOTE(review): only the two random draws and the return are visible here;
# any ordering or rejection of the drawn values would sit on lines missing
# from this view -- confirm against the full file.
# Two row indices, each drawn uniformly from [0, x.size(0)).
152 i1, i2 = torch.randint(x.size(0), (2,))
# Two column indices, each drawn uniformly from [0, x.size(1)).
156 j1, j2 = torch.randint(x.size(1), (2,))
# Returned in (row, column, row, column) order.
159 return i1, j1, i2, j2
def task_red_to_green(self, A, f_A, B, f_B):
    """Paint one rectangle red in each input grid and green in its counterpart.

    The pair (A, f_A) is handled first, then (B, f_B). For each pair a
    rectangle is obtained from self.rec_coo on the input grid; that region
    is set to the "red" index in the input grid and to the "green" index in
    the counterpart. All four grids are modified in place; nothing is
    returned.
    """
    for grid, target in ((A, f_A), (B, f_B)):
        top, left, bottom, right = self.rec_coo(grid)
        # Same half-open row/column window for the input and its counterpart.
        region = (slice(top, bottom), slice(left, right))
        grid[region] = self.name2color["red"]
        target[region] = self.name2color["green"]
169 def generate_prompts_and_answers(self, nb):
# Builds nb quizzes and returns them as (prompts, answers).
# prompts is an int64 tensor of shape (nb, self.height, 3 * self.width) and
# answers one of shape (nb, self.height, self.width); both start as zeros.
170 prompts = torch.zeros(nb, self.height, self.width * 3, dtype=torch.int64)
171 answers = torch.zeros(nb, self.height, self.width, dtype=torch.int64)
# NOTE(review): `w` is not defined in the visible lines -- presumably
# self.width, set on a line missing from this view; confirm.
173 for prompt, answer in zip(prompts, answers):
# Each prompt is cut into three side-by-side panels of w columns. The slices
# are views, so whatever the task writes lands directly in `prompts`.
174 self.task_red_to_green(
175 prompt[:, 0 * w : 1 * w],
176 prompt[:, 1 * w : 2 * w],
177 prompt[:, 2 * w : 3 * w],
# NOTE(review): the fourth argument (presumably `answer`) and the closing
# parenthesis of this call are not visible here -- confirm.
180 return prompts, answers
188 predicted_prompts=None,
189 predicted_answers=None,
193 filename_prefix + ".png",
201 ######################################################################
# Manual smoke test, run only when the file is executed as a script.
203 if __name__ == "__main__":
206 lang = Lang(nb_iterations=4)
# Generate 24 quizzes to render.
208 prompts, answers = lang.generate_prompts_and_answers(24)
210 # predicted_prompts = torch.rand(prompts.size(0)) < 0.5
211 # predicted_answers = torch.rand(answers.size(0)) < 0.5
# NOTE(review): the call that receives the arguments below is on a line
# missing from this view; presumably it writes an image of the quizzes under
# /tmp -- confirm against the full file.
214 "/tmp", "test", prompts, answers # , predicted_prompts, predicted_answers
217 # start_time = time.perf_counter()
218 # token_sequences = lang.generate_token_sequences(nb=64)
219 # delay = time.perf_counter() - start_time
220 # print(f"{token_sequences.size(0)/delay:02f} seq/s")
222 # print(lang.seq2str(seq[:4]))
224 # for t in range(len(it[0])):
225 # img = torch.cat([lang.frame2img(f[t]) for f in it], dim=0)
226 # torchvision.utils.save_image(
227 # img.float() / 255.0,
228 # f"/tmp/frame_{t:03d}.png",
234 # m = (torch.rand(seq.size()) < 0.05).long()
235 # seq = (1 - m) * seq + m * 23
238 # img = lang.seq2img(token_sequences)
241 # torchvision.utils.save_image(
242 # img.float() / 255.0, "/tmp/world.png", nrow=6, padding=6, pad_value=0