43550d7520cfca2e3fa986fa2697daebb261e802
[culture.git] / lang.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import math, sys, tqdm, os, warnings
9
10 import torch, torchvision
11
12 from torch import nn
13 from torch.nn import functional as F
14
15 ######################################################################
16
17 import problem
18
19
20 class Lang(problem.Problem):
21     named_colors = [
22         ("white", [255, 255, 255]),
23         ("red", [255, 0, 0]),
24         ("green", [0, 192, 0]),
25         ("blue", [0, 0, 255]),
26         ("orange", [255, 192, 0]),
27         ("cyan", [0, 255, 255]),
28         ("violet", [255, 0, 255]),
29         ("lightgreen", [192, 255, 192]),
30         ("pink", [255, 192, 192]),
31         ("lightblue", [192, 192, 255]),
32         ("gray", [192, 192, 192]),
33     ]
34
35     def __init__(
36         self,
37         nb_iterations=2,
38     ):
39         self.colors = torch.tensor([c for _, c in self.named_colors])
40         self.name2color = dict([(p[0], i) for i, p in enumerate(self.named_colors)])
41         self.height = 10
42         self.width = 10
43         self.nb_iterations = nb_iterations
44
45     ######################################################################
46
47     def frame2img(self, x, scale=15):
48         x = x.reshape(x.size(0), self.height, -1)
49         x = self.colors[x].permute(0, 3, 1, 2)
50         s = x.shape
51         x = x[:, :, :, None, :, None].expand(-1, -1, -1, scale, -1, scale)
52         x = x.reshape(s[0], s[1], s[2] * scale, s[3] * scale)
53
54         x[:, :, :, torch.arange(0, x.size(3), scale)] = 0
55         x[:, :, torch.arange(0, x.size(2), scale), :] = 0
56         x = x[:, :, 1:, 1:]
57
58         return x
59
60     def save_image(
61         self,
62         result_dir,
63         filename,
64         prompts,
65         answers,
66         predicted_prompts=None,
67         predicted_answers=None,
68     ):
69         prompts = prompts.reshape(prompts.size(0), self.height, -1)
70         answers = answers.reshape(answers.size(0), self.height, -1)
71
72         if predicted_prompts is None:
73             predicted_prompts = 255
74
75         if predicted_answers is None:
76             predicted_answers = 255
77
78         def add_frame(x, c, margin, bottom=False):
79             if bottom:
80                 h, w, di, dj = x.size(2) + margin, x.size(3), 0, 0
81             else:
82                 h, w, di, dj = (
83                     x.size(2) + 2 * margin,
84                     x.size(3) + 2 * margin,
85                     margin,
86                     margin,
87                 )
88
89             y = x.new_full((x.size(0), x.size(1), h, w), 0)
90
91             if type(c) is int:
92                 y[...] = c
93             else:
94                 c = c.long()[:, None]
95                 c = c * torch.tensor([192, 192, 192], device=c.device) + (
96                     1 - c
97                 ) * torch.tensor([255, 255, 255], device=c.device)
98                 y[...] = c[:, :, None, None]
99
100             y[:, :, di : di + x.size(2), dj : dj + x.size(3)] = x
101
102             return y
103
104         margin = 8
105
106         img_prompts = torch.cat(
107             [
108                 add_frame(
109                     add_frame(self.frame2img(x), c=0, margin=1),
110                     c=predicted_prompts,
111                     margin=margin,
112                 )
113                 for x in prompts.to("cpu").split(split_size=self.width, dim=2)
114             ],
115             dim=3,
116         )
117
118         h = img_prompts.size(2)
119         img_answers = add_frame(
120             add_frame(self.frame2img(answers.to("cpu")), c=0, margin=1),
121             c=predicted_answers,
122             margin=margin,
123         )
124
125         separator_size = 2 * margin
126
127         separator = img_prompts.new_full(
128             (
129                 img_prompts.size(0),
130                 img_prompts.size(1),
131                 img_prompts.size(2),
132                 separator_size,
133             ),
134             255,
135         )
136
137         marker = img_prompts.new_full(
138             (
139                 img_prompts.size(0),
140                 img_prompts.size(1),
141                 img_prompts.size(2),
142                 separator_size,
143             ),
144             255,
145         )
146
147         # marker[:, :, 0] = 0
148         # marker[:, :, h - 1] = 0
149
150         for k in range(1, 2 * separator_size - 8):
151             i = k - (separator_size - 4)
152             j = separator_size - 5 - abs(i)
153             marker[:, :, h // 2 - 1 + i, 2 + j] = 0
154             marker[:, :, h // 2 - 1 + i + 1, 2 + j] = 0
155
156         img = torch.cat(
157             [
158                 img_prompts,
159                 marker,
160                 img_answers,
161             ],
162             dim=3,
163         )
164
165         image_name = os.path.join(result_dir, filename)
166         torchvision.utils.save_image(
167             img.float() / 255.0, image_name, nrow=4, padding=margin * 4, pad_value=1.0
168         )
169
170     ######################################################################
171
172     def nb_token_values(self):
173         return len(self.colors)
174
175     def rec_coo(self, x, n, min_height=3, min_width=3):
176         while True:
177             collision = x.new_zeros(x.size())
178             result = []
179             for _ in range(n):
180                 while True:
181                     i1, i2 = torch.randint(x.size(0), (2,))
182                     if i1 + min_height <= i2:
183                         break
184                 while True:
185                     j1, j2 = torch.randint(x.size(1), (2,))
186                     if j1 + min_width <= j2:
187                         break
188                 collision[i1:i2, j1:j2] += 1
189                 if collision.max() > 1:
190                     break
191                 result.append((i1, j1, i2, j2))
192             if collision.max() == 1:
193                 break
194         return result
195
196     ######################################################################
197
198     def task_replace_color(self, A, f_A, B, f_B):
199         N = 3
200         c = torch.randperm(len(self.colors) - 1)[: N + 1] + 1
201         for X, f_X in [(A, f_A), (B, f_B)]:
202             r = self.rec_coo(X, N)
203             for n in range(N):
204                 i1, j1, i2, j2 = r[n]
205                 X[i1:i2, j1:j2] = c[n]
206                 f_X[i1:i2, j1:j2] = c[n if n > 0 else -1]
207
208     def task_move(self, A, f_A, B, f_B):
209         di, dj = torch.randint(2, (2,)) * 2 - 1
210         N = 3
211         c = torch.randperm(len(self.colors) - 1)[:N] + 1
212         for X, f_X in [(A, f_A), (B, f_B)]:
213             while True:
214                 r = self.rec_coo(X, N)
215                 i1, j1, i2, j2 = r[N - 1]
216                 if (
217                     i1 + di >= 0
218                     and i2 + di < X.size(0)
219                     and j1 + dj >= 0
220                     and j2 + dj < X.size(1)
221                 ):
222                     break
223
224             for n in range(N):
225                 i1, j1, i2, j2 = r[n]
226                 X[i1:i2, j1:j2] = c[n]
227                 if n == N - 1:
228                     f_X[i1 + di : i2 + di, j1 + dj : j2 + dj] = c[n]
229                 else:
230                     f_X[i1:i2, j1:j2] = c[n]
231
232     def task_grow(self, A, f_A, B, f_B):
233         di, dj = torch.randint(2, (2,)) * 2 - 1
234         N = 3
235         c = torch.randperm(len(self.colors) - 1)[:N] + 1
236         direction = torch.randint(2, (1,))
237         for X, f_X in [(A, f_A), (B, f_B)]:
238             while True:
239                 r = self.rec_coo(X, N)
240                 i1, j1, i2, j2 = r[N - 1]
241                 if i1 + 3 < i2 and j1 + 3 < j2:
242                     break
243
244             for n in range(N):
245                 i1, j1, i2, j2 = r[n]
246                 if n == N - 1:
247                     if direction == 0:
248                         X[i1 + 1 : i2 - 1, j1 + 1 : j2 - 1] = c[n]
249                         f_X[i1:i2, j1:j2] = c[n]
250                     else:
251                         X[i1:i2, j1:j2] = c[n]
252                         f_X[i1 + 1 : i2 - 1, j1 + 1 : j2 - 1] = c[n]
253                 else:
254                     X[i1:i2, j1:j2] = c[n]
255                     f_X[i1:i2, j1:j2] = c[n]
256
257     def task_color_grow(self, A, f_A, B, f_B):
258         di, dj = torch.randint(2, (2,)) * 2 - 1
259         N = 3
260         c = torch.randperm(len(self.colors) - 1)[: 2 * N] + 1
261         direction = torch.randint(2, (1,))
262         for X, f_X in [(A, f_A), (B, f_B)]:
263             r = self.rec_coo(X, N)
264             for n in range(N):
265                 i1, j1, i2, j2 = r[n]
266                 X[i1 : (i1 + i2) // 2, j1:j2] = c[2 * n]
267                 f_X[i1 : (i1 + i2) // 2, j1:j2] = c[2 * n]
268                 X[(i1 + i2) // 2 : (i1 + i2) // 2 + 1, j1:j2] = c[2 * n + 1]
269                 if n == N - 1:
270                     f_X[(i1 + i2) // 2 : i2, j1:j2] = c[2 * n + 1]
271                 else:
272                     f_X[(i1 + i2) // 2 : (i1 + i2) // 2 + 1, j1:j2] = c[2 * n + 1]
273
274     def task_frame(self, A, f_A, B, f_B):
275         N = 3
276         c = torch.randperm(len(self.colors) - 1)[: N + 1] + 1
277         for X, f_X in [(A, f_A), (B, f_B)]:
278             r = self.rec_coo(X, N)
279             for n in range(N):
280                 i1, j1, i2, j2 = r[n]
281                 X[i1:i2, j1:j2] = c[n]
282                 f_X[i1:i2, j1:j2] = c[n]
283                 if n == N - 1:
284                     f_X[i1 + 1 : i2 - 1, j1 + 1 : j2 - 1] = 0
285
286     ######################################################################
287
288     def generate_prompts_and_answers(self, nb):
289         tasks = [
290             self.task_replace_color,
291             self.task_move,
292             self.task_grow,
293             self.task_color_grow,
294             self.task_frame,
295         ]
296         prompts = torch.zeros(nb, self.height, self.width * 3, dtype=torch.int64)
297         answers = torch.zeros(nb, self.height, self.width, dtype=torch.int64)
298         w = self.width
299         for prompt, answer in zip(prompts, answers):
300             A = prompt[:, 0 * w : 1 * w]
301             f_A = prompt[:, 1 * w : 2 * w]
302             B = prompt[:, 2 * w : 3 * w]
303             f_B = answer
304             tasks[torch.randint(len(tasks), (1,))](A, f_A, B, f_B)
305         return prompts.flatten(1), answers.flatten(1)
306
307     def save_quizzes(
308         self,
309         result_dir,
310         filename_prefix,
311         prompts,
312         answers,
313         predicted_prompts=None,
314         predicted_answers=None,
315     ):
316         self.save_image(
317             result_dir,
318             filename_prefix + ".png",
319             prompts,
320             answers,
321             predicted_prompts,
322             predicted_answers,
323         )
324
325
326 ######################################################################
327
328 if __name__ == "__main__":
329     import time
330
331     lang = Lang(nb_iterations=4)
332
333     prompts, answers = lang.generate_prompts_and_answers(36)
334
335     # predicted_prompts = torch.rand(prompts.size(0)) < 0.5
336     # predicted_answers = torch.logical_not(predicted_prompts)
337
338     lang.save_quizzes(
339         "/tmp",
340         "test",
341         prompts,
342         answers,
343         # You can add a bool to put a frame around the predicted parts
344         # predicted_prompts, predicted_answers
345     )