Update.
[culture.git] / grids.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import math, sys, tqdm, os, warnings
9
10 import torch, torchvision
11
12 from torch import nn
13 from torch.nn import functional as F
14
15 ######################################################################
16
17 import problem
18
19
20 class Grids(problem.Problem):
21     named_colors = [
22         ("white", [255, 255, 255]),
23         ("red", [255, 0, 0]),
24         ("green", [0, 192, 0]),
25         ("blue", [0, 0, 255]),
26         ("yellow", [255, 224, 0]),
27         ("cyan", [0, 255, 255]),
28         ("violet", [224, 128, 255]),
29         ("lightgreen", [192, 255, 192]),
30         ("brown", [165, 42, 42]),
31         ("lightblue", [192, 192, 255]),
32         ("gray", [128, 128, 128]),
33     ]
34
35     def __init__(
36         self,
37         max_nb_cached_chunks=None,
38         chunk_size=None,
39         nb_threads=-1,
40     ):
41         self.colors = torch.tensor([c for _, c in self.named_colors])
42         self.height = 10
43         self.width = 10
44         self.cache_rec_coo = {}
45         super().__init__(max_nb_cached_chunks, chunk_size, nb_threads)
46
47     ######################################################################
48
49     def frame2img(self, x, scale=15):
50         x = x.reshape(x.size(0), self.height, -1)
51         m = torch.logical_and(x >= 0, x < self.nb_token_values()).long()
52         x = self.colors[x * m].permute(0, 3, 1, 2)
53         s = x.shape
54         x = x[:, :, :, None, :, None].expand(-1, -1, -1, scale, -1, scale)
55         x = x.reshape(s[0], s[1], s[2] * scale, s[3] * scale)
56
57         x[:, :, :, torch.arange(0, x.size(3), scale)] = 0
58         x[:, :, torch.arange(0, x.size(2), scale), :] = 0
59         x = x[:, :, 1:, 1:]
60
61         for n in range(m.size(0)):
62             for i in range(m.size(1)):
63                 for j in range(m.size(2)):
64                     if m[n, i, j] == 0:
65                         for k in range(2, scale - 2):
66                             for l in [0, 1]:
67                                 x[n, :, i * scale + k, j * scale + k - l] = 0
68                                 x[
69                                     n, :, i * scale + scale - 1 - k, j * scale + k - l
70                                 ] = 0
71
72         return x
73
74     def save_image(
75         self,
76         result_dir,
77         filename,
78         prompts,
79         answers,
80         predicted_prompts=None,
81         predicted_answers=None,
82         nrow=4,
83         margin=8,
84     ):
85         S = self.height * self.width
86         As = prompts[:, 0 * (S + 1) : 0 * (S + 1) + S].view(-1, self.height, self.width)
87         f_As = prompts[:, 1 * (S + 1) : 1 * (S + 1) + S].view(
88             -1, self.height, self.width
89         )
90         Bs = prompts[:, 2 * (S + 1) : 2 * (S + 1) + S].view(-1, self.height, self.width)
91         prompts = torch.cat([As, f_As, Bs], dim=2)
92         answers = answers.reshape(answers.size(0), self.height, self.width)
93
94         if predicted_prompts is None:
95             predicted_prompts = 255
96
97         if predicted_answers is None:
98             predicted_answers = 255
99
100         def add_frame(x, c, margin, bottom=False):
101             if bottom:
102                 h, w, di, dj = x.size(2) + margin, x.size(3), 0, 0
103             else:
104                 h, w, di, dj = (
105                     x.size(2) + 2 * margin,
106                     x.size(3) + 2 * margin,
107                     margin,
108                     margin,
109                 )
110
111             y = x.new_full((x.size(0), x.size(1), h, w), 0)
112
113             if type(c) is int:
114                 y[...] = c
115             else:
116                 c = c.long()[:, None]
117                 c = (
118                     (1 - ((c == 1).long() + (c == 0).long() + (c == -1).long()))
119                     * torch.tensor([64, 64, 64])
120                     + (c == 1).long() * torch.tensor([0, 255, 0])
121                     + (c == 0).long() * torch.tensor([255, 255, 255])
122                     + (c == -1).long() * torch.tensor([255, 0, 0])
123                 )
124                 y[...] = c[:, :, None, None]
125
126             y[:, :, di : di + x.size(2), dj : dj + x.size(3)] = x
127
128             return y
129
130         img_prompts = torch.cat(
131             [
132                 add_frame(
133                     add_frame(self.frame2img(x), c=0, margin=1),
134                     c=predicted_prompts,
135                     margin=margin,
136                 )
137                 for x in prompts.to("cpu").split(split_size=self.width, dim=2)
138             ],
139             dim=3,
140         )
141
142         h = img_prompts.size(2)
143         img_answers = add_frame(
144             add_frame(self.frame2img(answers.to("cpu")), c=0, margin=1),
145             c=predicted_answers,
146             margin=margin,
147         )
148
149         separator_size = 2 * margin
150
151         separator = img_prompts.new_full(
152             (
153                 img_prompts.size(0),
154                 img_prompts.size(1),
155                 img_prompts.size(2),
156                 separator_size,
157             ),
158             255,
159         )
160
161         marker = img_prompts.new_full(
162             (
163                 img_prompts.size(0),
164                 img_prompts.size(1),
165                 img_prompts.size(2),
166                 separator_size,
167             ),
168             255,
169         )
170
171         # marker[:, :, 0] = 0
172         # marker[:, :, h - 1] = 0
173
174         for k in range(1, 2 * separator_size - 8):
175             i = k - (separator_size - 4)
176             j = separator_size - 5 - abs(i)
177             marker[:, :, h // 2 - 1 + i, 2 + j] = 0
178             marker[:, :, h // 2 - 1 + i + 1, 2 + j] = 0
179
180         img = torch.cat(
181             [
182                 img_prompts,
183                 marker,
184                 img_answers,
185             ],
186             dim=3,
187         )
188
189         image_name = os.path.join(result_dir, filename)
190         torchvision.utils.save_image(
191             img.float() / 255.0,
192             image_name,
193             nrow=nrow,
194             padding=margin * 4,
195             pad_value=1.0,
196         )
197
198     ######################################################################
199
200     def nb_token_values(self):
201         return len(self.colors)
202
203     # @torch.compile
204     def rec_coo(
205         self,
206         nb_rec,
207         min_height=3,
208         min_width=3,
209         surface_max=None,
210         prevent_overlap=False,
211     ):
212         if surface_max is None:
213             surface_max = self.height * self.width // 2
214
215         signature = (nb_rec, min_height, min_width, surface_max)
216
217         try:
218             return self.cache_rec_coo[signature].pop()
219         except IndexError:
220             pass
221         except KeyError:
222             pass
223
224         N = 10000
225         while True:
226             while True:
227                 i = torch.randint(self.height, (N * nb_rec, 2)).sort(dim=-1).values
228                 j = torch.randint(self.width, (N * nb_rec, 2)).sort(dim=-1).values
229
230                 big_enough = (
231                     (i[:, 1] >= i[:, 0] + min_height)
232                     & (j[:, 1] >= j[:, 0] + min_height)
233                     & ((i[:, 1] - i[:, 0]) * (j[:, 1] - j[:, 0]) <= surface_max)
234                 )
235
236                 i, j = i[big_enough], j[big_enough]
237
238                 n = i.size(0) - i.size(0) % nb_rec
239
240                 if n > 0:
241                     break
242
243             i = i[:n].reshape(n // nb_rec, nb_rec, -1)
244             j = j[:n].reshape(n // nb_rec, nb_rec, -1)
245
246             if prevent_overlap:
247                 can_fit = ((i[:, :, 1] - i[:, :, 0]) * (j[:, :, 1] - j[:, :, 0])).sum(
248                     dim=-1
249                 ) <= self.height * self.width
250                 i, j = i[can_fit], j[can_fit]
251                 if nb_rec == 2:
252                     A_i1, A_i2, A_j1, A_j2 = (
253                         i[:, 0, 0],
254                         i[:, 0, 1],
255                         j[:, 0, 0],
256                         j[:, 0, 1],
257                     )
258                     B_i1, B_i2, B_j1, B_j2 = (
259                         i[:, 1, 0],
260                         i[:, 1, 1],
261                         j[:, 1, 0],
262                         j[:, 1, 1],
263                     )
264                     no_overlap = torch.logical_not(
265                         (A_i1 >= B_i2)
266                         & (A_i2 <= B_i1)
267                         & (A_j1 >= B_j1)
268                         & (A_j2 <= B_j1)
269                     )
270                     i, j = i[no_overlap], j[no_overlap]
271                 elif nb_rec == 3:
272                     A_i1, A_i2, A_j1, A_j2 = (
273                         i[:, 0, 0],
274                         i[:, 0, 1],
275                         j[:, 0, 0],
276                         j[:, 0, 1],
277                     )
278                     B_i1, B_i2, B_j1, B_j2 = (
279                         i[:, 1, 0],
280                         i[:, 1, 1],
281                         j[:, 1, 0],
282                         j[:, 1, 1],
283                     )
284                     C_i1, C_i2, C_j1, C_j2 = (
285                         i[:, 2, 0],
286                         i[:, 2, 1],
287                         j[:, 2, 0],
288                         j[:, 2, 1],
289                     )
290                     no_overlap = (
291                         (
292                             (A_i1 >= B_i2)
293                             | (A_i2 <= B_i1)
294                             | (A_j1 >= B_j2)
295                             | (A_j2 <= B_j1)
296                         )
297                         & (
298                             (A_i1 >= C_i2)
299                             | (A_i2 <= C_i1)
300                             | (A_j1 >= C_j2)
301                             | (A_j2 <= C_j1)
302                         )
303                         & (
304                             (B_i1 >= C_i2)
305                             | (B_i2 <= C_i1)
306                             | (B_j1 >= C_j2)
307                             | (B_j2 <= C_j1)
308                         )
309                     )
310                     i, j = (i[no_overlap], j[no_overlap])
311                 else:
312                     assert nb_rec == 1
313
314             if i.size(0) > 1:
315                 break
316
317         self.cache_rec_coo[signature] = [
318             [
319                 (
320                     i[n, k, 0].item(),
321                     j[n, k, 0].item(),
322                     i[n, k, 1].item(),
323                     j[n, k, 1].item(),
324                 )
325                 for k in range(nb_rec)
326             ]
327             for n in range(i.size(0))
328         ]
329
330         return self.cache_rec_coo[signature].pop()
331
332     ######################################################################
333
334     # @torch.compile
335     def task_replace_color(self, A, f_A, B, f_B):
336         nb_rec = 3
337         c = torch.randperm(len(self.colors) - 1)[: nb_rec + 1] + 1
338         for X, f_X in [(A, f_A), (B, f_B)]:
339             r = self.rec_coo(nb_rec, prevent_overlap=True)
340             for n in range(nb_rec):
341                 i1, j1, i2, j2 = r[n]
342                 X[i1:i2, j1:j2] = c[n]
343                 f_X[i1:i2, j1:j2] = c[n if n > 0 else -1]
344
345     # @torch.compile
346     def task_translate(self, A, f_A, B, f_B):
347         di, dj = torch.randint(3, (2,)) - 1
348         nb_rec = 3
349         c = torch.randperm(len(self.colors) - 1)[:nb_rec] + 1
350         for X, f_X in [(A, f_A), (B, f_B)]:
351             while True:
352                 r = self.rec_coo(nb_rec, prevent_overlap=True)
353                 i1, j1, i2, j2 = r[nb_rec - 1]
354                 if (
355                     i1 + di >= 0
356                     and i2 + di < X.size(0)
357                     and j1 + dj >= 0
358                     and j2 + dj < X.size(1)
359                 ):
360                     break
361
362             for n in range(nb_rec):
363                 i1, j1, i2, j2 = r[n]
364                 X[i1:i2, j1:j2] = c[n]
365                 if n == nb_rec - 1:
366                     f_X[i1 + di : i2 + di, j1 + dj : j2 + dj] = c[n]
367                 else:
368                     f_X[i1:i2, j1:j2] = c[n]
369
370     # @torch.compile
371     def task_grow(self, A, f_A, B, f_B):
372         di, dj = torch.randint(2, (2,)) * 2 - 1
373         nb_rec = 3
374         c = torch.randperm(len(self.colors) - 1)[:nb_rec] + 1
375         direction = torch.randint(2, (1,))
376         for X, f_X in [(A, f_A), (B, f_B)]:
377             while True:
378                 r = self.rec_coo(nb_rec, prevent_overlap=True)
379                 i1, j1, i2, j2 = r[nb_rec - 1]
380                 if i1 + 3 < i2 and j1 + 3 < j2:
381                     break
382
383             for n in range(nb_rec):
384                 i1, j1, i2, j2 = r[n]
385                 if n == nb_rec - 1:
386                     if direction == 0:
387                         X[i1 + 1 : i2 - 1, j1 + 1 : j2 - 1] = c[n]
388                         f_X[i1:i2, j1:j2] = c[n]
389                     else:
390                         X[i1:i2, j1:j2] = c[n]
391                         f_X[i1 + 1 : i2 - 1, j1 + 1 : j2 - 1] = c[n]
392                 else:
393                     X[i1:i2, j1:j2] = c[n]
394                     f_X[i1:i2, j1:j2] = c[n]
395
396     # @torch.compile
397     def task_color_grow(self, A, f_A, B, f_B):
398         di, dj = torch.randint(2, (2,)) * 2 - 1
399         nb_rec = 3
400         c = torch.randperm(len(self.colors) - 1)[: 2 * nb_rec] + 1
401         direction = torch.randint(4, (1,))
402         for X, f_X in [(A, f_A), (B, f_B)]:
403             r = self.rec_coo(nb_rec, prevent_overlap=True)
404             for n in range(nb_rec):
405                 i1, j1, i2, j2 = r[n]
406                 X[i1:i2, j1:j2] = c[2 * n]
407                 f_X[i1:i2, j1:j2] = c[2 * n]
408                 # Not my proudest moment
409                 if direction == 0:
410                     i = (i1 + i2) // 2
411                     X[i : i + 1, j1:j2] = c[2 * n + 1]
412                     if n == nb_rec - 1:
413                         f_X[i:i2, j1:j2] = c[2 * n + 1]
414                     else:
415                         f_X[i : i + 1, j1:j2] = c[2 * n + 1]
416                 elif direction == 1:
417                     i = (i1 + i2 - 1) // 2
418                     X[i : i + 1, j1:j2] = c[2 * n + 1]
419                     if n == nb_rec - 1:
420                         f_X[i1 : i + 1, j1:j2] = c[2 * n + 1]
421                     else:
422                         f_X[i : i + 1, j1:j2] = c[2 * n + 1]
423                 elif direction == 2:
424                     j = (j1 + j2) // 2
425                     X[i1:i2, j : j + 1] = c[2 * n + 1]
426                     if n == nb_rec - 1:
427                         f_X[i1:i2, j:j2] = c[2 * n + 1]
428                     else:
429                         f_X[i1:i2, j : j + 1] = c[2 * n + 1]
430                 elif direction == 3:
431                     j = (j1 + j2 - 1) // 2
432                     X[i1:i2, j : j + 1] = c[2 * n + 1]
433                     if n == nb_rec - 1:
434                         f_X[i1:i2, j1 : j + 1] = c[2 * n + 1]
435                     else:
436                         f_X[i1:i2, j : j + 1] = c[2 * n + 1]
437
438     # @torch.compile
439     def task_frame(self, A, f_A, B, f_B):
440         nb_rec = 3
441         c = torch.randperm(len(self.colors) - 1)[: nb_rec + 1] + 1
442         for X, f_X in [(A, f_A), (B, f_B)]:
443             r = self.rec_coo(nb_rec, prevent_overlap=True)
444             for n in range(nb_rec):
445                 i1, j1, i2, j2 = r[n]
446                 X[i1:i2, j1:j2] = c[n]
447                 if n == nb_rec - 1:
448                     f_X[i1:i2, j1] = c[n]
449                     f_X[i1:i2, j2 - 1] = c[n]
450                     f_X[i1, j1:j2] = c[n]
451                     f_X[i2 - 1, j1:j2] = c[n]
452                 else:
453                     f_X[i1:i2, j1:j2] = c[n]
454
455     # @torch.compile
456     def task_detect(self, A, f_A, B, f_B):
457         nb_rec = 3
458         c = torch.randperm(len(self.colors) - 1)[: nb_rec + 1] + 1
459         for X, f_X in [(A, f_A), (B, f_B)]:
460             r = self.rec_coo(nb_rec, prevent_overlap=True)
461             for n in range(nb_rec):
462                 i1, j1, i2, j2 = r[n]
463                 X[i1:i2, j1:j2] = c[n]
464                 if n < nb_rec - 1:
465                     f_X[i1, j1] = c[-1]
466
467     # @torch.compile
468     def contact(self, X, i, j, q):
469         nq, nq_diag = 0, 0
470         no = 0
471
472         for ii, jj in [
473             (i - 1, j - 1),
474             (i - 1, j),
475             (i - 1, j + 1),
476             (i, j - 1),
477             (i, j + 1),
478             (i + 1, j - 1),
479             (i + 1, j),
480             (i + 1, j + 1),
481         ]:
482             if ii >= 0 and ii < self.height and jj >= 0 and jj < self.width:
483                 if X[ii, jj] != 0 and X[ii, jj] != q:
484                     no += 1
485
486         for ii, jj in [
487             (i - 1, j - 1),
488             (i - 1, j + 1),
489             (i + 1, j - 1),
490             (i + 1, j + 1),
491         ]:
492             if ii >= 0 and ii < self.height and jj >= 0 and jj < self.width:
493                 if X[ii, jj] == q and X[i, jj] != q and X[ii, j] != q:
494                     nq_diag += 1
495
496         for ii, jj in [(i - 1, j), (i, j - 1), (i, j + 1), (i + 1, j)]:
497             if ii >= 0 and ii < self.height and jj >= 0 and jj < self.width:
498                 if X[ii, jj] == q:
499                     nq += 1
500
501         return no, nq, nq_diag
502
503     def task_count(self, A, f_A, B, f_B):
504         N = (torch.randint(4, (1,)) + 2).item()
505         c = torch.randperm(len(self.colors) - 1)[:N] + 1
506
507         for X, f_X in [(A, f_A), (B, f_B)]:
508             l_q = torch.randperm(self.height * self.width)[
509                 : self.height * self.width // 20
510             ]
511             l_d = torch.randint(N, l_q.size())
512             nb = torch.zeros(N, dtype=torch.int64)
513
514             for q, e in zip(l_q, l_d):
515                 d = c[e]
516                 i, j = q % self.height, q // self.height
517                 if (
518                     nb[e] < self.width
519                     and X[max(0, i - 1) : i + 2, max(0, j - 1) : j + 2] == 0
520                 ).all():
521                     X[i, j] = d
522                     nb[e] += 1
523
524             l_q = torch.randperm((self.height - 2) * (self.width - 2))[
525                 : self.height * self.width // 2
526             ]
527             l_d = torch.randint(N, l_q.size())
528             for q, e in zip(l_q, l_d):
529                 d = c[e]
530                 i, j = q % (self.height - 2) + 1, q // (self.height - 2) + 1
531                 a1, a2, a3 = X[i - 1, j - 1 : j + 2]
532                 a8, a4 = X[i, j - 1], X[i, j + 1]
533                 a7, a6, a5 = X[i + 1, j - 1 : j + 2]
534                 if (
535                     X[i, j] == 0
536                     and nb[e] < self.width
537                     and (a2 == 0 or a2 == d)
538                     and (a4 == 0 or a4 == d)
539                     and (a6 == 0 or a6 == d)
540                     and (a8 == 0 or a8 == d)
541                     and (a1 == 0 or a2 == d or a8 == d)
542                     and (a3 == 0 or a4 == d or a2 == d)
543                     and (a5 == 0 or a6 == d or a4 == d)
544                     and (a7 == 0 or a8 == d or a6 == d)
545                 ):
546                     o = (
547                         (a2 != 0).long()
548                         + (a4 != 0).long()
549                         + (a6 != 0).long()
550                         + (a8 != 0).long()
551                     )
552                     if o <= 1:
553                         X[i, j] = d
554                         nb[e] += 1 - o
555
556             for e in range(N):
557                 for j in range(nb[e]):
558                     f_X[e, j] = c[e]
559
560     # @torch.compile
561     def task_trajectory(self, A, f_A, B, f_B):
562         c = torch.randperm(len(self.colors) - 1)[:2] + 1
563         for X, f_X in [(A, f_A), (B, f_B)]:
564             while True:
565                 di, dj = torch.randint(7, (2,)) - 3
566                 i, j = torch.randint(self.height, (1,)), torch.randint(self.width, (1,))
567                 if (
568                     abs(di) + abs(dj) > 0
569                     and i + 2 * di >= 0
570                     and i + 2 * di < self.height
571                     and j + 2 * dj >= 0
572                     and j + 2 * dj < self.width
573                 ):
574                     break
575
576             k = 0
577             while (
578                 i + k * di >= 0
579                 and i + k * di < self.height
580                 and j + k * dj >= 0
581                 and j + k * dj < self.width
582             ):
583                 if k < 2:
584                     X[i + k * di, j + k * dj] = c[k]
585                 f_X[i + k * di, j + k * dj] = c[min(k, 1)]
586                 k += 1
587
588     # @torch.compile
589     def task_bounce(self, A, f_A, B, f_B):
590         c = torch.randperm(len(self.colors) - 1)[:3] + 1
591         for X, f_X in [(A, f_A), (B, f_B)]:
592             # @torch.compile
593             def free(i, j):
594                 return (
595                     i >= 0
596                     and i < self.height
597                     and j >= 0
598                     and j < self.width
599                     and f_X[i, j] == 0
600                 )
601
602             while True:
603                 f_X[...] = 0
604                 X[...] = 0
605
606                 for _ in range((self.height * self.width) // 10):
607                     i, j = torch.randint(self.height, (1,)), torch.randint(
608                         self.width, (1,)
609                     )
610                     X[i, j] = c[0]
611                     f_X[i, j] = c[0]
612
613                 while True:
614                     di, dj = torch.randint(7, (2,)) - 3
615                     if abs(di) + abs(dj) == 1:
616                         break
617
618                 i, j = torch.randint(self.height, (1,)), torch.randint(self.width, (1,))
619
620                 X[i, j] = c[1]
621                 f_X[i, j] = c[1]
622                 l = 0
623
624                 while True:
625                     l += 1
626                     if free(i + di, j + dj):
627                         pass
628                     elif free(i - dj, j + di):
629                         di, dj = -dj, di
630                         if free(i + dj, j - di):
631                             if torch.rand(1) < 0.5:
632                                 di, dj = -di, -dj
633                     elif free(i + dj, j - di):
634                         di, dj = dj, -di
635                     else:
636                         break
637
638                     i, j = i + di, j + dj
639                     f_X[i, j] = c[2]
640                     if l <= 1:
641                         X[i, j] = c[2]
642
643                     if l >= self.width:
644                         break
645
646                 f_X[i, j] = c[1]
647                 X[i, j] = c[1]
648
649                 if l > 3:
650                     break
651
652     # @torch.compile
653     def task_scale(self, A, f_A, B, f_B):
654         c = torch.randperm(len(self.colors) - 1)[:2] + 1
655
656         i, j = torch.randint(self.height // 2, (1,)), torch.randint(
657             self.width // 2, (1,)
658         )
659
660         for X, f_X in [(A, f_A), (B, f_B)]:
661             for _ in range(3):
662                 while True:
663                     i1, j1 = torch.randint(self.height // 2 + 1, (1,)), torch.randint(
664                         self.width // 2 + 1, (1,)
665                     )
666                     i2, j2 = torch.randint(self.height // 2 + 1, (1,)), torch.randint(
667                         self.width // 2 + 1, (1,)
668                     )
669                     if i1 < i2 and j1 < j2 and min(i2 - i1, j2 - j1) <= 3:
670                         break
671                 X[i + i1 : i + i2, j + j1 : j + j2] = c[0]
672                 f_X[2 * i1 : 2 * i2, 2 * j1 : 2 * j2] = c[0]
673
674             X[i, j] = c[1]
675             f_X[0:2, 0:2] = c[1]
676
677     # @torch.compile
678     def task_symbols(self, A, f_A, B, f_B):
679         nb_rec = 4
680         c = torch.randperm(len(self.colors) - 1)[: nb_rec + 1] + 1
681         delta = 3
682         for X, f_X in [(A, f_A), (B, f_B)]:
683             while True:
684                 i, j = torch.randint(self.height - delta + 1, (nb_rec,)), torch.randint(
685                     self.width - delta + 1, (nb_rec,)
686                 )
687                 d = (i[None, :] - i[:, None]).abs().max((j[None, :] - j[:, None]).abs())
688                 d.fill_diagonal_(delta + 1)
689                 if d.min() > delta:
690                     break
691
692             for k in range(1, nb_rec):
693                 X[i[k] : i[k] + delta, j[k] : j[k] + delta] = c[k]
694
695             ai, aj = i.float().mean(), j.float().mean()
696
697             q = torch.randint(3, (1,)) + 1
698
699             X[i[0] + delta // 2 - 1, j[0] + delta // 2 - 1] = c[0]
700             X[i[0] + delta // 2 - 1, j[0] + delta // 2 + 1] = c[0]
701             X[i[0] + delta // 2 + 1, j[0] + delta // 2 - 1] = c[0]
702             X[i[0] + delta // 2 + 1, j[0] + delta // 2 + 1] = c[0]
703
704             assert i[q] != ai and j[q] != aj
705
706             X[
707                 i[0] + delta // 2 + (i[q] - ai).sign().long(),
708                 j[0] + delta // 2 + (j[q] - aj).sign().long(),
709             ] = c[nb_rec]
710
711             f_X[i[0] : i[0] + delta, j[0] : j[0] + delta] = c[q]
712
713     # @torch.compile
714     def task_ortho(self, A, f_A, B, f_B):
715         nb_rec = 3
716         di, dj = torch.randint(3, (2,)) - 1
717         o = torch.tensor([[0.0, 1.0], [-1.0, 0.0]])
718         m = torch.eye(2)
719         for _ in range(torch.randint(4, (1,))):
720             m = m @ o
721         if torch.rand(1) < 0.5:
722             m[0, :] = -m[0, :]
723
724         ci, cj = (self.height - 1) / 2, (self.width - 1) / 2
725
726         for X, f_X in [(A, f_A), (B, f_B)]:
727             while True:
728                 X[...] = 0
729                 f_X[...] = 0
730
731                 c = torch.randperm(len(self.colors) - 1)[:nb_rec] + 1
732
733                 for r in range(nb_rec):
734                     while True:
735                         i1, i2 = torch.randint(self.height - 2, (2,)) + 1
736                         j1, j2 = torch.randint(self.width - 2, (2,)) + 1
737                         if (
738                             i2 >= i1
739                             and j2 >= j1
740                             and max(i2 - i1, j2 - j1) >= 2
741                             and min(i2 - i1, j2 - j1) <= 3
742                         ):
743                             break
744                     X[i1 : i2 + 1, j1 : j2 + 1] = c[r]
745
746                     i1, j1, i2, j2 = i1 - ci, j1 - cj, i2 - ci, j2 - cj
747
748                     i1, j1 = m[0, 0] * i1 + m[0, 1] * j1, m[1, 0] * i1 + m[1, 1] * j1
749                     i2, j2 = m[0, 0] * i2 + m[0, 1] * j2, m[1, 0] * i2 + m[1, 1] * j2
750
751                     i1, j1, i2, j2 = i1 + ci, j1 + cj, i2 + ci, j2 + cj
752                     i1, i2 = i1.long() + di, i2.long() + di
753                     j1, j2 = j1.long() + dj, j2.long() + dj
754                     if i1 > i2:
755                         i1, i2 = i2, i1
756                     if j1 > j2:
757                         j1, j2 = j2, j1
758
759                     f_X[i1 : i2 + 1, j1 : j2 + 1] = c[r]
760
761                 n = F.one_hot(X.flatten()).sum(dim=0)[1:]
762                 if (
763                     n.sum() > self.height * self.width // 4
764                     and (n > 0).long().sum() == nb_rec
765                 ):
766                     break
767
768     # @torch.compile
769     def task_islands(self, A, f_A, B, f_B):
770         pass
771
772     # for X, f_X in [(A, f_A), (B, f_B)]:
773     # n = torch.arange(self.height * self.width).reshape(self.height, self.width)
774     # k = torch.randperm(self.height * self.width)
775     # X[...]=-1
776     # for q in k:
777     # i,j=q%self.height,q//self.height
778     # if
779
780     ######################################################################
781
782     def all_tasks(self):
783         return [
784             self.task_replace_color,
785             self.task_translate,
786             self.task_grow,
787             self.task_color_grow,
788             self.task_frame,
789             self.task_detect,
790             self.task_count,
791             self.task_trajectory,
792             self.task_bounce,
793             self.task_scale,
794             self.task_symbols,
795             self.task_ortho,
796             # self.task_islands,
797         ]
798
799     def trivial_prompts_and_answers(self, prompts, answers):
800         S = self.height * self.width
801         Bs = prompts[:, 2 * (S + 1) : 2 * (S + 1) + S]
802         f_Bs = answers
803         return (Bs == f_Bs).long().min(dim=-1).values > 0
804
805     def generate_prompts_and_answers_(self, nb, tasks=None, progress_bar=False):
806         if tasks is None:
807             tasks = self.all_tasks()
808
809         S = self.height * self.width
810         prompts = torch.zeros(nb, 3 * S + 2, dtype=torch.int64)
811         answers = torch.zeros(nb, S, dtype=torch.int64)
812
813         bunch = zip(prompts, answers)
814
815         if progress_bar:
816             bunch = tqdm.tqdm(
817                 bunch,
818                 dynamic_ncols=True,
819                 desc="world generation",
820                 total=prompts.size(0),
821             )
822
823         for prompt, answer in bunch:
824             A = prompt[0 * (S + 1) : 0 * (S + 1) + S].view(self.height, self.width)
825             f_A = prompt[1 * (S + 1) : 1 * (S + 1) + S].view(self.height, self.width)
826             B = prompt[2 * (S + 1) : 2 * (S + 1) + S].view(self.height, self.width)
827             f_B = answer.view(self.height, self.width)
828             task = tasks[torch.randint(len(tasks), (1,))]
829             task(A, f_A, B, f_B)
830
831         return prompts.flatten(1), answers.flatten(1)
832
833     def save_quizzes(
834         self,
835         result_dir,
836         filename_prefix,
837         prompts,
838         answers,
839         predicted_prompts=None,
840         predicted_answers=None,
841         nrow=4,
842     ):
843         self.save_image(
844             result_dir,
845             filename_prefix + ".png",
846             prompts,
847             answers,
848             predicted_prompts,
849             predicted_answers,
850             nrow,
851         )
852
853
854 ######################################################################
855
856 if __name__ == "__main__":
857     import time
858
859     # grids = Grids(max_nb_cached_chunks=5, chunk_size=100, nb_threads=4)
860     grids = Grids()
861
862     # nb = 1000
863     # grids = problem.MultiThreadProblem(
864     # grids, max_nb_cached_chunks=50, chunk_size=100, nb_threads=1
865     # )
866     #    time.sleep(10)
867     # start_time = time.perf_counter()
868     # prompts, answers = grids.generate_prompts_and_answers(nb)
869     # delay = time.perf_counter() - start_time
870     # print(f"{prompts.size(0)/delay:02f} seq/s")
871     # exit(0)
872
873     # if True:
874     nb = 72
875
876     for t in grids.all_tasks():
877         # for t in [grids.task_replace_color]:
878         print(t.__name__)
879         prompts, answers = grids.generate_prompts_and_answers_(nb, tasks=[t])
880         grids.save_quizzes("/tmp", t.__name__, prompts[:nb], answers[:nb], nrow=4)
881
882     nb = 1000
883
884     for t in grids.all_tasks():
885         # for t in [ grids.task_replace_color ]: #grids.all_tasks():
886         start_time = time.perf_counter()
887         prompts, answers = grids.generate_prompts_and_answers_(nb, tasks=[t])
888         delay = time.perf_counter() - start_time
889         print(f"{t.__name__} {prompts.size(0)/delay:02f} seq/s")
890
891     exit(0)
892
893     m = torch.randint(2, (prompts.size(0),))
894     predicted_prompts = m * (torch.randint(2, (prompts.size(0),)) * 2 - 1)
895     predicted_answers = (1 - m) * (torch.randint(2, (prompts.size(0),)) * 2 - 1)
896
897     grids.save_quizzes(
898         "/tmp",
899         "test",
900         prompts[:nb],
901         answers[:nb],
902         # You can add a bool to put a frame around the predicted parts
903         predicted_prompts[:nb],
904         predicted_answers[:nb],
905     )