#!/usr/bin/env python

# Any copyright is dedicated to the Public Domain.
# https://creativecommons.org/publicdomain/zero/1.0/

# Written by Francois Fleuret <francois@fleuret.org>

import math, sys, tqdm, os, warnings

import torch, torchvision

from torch import nn
from torch.nn import functional as F

######################################################################

import problem

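# Each quiz lives on small color grids: a prompt is the horizontal
# concatenation of three grids A, f(A) and B, and the answer is the grid
# f(B), where f is one of the hand-designed task_* transformations below.
# Cell values are indices into named_colors (0 is the white background).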
class Reasoning(problem.Problem):
    named_colors = [
        ("white", [255, 255, 255]),
        ("red", [255, 0, 0]),
        ("green", [0, 192, 0]),
        ("blue", [0, 0, 255]),
        ("orange", [255, 192, 0]),
        ("cyan", [0, 255, 255]),
        ("violet", [255, 0, 255]),
        ("lightgreen", [192, 255, 192]),
        ("pink", [255, 192, 192]),
        ("lightblue", [192, 192, 255]),
        ("gray", [192, 192, 192]),
    ]

    def __init__(self, device=torch.device("cpu")):
        self.colors = torch.tensor([c for _, c in self.named_colors])
        self.name2color = dict([(p[0], i) for i, p in enumerate(self.named_colors)])
        self.height = 10
        self.width = 10
        self.device = device

    ######################################################################

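    # Renders a batch of integer grids as RGB images: every cell becomes a
    # scale x scale block of its color, separated by thin black grid lines.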
    def frame2img(self, x, scale=15):
        x = x.reshape(x.size(0), self.height, -1)
        x = self.colors[x].permute(0, 3, 1, 2)
        s = x.shape
        # Upscale each cell to a scale x scale block
        x = x[:, :, :, None, :, None].expand(-1, -1, -1, scale, -1, scale)
        x = x.reshape(s[0], s[1], s[2] * scale, s[3] * scale)

        # Draw black separation lines between cells
        x[:, :, :, torch.arange(0, x.size(3), scale)] = 0
        x[:, :, torch.arange(0, x.size(2), scale), :] = 0
        x = x[:, :, 1:, 1:]

        return x

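    # Lays out each quiz as A | f(A) | B, a chevron marker, and the answer
    # grid, then writes the whole batch as a single PNG. When the optional
    # boolean masks predicted_prompts / predicted_answers are given, the
    # corresponding margins are drawn in gray instead of white.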
    def save_image(
        self,
        result_dir,
        filename,
        prompts,
        answers,
        predicted_prompts=None,
        predicted_answers=None,
    ):
        prompts = prompts.reshape(prompts.size(0), self.height, -1)
        answers = answers.reshape(answers.size(0), self.height, -1)

        if predicted_prompts is None:
            predicted_prompts = 255

        if predicted_answers is None:
            predicted_answers = 255

        # Pad x with a margin, colored per sample: gray where the boolean
        # mask c is True, white otherwise (or the plain value c if c is an int)
        def add_frame(x, c, margin, bottom=False):
            if bottom:
                h, w, di, dj = x.size(2) + margin, x.size(3), 0, 0
            else:
                h, w, di, dj = (
                    x.size(2) + 2 * margin,
                    x.size(3) + 2 * margin,
                    margin,
                    margin,
                )

            y = x.new_full((x.size(0), x.size(1), h, w), 0)

            if type(c) is int:
                y[...] = c
            else:
                c = c.long()[:, None]
                c = c * torch.tensor([192, 192, 192], device=c.device) + (
                    1 - c
                ) * torch.tensor([255, 255, 255], device=c.device)
                y[...] = c[:, :, None, None]

            y[:, :, di : di + x.size(2), dj : dj + x.size(3)] = x

            return y

        margin = 8

        img_prompts = torch.cat(
            [
                add_frame(
                    add_frame(self.frame2img(x), c=0, margin=1),
                    c=predicted_prompts,
                    margin=margin,
                )
                for x in prompts.to("cpu").split(split_size=self.width, dim=2)
            ],
            dim=3,
        )

        h = img_prompts.size(2)
        img_answers = add_frame(
            add_frame(self.frame2img(answers.to("cpu")), c=0, margin=1),
            c=predicted_answers,
            margin=margin,
        )

        separator_size = 2 * margin

        separator = img_prompts.new_full(
            (
                img_prompts.size(0),
                img_prompts.size(1),
                img_prompts.size(2),
                separator_size,
            ),
            255,
        )

        marker = img_prompts.new_full(
            (
                img_prompts.size(0),
                img_prompts.size(1),
                img_prompts.size(2),
                separator_size,
            ),
            255,
        )

        # marker[:, :, 0] = 0
        # marker[:, :, h - 1] = 0

        # Draw a chevron pointing from the prompt toward the answer
        for k in range(1, 2 * separator_size - 8):
            i = k - (separator_size - 4)
            j = separator_size - 5 - abs(i)
            marker[:, :, h // 2 - 1 + i, 2 + j] = 0
            marker[:, :, h // 2 - 1 + i + 1, 2 + j] = 0

        img = torch.cat(
            [
                img_prompts,
                marker,
                img_answers,
            ],
            dim=3,
        )

        image_name = os.path.join(result_dir, filename)
        torchvision.utils.save_image(
            img.float() / 255.0, image_name, nrow=4, padding=margin * 4, pad_value=1.0
        )

    ######################################################################

    def nb_token_values(self):
        return len(self.colors)

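    # Samples the coordinates of non-overlapping rectangles by rejection:
    # many candidate vertical/horizontal intervals are drawn in parallel,
    # grouped by K, and the first group whose rectangles do not collide is
    # returned as a list of (i1, j1, i2, j2) with exclusive upper bounds.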
    def rec_coo(self, x, n, min_height=3, min_width=3):
        # The callers below always request n == K == 3 rectangles
        K = 3
        N = 4000

        while True:
            v = (
                (
                    torch.rand(N * K, self.height + 1, device=self.device)
                    .sort(dim=-1)
                    .indices
                    < 2
                )
                .long()
                .cumsum(dim=1)
                == 1
            ).long()

            h = (
                (
                    torch.rand(N * K, self.width + 1, device=self.device)
                    .sort(dim=-1)
                    .indices
                    < 2
                )
                .long()
                .cumsum(dim=1)
                == 1
            ).long()

            # Keep only intervals that are large enough
            i = torch.logical_and(
                v.sum(dim=-1) >= min_height, h.sum(dim=-1) >= min_width
            )

            v, h = v[i], h[i]
            v = v[: v.size(0) - v.size(0) % K]
            h = h[: h.size(0) - h.size(0) % K]
            v = v.reshape(v.size(0) // K, K, -1)
            h = h.reshape(h.size(0) // K, K, -1)

            # Rectangle masks as outer products of vertical and horizontal intervals
            r = v[:, :, :, None] * h[:, :, None, :]

            # A group is valid if no cell is covered by more than one rectangle
            valid = r.sum(dim=1).flatten(1).max(dim=-1).values == 1

            v = v[valid]
            h = h[valid]

            if v.size(0) > 0:
                break

        av = torch.arange(v.size(2), device=self.device)[None, :]
        ah = torch.arange(h.size(2), device=self.device)[None, :]

        # Recover (i1, j1, i2, j2) coordinates from the interval masks of the
        # first valid group
        return [
            (i1.item(), j1.item(), i2.item() + 1, j2.item() + 1)
            for i1, j1, i2, j2 in zip(
                v.size(2) - (v[0] * (v.size(2) - av)).max(dim=-1).values,
                h.size(2) - (h[0] * (h.size(2) - ah)).max(dim=-1).values,
                (v[0] * av).max(dim=-1).values,
                (h[0] * ah).max(dim=-1).values,
            )
        ]

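    # Simpler sequential rejection-sampling variant of rec_coo, kept for
    # reference; it is not used by the tasks below.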
    def rec_coo_(self, x, n, min_height=3, min_width=3):
        collision = x.new(x.size())
        while True:
            collision[...] = 0
            result = []
            for _ in range(n):
                while True:
                    i1, i2 = torch.randint(x.size(0), (2,))
                    if i1 + min_height <= i2:
                        break
                while True:
                    j1, j2 = torch.randint(x.size(1), (2,))
                    if j1 + min_width <= j2:
                        break
                collision[i1:i2, j1:j2] += 1
                if collision.max() > 1:
                    break
                result.append((i1, j1, i2, j2))
            if collision.max() == 1:
                break
        return result

    ######################################################################

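    # Each task_* mutates its four grid arguments in place: A and B are input
    # grids, f_A and f_B their transformed versions. The pair (A, f_A)
    # demonstrates the transformation that must be applied to B to get f_B.

    # Replace color: N rectangles are drawn; in the transformed grid the first
    # one takes a fresh color while the others keep theirs.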
    def task_replace_color(self, A, f_A, B, f_B):
        N = 3
        c = torch.randperm(len(self.colors) - 1)[: N + 1] + 1
        for X, f_X in [(A, f_A), (B, f_B)]:
            r = self.rec_coo(X, N)
            for n in range(N):
                i1, j1, i2, j2 = r[n]
                X[i1:i2, j1:j2] = c[n]
                f_X[i1:i2, j1:j2] = c[n if n > 0 else -1]

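    # Move: the last of the N rectangles is translated by one cell along a
    # random diagonal (di, dj) in the transformed grid; the others stay put.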
    def task_move(self, A, f_A, B, f_B):
        di, dj = torch.randint(2, (2,)) * 2 - 1
        N = 3
        c = torch.randperm(len(self.colors) - 1)[:N] + 1
        for X, f_X in [(A, f_A), (B, f_B)]:
            while True:
                r = self.rec_coo(X, N)
                i1, j1, i2, j2 = r[N - 1]
                if (
                    i1 + di >= 0
                    and i2 + di < X.size(0)
                    and j1 + dj >= 0
                    and j2 + dj < X.size(1)
                ):
                    break

            for n in range(N):
                i1, j1, i2, j2 = r[n]
                X[i1:i2, j1:j2] = c[n]
                if n == N - 1:
                    f_X[i1 + di : i2 + di, j1 + dj : j2 + dj] = c[n]
                else:
                    f_X[i1:i2, j1:j2] = c[n]

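    # Grow: the last rectangle either grows or shrinks by one cell on every
    # side in the transformed grid (rectangles are resampled until the last
    # one is big enough to shrink); the others are unchanged.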
    def task_grow(self, A, f_A, B, f_B):
        di, dj = torch.randint(2, (2,)) * 2 - 1
        N = 3
        c = torch.randperm(len(self.colors) - 1)[:N] + 1
        direction = torch.randint(2, (1,))
        for X, f_X in [(A, f_A), (B, f_B)]:
            while True:
                r = self.rec_coo(X, N)
                i1, j1, i2, j2 = r[N - 1]
                if i1 + 3 < i2 and j1 + 3 < j2:
                    break

            for n in range(N):
                i1, j1, i2, j2 = r[n]
                if n == N - 1:
                    if direction == 0:
                        X[i1 + 1 : i2 - 1, j1 + 1 : j2 - 1] = c[n]
                        f_X[i1:i2, j1:j2] = c[n]
                    else:
                        X[i1:i2, j1:j2] = c[n]
                        f_X[i1 + 1 : i2 - 1, j1 + 1 : j2 - 1] = c[n]
                else:
                    X[i1:i2, j1:j2] = c[n]
                    f_X[i1:i2, j1:j2] = c[n]

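    # Color grow: every rectangle carries a one-row stripe of a second color
    # across its middle; in the transformed grid the stripe of the last
    # rectangle expands to fill its lower half, the others are unchanged.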
    def task_color_grow(self, A, f_A, B, f_B):
        di, dj = torch.randint(2, (2,)) * 2 - 1
        N = 3
        c = torch.randperm(len(self.colors) - 1)[: 2 * N] + 1
        direction = torch.randint(2, (1,))
        for X, f_X in [(A, f_A), (B, f_B)]:
            r = self.rec_coo(X, N)
            for n in range(N):
                i1, j1, i2, j2 = r[n]
                i = (i1 + i2) // 2
                X[i1:i2, j1:j2] = c[2 * n]
                X[i : i + 1, j1:j2] = c[2 * n + 1]
                f_X[i1:i2, j1:j2] = c[2 * n]
                if n == N - 1:
                    f_X[i:i2, j1:j2] = c[2 * n + 1]
                else:
                    f_X[i : i + 1, j1:j2] = c[2 * n + 1]

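    # Frame: in the transformed grid the last rectangle is hollowed out to a
    # one-cell-thick frame; the others are unchanged.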
    def task_frame(self, A, f_A, B, f_B):
        N = 3
        c = torch.randperm(len(self.colors) - 1)[: N + 1] + 1
        for X, f_X in [(A, f_A), (B, f_B)]:
            r = self.rec_coo(X, N)
            for n in range(N):
                i1, j1, i2, j2 = r[n]
                X[i1:i2, j1:j2] = c[n]
                f_X[i1:i2, j1:j2] = c[n]
                if n == N - 1:
                    f_X[i1 + 1 : i2 - 1, j1 + 1 : j2 - 1] = 0

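    # Detect: the transformed grid is empty except for a marker of a common
    # color placed on the top-left corner cell of every rectangle.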
    def task_detect(self, A, f_A, B, f_B):
        N = 3
        c = torch.randperm(len(self.colors) - 1)[: N + 1] + 1
        for X, f_X in [(A, f_A), (B, f_B)]:
            r = self.rec_coo(X, N)
            for n in range(N):
                i1, j1, i2, j2 = r[n]
                X[i1:i2, j1:j2] = c[n]
                f_X[i1, j1] = c[-1]

    ######################################################################

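    # Builds nb quizzes: each prompt is the horizontal concatenation of the
    # grids A | f(A) | B, each answer is the grid f(B), with the task drawn
    # uniformly at random; both are returned flattened into token sequences.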
    def generate_prompts_and_answers(self, nb, device="cpu"):
        tasks = [
            self.task_replace_color,
            self.task_move,
            self.task_grow,
            self.task_color_grow,
            self.task_frame,
            self.task_detect,
        ]
        prompts = torch.zeros(
            nb, self.height, self.width * 3, dtype=torch.int64, device=self.device
        )
        answers = torch.zeros(
            nb, self.height, self.width, dtype=torch.int64, device=self.device
        )
        w = self.width

        for prompt, answer in tqdm.tqdm(
            zip(prompts, answers),
            dynamic_ncols=True,
            desc="world generation",
            total=prompts.size(0),
        ):
            A = prompt[:, 0 * w : 1 * w]
            f_A = prompt[:, 1 * w : 2 * w]
            B = prompt[:, 2 * w : 3 * w]
            f_B = answer
            task = tasks[torch.randint(len(tasks), (1,))]
            task(A, f_A, B, f_B)
        return prompts.flatten(1), answers.flatten(1)

    def save_quizzes(
        self,
        result_dir,
        filename_prefix,
        prompts,
        answers,
        predicted_prompts=None,
        predicted_answers=None,
    ):
        self.save_image(
            result_dir,
            filename_prefix + ".png",
            prompts,
            answers,
            predicted_prompts,
            predicted_answers,
        )


######################################################################

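# Quick smoke test: generate a small batch of quizzes, report the generation
# speed, and save an image of the first 36 of them under /tmp.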
if __name__ == "__main__":
    import time

    reasoning = Reasoning()

    start_time = time.perf_counter()
    prompts, answers = reasoning.generate_prompts_and_answers(100)
    delay = time.perf_counter() - start_time
    print(f"{prompts.size(0)/delay:.02f} seq/s")

    # predicted_prompts = torch.rand(prompts.size(0)) < 0.5
    # predicted_answers = torch.logical_not(predicted_prompts)

    reasoning.save_quizzes(
        "/tmp",
        "test",
        prompts[:36],
        answers[:36],
        # You can add a bool to put a frame around the predicted parts
        # predicted_prompts, predicted_answers
    )