92699e82f50f06e9edcd156a84d1e02366e8ace0
[culture.git] / reasoning.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import math, sys, tqdm, os, warnings
9
10 import torch, torchvision
11
12 from torch import nn
13 from torch.nn import functional as F
14
15 ######################################################################
16
17 import problem
18
19
20 class Reasoning(problem.Problem):
21     named_colors = [
22         ("white", [255, 255, 255]),
23         ("red", [255, 0, 0]),
24         ("green", [0, 192, 0]),
25         ("blue", [0, 0, 255]),
26         ("orange", [255, 192, 0]),
27         ("cyan", [0, 255, 255]),
28         ("violet", [255, 0, 255]),
29         ("lightgreen", [192, 255, 192]),
30         ("pink", [255, 192, 192]),
31         ("lightblue", [192, 192, 255]),
32         ("gray", [192, 192, 192]),
33     ]
34
35     def __init__(
36         self,
37     ):
38         self.colors = torch.tensor([c for _, c in self.named_colors])
39         self.name2color = dict([(p[0], i) for i, p in enumerate(self.named_colors)])
40         self.height = 10
41         self.width = 10
42
43     ######################################################################
44
45     def frame2img(self, x, scale=15):
46         x = x.reshape(x.size(0), self.height, -1)
47         x = self.colors[x].permute(0, 3, 1, 2)
48         s = x.shape
49         x = x[:, :, :, None, :, None].expand(-1, -1, -1, scale, -1, scale)
50         x = x.reshape(s[0], s[1], s[2] * scale, s[3] * scale)
51
52         x[:, :, :, torch.arange(0, x.size(3), scale)] = 0
53         x[:, :, torch.arange(0, x.size(2), scale), :] = 0
54         x = x[:, :, 1:, 1:]
55
56         return x
57
58     def save_image(
59         self,
60         result_dir,
61         filename,
62         prompts,
63         answers,
64         predicted_prompts=None,
65         predicted_answers=None,
66     ):
67         prompts = prompts.reshape(prompts.size(0), self.height, -1)
68         answers = answers.reshape(answers.size(0), self.height, -1)
69
70         if predicted_prompts is None:
71             predicted_prompts = 255
72
73         if predicted_answers is None:
74             predicted_answers = 255
75
76         def add_frame(x, c, margin, bottom=False):
77             if bottom:
78                 h, w, di, dj = x.size(2) + margin, x.size(3), 0, 0
79             else:
80                 h, w, di, dj = (
81                     x.size(2) + 2 * margin,
82                     x.size(3) + 2 * margin,
83                     margin,
84                     margin,
85                 )
86
87             y = x.new_full((x.size(0), x.size(1), h, w), 0)
88
89             if type(c) is int:
90                 y[...] = c
91             else:
92                 c = c.long()[:, None]
93                 c = c * torch.tensor([192, 192, 192], device=c.device) + (
94                     1 - c
95                 ) * torch.tensor([255, 255, 255], device=c.device)
96                 y[...] = c[:, :, None, None]
97
98             y[:, :, di : di + x.size(2), dj : dj + x.size(3)] = x
99
100             return y
101
102         margin = 8
103
104         img_prompts = torch.cat(
105             [
106                 add_frame(
107                     add_frame(self.frame2img(x), c=0, margin=1),
108                     c=predicted_prompts,
109                     margin=margin,
110                 )
111                 for x in prompts.to("cpu").split(split_size=self.width, dim=2)
112             ],
113             dim=3,
114         )
115
116         h = img_prompts.size(2)
117         img_answers = add_frame(
118             add_frame(self.frame2img(answers.to("cpu")), c=0, margin=1),
119             c=predicted_answers,
120             margin=margin,
121         )
122
123         separator_size = 2 * margin
124
125         separator = img_prompts.new_full(
126             (
127                 img_prompts.size(0),
128                 img_prompts.size(1),
129                 img_prompts.size(2),
130                 separator_size,
131             ),
132             255,
133         )
134
135         marker = img_prompts.new_full(
136             (
137                 img_prompts.size(0),
138                 img_prompts.size(1),
139                 img_prompts.size(2),
140                 separator_size,
141             ),
142             255,
143         )
144
145         # marker[:, :, 0] = 0
146         # marker[:, :, h - 1] = 0
147
148         for k in range(1, 2 * separator_size - 8):
149             i = k - (separator_size - 4)
150             j = separator_size - 5 - abs(i)
151             marker[:, :, h // 2 - 1 + i, 2 + j] = 0
152             marker[:, :, h // 2 - 1 + i + 1, 2 + j] = 0
153
154         img = torch.cat(
155             [
156                 img_prompts,
157                 marker,
158                 img_answers,
159             ],
160             dim=3,
161         )
162
163         image_name = os.path.join(result_dir, filename)
164         torchvision.utils.save_image(
165             img.float() / 255.0, image_name, nrow=4, padding=margin * 4, pad_value=1.0
166         )
167
168     ######################################################################
169
170     def nb_token_values(self):
171         return len(self.colors)
172
173     def rec_coo(self, x, n, min_height=3, min_width=3):
174         collision = x.new(x.size())
175         while True:
176             collision[...] = 0
177             result = []
178             for _ in range(n):
179                 while True:
180                     i1, i2 = torch.randint(x.size(0), (2,))
181                     if i1 + min_height <= i2:
182                         break
183                 while True:
184                     j1, j2 = torch.randint(x.size(1), (2,))
185                     if j1 + min_width <= j2:
186                         break
187                 collision[i1:i2, j1:j2] += 1
188                 if collision.max() > 1:
189                     break
190                 result.append((i1, j1, i2, j2))
191             if collision.max() == 1:
192                 break
193         return result
194
195     ######################################################################
196
197     def task_replace_color(self, A, f_A, B, f_B):
198         N = 3
199         c = torch.randperm(len(self.colors) - 1)[: N + 1] + 1
200         for X, f_X in [(A, f_A), (B, f_B)]:
201             r = self.rec_coo(X, N)
202             for n in range(N):
203                 i1, j1, i2, j2 = r[n]
204                 X[i1:i2, j1:j2] = c[n]
205                 f_X[i1:i2, j1:j2] = c[n if n > 0 else -1]
206
207     def task_move(self, A, f_A, B, f_B):
208         di, dj = torch.randint(2, (2,)) * 2 - 1
209         N = 3
210         c = torch.randperm(len(self.colors) - 1)[:N] + 1
211         for X, f_X in [(A, f_A), (B, f_B)]:
212             while True:
213                 r = self.rec_coo(X, N)
214                 i1, j1, i2, j2 = r[N - 1]
215                 if (
216                     i1 + di >= 0
217                     and i2 + di < X.size(0)
218                     and j1 + dj >= 0
219                     and j2 + dj < X.size(1)
220                 ):
221                     break
222
223             for n in range(N):
224                 i1, j1, i2, j2 = r[n]
225                 X[i1:i2, j1:j2] = c[n]
226                 if n == N - 1:
227                     f_X[i1 + di : i2 + di, j1 + dj : j2 + dj] = c[n]
228                 else:
229                     f_X[i1:i2, j1:j2] = c[n]
230
231     def task_grow(self, A, f_A, B, f_B):
232         di, dj = torch.randint(2, (2,)) * 2 - 1
233         N = 3
234         c = torch.randperm(len(self.colors) - 1)[:N] + 1
235         direction = torch.randint(2, (1,))
236         for X, f_X in [(A, f_A), (B, f_B)]:
237             while True:
238                 r = self.rec_coo(X, N)
239                 i1, j1, i2, j2 = r[N - 1]
240                 if i1 + 3 < i2 and j1 + 3 < j2:
241                     break
242
243             for n in range(N):
244                 i1, j1, i2, j2 = r[n]
245                 if n == N - 1:
246                     if direction == 0:
247                         X[i1 + 1 : i2 - 1, j1 + 1 : j2 - 1] = c[n]
248                         f_X[i1:i2, j1:j2] = c[n]
249                     else:
250                         X[i1:i2, j1:j2] = c[n]
251                         f_X[i1 + 1 : i2 - 1, j1 + 1 : j2 - 1] = c[n]
252                 else:
253                     X[i1:i2, j1:j2] = c[n]
254                     f_X[i1:i2, j1:j2] = c[n]
255
256     def task_color_grow(self, A, f_A, B, f_B):
257         di, dj = torch.randint(2, (2,)) * 2 - 1
258         N = 3
259         c = torch.randperm(len(self.colors) - 1)[: 2 * N] + 1
260         direction = torch.randint(2, (1,))
261         for X, f_X in [(A, f_A), (B, f_B)]:
262             r = self.rec_coo(X, N)
263             for n in range(N):
264                 i1, j1, i2, j2 = r[n]
265                 i = (i1 + i2) // 2
266                 X[i1:i2, j1:j2] = c[2 * n]
267                 X[i : i + 1, j1:j2] = c[2 * n + 1]
268                 f_X[i1:i2, j1:j2] = c[2 * n]
269                 if n == N - 1:
270                     f_X[i:i2, j1:j2] = c[2 * n + 1]
271                 else:
272                     f_X[i : i + 1, j1:j2] = c[2 * n + 1]
273
274     def task_frame(self, A, f_A, B, f_B):
275         N = 3
276         c = torch.randperm(len(self.colors) - 1)[: N + 1] + 1
277         for X, f_X in [(A, f_A), (B, f_B)]:
278             r = self.rec_coo(X, N)
279             for n in range(N):
280                 i1, j1, i2, j2 = r[n]
281                 X[i1:i2, j1:j2] = c[n]
282                 f_X[i1:i2, j1:j2] = c[n]
283                 if n == N - 1:
284                     f_X[i1 + 1 : i2 - 1, j1 + 1 : j2 - 1] = 0
285
286     def task_detect(self, A, f_A, B, f_B):
287         N = 3
288         c = torch.randperm(len(self.colors) - 1)[: N + 1] + 1
289         for X, f_X in [(A, f_A), (B, f_B)]:
290             r = self.rec_coo(X, N)
291             for n in range(N):
292                 i1, j1, i2, j2 = r[n]
293                 X[i1:i2, j1:j2] = c[n]
294                 f_X[i1, j1] = c[-1]
295
296     ######################################################################
297
298     def generate_prompts_and_answers(self, nb):
299         tasks = [
300             self.task_replace_color,
301             self.task_move,
302             self.task_grow,
303             self.task_color_grow,
304             self.task_frame,
305             self.task_detect,
306         ]
307         prompts = torch.zeros(nb, self.height, self.width * 3, dtype=torch.int64)
308         answers = torch.zeros(nb, self.height, self.width, dtype=torch.int64)
309         w = self.width
310         for prompt, answer in zip(prompts, answers):
311             A = prompt[:, 0 * w : 1 * w]
312             f_A = prompt[:, 1 * w : 2 * w]
313             B = prompt[:, 2 * w : 3 * w]
314             f_B = answer
315             task = tasks[torch.randint(len(tasks), (1,))]
316             task(A, f_A, B, f_B)
317         return prompts.flatten(1), answers.flatten(1)
318
319     def save_quizzes(
320         self,
321         result_dir,
322         filename_prefix,
323         prompts,
324         answers,
325         predicted_prompts=None,
326         predicted_answers=None,
327     ):
328         self.save_image(
329             result_dir,
330             filename_prefix + ".png",
331             prompts,
332             answers,
333             predicted_prompts,
334             predicted_answers,
335         )
336
337
338 ######################################################################
339
340 if __name__ == "__main__":
341     import time
342
343     reasoning = Reasoning()
344
345     start_time = time.perf_counter()
346     prompts, answers = reasoning.generate_prompts_and_answers(100)
347     delay = time.perf_counter() - start_time
348     print(f"{prompts.size(0)/delay:02f} seq/s")
349
350     # predicted_prompts = torch.rand(prompts.size(0)) < 0.5
351     # predicted_answers = torch.logical_not(predicted_prompts)
352
353     reasoning.save_quizzes(
354         "/tmp",
355         "test",
356         prompts[:36],
357         answers[:36],
358         # You can add a bool to put a frame around the predicted parts
359         # predicted_prompts, predicted_answers
360     )