Update.
[culture.git] / reasoning.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import math, sys, tqdm, os, warnings
9
10 import torch, torchvision
11
12 from torch import nn
13 from torch.nn import functional as F
14
15 ######################################################################
16
17 import problem
18
19
20 class Reasoning(problem.Problem):
21     named_colors = [
22         ("white", [255, 255, 255]),
23         ("red", [255, 0, 0]),
24         ("green", [0, 192, 0]),
25         ("blue", [0, 0, 255]),
26         ("orange", [255, 192, 0]),
27         ("cyan", [0, 255, 255]),
28         ("violet", [255, 0, 255]),
29         ("lightgreen", [192, 255, 192]),
30         ("pink", [255, 192, 192]),
31         ("lightblue", [192, 192, 255]),
32         ("gray", [192, 192, 192]),
33     ]
34
35     def __init__(self, device=torch.device("cpu")):
36         self.colors = torch.tensor([c for _, c in self.named_colors])
37         self.name2color = dict([(p[0], i) for i, p in enumerate(self.named_colors)])
38         self.height = 10
39         self.width = 10
40         self.device = device
41
42     ######################################################################
43
44     def frame2img(self, x, scale=15):
45         x = x.reshape(x.size(0), self.height, -1)
46         x = self.colors[x].permute(0, 3, 1, 2)
47         s = x.shape
48         x = x[:, :, :, None, :, None].expand(-1, -1, -1, scale, -1, scale)
49         x = x.reshape(s[0], s[1], s[2] * scale, s[3] * scale)
50
51         x[:, :, :, torch.arange(0, x.size(3), scale)] = 0
52         x[:, :, torch.arange(0, x.size(2), scale), :] = 0
53         x = x[:, :, 1:, 1:]
54
55         return x
56
57     def save_image(
58         self,
59         result_dir,
60         filename,
61         prompts,
62         answers,
63         predicted_prompts=None,
64         predicted_answers=None,
65     ):
66         prompts = prompts.reshape(prompts.size(0), self.height, -1)
67         answers = answers.reshape(answers.size(0), self.height, -1)
68
69         if predicted_prompts is None:
70             predicted_prompts = 255
71
72         if predicted_answers is None:
73             predicted_answers = 255
74
75         def add_frame(x, c, margin, bottom=False):
76             if bottom:
77                 h, w, di, dj = x.size(2) + margin, x.size(3), 0, 0
78             else:
79                 h, w, di, dj = (
80                     x.size(2) + 2 * margin,
81                     x.size(3) + 2 * margin,
82                     margin,
83                     margin,
84                 )
85
86             y = x.new_full((x.size(0), x.size(1), h, w), 0)
87
88             if type(c) is int:
89                 y[...] = c
90             else:
91                 c = c.long()[:, None]
92                 c = c * torch.tensor([192, 192, 192], device=c.device) + (
93                     1 - c
94                 ) * torch.tensor([255, 255, 255], device=c.device)
95                 y[...] = c[:, :, None, None]
96
97             y[:, :, di : di + x.size(2), dj : dj + x.size(3)] = x
98
99             return y
100
101         margin = 8
102
103         img_prompts = torch.cat(
104             [
105                 add_frame(
106                     add_frame(self.frame2img(x), c=0, margin=1),
107                     c=predicted_prompts,
108                     margin=margin,
109                 )
110                 for x in prompts.to("cpu").split(split_size=self.width, dim=2)
111             ],
112             dim=3,
113         )
114
115         h = img_prompts.size(2)
116         img_answers = add_frame(
117             add_frame(self.frame2img(answers.to("cpu")), c=0, margin=1),
118             c=predicted_answers,
119             margin=margin,
120         )
121
122         separator_size = 2 * margin
123
124         separator = img_prompts.new_full(
125             (
126                 img_prompts.size(0),
127                 img_prompts.size(1),
128                 img_prompts.size(2),
129                 separator_size,
130             ),
131             255,
132         )
133
134         marker = img_prompts.new_full(
135             (
136                 img_prompts.size(0),
137                 img_prompts.size(1),
138                 img_prompts.size(2),
139                 separator_size,
140             ),
141             255,
142         )
143
144         # marker[:, :, 0] = 0
145         # marker[:, :, h - 1] = 0
146
147         for k in range(1, 2 * separator_size - 8):
148             i = k - (separator_size - 4)
149             j = separator_size - 5 - abs(i)
150             marker[:, :, h // 2 - 1 + i, 2 + j] = 0
151             marker[:, :, h // 2 - 1 + i + 1, 2 + j] = 0
152
153         img = torch.cat(
154             [
155                 img_prompts,
156                 marker,
157                 img_answers,
158             ],
159             dim=3,
160         )
161
162         image_name = os.path.join(result_dir, filename)
163         torchvision.utils.save_image(
164             img.float() / 255.0, image_name, nrow=4, padding=margin * 4, pad_value=1.0
165         )
166
167     ######################################################################
168
169     def nb_token_values(self):
170         return len(self.colors)
171
172     # That's quite a tensorial spaghetti mess to sample
173     # non-overlapping rectangles quickly, but made the generation of
174     # 100k samples from 1h50 with a lame pure python code to 4min with
175     # this one.
176     def rec_coo(self, x, n, min_height=3, min_width=3):
177         K = 3
178         N = 1000
179
180         while True:
181             v = (
182                 (
183                     torch.rand(N * K, self.height + 1, device=self.device)
184                     .sort(dim=-1)
185                     .indices
186                     < 2
187                 )
188                 .long()
189                 .cumsum(dim=1)
190                 == 1
191             ).long()
192
193             h = (
194                 (
195                     torch.rand(N * K, self.width + 1, device=self.device)
196                     .sort(dim=-1)
197                     .indices
198                     < 2
199                 )
200                 .long()
201                 .cumsum(dim=1)
202                 == 1
203             ).long()
204
205             i = torch.logical_and(
206                 v.sum(dim=-1) >= min_height, h.sum(dim=-1) >= min_width
207             )
208
209             v, h = v[i], h[i]
210             v = v[: v.size(0) - v.size(0) % K]
211             h = h[: h.size(0) - h.size(0) % K]
212             v = v.reshape(v.size(0) // K, K, -1)
213             h = h.reshape(h.size(0) // K, K, -1)
214
215             r = v[:, :, :, None] * h[:, :, None, :]
216
217             valid = r.sum(dim=1).flatten(1).max(dim=-1).values == 1
218
219             v = v[valid]
220             h = h[valid]
221
222             if v.size(0) > 0:
223                 break
224
225         av = torch.arange(v.size(2), device=self.device)[None, :]
226         ah = torch.arange(h.size(2), device=self.device)[None, :]
227
228         return [
229             (i1.item(), j1.item(), i2.item() + 1, j2.item() + 1)
230             for i1, j1, i2, j2 in zip(
231                 v.size(2) - (v[0] * (v.size(2) - av)).max(dim=-1).values,
232                 h.size(2) - (h[0] * (h.size(2) - ah)).max(dim=-1).values,
233                 (v[0] * av).max(dim=-1).values,
234                 (h[0] * ah).max(dim=-1).values,
235             )
236         ]
237
238     def rec_coo_(self, x, n, min_height=3, min_width=3):
239         collision = x.new(x.size())
240         while True:
241             collision[...] = 0
242             result = []
243             for _ in range(n):
244                 while True:
245                     i1, i2 = torch.randint(x.size(0), (2,))
246                     if i1 + min_height <= i2:
247                         break
248                 while True:
249                     j1, j2 = torch.randint(x.size(1), (2,))
250                     if j1 + min_width <= j2:
251                         break
252                 collision[i1:i2, j1:j2] += 1
253                 if collision.max() > 1:
254                     break
255                 result.append((i1, j1, i2, j2))
256             if collision.max() == 1:
257                 break
258         return result
259
260     ######################################################################
261
262     def task_replace_color(self, A, f_A, B, f_B):
263         N = 3
264         c = torch.randperm(len(self.colors) - 1)[: N + 1] + 1
265         for X, f_X in [(A, f_A), (B, f_B)]:
266             r = self.rec_coo(X, N)
267             for n in range(N):
268                 i1, j1, i2, j2 = r[n]
269                 X[i1:i2, j1:j2] = c[n]
270                 f_X[i1:i2, j1:j2] = c[n if n > 0 else -1]
271
272     def task_move(self, A, f_A, B, f_B):
273         di, dj = torch.randint(2, (2,)) * 2 - 1
274         N = 3
275         c = torch.randperm(len(self.colors) - 1)[:N] + 1
276         for X, f_X in [(A, f_A), (B, f_B)]:
277             while True:
278                 r = self.rec_coo(X, N)
279                 i1, j1, i2, j2 = r[N - 1]
280                 if (
281                     i1 + di >= 0
282                     and i2 + di < X.size(0)
283                     and j1 + dj >= 0
284                     and j2 + dj < X.size(1)
285                 ):
286                     break
287
288             for n in range(N):
289                 i1, j1, i2, j2 = r[n]
290                 X[i1:i2, j1:j2] = c[n]
291                 if n == N - 1:
292                     f_X[i1 + di : i2 + di, j1 + dj : j2 + dj] = c[n]
293                 else:
294                     f_X[i1:i2, j1:j2] = c[n]
295
296     def task_grow(self, A, f_A, B, f_B):
297         di, dj = torch.randint(2, (2,)) * 2 - 1
298         N = 3
299         c = torch.randperm(len(self.colors) - 1)[:N] + 1
300         direction = torch.randint(2, (1,))
301         for X, f_X in [(A, f_A), (B, f_B)]:
302             while True:
303                 r = self.rec_coo(X, N)
304                 i1, j1, i2, j2 = r[N - 1]
305                 if i1 + 3 < i2 and j1 + 3 < j2:
306                     break
307
308             for n in range(N):
309                 i1, j1, i2, j2 = r[n]
310                 if n == N - 1:
311                     if direction == 0:
312                         X[i1 + 1 : i2 - 1, j1 + 1 : j2 - 1] = c[n]
313                         f_X[i1:i2, j1:j2] = c[n]
314                     else:
315                         X[i1:i2, j1:j2] = c[n]
316                         f_X[i1 + 1 : i2 - 1, j1 + 1 : j2 - 1] = c[n]
317                 else:
318                     X[i1:i2, j1:j2] = c[n]
319                     f_X[i1:i2, j1:j2] = c[n]
320
321     def task_color_grow(self, A, f_A, B, f_B):
322         di, dj = torch.randint(2, (2,)) * 2 - 1
323         N = 3
324         c = torch.randperm(len(self.colors) - 1)[: 2 * N] + 1
325         direction = torch.randint(2, (1,))
326         for X, f_X in [(A, f_A), (B, f_B)]:
327             r = self.rec_coo(X, N)
328             for n in range(N):
329                 i1, j1, i2, j2 = r[n]
330                 i = (i1 + i2) // 2
331                 X[i1:i2, j1:j2] = c[2 * n]
332                 X[i : i + 1, j1:j2] = c[2 * n + 1]
333                 f_X[i1:i2, j1:j2] = c[2 * n]
334                 if n == N - 1:
335                     f_X[i:i2, j1:j2] = c[2 * n + 1]
336                 else:
337                     f_X[i : i + 1, j1:j2] = c[2 * n + 1]
338
339     def task_frame(self, A, f_A, B, f_B):
340         N = 3
341         c = torch.randperm(len(self.colors) - 1)[: N + 1] + 1
342         for X, f_X in [(A, f_A), (B, f_B)]:
343             r = self.rec_coo(X, N)
344             for n in range(N):
345                 i1, j1, i2, j2 = r[n]
346                 X[i1:i2, j1:j2] = c[n]
347                 f_X[i1:i2, j1:j2] = c[n]
348                 if n == N - 1:
349                     f_X[i1 + 1 : i2 - 1, j1 + 1 : j2 - 1] = 0
350
351     def task_detect(self, A, f_A, B, f_B):
352         N = 3
353         c = torch.randperm(len(self.colors) - 1)[: N + 1] + 1
354         for X, f_X in [(A, f_A), (B, f_B)]:
355             r = self.rec_coo(X, N)
356             for n in range(N):
357                 i1, j1, i2, j2 = r[n]
358                 X[i1:i2, j1:j2] = c[n]
359                 f_X[i1, j1] = c[-1]
360
361     ######################################################################
362
363     def generate_prompts_and_answers(self, nb, device="cpu"):
364         tasks = [
365             self.task_replace_color,
366             self.task_move,
367             self.task_grow,
368             self.task_color_grow,
369             self.task_frame,
370             self.task_detect,
371         ]
372         prompts = torch.zeros(nb, self.height, self.width * 3, dtype=torch.int64)
373         answers = torch.zeros(nb, self.height, self.width, dtype=torch.int64)
374         w = self.width
375
376         for prompt, answer in tqdm.tqdm(
377             zip(prompts, answers),
378             dynamic_ncols=True,
379             desc="world generation",
380             total=prompts.size(0),
381         ):
382             A = prompt[:, 0 * w : 1 * w]
383             f_A = prompt[:, 1 * w : 2 * w]
384             B = prompt[:, 2 * w : 3 * w]
385             f_B = answer
386             task = tasks[torch.randint(len(tasks), (1,))]
387             task(A, f_A, B, f_B)
388
389         return prompts.flatten(1), answers.flatten(1)
390
391     def save_quizzes(
392         self,
393         result_dir,
394         filename_prefix,
395         prompts,
396         answers,
397         predicted_prompts=None,
398         predicted_answers=None,
399     ):
400         self.save_image(
401             result_dir,
402             filename_prefix + ".png",
403             prompts,
404             answers,
405             predicted_prompts,
406             predicted_answers,
407         )
408
409
410 ######################################################################
411
412 if __name__ == "__main__":
413     import time
414
415     reasoning = Reasoning()
416
417     start_time = time.perf_counter()
418     prompts, answers = reasoning.generate_prompts_and_answers(100)
419     delay = time.perf_counter() - start_time
420     print(f"{prompts.size(0)/delay:02f} seq/s")
421
422     # predicted_prompts = torch.rand(prompts.size(0)) < 0.5
423     # predicted_answers = torch.logical_not(predicted_prompts)
424
425     reasoning.save_quizzes(
426         "/tmp",
427         "test",
428         prompts[:36],
429         answers[:36],
430         # You can add a bool to put a frame around the predicted parts
431         # predicted_prompts, predicted_answers
432     )