Update.
[culture.git] / grids.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import math, sys, tqdm, os, warnings
9
10 import torch, torchvision
11
12 from torch import nn
13 from torch.nn import functional as F
14
15 ######################################################################
16
17 import problem
18
19
20 class Grids(problem.Problem):
21     named_colors = [
22         ("white", [255, 255, 255]),
23         ("red", [255, 0, 0]),
24         ("green", [0, 192, 0]),
25         ("blue", [0, 0, 255]),
26         ("yellow", [255, 224, 0]),
27         ("cyan", [0, 255, 255]),
28         ("violet", [224, 128, 255]),
29         ("lightgreen", [192, 255, 192]),
30         ("brown", [165, 42, 42]),
31         ("lightblue", [192, 192, 255]),
32         ("gray", [128, 128, 128]),
33     ]
34
35     def __init__(
36         self,
37         max_nb_cached_chunks=None,
38         chunk_size=None,
39         nb_threads=-1,
40     ):
41         self.colors = torch.tensor([c for _, c in self.named_colors])
42         self.height = 10
43         self.width = 10
44         self.cache_rec_coo = {}
45         super().__init__(max_nb_cached_chunks, chunk_size, nb_threads)
46
47     ######################################################################
48
49     def frame2img(self, x, scale=15):
50         x = x.reshape(x.size(0), self.height, -1)
51         m = torch.logical_and(x >= 0, x < self.nb_token_values()).long()
52         x = self.colors[x * m].permute(0, 3, 1, 2)
53         s = x.shape
54         x = x[:, :, :, None, :, None].expand(-1, -1, -1, scale, -1, scale)
55         x = x.reshape(s[0], s[1], s[2] * scale, s[3] * scale)
56
57         x[:, :, :, torch.arange(0, x.size(3), scale)] = 0
58         x[:, :, torch.arange(0, x.size(2), scale), :] = 0
59         x = x[:, :, 1:, 1:]
60
61         for n in range(m.size(0)):
62             for i in range(m.size(1)):
63                 for j in range(m.size(2)):
64                     if m[n, i, j] == 0:
65                         for k in range(2, scale - 2):
66                             for l in [0, 1]:
67                                 x[n, :, i * scale + k, j * scale + k - l] = 0
68                                 x[
69                                     n, :, i * scale + scale - 1 - k, j * scale + k - l
70                                 ] = 0
71
72         return x
73
74     def save_image(
75         self,
76         result_dir,
77         filename,
78         prompts,
79         answers,
80         predicted_prompts=None,
81         predicted_answers=None,
82         nrow=4,
83         margin=8,
84     ):
85         S = self.height * self.width
86         As = prompts[:, 0 * (S + 1) : 0 * (S + 1) + S].view(-1, self.height, self.width)
87         f_As = prompts[:, 1 * (S + 1) : 1 * (S + 1) + S].view(
88             -1, self.height, self.width
89         )
90         Bs = prompts[:, 2 * (S + 1) : 2 * (S + 1) + S].view(-1, self.height, self.width)
91         prompts = torch.cat([As, f_As, Bs], dim=2)
92         answers = answers.reshape(answers.size(0), self.height, self.width)
93
94         if predicted_prompts is None:
95             predicted_prompts = 255
96
97         if predicted_answers is None:
98             predicted_answers = 255
99
100         def add_frame(x, c, margin, bottom=False):
101             if bottom:
102                 h, w, di, dj = x.size(2) + margin, x.size(3), 0, 0
103             else:
104                 h, w, di, dj = (
105                     x.size(2) + 2 * margin,
106                     x.size(3) + 2 * margin,
107                     margin,
108                     margin,
109                 )
110
111             y = x.new_full((x.size(0), x.size(1), h, w), 0)
112
113             if type(c) is int:
114                 y[...] = c
115             else:
116                 c = c.long()[:, None]
117                 c = (
118                     (1 - ((c == 1).long() + (c == 0).long() + (c == -1).long()))
119                     * torch.tensor([64, 64, 64])
120                     + (c == 1).long() * torch.tensor([0, 255, 0])
121                     + (c == 0).long() * torch.tensor([255, 255, 255])
122                     + (c == -1).long() * torch.tensor([255, 0, 0])
123                 )
124                 y[...] = c[:, :, None, None]
125
126             y[:, :, di : di + x.size(2), dj : dj + x.size(3)] = x
127
128             return y
129
130         img_prompts = torch.cat(
131             [
132                 add_frame(
133                     add_frame(self.frame2img(x), c=0, margin=1),
134                     c=predicted_prompts,
135                     margin=margin,
136                 )
137                 for x in prompts.to("cpu").split(split_size=self.width, dim=2)
138             ],
139             dim=3,
140         )
141
142         h = img_prompts.size(2)
143         img_answers = add_frame(
144             add_frame(self.frame2img(answers.to("cpu")), c=0, margin=1),
145             c=predicted_answers,
146             margin=margin,
147         )
148
149         separator_size = 2 * margin
150
151         separator = img_prompts.new_full(
152             (
153                 img_prompts.size(0),
154                 img_prompts.size(1),
155                 img_prompts.size(2),
156                 separator_size,
157             ),
158             255,
159         )
160
161         marker = img_prompts.new_full(
162             (
163                 img_prompts.size(0),
164                 img_prompts.size(1),
165                 img_prompts.size(2),
166                 separator_size,
167             ),
168             255,
169         )
170
171         # marker[:, :, 0] = 0
172         # marker[:, :, h - 1] = 0
173
174         for k in range(1, 2 * separator_size - 8):
175             i = k - (separator_size - 4)
176             j = separator_size - 5 - abs(i)
177             marker[:, :, h // 2 - 1 + i, 2 + j] = 0
178             marker[:, :, h // 2 - 1 + i + 1, 2 + j] = 0
179
180         img = torch.cat(
181             [
182                 img_prompts,
183                 marker,
184                 img_answers,
185             ],
186             dim=3,
187         )
188
189         image_name = os.path.join(result_dir, filename)
190         torchvision.utils.save_image(
191             img.float() / 255.0,
192             image_name,
193             nrow=nrow,
194             padding=margin * 4,
195             pad_value=1.0,
196         )
197
198     ######################################################################
199
200     def nb_token_values(self):
201         return len(self.colors)
202
203     # @torch.compile
204     def rec_coo(
205         self,
206         nb_rec,
207         min_height=3,
208         min_width=3,
209         surface_max=None,
210         prevent_overlap=False,
211     ):
212         if surface_max is None:
213             surface_max = self.height * self.width // 2
214
215         signature = (nb_rec, min_height, min_width, surface_max)
216
217         try:
218             return self.cache_rec_coo[signature].pop()
219         except IndexError:
220             pass
221         except KeyError:
222             pass
223
224         N = 10000
225         while True:
226             while True:
227                 i = torch.randint(self.height, (N * nb_rec, 2)).sort(dim=-1).values
228                 j = torch.randint(self.width, (N * nb_rec, 2)).sort(dim=-1).values
229
230                 big_enough = (
231                     (i[:, 1] >= i[:, 0] + min_height)
232                     & (j[:, 1] >= j[:, 0] + min_height)
233                     & ((i[:, 1] - i[:, 0]) * (j[:, 1] - j[:, 0]) <= surface_max)
234                 )
235
236                 i, j = i[big_enough], j[big_enough]
237
238                 n = i.size(0) - i.size(0) % nb_rec
239
240                 if n > 0:
241                     break
242
243             i = i[:n].reshape(n // nb_rec, nb_rec, -1)
244             j = j[:n].reshape(n // nb_rec, nb_rec, -1)
245
246             if prevent_overlap:
247                 can_fit = ((i[:, :, 1] - i[:, :, 0]) * (j[:, :, 1] - j[:, :, 0])).sum(
248                     dim=-1
249                 ) <= self.height * self.width
250                 i, j = i[can_fit], j[can_fit]
251                 if nb_rec == 2:
252                     A_i1, A_i2, A_j1, A_j2 = (
253                         i[:, 0, 0],
254                         i[:, 0, 1],
255                         j[:, 0, 0],
256                         j[:, 0, 1],
257                     )
258                     B_i1, B_i2, B_j1, B_j2 = (
259                         i[:, 1, 0],
260                         i[:, 1, 1],
261                         j[:, 1, 0],
262                         j[:, 1, 1],
263                     )
264                     no_overlap = torch.logical_not(
265                         (A_i1 >= B_i2)
266                         & (A_i2 <= B_i1)
267                         & (A_j1 >= B_j1)
268                         & (A_j2 <= B_j1)
269                     )
270                     i, j = i[no_overlap], j[no_overlap]
271                 elif nb_rec == 3:
272                     A_i1, A_i2, A_j1, A_j2 = (
273                         i[:, 0, 0],
274                         i[:, 0, 1],
275                         j[:, 0, 0],
276                         j[:, 0, 1],
277                     )
278                     B_i1, B_i2, B_j1, B_j2 = (
279                         i[:, 1, 0],
280                         i[:, 1, 1],
281                         j[:, 1, 0],
282                         j[:, 1, 1],
283                     )
284                     C_i1, C_i2, C_j1, C_j2 = (
285                         i[:, 2, 0],
286                         i[:, 2, 1],
287                         j[:, 2, 0],
288                         j[:, 2, 1],
289                     )
290                     no_overlap = (
291                         (
292                             (A_i1 >= B_i2)
293                             | (A_i2 <= B_i1)
294                             | (A_j1 >= B_j2)
295                             | (A_j2 <= B_j1)
296                         )
297                         & (
298                             (A_i1 >= C_i2)
299                             | (A_i2 <= C_i1)
300                             | (A_j1 >= C_j2)
301                             | (A_j2 <= C_j1)
302                         )
303                         & (
304                             (B_i1 >= C_i2)
305                             | (B_i2 <= C_i1)
306                             | (B_j1 >= C_j2)
307                             | (B_j2 <= C_j1)
308                         )
309                     )
310                     i, j = (i[no_overlap], j[no_overlap])
311                 else:
312                     assert nb_rec == 1
313
314             if i.size(0) > 1:
315                 break
316
317         self.cache_rec_coo[signature] = [
318             [
319                 (
320                     i[n, k, 0].item(),
321                     j[n, k, 0].item(),
322                     i[n, k, 1].item(),
323                     j[n, k, 1].item(),
324                 )
325                 for k in range(nb_rec)
326             ]
327             for n in range(i.size(0))
328         ]
329
330         return self.cache_rec_coo[signature].pop()
331
332     ######################################################################
333
334     # @torch.compile
335     def task_replace_color(self, A, f_A, B, f_B):
336         nb_rec = 3
337         c = torch.randperm(len(self.colors) - 1)[: nb_rec + 1] + 1
338         for X, f_X in [(A, f_A), (B, f_B)]:
339             r = self.rec_coo(nb_rec, prevent_overlap=True)
340             for n in range(nb_rec):
341                 i1, j1, i2, j2 = r[n]
342                 X[i1:i2, j1:j2] = c[n]
343                 f_X[i1:i2, j1:j2] = c[n if n > 0 else -1]
344
345     # @torch.compile
346     def task_translate(self, A, f_A, B, f_B):
347         while True:
348             di, dj = torch.randint(3, (2,)) - 1
349             if di.abs() + dj.abs() > 0:
350                 break
351
352         nb_rec = 3
353         c = torch.randperm(len(self.colors) - 1)[:nb_rec] + 1
354         for X, f_X in [(A, f_A), (B, f_B)]:
355             while True:
356                 r = self.rec_coo(nb_rec, prevent_overlap=True)
357                 i1, j1, i2, j2 = r[nb_rec - 1]
358                 if (
359                     i1 + di >= 0
360                     and i2 + di < X.size(0)
361                     and j1 + dj >= 0
362                     and j2 + dj < X.size(1)
363                 ):
364                     break
365
366             for n in range(nb_rec):
367                 i1, j1, i2, j2 = r[n]
368                 X[i1:i2, j1:j2] = c[n]
369                 if n == nb_rec - 1:
370                     f_X[i1 + di : i2 + di, j1 + dj : j2 + dj] = c[n]
371                 else:
372                     f_X[i1:i2, j1:j2] = c[n]
373
374     # @torch.compile
375     def task_grow(self, A, f_A, B, f_B):
376         di, dj = torch.randint(2, (2,)) * 2 - 1
377         nb_rec = 3
378         c = torch.randperm(len(self.colors) - 1)[:nb_rec] + 1
379         direction = torch.randint(2, (1,))
380         for X, f_X in [(A, f_A), (B, f_B)]:
381             while True:
382                 r = self.rec_coo(nb_rec, prevent_overlap=True)
383                 i1, j1, i2, j2 = r[nb_rec - 1]
384                 if i1 + 3 < i2 and j1 + 3 < j2:
385                     break
386
387             for n in range(nb_rec):
388                 i1, j1, i2, j2 = r[n]
389                 if n == nb_rec - 1:
390                     if direction == 0:
391                         X[i1 + 1 : i2 - 1, j1 + 1 : j2 - 1] = c[n]
392                         f_X[i1:i2, j1:j2] = c[n]
393                     else:
394                         X[i1:i2, j1:j2] = c[n]
395                         f_X[i1 + 1 : i2 - 1, j1 + 1 : j2 - 1] = c[n]
396                 else:
397                     X[i1:i2, j1:j2] = c[n]
398                     f_X[i1:i2, j1:j2] = c[n]
399
400     # @torch.compile
401     def task_color_grow(self, A, f_A, B, f_B):
402         di, dj = torch.randint(2, (2,)) * 2 - 1
403         nb_rec = 3
404         c = torch.randperm(len(self.colors) - 1)[: 2 * nb_rec] + 1
405         direction = torch.randint(4, (1,))
406         for X, f_X in [(A, f_A), (B, f_B)]:
407             r = self.rec_coo(nb_rec, prevent_overlap=True)
408             for n in range(nb_rec):
409                 i1, j1, i2, j2 = r[n]
410                 X[i1:i2, j1:j2] = c[2 * n]
411                 f_X[i1:i2, j1:j2] = c[2 * n]
412                 # Not my proudest moment
413                 if direction == 0:
414                     i = (i1 + i2) // 2
415                     X[i : i + 1, j1:j2] = c[2 * n + 1]
416                     if n == nb_rec - 1:
417                         f_X[i:i2, j1:j2] = c[2 * n + 1]
418                     else:
419                         f_X[i : i + 1, j1:j2] = c[2 * n + 1]
420                 elif direction == 1:
421                     i = (i1 + i2 - 1) // 2
422                     X[i : i + 1, j1:j2] = c[2 * n + 1]
423                     if n == nb_rec - 1:
424                         f_X[i1 : i + 1, j1:j2] = c[2 * n + 1]
425                     else:
426                         f_X[i : i + 1, j1:j2] = c[2 * n + 1]
427                 elif direction == 2:
428                     j = (j1 + j2) // 2
429                     X[i1:i2, j : j + 1] = c[2 * n + 1]
430                     if n == nb_rec - 1:
431                         f_X[i1:i2, j:j2] = c[2 * n + 1]
432                     else:
433                         f_X[i1:i2, j : j + 1] = c[2 * n + 1]
434                 elif direction == 3:
435                     j = (j1 + j2 - 1) // 2
436                     X[i1:i2, j : j + 1] = c[2 * n + 1]
437                     if n == nb_rec - 1:
438                         f_X[i1:i2, j1 : j + 1] = c[2 * n + 1]
439                     else:
440                         f_X[i1:i2, j : j + 1] = c[2 * n + 1]
441
442     # @torch.compile
443     def task_frame(self, A, f_A, B, f_B):
444         nb_rec = 3
445         c = torch.randperm(len(self.colors) - 1)[: nb_rec + 1] + 1
446         for X, f_X in [(A, f_A), (B, f_B)]:
447             r = self.rec_coo(nb_rec, prevent_overlap=True)
448             for n in range(nb_rec):
449                 i1, j1, i2, j2 = r[n]
450                 X[i1:i2, j1:j2] = c[n]
451                 if n == nb_rec - 1:
452                     f_X[i1:i2, j1] = c[n]
453                     f_X[i1:i2, j2 - 1] = c[n]
454                     f_X[i1, j1:j2] = c[n]
455                     f_X[i2 - 1, j1:j2] = c[n]
456                 else:
457                     f_X[i1:i2, j1:j2] = c[n]
458
459     # @torch.compile
460     def task_detect(self, A, f_A, B, f_B):
461         nb_rec = 3
462         c = torch.randperm(len(self.colors) - 1)[: nb_rec + 1] + 1
463         for X, f_X in [(A, f_A), (B, f_B)]:
464             r = self.rec_coo(nb_rec, prevent_overlap=True)
465             for n in range(nb_rec):
466                 i1, j1, i2, j2 = r[n]
467                 X[i1:i2, j1:j2] = c[n]
468                 if n < nb_rec - 1:
469                     f_X[i1, j1] = c[-1]
470
471     # @torch.compile
472     def contact(self, X, i, j, q):
473         nq, nq_diag = 0, 0
474         no = 0
475
476         for ii, jj in [
477             (i - 1, j - 1),
478             (i - 1, j),
479             (i - 1, j + 1),
480             (i, j - 1),
481             (i, j + 1),
482             (i + 1, j - 1),
483             (i + 1, j),
484             (i + 1, j + 1),
485         ]:
486             if ii >= 0 and ii < self.height and jj >= 0 and jj < self.width:
487                 if X[ii, jj] != 0 and X[ii, jj] != q:
488                     no += 1
489
490         for ii, jj in [
491             (i - 1, j - 1),
492             (i - 1, j + 1),
493             (i + 1, j - 1),
494             (i + 1, j + 1),
495         ]:
496             if ii >= 0 and ii < self.height and jj >= 0 and jj < self.width:
497                 if X[ii, jj] == q and X[i, jj] != q and X[ii, j] != q:
498                     nq_diag += 1
499
500         for ii, jj in [(i - 1, j), (i, j - 1), (i, j + 1), (i + 1, j)]:
501             if ii >= 0 and ii < self.height and jj >= 0 and jj < self.width:
502                 if X[ii, jj] == q:
503                     nq += 1
504
505         return no, nq, nq_diag
506
507     def task_count(self, A, f_A, B, f_B):
508         N = (torch.randint(4, (1,)) + 2).item()
509         c = torch.randperm(len(self.colors) - 1)[:N] + 1
510
511         for X, f_X in [(A, f_A), (B, f_B)]:
512             l_q = torch.randperm(self.height * self.width)[
513                 : self.height * self.width // 20
514             ]
515             l_d = torch.randint(N, l_q.size())
516             nb = torch.zeros(N, dtype=torch.int64)
517
518             for q, e in zip(l_q, l_d):
519                 d = c[e]
520                 i, j = q % self.height, q // self.height
521                 if (
522                     nb[e] < self.width
523                     and X[max(0, i - 1) : i + 2, max(0, j - 1) : j + 2] == 0
524                 ).all():
525                     X[i, j] = d
526                     nb[e] += 1
527
528             l_q = torch.randperm((self.height - 2) * (self.width - 2))[
529                 : self.height * self.width // 2
530             ]
531             l_d = torch.randint(N, l_q.size())
532             for q, e in zip(l_q, l_d):
533                 d = c[e]
534                 i, j = q % (self.height - 2) + 1, q // (self.height - 2) + 1
535                 a1, a2, a3 = X[i - 1, j - 1 : j + 2]
536                 a8, a4 = X[i, j - 1], X[i, j + 1]
537                 a7, a6, a5 = X[i + 1, j - 1 : j + 2]
538                 if (
539                     X[i, j] == 0
540                     and nb[e] < self.width
541                     and (a2 == 0 or a2 == d)
542                     and (a4 == 0 or a4 == d)
543                     and (a6 == 0 or a6 == d)
544                     and (a8 == 0 or a8 == d)
545                     and (a1 == 0 or a2 == d or a8 == d)
546                     and (a3 == 0 or a4 == d or a2 == d)
547                     and (a5 == 0 or a6 == d or a4 == d)
548                     and (a7 == 0 or a8 == d or a6 == d)
549                 ):
550                     o = (
551                         (a2 != 0).long()
552                         + (a4 != 0).long()
553                         + (a6 != 0).long()
554                         + (a8 != 0).long()
555                     )
556                     if o <= 1:
557                         X[i, j] = d
558                         nb[e] += 1 - o
559
560             for e in range(N):
561                 for j in range(nb[e]):
562                     f_X[e, j] = c[e]
563
564     # @torch.compile
565     def task_trajectory(self, A, f_A, B, f_B):
566         c = torch.randperm(len(self.colors) - 1)[:2] + 1
567         for X, f_X in [(A, f_A), (B, f_B)]:
568             while True:
569                 di, dj = torch.randint(7, (2,)) - 3
570                 i, j = torch.randint(self.height, (1,)), torch.randint(self.width, (1,))
571                 if (
572                     abs(di) + abs(dj) > 0
573                     and i + 2 * di >= 0
574                     and i + 2 * di < self.height
575                     and j + 2 * dj >= 0
576                     and j + 2 * dj < self.width
577                 ):
578                     break
579
580             k = 0
581             while (
582                 i + k * di >= 0
583                 and i + k * di < self.height
584                 and j + k * dj >= 0
585                 and j + k * dj < self.width
586             ):
587                 if k < 2:
588                     X[i + k * di, j + k * dj] = c[k]
589                 f_X[i + k * di, j + k * dj] = c[min(k, 1)]
590                 k += 1
591
592     # @torch.compile
593     def task_bounce(self, A, f_A, B, f_B):
594         c = torch.randperm(len(self.colors) - 1)[:3] + 1
595         for X, f_X in [(A, f_A), (B, f_B)]:
596             # @torch.compile
597             def free(i, j):
598                 return (
599                     i >= 0
600                     and i < self.height
601                     and j >= 0
602                     and j < self.width
603                     and f_X[i, j] == 0
604                 )
605
606             while True:
607                 f_X[...] = 0
608                 X[...] = 0
609
610                 for _ in range((self.height * self.width) // 10):
611                     i, j = torch.randint(self.height, (1,)), torch.randint(
612                         self.width, (1,)
613                     )
614                     X[i, j] = c[0]
615                     f_X[i, j] = c[0]
616
617                 while True:
618                     di, dj = torch.randint(7, (2,)) - 3
619                     if abs(di) + abs(dj) == 1:
620                         break
621
622                 i, j = torch.randint(self.height, (1,)), torch.randint(self.width, (1,))
623
624                 X[i, j] = c[1]
625                 f_X[i, j] = c[1]
626                 l = 0
627
628                 while True:
629                     l += 1
630                     if free(i + di, j + dj):
631                         pass
632                     elif free(i - dj, j + di):
633                         di, dj = -dj, di
634                         if free(i + dj, j - di):
635                             if torch.rand(1) < 0.5:
636                                 di, dj = -di, -dj
637                     elif free(i + dj, j - di):
638                         di, dj = dj, -di
639                     else:
640                         break
641
642                     i, j = i + di, j + dj
643                     f_X[i, j] = c[2]
644                     if l <= 1:
645                         X[i, j] = c[2]
646
647                     if l >= self.width:
648                         break
649
650                 f_X[i, j] = c[1]
651                 X[i, j] = c[1]
652
653                 if l > 3:
654                     break
655
656     # @torch.compile
657     def task_scale(self, A, f_A, B, f_B):
658         c = torch.randperm(len(self.colors) - 1)[:2] + 1
659
660         i, j = torch.randint(self.height // 2, (1,)), torch.randint(
661             self.width // 2, (1,)
662         )
663
664         for X, f_X in [(A, f_A), (B, f_B)]:
665             for _ in range(3):
666                 while True:
667                     i1, j1 = torch.randint(self.height // 2 + 1, (1,)), torch.randint(
668                         self.width // 2 + 1, (1,)
669                     )
670                     i2, j2 = torch.randint(self.height // 2 + 1, (1,)), torch.randint(
671                         self.width // 2 + 1, (1,)
672                     )
673                     if i1 < i2 and j1 < j2 and min(i2 - i1, j2 - j1) <= 3:
674                         break
675                 X[i + i1 : i + i2, j + j1 : j + j2] = c[0]
676                 f_X[2 * i1 : 2 * i2, 2 * j1 : 2 * j2] = c[0]
677
678             X[i, j] = c[1]
679             f_X[0:2, 0:2] = c[1]
680
681     # @torch.compile
682     def task_symbols(self, A, f_A, B, f_B):
683         nb_rec = 4
684         c = torch.randperm(len(self.colors) - 1)[: nb_rec + 1] + 1
685         delta = 3
686         for X, f_X in [(A, f_A), (B, f_B)]:
687             while True:
688                 i, j = torch.randint(self.height - delta + 1, (nb_rec,)), torch.randint(
689                     self.width - delta + 1, (nb_rec,)
690                 )
691                 d = (i[None, :] - i[:, None]).abs().max((j[None, :] - j[:, None]).abs())
692                 d.fill_diagonal_(delta + 1)
693                 if d.min() > delta:
694                     break
695
696             for k in range(1, nb_rec):
697                 X[i[k] : i[k] + delta, j[k] : j[k] + delta] = c[k]
698
699             ai, aj = i.float().mean(), j.float().mean()
700
701             q = torch.randint(3, (1,)) + 1
702
703             X[i[0] + delta // 2 - 1, j[0] + delta // 2 - 1] = c[0]
704             X[i[0] + delta // 2 - 1, j[0] + delta // 2 + 1] = c[0]
705             X[i[0] + delta // 2 + 1, j[0] + delta // 2 - 1] = c[0]
706             X[i[0] + delta // 2 + 1, j[0] + delta // 2 + 1] = c[0]
707
708             assert i[q] != ai and j[q] != aj
709
710             X[
711                 i[0] + delta // 2 + (i[q] - ai).sign().long(),
712                 j[0] + delta // 2 + (j[q] - aj).sign().long(),
713             ] = c[nb_rec]
714
715             f_X[i[0] : i[0] + delta, j[0] : j[0] + delta] = c[q]
716
717     # @torch.compile
718     def task_ortho(self, A, f_A, B, f_B):
719         nb_rec = 3
720         di, dj = torch.randint(3, (2,)) - 1
721         o = torch.tensor([[0.0, 1.0], [-1.0, 0.0]])
722         m = torch.eye(2)
723         for _ in range(torch.randint(4, (1,))):
724             m = m @ o
725         if torch.rand(1) < 0.5:
726             m[0, :] = -m[0, :]
727
728         ci, cj = (self.height - 1) / 2, (self.width - 1) / 2
729
730         for X, f_X in [(A, f_A), (B, f_B)]:
731             while True:
732                 X[...] = 0
733                 f_X[...] = 0
734
735                 c = torch.randperm(len(self.colors) - 1)[:nb_rec] + 1
736
737                 for r in range(nb_rec):
738                     while True:
739                         i1, i2 = torch.randint(self.height - 2, (2,)) + 1
740                         j1, j2 = torch.randint(self.width - 2, (2,)) + 1
741                         if (
742                             i2 >= i1
743                             and j2 >= j1
744                             and max(i2 - i1, j2 - j1) >= 2
745                             and min(i2 - i1, j2 - j1) <= 3
746                         ):
747                             break
748                     X[i1 : i2 + 1, j1 : j2 + 1] = c[r]
749
750                     i1, j1, i2, j2 = i1 - ci, j1 - cj, i2 - ci, j2 - cj
751
752                     i1, j1 = m[0, 0] * i1 + m[0, 1] * j1, m[1, 0] * i1 + m[1, 1] * j1
753                     i2, j2 = m[0, 0] * i2 + m[0, 1] * j2, m[1, 0] * i2 + m[1, 1] * j2
754
755                     i1, j1, i2, j2 = i1 + ci, j1 + cj, i2 + ci, j2 + cj
756                     i1, i2 = i1.long() + di, i2.long() + di
757                     j1, j2 = j1.long() + dj, j2.long() + dj
758                     if i1 > i2:
759                         i1, i2 = i2, i1
760                     if j1 > j2:
761                         j1, j2 = j2, j1
762
763                     f_X[i1 : i2 + 1, j1 : j2 + 1] = c[r]
764
765                 n = F.one_hot(X.flatten()).sum(dim=0)[1:]
766                 if (
767                     n.sum() > self.height * self.width // 4
768                     and (n > 0).long().sum() == nb_rec
769                 ):
770                     break
771
772     def compute_distance(self, walls, goal_i, goal_j, start_i, start_j):
773         max_length = walls.numel()
774         dist = torch.full_like(walls, max_length)
775
776         dist[goal_i, goal_j] = 0
777         pred_dist = torch.empty_like(dist)
778
779         while True:
780             pred_dist.copy_(dist)
781             d = (
782                 torch.cat(
783                     (
784                         dist[None, 1:-1, 0:-2],
785                         dist[None, 2:, 1:-1],
786                         dist[None, 1:-1, 2:],
787                         dist[None, 0:-2, 1:-1],
788                     ),
789                     0,
790                 ).min(dim=0)[0]
791                 + 1
792             )
793
794             dist[1:-1, 1:-1].minimum_(d)  # = torch.min(dist[1:-1, 1:-1], d)
795             dist = walls * max_length + (1 - walls) * dist
796
797             if dist[start_i, start_j] < max_length or dist.equal(pred_dist):
798                 return dist * (1 - walls)
799
800     # @torch.compile
801     def task_path(self, A, f_A, B, f_B):
802         c = torch.randperm(len(self.colors) - 1)[:3] + 1
803         dist = torch.empty(self.height + 2, self.width + 2)
804         for X, f_X in [(A, f_A), (B, f_B)]:
805             nb_rec = torch.randint(3, (1,)) + 1
806             while True:
807                 r = self.rec_coo(nb_rec, prevent_overlap=True)
808                 X[...] = 0
809                 f_X[...] = 0
810                 for n in range(nb_rec):
811                     i1, j1, i2, j2 = r[n]
812                     X[i1:i2, j1:j2] = c[0]
813                     f_X[i1:i2, j1:j2] = c[0]
814                 while True:
815                     i0, j0 = torch.randint(self.height, (1,)), torch.randint(
816                         self.width, (1,)
817                     )
818                     if X[i0, j0] == 0:
819                         break
820                 while True:
821                     i1, j1 = torch.randint(self.height, (1,)), torch.randint(
822                         self.width, (1,)
823                     )
824                     if X[i1, j1] == 0:
825                         break
826                 dist[...] = 1
827                 dist[1:-1, 1:-1] = (X != 0).long()
828                 dist[...] = self.compute_distance(dist, i1 + 1, j1 + 1, i0 + 1, j0 + 1)
829                 if dist[i0 + 1, j0 + 1] >= 1 and dist[i0 + 1, j0 + 1] < self.height * 4:
830                     break
831
832             dist[1:-1, 1:-1] += (X != 0).long() * self.height * self.width
833             dist[0, :] = self.height * self.width
834             dist[-1, :] = self.height * self.width
835             dist[:, 0] = self.height * self.width
836             dist[:, -1] = self.height * self.width
837             # dist += torch.rand(dist.size())
838
839             i, j = i0 + 1, j0 + 1
840             while i != i1 + 1 or j != j1 + 1:
841                 f_X[i - 1, j - 1] = c[2]
842                 r, s, t, u = (
843                     dist[i - 1, j],
844                     dist[i, j - 1],
845                     dist[i + 1, j],
846                     dist[i, j + 1],
847                 )
848                 m = min(r, s, t, u)
849                 if r == m:
850                     i = i - 1
851                 elif t == m:
852                     i = i + 1
853                 elif s == m:
854                     j = j - 1
855                 else:
856                     j = j + 1
857
858             X[i0, j0] = c[2]
859             # f_X[i0, j0] = c[1]
860
861             X[i1, j1] = c[1]
862             f_X[i1, j1] = c[1]
863
864     # for X, f_X in [(A, f_A), (B, f_B)]:
865     # n = torch.arange(self.height * self.width).reshape(self.height, self.width)
866     # k = torch.randperm(self.height * self.width)
867     # X[...]=-1
868     # for q in k:
869     # i,j=q%self.height,q//self.height
870     # if
871
872     ######################################################################
873
874     def all_tasks(self):
875         return [
876             self.task_replace_color,
877             self.task_translate,
878             self.task_grow,
879             self.task_color_grow,
880             self.task_frame,
881             self.task_detect,
882             self.task_count,
883             self.task_trajectory,
884             self.task_bounce,
885             self.task_scale,
886             self.task_symbols,
887             self.task_ortho,
888             #            self.task_path,
889         ]
890
891     def trivial_prompts_and_answers(self, prompts, answers):
892         S = self.height * self.width
893         Bs = prompts[:, 2 * (S + 1) : 2 * (S + 1) + S]
894         f_Bs = answers
895         return (Bs == f_Bs).long().min(dim=-1).values > 0
896
897     def generate_prompts_and_answers_(self, nb, tasks=None, progress_bar=False):
898         if tasks is None:
899             tasks = self.all_tasks()
900
901         S = self.height * self.width
902         prompts = torch.zeros(nb, 3 * S + 2, dtype=torch.int64)
903         answers = torch.zeros(nb, S, dtype=torch.int64)
904
905         bunch = zip(prompts, answers)
906
907         if progress_bar:
908             bunch = tqdm.tqdm(
909                 bunch,
910                 dynamic_ncols=True,
911                 desc="world generation",
912                 total=prompts.size(0),
913             )
914
915         for prompt, answer in bunch:
916             A = prompt[0 * (S + 1) : 0 * (S + 1) + S].view(self.height, self.width)
917             f_A = prompt[1 * (S + 1) : 1 * (S + 1) + S].view(self.height, self.width)
918             B = prompt[2 * (S + 1) : 2 * (S + 1) + S].view(self.height, self.width)
919             f_B = answer.view(self.height, self.width)
920             task = tasks[torch.randint(len(tasks), (1,))]
921             task(A, f_A, B, f_B)
922
923         return prompts.flatten(1), answers.flatten(1)
924
925     def save_quizzes(
926         self,
927         result_dir,
928         filename_prefix,
929         prompts,
930         answers,
931         predicted_prompts=None,
932         predicted_answers=None,
933         nrow=4,
934     ):
935         self.save_image(
936             result_dir,
937             filename_prefix + ".png",
938             prompts,
939             answers,
940             predicted_prompts,
941             predicted_answers,
942             nrow,
943         )
944
945     def save_some_examples(self, result_dir):
946         nb, nrow = 72, 4
947         for t in self.all_tasks():
948             print(t.__name__)
949             prompts, answers = self.generate_prompts_and_answers_(nb, tasks=[t])
950             self.save_quizzes(
951                 result_dir, t.__name__, prompts[:nb], answers[:nb], nrow=nrow
952             )
953
954
955 ######################################################################
956
957 if __name__ == "__main__":
958     import time
959
960     # grids = Grids(max_nb_cached_chunks=5, chunk_size=100, nb_threads=4)
961     grids = Grids()
962
963     # nb = 1000
964     # grids = problem.MultiThreadProblem(
965     # grids, max_nb_cached_chunks=50, chunk_size=100, nb_threads=1
966     # )
967     #    time.sleep(10)
968     # start_time = time.perf_counter()
969     # prompts, answers = grids.generate_prompts_and_answers(nb)
970     # delay = time.perf_counter() - start_time
971     # print(f"{prompts.size(0)/delay:02f} seq/s")
972     # exit(0)
973
974     # if True:
975     nb, nrow = 72, 4
976     # nb, nrow = 8, 2
977
978     # for t in grids.all_tasks():
979     for t in [grids.task_path]:
980         print(t.__name__)
981         prompts, answers = grids.generate_prompts_and_answers_(nb, tasks=[t])
982         grids.save_quizzes("/tmp", t.__name__, prompts[:nb], answers[:nb], nrow=nrow)
983
984     # exit(0)
985
986     nb = 1000
987
988     for t in grids.all_tasks():
989         # for t in [ grids.task_replace_color ]: #grids.all_tasks():
990         start_time = time.perf_counter()
991         prompts, answers = grids.generate_prompts_and_answers_(nb, tasks=[t])
992         delay = time.perf_counter() - start_time
993         print(f"{t.__name__} {prompts.size(0)/delay:02f} seq/s")
994
995     exit(0)
996
997     m = torch.randint(2, (prompts.size(0),))
998     predicted_prompts = m * (torch.randint(2, (prompts.size(0),)) * 2 - 1)
999     predicted_answers = (1 - m) * (torch.randint(2, (prompts.size(0),)) * 2 - 1)
1000
1001     grids.save_quizzes(
1002         "/tmp",
1003         "test",
1004         prompts[:nb],
1005         answers[:nb],
1006         # You can add a bool to put a frame around the predicted parts
1007         predicted_prompts[:nb],
1008         predicted_answers[:nb],
1009     )