Update.
[culture.git] / grids.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import math, sys, tqdm, os, warnings
9
10 import torch, torchvision
11
12 from torch import nn
13 from torch.nn import functional as F
14
15 ######################################################################
16
17 import problem
18
19
20 class Grids(problem.Problem):
21     named_colors = [
22         ("white", [255, 255, 255]),
23         ("red", [255, 0, 0]),
24         ("green", [0, 192, 0]),
25         ("blue", [0, 0, 255]),
26         ("yellow", [255, 224, 0]),
27         ("cyan", [0, 255, 255]),
28         ("violet", [224, 128, 255]),
29         ("lightgreen", [192, 255, 192]),
30         ("brown", [165, 42, 42]),
31         ("lightblue", [192, 192, 255]),
32         ("gray", [128, 128, 128]),
33     ]
34
35     def __init__(
36         self,
37         max_nb_cached_chunks=None,
38         chunk_size=None,
39         nb_threads=-1,
40         tasks=None,
41     ):
42         self.colors = torch.tensor([c for _, c in self.named_colors])
43         self.height = 10
44         self.width = 10
45         self.cache_rec_coo = {}
46
47         all_tasks = [
48             self.task_replace_color,
49             self.task_translate,
50             self.task_grow,
51             self.task_half_fill,
52             self.task_frame,
53             self.task_detect,
54             self.task_count,
55             self.task_trajectory,
56             self.task_bounce,
57             self.task_scale,
58             self.task_symbols,
59             self.task_isometry,
60             #            self.task_path,
61         ]
62
63         if tasks is None:
64             self.all_tasks = all_tasks
65         else:
66             self.all_tasks = [getattr(self, "task_" + t) for t in tasks.split(",")]
67
68         super().__init__(max_nb_cached_chunks, chunk_size, nb_threads)
69
70     ######################################################################
71
72     def frame2img(self, x, scale=15):
73         x = x.reshape(x.size(0), self.height, -1)
74         m = torch.logical_and(x >= 0, x < self.nb_token_values()).long()
75         x = self.colors[x * m].permute(0, 3, 1, 2)
76         s = x.shape
77         x = x[:, :, :, None, :, None].expand(-1, -1, -1, scale, -1, scale)
78         x = x.reshape(s[0], s[1], s[2] * scale, s[3] * scale)
79
80         x[:, :, :, torch.arange(0, x.size(3), scale)] = 0
81         x[:, :, torch.arange(0, x.size(2), scale), :] = 0
82         x = x[:, :, 1:, 1:]
83
84         for n in range(m.size(0)):
85             for i in range(m.size(1)):
86                 for j in range(m.size(2)):
87                     if m[n, i, j] == 0:
88                         for k in range(2, scale - 2):
89                             for l in [0, 1]:
90                                 x[n, :, i * scale + k, j * scale + k - l] = 0
91                                 x[
92                                     n, :, i * scale + scale - 1 - k, j * scale + k - l
93                                 ] = 0
94
95         return x
96
97     def save_image(
98         self,
99         result_dir,
100         filename,
101         prompts,
102         answers,
103         predicted_prompts=None,
104         predicted_answers=None,
105         nrow=4,
106         margin=8,
107     ):
108         S = self.height * self.width
109         As = prompts[:, 0 * (S + 1) : 0 * (S + 1) + S].view(-1, self.height, self.width)
110         f_As = prompts[:, 1 * (S + 1) : 1 * (S + 1) + S].view(
111             -1, self.height, self.width
112         )
113         Bs = prompts[:, 2 * (S + 1) : 2 * (S + 1) + S].view(-1, self.height, self.width)
114         prompts = torch.cat([As, f_As, Bs], dim=2)
115         answers = answers.reshape(answers.size(0), self.height, self.width)
116
117         if predicted_prompts is None:
118             predicted_prompts = 255
119
120         if predicted_answers is None:
121             predicted_answers = 255
122
123         def add_frame(x, c, margin, bottom=False):
124             if bottom:
125                 h, w, di, dj = x.size(2) + margin, x.size(3), 0, 0
126             else:
127                 h, w, di, dj = (
128                     x.size(2) + 2 * margin,
129                     x.size(3) + 2 * margin,
130                     margin,
131                     margin,
132                 )
133
134             y = x.new_full((x.size(0), x.size(1), h, w), 0)
135
136             if type(c) is int:
137                 y[...] = c
138             else:
139                 c = c.long()[:, None]
140                 c = (
141                     (1 - ((c == 1).long() + (c == 0).long() + (c == -1).long()))
142                     * torch.tensor([64, 64, 64])
143                     + (c == 1).long() * torch.tensor([0, 255, 0])
144                     + (c == 0).long() * torch.tensor([255, 255, 255])
145                     + (c == -1).long() * torch.tensor([255, 0, 0])
146                 )
147                 y[...] = c[:, :, None, None]
148
149             y[:, :, di : di + x.size(2), dj : dj + x.size(3)] = x
150
151             return y
152
153         img_prompts = torch.cat(
154             [
155                 add_frame(
156                     add_frame(self.frame2img(x), c=0, margin=1),
157                     c=predicted_prompts,
158                     margin=margin,
159                 )
160                 for x in prompts.to("cpu").split(split_size=self.width, dim=2)
161             ],
162             dim=3,
163         )
164
165         h = img_prompts.size(2)
166         img_answers = add_frame(
167             add_frame(self.frame2img(answers.to("cpu")), c=0, margin=1),
168             c=predicted_answers,
169             margin=margin,
170         )
171
172         separator_size = 2 * margin
173
174         separator = img_prompts.new_full(
175             (
176                 img_prompts.size(0),
177                 img_prompts.size(1),
178                 img_prompts.size(2),
179                 separator_size,
180             ),
181             255,
182         )
183
184         marker = img_prompts.new_full(
185             (
186                 img_prompts.size(0),
187                 img_prompts.size(1),
188                 img_prompts.size(2),
189                 separator_size,
190             ),
191             255,
192         )
193
194         # marker[:, :, 0] = 0
195         # marker[:, :, h - 1] = 0
196
197         for k in range(1, 2 * separator_size - 8):
198             i = k - (separator_size - 4)
199             j = separator_size - 5 - abs(i)
200             marker[:, :, h // 2 - 1 + i, 2 + j] = 0
201             marker[:, :, h // 2 - 1 + i + 1, 2 + j] = 0
202
203         img = torch.cat(
204             [
205                 img_prompts,
206                 marker,
207                 img_answers,
208             ],
209             dim=3,
210         )
211
212         image_name = os.path.join(result_dir, filename)
213         torchvision.utils.save_image(
214             img.float() / 255.0,
215             image_name,
216             nrow=nrow,
217             padding=margin * 4,
218             pad_value=1.0,
219         )
220
221     ######################################################################
222
223     def nb_token_values(self):
224         return len(self.colors)
225
226     # @torch.compile
227     def rec_coo(
228         self,
229         nb_rec,
230         min_height=3,
231         min_width=3,
232         surface_max=None,
233         prevent_overlap=False,
234     ):
235         if surface_max is None:
236             surface_max = self.height * self.width // 2
237
238         signature = (nb_rec, min_height, min_width, surface_max)
239
240         try:
241             return self.cache_rec_coo[signature].pop()
242         except IndexError:
243             pass
244         except KeyError:
245             pass
246
247         N = 10000
248         while True:
249             while True:
250                 i = torch.randint(self.height, (N * nb_rec, 2)).sort(dim=-1).values
251                 j = torch.randint(self.width, (N * nb_rec, 2)).sort(dim=-1).values
252
253                 big_enough = (
254                     (i[:, 1] >= i[:, 0] + min_height)
255                     & (j[:, 1] >= j[:, 0] + min_height)
256                     & ((i[:, 1] - i[:, 0]) * (j[:, 1] - j[:, 0]) <= surface_max)
257                 )
258
259                 i, j = i[big_enough], j[big_enough]
260
261                 n = i.size(0) - i.size(0) % nb_rec
262
263                 if n > 0:
264                     break
265
266             i = i[:n].reshape(n // nb_rec, nb_rec, -1)
267             j = j[:n].reshape(n // nb_rec, nb_rec, -1)
268
269             if prevent_overlap:
270                 can_fit = ((i[:, :, 1] - i[:, :, 0]) * (j[:, :, 1] - j[:, :, 0])).sum(
271                     dim=-1
272                 ) <= self.height * self.width
273                 i, j = i[can_fit], j[can_fit]
274                 if nb_rec == 2:
275                     A_i1, A_i2, A_j1, A_j2 = (
276                         i[:, 0, 0],
277                         i[:, 0, 1],
278                         j[:, 0, 0],
279                         j[:, 0, 1],
280                     )
281                     B_i1, B_i2, B_j1, B_j2 = (
282                         i[:, 1, 0],
283                         i[:, 1, 1],
284                         j[:, 1, 0],
285                         j[:, 1, 1],
286                     )
287                     no_overlap = torch.logical_not(
288                         (A_i1 >= B_i2)
289                         & (A_i2 <= B_i1)
290                         & (A_j1 >= B_j1)
291                         & (A_j2 <= B_j1)
292                     )
293                     i, j = i[no_overlap], j[no_overlap]
294                 elif nb_rec == 3:
295                     A_i1, A_i2, A_j1, A_j2 = (
296                         i[:, 0, 0],
297                         i[:, 0, 1],
298                         j[:, 0, 0],
299                         j[:, 0, 1],
300                     )
301                     B_i1, B_i2, B_j1, B_j2 = (
302                         i[:, 1, 0],
303                         i[:, 1, 1],
304                         j[:, 1, 0],
305                         j[:, 1, 1],
306                     )
307                     C_i1, C_i2, C_j1, C_j2 = (
308                         i[:, 2, 0],
309                         i[:, 2, 1],
310                         j[:, 2, 0],
311                         j[:, 2, 1],
312                     )
313                     no_overlap = (
314                         (
315                             (A_i1 >= B_i2)
316                             | (A_i2 <= B_i1)
317                             | (A_j1 >= B_j2)
318                             | (A_j2 <= B_j1)
319                         )
320                         & (
321                             (A_i1 >= C_i2)
322                             | (A_i2 <= C_i1)
323                             | (A_j1 >= C_j2)
324                             | (A_j2 <= C_j1)
325                         )
326                         & (
327                             (B_i1 >= C_i2)
328                             | (B_i2 <= C_i1)
329                             | (B_j1 >= C_j2)
330                             | (B_j2 <= C_j1)
331                         )
332                     )
333                     i, j = (i[no_overlap], j[no_overlap])
334                 else:
335                     assert nb_rec == 1
336
337             if i.size(0) > 1:
338                 break
339
340         self.cache_rec_coo[signature] = [
341             [
342                 (
343                     i[n, k, 0].item(),
344                     j[n, k, 0].item(),
345                     i[n, k, 1].item(),
346                     j[n, k, 1].item(),
347                 )
348                 for k in range(nb_rec)
349             ]
350             for n in range(i.size(0))
351         ]
352
353         return self.cache_rec_coo[signature].pop()
354
355     ######################################################################
356
357     # @torch.compile
358     def task_replace_color(self, A, f_A, B, f_B):
359         nb_rec = 3
360         c = torch.randperm(len(self.colors) - 1)[: nb_rec + 1] + 1
361         for X, f_X in [(A, f_A), (B, f_B)]:
362             r = self.rec_coo(nb_rec, prevent_overlap=True)
363             for n in range(nb_rec):
364                 i1, j1, i2, j2 = r[n]
365                 X[i1:i2, j1:j2] = c[n]
366                 f_X[i1:i2, j1:j2] = c[n if n > 0 else -1]
367
368     # @torch.compile
369     def task_translate(self, A, f_A, B, f_B):
370         while True:
371             di, dj = torch.randint(3, (2,)) - 1
372             if di.abs() + dj.abs() > 0:
373                 break
374
375         nb_rec = 3
376         c = torch.randperm(len(self.colors) - 1)[:nb_rec] + 1
377         for X, f_X in [(A, f_A), (B, f_B)]:
378             while True:
379                 r = self.rec_coo(nb_rec, prevent_overlap=True)
380                 i1, j1, i2, j2 = r[nb_rec - 1]
381                 if (
382                     i1 + di >= 0
383                     and i2 + di < X.size(0)
384                     and j1 + dj >= 0
385                     and j2 + dj < X.size(1)
386                 ):
387                     break
388
389             for n in range(nb_rec):
390                 i1, j1, i2, j2 = r[n]
391                 X[i1:i2, j1:j2] = c[n]
392                 if n == nb_rec - 1:
393                     f_X[i1 + di : i2 + di, j1 + dj : j2 + dj] = c[n]
394                 else:
395                     f_X[i1:i2, j1:j2] = c[n]
396
397     # @torch.compile
398     def task_grow(self, A, f_A, B, f_B):
399         di, dj = torch.randint(2, (2,)) * 2 - 1
400         nb_rec = 3
401         c = torch.randperm(len(self.colors) - 1)[:nb_rec] + 1
402         direction = torch.randint(2, (1,))
403         for X, f_X in [(A, f_A), (B, f_B)]:
404             while True:
405                 r = self.rec_coo(nb_rec, prevent_overlap=True)
406                 i1, j1, i2, j2 = r[nb_rec - 1]
407                 if i1 + 3 < i2 and j1 + 3 < j2:
408                     break
409
410             for n in range(nb_rec):
411                 i1, j1, i2, j2 = r[n]
412                 if n == nb_rec - 1:
413                     if direction == 0:
414                         X[i1 + 1 : i2 - 1, j1 + 1 : j2 - 1] = c[n]
415                         f_X[i1:i2, j1:j2] = c[n]
416                     else:
417                         X[i1:i2, j1:j2] = c[n]
418                         f_X[i1 + 1 : i2 - 1, j1 + 1 : j2 - 1] = c[n]
419                 else:
420                     X[i1:i2, j1:j2] = c[n]
421                     f_X[i1:i2, j1:j2] = c[n]
422
423     # @torch.compile
424     def task_half_fill(self, A, f_A, B, f_B):
425         di, dj = torch.randint(2, (2,)) * 2 - 1
426         nb_rec = 3
427         c = torch.randperm(len(self.colors) - 1)[: 2 * nb_rec] + 1
428         direction = torch.randint(4, (1,))
429         for X, f_X in [(A, f_A), (B, f_B)]:
430             r = self.rec_coo(nb_rec, prevent_overlap=True)
431             for n in range(nb_rec):
432                 i1, j1, i2, j2 = r[n]
433                 X[i1:i2, j1:j2] = c[2 * n]
434                 f_X[i1:i2, j1:j2] = c[2 * n]
435                 # Not my proudest moment
436                 if direction == 0:
437                     i = (i1 + i2) // 2
438                     X[i : i + 1, j1:j2] = c[2 * n + 1]
439                     if n == nb_rec - 1:
440                         f_X[i:i2, j1:j2] = c[2 * n + 1]
441                     else:
442                         f_X[i : i + 1, j1:j2] = c[2 * n + 1]
443                 elif direction == 1:
444                     i = (i1 + i2 - 1) // 2
445                     X[i : i + 1, j1:j2] = c[2 * n + 1]
446                     if n == nb_rec - 1:
447                         f_X[i1 : i + 1, j1:j2] = c[2 * n + 1]
448                     else:
449                         f_X[i : i + 1, j1:j2] = c[2 * n + 1]
450                 elif direction == 2:
451                     j = (j1 + j2) // 2
452                     X[i1:i2, j : j + 1] = c[2 * n + 1]
453                     if n == nb_rec - 1:
454                         f_X[i1:i2, j:j2] = c[2 * n + 1]
455                     else:
456                         f_X[i1:i2, j : j + 1] = c[2 * n + 1]
457                 elif direction == 3:
458                     j = (j1 + j2 - 1) // 2
459                     X[i1:i2, j : j + 1] = c[2 * n + 1]
460                     if n == nb_rec - 1:
461                         f_X[i1:i2, j1 : j + 1] = c[2 * n + 1]
462                     else:
463                         f_X[i1:i2, j : j + 1] = c[2 * n + 1]
464
465     # @torch.compile
466     def task_frame(self, A, f_A, B, f_B):
467         nb_rec = 3
468         c = torch.randperm(len(self.colors) - 1)[: nb_rec + 1] + 1
469         for X, f_X in [(A, f_A), (B, f_B)]:
470             r = self.rec_coo(nb_rec, prevent_overlap=True)
471             for n in range(nb_rec):
472                 i1, j1, i2, j2 = r[n]
473                 X[i1:i2, j1:j2] = c[n]
474                 if n == nb_rec - 1:
475                     f_X[i1:i2, j1] = c[n]
476                     f_X[i1:i2, j2 - 1] = c[n]
477                     f_X[i1, j1:j2] = c[n]
478                     f_X[i2 - 1, j1:j2] = c[n]
479                 else:
480                     f_X[i1:i2, j1:j2] = c[n]
481
482     # @torch.compile
483     def task_detect(self, A, f_A, B, f_B):
484         nb_rec = 3
485         c = torch.randperm(len(self.colors) - 1)[: nb_rec + 1] + 1
486         for X, f_X in [(A, f_A), (B, f_B)]:
487             r = self.rec_coo(nb_rec, prevent_overlap=True)
488             for n in range(nb_rec):
489                 i1, j1, i2, j2 = r[n]
490                 X[i1:i2, j1:j2] = c[n]
491                 if n < nb_rec - 1:
492                     f_X[i1, j1] = c[-1]
493
494     # @torch.compile
495     def contact(self, X, i, j, q):
496         nq, nq_diag = 0, 0
497         no = 0
498
499         for ii, jj in [
500             (i - 1, j - 1),
501             (i - 1, j),
502             (i - 1, j + 1),
503             (i, j - 1),
504             (i, j + 1),
505             (i + 1, j - 1),
506             (i + 1, j),
507             (i + 1, j + 1),
508         ]:
509             if ii >= 0 and ii < self.height and jj >= 0 and jj < self.width:
510                 if X[ii, jj] != 0 and X[ii, jj] != q:
511                     no += 1
512
513         for ii, jj in [
514             (i - 1, j - 1),
515             (i - 1, j + 1),
516             (i + 1, j - 1),
517             (i + 1, j + 1),
518         ]:
519             if ii >= 0 and ii < self.height and jj >= 0 and jj < self.width:
520                 if X[ii, jj] == q and X[i, jj] != q and X[ii, j] != q:
521                     nq_diag += 1
522
523         for ii, jj in [(i - 1, j), (i, j - 1), (i, j + 1), (i + 1, j)]:
524             if ii >= 0 and ii < self.height and jj >= 0 and jj < self.width:
525                 if X[ii, jj] == q:
526                     nq += 1
527
528         return no, nq, nq_diag
529
530     def task_count(self, A, f_A, B, f_B):
531         N = (torch.randint(4, (1,)) + 2).item()
532         c = torch.randperm(len(self.colors) - 1)[:N] + 1
533
534         for X, f_X in [(A, f_A), (B, f_B)]:
535             l_q = torch.randperm(self.height * self.width)[
536                 : self.height * self.width // 20
537             ]
538             l_d = torch.randint(N, l_q.size())
539             nb = torch.zeros(N, dtype=torch.int64)
540
541             for q, e in zip(l_q, l_d):
542                 d = c[e]
543                 i, j = q % self.height, q // self.height
544                 if (
545                     nb[e] < self.width
546                     and X[max(0, i - 1) : i + 2, max(0, j - 1) : j + 2] == 0
547                 ).all():
548                     X[i, j] = d
549                     nb[e] += 1
550
551             l_q = torch.randperm((self.height - 2) * (self.width - 2))[
552                 : self.height * self.width // 2
553             ]
554             l_d = torch.randint(N, l_q.size())
555             for q, e in zip(l_q, l_d):
556                 d = c[e]
557                 i, j = q % (self.height - 2) + 1, q // (self.height - 2) + 1
558                 a1, a2, a3 = X[i - 1, j - 1 : j + 2]
559                 a8, a4 = X[i, j - 1], X[i, j + 1]
560                 a7, a6, a5 = X[i + 1, j - 1 : j + 2]
561                 if (
562                     X[i, j] == 0
563                     and nb[e] < self.width
564                     and (a2 == 0 or a2 == d)
565                     and (a4 == 0 or a4 == d)
566                     and (a6 == 0 or a6 == d)
567                     and (a8 == 0 or a8 == d)
568                     and (a1 == 0 or a2 == d or a8 == d)
569                     and (a3 == 0 or a4 == d or a2 == d)
570                     and (a5 == 0 or a6 == d or a4 == d)
571                     and (a7 == 0 or a8 == d or a6 == d)
572                 ):
573                     o = (
574                         (a2 != 0).long()
575                         + (a4 != 0).long()
576                         + (a6 != 0).long()
577                         + (a8 != 0).long()
578                     )
579                     if o <= 1:
580                         X[i, j] = d
581                         nb[e] += 1 - o
582
583             for e in range(N):
584                 for j in range(nb[e]):
585                     f_X[e, j] = c[e]
586
587     # @torch.compile
588     def task_trajectory(self, A, f_A, B, f_B):
589         c = torch.randperm(len(self.colors) - 1)[:2] + 1
590         for X, f_X in [(A, f_A), (B, f_B)]:
591             while True:
592                 di, dj = torch.randint(7, (2,)) - 3
593                 i, j = torch.randint(self.height, (1,)), torch.randint(self.width, (1,))
594                 if (
595                     abs(di) + abs(dj) > 0
596                     and i + 2 * di >= 0
597                     and i + 2 * di < self.height
598                     and j + 2 * dj >= 0
599                     and j + 2 * dj < self.width
600                 ):
601                     break
602
603             k = 0
604             while (
605                 i + k * di >= 0
606                 and i + k * di < self.height
607                 and j + k * dj >= 0
608                 and j + k * dj < self.width
609             ):
610                 if k < 2:
611                     X[i + k * di, j + k * dj] = c[k]
612                 f_X[i + k * di, j + k * dj] = c[min(k, 1)]
613                 k += 1
614
615     # @torch.compile
616     def task_bounce(self, A, f_A, B, f_B):
617         c = torch.randperm(len(self.colors) - 1)[:3] + 1
618         for X, f_X in [(A, f_A), (B, f_B)]:
619             # @torch.compile
620             def free(i, j):
621                 return (
622                     i >= 0
623                     and i < self.height
624                     and j >= 0
625                     and j < self.width
626                     and f_X[i, j] == 0
627                 )
628
629             while True:
630                 f_X[...] = 0
631                 X[...] = 0
632
633                 for _ in range((self.height * self.width) // 10):
634                     i, j = torch.randint(self.height, (1,)), torch.randint(
635                         self.width, (1,)
636                     )
637                     X[i, j] = c[0]
638                     f_X[i, j] = c[0]
639
640                 while True:
641                     di, dj = torch.randint(7, (2,)) - 3
642                     if abs(di) + abs(dj) == 1:
643                         break
644
645                 i, j = torch.randint(self.height, (1,)), torch.randint(self.width, (1,))
646
647                 X[i, j] = c[1]
648                 f_X[i, j] = c[1]
649                 l = 0
650
651                 while True:
652                     l += 1
653                     if free(i + di, j + dj):
654                         pass
655                     elif free(i - dj, j + di):
656                         di, dj = -dj, di
657                         if free(i + dj, j - di):
658                             if torch.rand(1) < 0.5:
659                                 di, dj = -di, -dj
660                     elif free(i + dj, j - di):
661                         di, dj = dj, -di
662                     else:
663                         break
664
665                     i, j = i + di, j + dj
666                     f_X[i, j] = c[2]
667                     if l <= 1:
668                         X[i, j] = c[2]
669
670                     if l >= self.width:
671                         break
672
673                 f_X[i, j] = c[1]
674                 X[i, j] = c[1]
675
676                 if l > 3:
677                     break
678
679     # @torch.compile
680     def task_scale(self, A, f_A, B, f_B):
681         c = torch.randperm(len(self.colors) - 1)[:2] + 1
682
683         i, j = torch.randint(self.height // 2, (1,)), torch.randint(
684             self.width // 2, (1,)
685         )
686
687         for X, f_X in [(A, f_A), (B, f_B)]:
688             for _ in range(3):
689                 while True:
690                     i1, j1 = torch.randint(self.height // 2 + 1, (1,)), torch.randint(
691                         self.width // 2 + 1, (1,)
692                     )
693                     i2, j2 = torch.randint(self.height // 2 + 1, (1,)), torch.randint(
694                         self.width // 2 + 1, (1,)
695                     )
696                     if i1 < i2 and j1 < j2 and min(i2 - i1, j2 - j1) <= 3:
697                         break
698                 X[i + i1 : i + i2, j + j1 : j + j2] = c[0]
699                 f_X[2 * i1 : 2 * i2, 2 * j1 : 2 * j2] = c[0]
700
701             X[i, j] = c[1]
702             f_X[0:2, 0:2] = c[1]
703
704     # @torch.compile
705     def task_symbols(self, A, f_A, B, f_B):
706         nb_rec = 4
707         c = torch.randperm(len(self.colors) - 1)[: nb_rec + 1] + 1
708         delta = 3
709         for X, f_X in [(A, f_A), (B, f_B)]:
710             while True:
711                 i, j = torch.randint(self.height - delta + 1, (nb_rec,)), torch.randint(
712                     self.width - delta + 1, (nb_rec,)
713                 )
714                 d = (i[None, :] - i[:, None]).abs().max((j[None, :] - j[:, None]).abs())
715                 d.fill_diagonal_(delta + 1)
716                 if d.min() > delta:
717                     break
718
719             for k in range(1, nb_rec):
720                 X[i[k] : i[k] + delta, j[k] : j[k] + delta] = c[k]
721
722             ai, aj = i.float().mean(), j.float().mean()
723
724             q = torch.randint(3, (1,)) + 1
725
726             X[i[0] + delta // 2 - 1, j[0] + delta // 2 - 1] = c[0]
727             X[i[0] + delta // 2 - 1, j[0] + delta // 2 + 1] = c[0]
728             X[i[0] + delta // 2 + 1, j[0] + delta // 2 - 1] = c[0]
729             X[i[0] + delta // 2 + 1, j[0] + delta // 2 + 1] = c[0]
730
731             assert i[q] != ai and j[q] != aj
732
733             X[
734                 i[0] + delta // 2 + (i[q] - ai).sign().long(),
735                 j[0] + delta // 2 + (j[q] - aj).sign().long(),
736             ] = c[nb_rec]
737
738             f_X[i[0] : i[0] + delta, j[0] : j[0] + delta] = c[q]
739
740     # @torch.compile
741     def task_isometry(self, A, f_A, B, f_B):
742         nb_rec = 3
743         di, dj = torch.randint(3, (2,)) - 1
744         o = torch.tensor([[0.0, 1.0], [-1.0, 0.0]])
745         m = torch.eye(2)
746         for _ in range(torch.randint(4, (1,))):
747             m = m @ o
748         if torch.rand(1) < 0.5:
749             m[0, :] = -m[0, :]
750
751         ci, cj = (self.height - 1) / 2, (self.width - 1) / 2
752
753         for X, f_X in [(A, f_A), (B, f_B)]:
754             while True:
755                 X[...] = 0
756                 f_X[...] = 0
757
758                 c = torch.randperm(len(self.colors) - 1)[:nb_rec] + 1
759
760                 for r in range(nb_rec):
761                     while True:
762                         i1, i2 = torch.randint(self.height - 2, (2,)) + 1
763                         j1, j2 = torch.randint(self.width - 2, (2,)) + 1
764                         if (
765                             i2 >= i1
766                             and j2 >= j1
767                             and max(i2 - i1, j2 - j1) >= 2
768                             and min(i2 - i1, j2 - j1) <= 3
769                         ):
770                             break
771                     X[i1 : i2 + 1, j1 : j2 + 1] = c[r]
772
773                     i1, j1, i2, j2 = i1 - ci, j1 - cj, i2 - ci, j2 - cj
774
775                     i1, j1 = m[0, 0] * i1 + m[0, 1] * j1, m[1, 0] * i1 + m[1, 1] * j1
776                     i2, j2 = m[0, 0] * i2 + m[0, 1] * j2, m[1, 0] * i2 + m[1, 1] * j2
777
778                     i1, j1, i2, j2 = i1 + ci, j1 + cj, i2 + ci, j2 + cj
779                     i1, i2 = i1.long() + di, i2.long() + di
780                     j1, j2 = j1.long() + dj, j2.long() + dj
781                     if i1 > i2:
782                         i1, i2 = i2, i1
783                     if j1 > j2:
784                         j1, j2 = j2, j1
785
786                     f_X[i1 : i2 + 1, j1 : j2 + 1] = c[r]
787
788                 n = F.one_hot(X.flatten()).sum(dim=0)[1:]
789                 if (
790                     n.sum() > self.height * self.width // 4
791                     and (n > 0).long().sum() == nb_rec
792                 ):
793                     break
794
795     def compute_distance(self, walls, goal_i, goal_j, start_i, start_j):
796         max_length = walls.numel()
797         dist = torch.full_like(walls, max_length)
798
799         dist[goal_i, goal_j] = 0
800         pred_dist = torch.empty_like(dist)
801
802         while True:
803             pred_dist.copy_(dist)
804             d = (
805                 torch.cat(
806                     (
807                         dist[None, 1:-1, 0:-2],
808                         dist[None, 2:, 1:-1],
809                         dist[None, 1:-1, 2:],
810                         dist[None, 0:-2, 1:-1],
811                     ),
812                     0,
813                 ).min(dim=0)[0]
814                 + 1
815             )
816
817             dist[1:-1, 1:-1].minimum_(d)  # = torch.min(dist[1:-1, 1:-1], d)
818             dist = walls * max_length + (1 - walls) * dist
819
820             if dist[start_i, start_j] < max_length or dist.equal(pred_dist):
821                 return dist * (1 - walls)
822
823     # @torch.compile
824     def task_path(self, A, f_A, B, f_B):
825         c = torch.randperm(len(self.colors) - 1)[:3] + 1
826         dist = torch.empty(self.height + 2, self.width + 2)
827         for X, f_X in [(A, f_A), (B, f_B)]:
828             nb_rec = torch.randint(3, (1,)) + 1
829             while True:
830                 r = self.rec_coo(nb_rec, prevent_overlap=True)
831                 X[...] = 0
832                 f_X[...] = 0
833                 for n in range(nb_rec):
834                     i1, j1, i2, j2 = r[n]
835                     X[i1:i2, j1:j2] = c[0]
836                     f_X[i1:i2, j1:j2] = c[0]
837                 while True:
838                     i0, j0 = torch.randint(self.height, (1,)), torch.randint(
839                         self.width, (1,)
840                     )
841                     if X[i0, j0] == 0:
842                         break
843                 while True:
844                     i1, j1 = torch.randint(self.height, (1,)), torch.randint(
845                         self.width, (1,)
846                     )
847                     if X[i1, j1] == 0:
848                         break
849                 dist[...] = 1
850                 dist[1:-1, 1:-1] = (X != 0).long()
851                 dist[...] = self.compute_distance(dist, i1 + 1, j1 + 1, i0 + 1, j0 + 1)
852                 if dist[i0 + 1, j0 + 1] >= 1 and dist[i0 + 1, j0 + 1] < self.height * 4:
853                     break
854
855             dist[1:-1, 1:-1] += (X != 0).long() * self.height * self.width
856             dist[0, :] = self.height * self.width
857             dist[-1, :] = self.height * self.width
858             dist[:, 0] = self.height * self.width
859             dist[:, -1] = self.height * self.width
860             # dist += torch.rand(dist.size())
861
862             i, j = i0 + 1, j0 + 1
863             while i != i1 + 1 or j != j1 + 1:
864                 f_X[i - 1, j - 1] = c[2]
865                 r, s, t, u = (
866                     dist[i - 1, j],
867                     dist[i, j - 1],
868                     dist[i + 1, j],
869                     dist[i, j + 1],
870                 )
871                 m = min(r, s, t, u)
872                 if r == m:
873                     i = i - 1
874                 elif t == m:
875                     i = i + 1
876                 elif s == m:
877                     j = j - 1
878                 else:
879                     j = j + 1
880
881             X[i0, j0] = c[2]
882             # f_X[i0, j0] = c[1]
883
884             X[i1, j1] = c[1]
885             f_X[i1, j1] = c[1]
886
887     # for X, f_X in [(A, f_A), (B, f_B)]:
888     # n = torch.arange(self.height * self.width).reshape(self.height, self.width)
889     # k = torch.randperm(self.height * self.width)
890     # X[...]=-1
891     # for q in k:
892     # i,j=q%self.height,q//self.height
893     # if
894
895     # @torch.compile
896     def task_puzzle(self, A, f_A, B, f_B):
897         S = 4
898         i0, j0 = (self.height - S) // 2, (self.width - S) // 2
899         c = torch.randperm(len(self.colors) - 1)[:4] + 1
900         for X, f_X in [(A, f_A), (B, f_B)]:
901             while True:
902                 f_X[...] = 0
903                 h = list(torch.randperm(c.size(0)))
904                 n = torch.zeros(c.max() + 1)
905                 for _ in range(2):
906                     k = torch.randperm(S * S)
907                     for q in k:
908                         i, j = q % S + i0, q // S + j0
909                         if f_X[i, j] == 0:
910                             r, s, t, u = (
911                                 f_X[i - 1, j],
912                                 f_X[i, j - 1],
913                                 f_X[i + 1, j],
914                                 f_X[i, j + 1],
915                             )
916                             r, s, t, u = torch.tensor([r, s, t, u])[torch.randperm(4)]
917                             if r > 0 and n[r] < 6:
918                                 n[r] += 1
919                                 f_X[i, j] = r
920                             elif s > 0 and n[s] < 6:
921                                 n[s] += 1
922                                 f_X[i, j] = s
923                             elif t > 0 and n[t] < 6:
924                                 n[t] += 1
925                                 f_X[i, j] = t
926                             elif u > 0 and n[u] < 6:
927                                 n[u] += 1
928                                 f_X[i, j] = u
929                             else:
930                                 if len(h) > 0:
931                                     d = c[h.pop()]
932                                     n[d] += 1
933                                     f_X[i, j] = d
934
935                 if n.sum() == S * S:
936                     break
937
938             k = 0
939             for d in range(4):
940                 while True:
941                     ii, jj = torch.randint(self.height, (1,)), torch.randint(
942                         self.width, (1,)
943                     )
944                     e = 0
945                     for i in range(S):
946                         for j in range(S):
947                             if (
948                                 ii + i >= self.height
949                                 or jj + j >= self.width
950                                 or (
951                                     f_X[i + i0, j + j0] == c[d]
952                                     and X[ii + i, jj + j] > 0
953                                 )
954                             ):
955                                 e = 1
956                     if e == 0:
957                         break
958                 for i in range(S):
959                     for j in range(S):
960                         if f_X[i + i0, j + j0] == c[d]:
961                             X[ii + i, jj + j] = c[d]
962
963     ######################################################################
964
965     def trivial_prompts_and_answers(self, prompts, answers):
966         S = self.height * self.width
967         Bs = prompts[:, 2 * (S + 1) : 2 * (S + 1) + S]
968         f_Bs = answers
969         return (Bs == f_Bs).long().min(dim=-1).values > 0
970
971     def generate_prompts_and_answers_(self, nb, tasks=None, progress_bar=False):
972         if tasks is None:
973             tasks = self.all_tasks
974
975         S = self.height * self.width
976         prompts = torch.zeros(nb, 3 * S + 2, dtype=torch.int64)
977         answers = torch.zeros(nb, S, dtype=torch.int64)
978
979         bunch = zip(prompts, answers)
980
981         if progress_bar:
982             bunch = tqdm.tqdm(
983                 bunch,
984                 dynamic_ncols=True,
985                 desc="world generation",
986                 total=prompts.size(0),
987             )
988
989         for prompt, answer in bunch:
990             A = prompt[0 * (S + 1) : 0 * (S + 1) + S].view(self.height, self.width)
991             f_A = prompt[1 * (S + 1) : 1 * (S + 1) + S].view(self.height, self.width)
992             B = prompt[2 * (S + 1) : 2 * (S + 1) + S].view(self.height, self.width)
993             f_B = answer.view(self.height, self.width)
994             task = tasks[torch.randint(len(tasks), (1,))]
995             task(A, f_A, B, f_B)
996
997         return prompts.flatten(1), answers.flatten(1)
998
999     def save_quizzes(
1000         self,
1001         result_dir,
1002         filename_prefix,
1003         prompts,
1004         answers,
1005         predicted_prompts=None,
1006         predicted_answers=None,
1007         nrow=4,
1008     ):
1009         self.save_image(
1010             result_dir,
1011             filename_prefix + ".png",
1012             prompts,
1013             answers,
1014             predicted_prompts,
1015             predicted_answers,
1016             nrow,
1017         )
1018
1019     def save_some_examples(self, result_dir):
1020         nb, nrow = 72, 4
1021         for t in self.all_tasks:
1022             print(t.__name__)
1023             prompts, answers = self.generate_prompts_and_answers_(nb, tasks=[t])
1024             self.save_quizzes(
1025                 result_dir, t.__name__, prompts[:nb], answers[:nb], nrow=nrow
1026             )
1027
1028
1029 ######################################################################
1030
1031 if __name__ == "__main__":
1032     import time
1033
1034     # grids = Grids(max_nb_cached_chunks=5, chunk_size=100, nb_threads=4)
1035     grids = Grids()
1036
1037     # nb = 1000
1038     # grids = problem.MultiThreadProblem(
1039     # grids, max_nb_cached_chunks=50, chunk_size=100, nb_threads=1
1040     # )
1041     #    time.sleep(10)
1042     # start_time = time.perf_counter()
1043     # prompts, answers = grids.generate_prompts_and_answers(nb)
1044     # delay = time.perf_counter() - start_time
1045     # print(f"{prompts.size(0)/delay:02f} seq/s")
1046     # exit(0)
1047
1048     # if True:
1049     nb, nrow = 72, 4
1050     # nb, nrow = 8, 2
1051
1052     # for t in grids.all_tasks:
1053     for t in [
1054         grids.task_replace_color,
1055         grids.task_frame,
1056     ]:
1057         print(t.__name__)
1058         prompts, answers = grids.generate_prompts_and_answers_(nb, tasks=[t])
1059         grids.save_quizzes("/tmp", t.__name__, prompts[:nb], answers[:nb], nrow=nrow)
1060
1061     exit(0)
1062
1063     nb = 1000
1064
1065     for t in grids.all_tasks:
1066         # for t in [ grids.task_replace_color ]: #grids.all_tasks:
1067         start_time = time.perf_counter()
1068         prompts, answers = grids.generate_prompts_and_answers_(nb, tasks=[t])
1069         delay = time.perf_counter() - start_time
1070         print(f"{t.__name__} {prompts.size(0)/delay:02f} seq/s")
1071
1072     exit(0)
1073
1074     m = torch.randint(2, (prompts.size(0),))
1075     predicted_prompts = m * (torch.randint(2, (prompts.size(0),)) * 2 - 1)
1076     predicted_answers = (1 - m) * (torch.randint(2, (prompts.size(0),)) * 2 - 1)
1077
1078     grids.save_quizzes(
1079         "/tmp",
1080         "test",
1081         prompts[:nb],
1082         answers[:nb],
1083         # You can add a bool to put a frame around the predicted parts
1084         predicted_prompts[:nb],
1085         predicted_answers[:nb],
1086     )