8d144cfdace227083ca164cc16b76b7b20eb401f
[culture.git] / grids.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import math, sys, tqdm, os, warnings
9
10 import torch, torchvision
11
12 from torch import nn
13 from torch.nn import functional as F
14
15 ######################################################################
16
17 import problem
18
19
20 def grow_islands(nb, height, width, nb_seeds, nb_iterations):
21     w = torch.empty(5, 1, 3, 3)
22
23     w[0, 0] = torch.tensor(
24         [
25             [1.0, 1.0, 1.0],
26             [1.0, 0.0, 1.0],
27             [1.0, 1.0, 1.0],
28         ]
29     )
30
31     w[1, 0] = torch.tensor(
32         [
33             [-1.0, 1.0, 0.0],
34             [1.0, 0.0, 0.0],
35             [0.0, 0.0, 0.0],
36         ]
37     )
38
39     w[2, 0] = torch.tensor(
40         [
41             [0.0, 1.0, -1.0],
42             [0.0, 0.0, 1.0],
43             [0.0, 0.0, 0.0],
44         ]
45     )
46
47     w[3, 0] = torch.tensor(
48         [
49             [0.0, 0.0, 0.0],
50             [0.0, 0.0, 1.0],
51             [0.0, 1.0, -1.0],
52         ]
53     )
54
55     w[4, 0] = torch.tensor(
56         [
57             [0.0, 0.0, 0.0],
58             [1.0, 0.0, 0.0],
59             [-1.0, 1.0, 0.0],
60         ]
61     )
62
63     Z = torch.zeros(nb, height, width)
64     U = Z.flatten(1)
65
66     for _ in range(nb_seeds):
67         M = F.conv2d(Z[:, None, :, :], w, padding=1)
68         M = torch.cat([M[:, :1], M[:, 1:].min(dim=1, keepdim=True).values], dim=1)
69         M = ((M[:, 0] == 0) & (Z == 0)).long()
70         M = M * torch.rand(M.size())
71         M = M.flatten(1)
72         M = F.one_hot(M.argmax(dim=1), num_classes=M.size(1))
73         U += M
74
75     for _ in range(nb_iterations):
76         M = F.conv2d(Z[:, None, :, :], w, padding=1)
77         M = torch.cat([M[:, :1], M[:, 1:].min(dim=1, keepdim=True).values], dim=1)
78         M = ((M[:, 1] >= 0) & (Z == 0)).long()
79         M = M * torch.rand(M.size())
80         M = M.flatten(1)
81         M = F.one_hot(M.argmax(dim=1), num_classes=M.size(1))
82         U = Z.flatten(1)
83         U += M
84
85     M = Z.clone()
86     Z = Z * (torch.arange(Z.size(1) * Z.size(2)) + 1).reshape(1, Z.size(1), Z.size(2))
87
88     for _ in range(100):
89         Z = F.max_pool2d(Z, 3, 1, 1) * M
90
91     Z = Z.long()
92     U = Z.flatten(1)
93     V = F.one_hot(U).max(dim=1).values
94     W = V.cumsum(dim=1) - V
95     N = torch.arange(Z.size(0))[:, None, None].expand_as(Z)
96     Z = W[N, Z]
97
98     return Z
99
100
101 class Grids(problem.Problem):
102     named_colors = [
103         ("white", [255, 255, 255]),
104         ("red", [255, 0, 0]),
105         ("green", [0, 192, 0]),
106         ("blue", [0, 0, 255]),
107         ("yellow", [255, 224, 0]),
108         ("cyan", [0, 255, 255]),
109         ("violet", [224, 128, 255]),
110         ("lightgreen", [192, 255, 192]),
111         ("brown", [165, 42, 42]),
112         ("lightblue", [192, 192, 255]),
113         ("gray", [128, 128, 128]),
114     ]
115
116     def __init__(
117         self,
118         max_nb_cached_chunks=None,
119         chunk_size=None,
120         nb_threads=-1,
121         tasks=None,
122     ):
123         self.colors = torch.tensor([c for _, c in self.named_colors])
124         self.height = 10
125         self.width = 10
126         self.cache_rec_coo = {}
127
128         all_tasks = [
129             self.task_replace_color,
130             self.task_translate,
131             self.task_grow,
132             self.task_half_fill,
133             self.task_frame,
134             self.task_detect,
135             self.task_count,
136             self.task_trajectory,
137             self.task_bounce,
138             self.task_scale,
139             self.task_symbols,
140             self.task_isometry,
141             #            self.task_path,
142         ]
143
144         if tasks is None:
145             self.all_tasks = all_tasks
146         else:
147             self.all_tasks = [getattr(self, "task_" + t) for t in tasks.split(",")]
148
149         super().__init__(max_nb_cached_chunks, chunk_size, nb_threads)
150
151     ######################################################################
152
153     def frame2img(self, x, scale=15):
154         x = x.reshape(x.size(0), self.height, -1)
155         m = torch.logical_and(x >= 0, x < self.nb_token_values()).long()
156         x = self.colors[x * m].permute(0, 3, 1, 2)
157         s = x.shape
158         x = x[:, :, :, None, :, None].expand(-1, -1, -1, scale, -1, scale)
159         x = x.reshape(s[0], s[1], s[2] * scale, s[3] * scale)
160
161         x[:, :, :, torch.arange(0, x.size(3), scale)] = 0
162         x[:, :, torch.arange(0, x.size(2), scale), :] = 0
163         x = x[:, :, 1:, 1:]
164
165         for n in range(m.size(0)):
166             for i in range(m.size(1)):
167                 for j in range(m.size(2)):
168                     if m[n, i, j] == 0:
169                         for k in range(2, scale - 2):
170                             for l in [0, 1]:
171                                 x[n, :, i * scale + k, j * scale + k - l] = 0
172                                 x[
173                                     n, :, i * scale + scale - 1 - k, j * scale + k - l
174                                 ] = 0
175
176         return x
177
178     def save_image(
179         self,
180         result_dir,
181         filename,
182         prompts,
183         answers,
184         predicted_prompts=None,
185         predicted_answers=None,
186         nrow=4,
187         margin=8,
188     ):
189         S = self.height * self.width
190         As = prompts[:, 0 * (S + 1) : 0 * (S + 1) + S].view(-1, self.height, self.width)
191         f_As = prompts[:, 1 * (S + 1) : 1 * (S + 1) + S].view(
192             -1, self.height, self.width
193         )
194         Bs = prompts[:, 2 * (S + 1) : 2 * (S + 1) + S].view(-1, self.height, self.width)
195         prompts = torch.cat([As, f_As, Bs], dim=2)
196         answers = answers.reshape(answers.size(0), self.height, self.width)
197
198         if predicted_prompts is None:
199             predicted_prompts = 255
200
201         if predicted_answers is None:
202             predicted_answers = 255
203
204         def add_frame(x, c, margin, bottom=False):
205             if bottom:
206                 h, w, di, dj = x.size(2) + margin, x.size(3), 0, 0
207             else:
208                 h, w, di, dj = (
209                     x.size(2) + 2 * margin,
210                     x.size(3) + 2 * margin,
211                     margin,
212                     margin,
213                 )
214
215             y = x.new_full((x.size(0), x.size(1), h, w), 0)
216
217             if type(c) is int:
218                 y[...] = c
219             else:
220                 c = c.long()[:, None]
221                 c = (
222                     (1 - ((c == 1).long() + (c == 0).long() + (c == -1).long()))
223                     * torch.tensor([64, 64, 64])
224                     + (c == 1).long() * torch.tensor([0, 255, 0])
225                     + (c == 0).long() * torch.tensor([255, 255, 255])
226                     + (c == -1).long() * torch.tensor([255, 0, 0])
227                 )
228                 y[...] = c[:, :, None, None]
229
230             y[:, :, di : di + x.size(2), dj : dj + x.size(3)] = x
231
232             return y
233
234         img_prompts = torch.cat(
235             [
236                 add_frame(
237                     add_frame(self.frame2img(x), c=0, margin=1),
238                     c=predicted_prompts,
239                     margin=margin,
240                 )
241                 for x in prompts.to("cpu").split(split_size=self.width, dim=2)
242             ],
243             dim=3,
244         )
245
246         h = img_prompts.size(2)
247         img_answers = add_frame(
248             add_frame(self.frame2img(answers.to("cpu")), c=0, margin=1),
249             c=predicted_answers,
250             margin=margin,
251         )
252
253         separator_size = 2 * margin
254
255         separator = img_prompts.new_full(
256             (
257                 img_prompts.size(0),
258                 img_prompts.size(1),
259                 img_prompts.size(2),
260                 separator_size,
261             ),
262             255,
263         )
264
265         marker = img_prompts.new_full(
266             (
267                 img_prompts.size(0),
268                 img_prompts.size(1),
269                 img_prompts.size(2),
270                 separator_size,
271             ),
272             255,
273         )
274
275         # marker[:, :, 0] = 0
276         # marker[:, :, h - 1] = 0
277
278         for k in range(1, 2 * separator_size - 8):
279             i = k - (separator_size - 4)
280             j = separator_size - 5 - abs(i)
281             marker[:, :, h // 2 - 1 + i, 2 + j] = 0
282             marker[:, :, h // 2 - 1 + i + 1, 2 + j] = 0
283
284         img = torch.cat(
285             [
286                 img_prompts,
287                 marker,
288                 img_answers,
289             ],
290             dim=3,
291         )
292
293         image_name = os.path.join(result_dir, filename)
294         torchvision.utils.save_image(
295             img.float() / 255.0,
296             image_name,
297             nrow=nrow,
298             padding=margin * 4,
299             pad_value=1.0,
300         )
301
302     ######################################################################
303
304     def nb_token_values(self):
305         return len(self.colors)
306
307     # @torch.compile
308     def rec_coo(
309         self,
310         nb_rec,
311         min_height=3,
312         min_width=3,
313         surface_max=None,
314         prevent_overlap=False,
315     ):
316         if surface_max is None:
317             surface_max = self.height * self.width // 2
318
319         signature = (nb_rec, min_height, min_width, surface_max)
320
321         try:
322             return self.cache_rec_coo[signature].pop()
323         except IndexError:
324             pass
325         except KeyError:
326             pass
327
328         N = 10000
329         while True:
330             while True:
331                 i = torch.randint(self.height, (N * nb_rec, 2)).sort(dim=-1).values
332                 j = torch.randint(self.width, (N * nb_rec, 2)).sort(dim=-1).values
333
334                 big_enough = (
335                     (i[:, 1] >= i[:, 0] + min_height)
336                     & (j[:, 1] >= j[:, 0] + min_height)
337                     & ((i[:, 1] - i[:, 0]) * (j[:, 1] - j[:, 0]) <= surface_max)
338                 )
339
340                 i, j = i[big_enough], j[big_enough]
341
342                 n = i.size(0) - i.size(0) % nb_rec
343
344                 if n > 0:
345                     break
346
347             i = i[:n].reshape(n // nb_rec, nb_rec, -1)
348             j = j[:n].reshape(n // nb_rec, nb_rec, -1)
349
350             if prevent_overlap:
351                 can_fit = ((i[:, :, 1] - i[:, :, 0]) * (j[:, :, 1] - j[:, :, 0])).sum(
352                     dim=-1
353                 ) <= self.height * self.width
354                 i, j = i[can_fit], j[can_fit]
355                 if nb_rec == 2:
356                     A_i1, A_i2, A_j1, A_j2 = (
357                         i[:, 0, 0],
358                         i[:, 0, 1],
359                         j[:, 0, 0],
360                         j[:, 0, 1],
361                     )
362                     B_i1, B_i2, B_j1, B_j2 = (
363                         i[:, 1, 0],
364                         i[:, 1, 1],
365                         j[:, 1, 0],
366                         j[:, 1, 1],
367                     )
368                     no_overlap = torch.logical_not(
369                         (A_i1 >= B_i2)
370                         & (A_i2 <= B_i1)
371                         & (A_j1 >= B_j1)
372                         & (A_j2 <= B_j1)
373                     )
374                     i, j = i[no_overlap], j[no_overlap]
375                 elif nb_rec == 3:
376                     A_i1, A_i2, A_j1, A_j2 = (
377                         i[:, 0, 0],
378                         i[:, 0, 1],
379                         j[:, 0, 0],
380                         j[:, 0, 1],
381                     )
382                     B_i1, B_i2, B_j1, B_j2 = (
383                         i[:, 1, 0],
384                         i[:, 1, 1],
385                         j[:, 1, 0],
386                         j[:, 1, 1],
387                     )
388                     C_i1, C_i2, C_j1, C_j2 = (
389                         i[:, 2, 0],
390                         i[:, 2, 1],
391                         j[:, 2, 0],
392                         j[:, 2, 1],
393                     )
394                     no_overlap = (
395                         (
396                             (A_i1 >= B_i2)
397                             | (A_i2 <= B_i1)
398                             | (A_j1 >= B_j2)
399                             | (A_j2 <= B_j1)
400                         )
401                         & (
402                             (A_i1 >= C_i2)
403                             | (A_i2 <= C_i1)
404                             | (A_j1 >= C_j2)
405                             | (A_j2 <= C_j1)
406                         )
407                         & (
408                             (B_i1 >= C_i2)
409                             | (B_i2 <= C_i1)
410                             | (B_j1 >= C_j2)
411                             | (B_j2 <= C_j1)
412                         )
413                     )
414                     i, j = (i[no_overlap], j[no_overlap])
415                 else:
416                     assert nb_rec == 1
417
418             if i.size(0) > 1:
419                 break
420
421         self.cache_rec_coo[signature] = [
422             [
423                 (
424                     i[n, k, 0].item(),
425                     j[n, k, 0].item(),
426                     i[n, k, 1].item(),
427                     j[n, k, 1].item(),
428                 )
429                 for k in range(nb_rec)
430             ]
431             for n in range(i.size(0))
432         ]
433
434         return self.cache_rec_coo[signature].pop()
435
436     ######################################################################
437
438     # @torch.compile
439     def task_replace_color(self, A, f_A, B, f_B):
440         nb_rec = 3
441         c = torch.randperm(len(self.colors) - 1)[: nb_rec + 1] + 1
442         for X, f_X in [(A, f_A), (B, f_B)]:
443             r = self.rec_coo(nb_rec, prevent_overlap=True)
444             for n in range(nb_rec):
445                 i1, j1, i2, j2 = r[n]
446                 X[i1:i2, j1:j2] = c[n]
447                 f_X[i1:i2, j1:j2] = c[n if n > 0 else -1]
448
449     # @torch.compile
450     def task_translate(self, A, f_A, B, f_B):
451         while True:
452             di, dj = torch.randint(3, (2,)) - 1
453             if di.abs() + dj.abs() > 0:
454                 break
455
456         nb_rec = 3
457         c = torch.randperm(len(self.colors) - 1)[:nb_rec] + 1
458         for X, f_X in [(A, f_A), (B, f_B)]:
459             while True:
460                 r = self.rec_coo(nb_rec, prevent_overlap=True)
461                 i1, j1, i2, j2 = r[nb_rec - 1]
462                 if (
463                     i1 + di >= 0
464                     and i2 + di < X.size(0)
465                     and j1 + dj >= 0
466                     and j2 + dj < X.size(1)
467                 ):
468                     break
469
470             for n in range(nb_rec):
471                 i1, j1, i2, j2 = r[n]
472                 X[i1:i2, j1:j2] = c[n]
473                 if n == nb_rec - 1:
474                     f_X[i1 + di : i2 + di, j1 + dj : j2 + dj] = c[n]
475                 else:
476                     f_X[i1:i2, j1:j2] = c[n]
477
478     # @torch.compile
479     def task_grow(self, A, f_A, B, f_B):
480         di, dj = torch.randint(2, (2,)) * 2 - 1
481         nb_rec = 3
482         c = torch.randperm(len(self.colors) - 1)[:nb_rec] + 1
483         direction = torch.randint(2, (1,)).item()
484         for X, f_X in [(A, f_A), (B, f_B)]:
485             while True:
486                 r = self.rec_coo(nb_rec, prevent_overlap=True)
487                 i1, j1, i2, j2 = r[nb_rec - 1]
488                 if i1 + 3 < i2 and j1 + 3 < j2:
489                     break
490
491             for n in range(nb_rec):
492                 i1, j1, i2, j2 = r[n]
493                 if n == nb_rec - 1:
494                     if direction == 0:
495                         X[i1 + 1 : i2 - 1, j1 + 1 : j2 - 1] = c[n]
496                         f_X[i1:i2, j1:j2] = c[n]
497                     else:
498                         X[i1:i2, j1:j2] = c[n]
499                         f_X[i1 + 1 : i2 - 1, j1 + 1 : j2 - 1] = c[n]
500                 else:
501                     X[i1:i2, j1:j2] = c[n]
502                     f_X[i1:i2, j1:j2] = c[n]
503
504     # @torch.compile
505     def task_half_fill(self, A, f_A, B, f_B):
506         di, dj = torch.randint(2, (2,)) * 2 - 1
507         nb_rec = 3
508         c = torch.randperm(len(self.colors) - 1)[: 2 * nb_rec] + 1
509         direction = torch.randint(4, (1,)).item()
510         for X, f_X in [(A, f_A), (B, f_B)]:
511             r = self.rec_coo(nb_rec, prevent_overlap=True)
512             for n in range(nb_rec):
513                 i1, j1, i2, j2 = r[n]
514                 X[i1:i2, j1:j2] = c[2 * n]
515                 f_X[i1:i2, j1:j2] = c[2 * n]
516                 # Not my proudest moment
517                 if direction == 0:
518                     i = (i1 + i2) // 2
519                     X[i : i + 1, j1:j2] = c[2 * n + 1]
520                     if n == nb_rec - 1:
521                         f_X[i:i2, j1:j2] = c[2 * n + 1]
522                     else:
523                         f_X[i : i + 1, j1:j2] = c[2 * n + 1]
524                 elif direction == 1:
525                     i = (i1 + i2 - 1) // 2
526                     X[i : i + 1, j1:j2] = c[2 * n + 1]
527                     if n == nb_rec - 1:
528                         f_X[i1 : i + 1, j1:j2] = c[2 * n + 1]
529                     else:
530                         f_X[i : i + 1, j1:j2] = c[2 * n + 1]
531                 elif direction == 2:
532                     j = (j1 + j2) // 2
533                     X[i1:i2, j : j + 1] = c[2 * n + 1]
534                     if n == nb_rec - 1:
535                         f_X[i1:i2, j:j2] = c[2 * n + 1]
536                     else:
537                         f_X[i1:i2, j : j + 1] = c[2 * n + 1]
538                 elif direction == 3:
539                     j = (j1 + j2 - 1) // 2
540                     X[i1:i2, j : j + 1] = c[2 * n + 1]
541                     if n == nb_rec - 1:
542                         f_X[i1:i2, j1 : j + 1] = c[2 * n + 1]
543                     else:
544                         f_X[i1:i2, j : j + 1] = c[2 * n + 1]
545
546     # @torch.compile
547     def task_frame(self, A, f_A, B, f_B):
548         nb_rec = 3
549         c = torch.randperm(len(self.colors) - 1)[: nb_rec + 1] + 1
550         for X, f_X in [(A, f_A), (B, f_B)]:
551             r = self.rec_coo(nb_rec, prevent_overlap=True)
552             for n in range(nb_rec):
553                 i1, j1, i2, j2 = r[n]
554                 X[i1:i2, j1:j2] = c[n]
555                 if n == nb_rec - 1:
556                     f_X[i1:i2, j1] = c[n]
557                     f_X[i1:i2, j2 - 1] = c[n]
558                     f_X[i1, j1:j2] = c[n]
559                     f_X[i2 - 1, j1:j2] = c[n]
560                 else:
561                     f_X[i1:i2, j1:j2] = c[n]
562
563     # @torch.compile
564     def task_detect(self, A, f_A, B, f_B):
565         nb_rec = 3
566         c = torch.randperm(len(self.colors) - 1)[: nb_rec + 1] + 1
567         for X, f_X in [(A, f_A), (B, f_B)]:
568             r = self.rec_coo(nb_rec, prevent_overlap=True)
569             for n in range(nb_rec):
570                 i1, j1, i2, j2 = r[n]
571                 X[i1:i2, j1:j2] = c[n]
572                 if n < nb_rec - 1:
573                     f_X[i1, j1] = c[-1]
574
575     # @torch.compile
576     def contact(self, X, i, j, q):
577         nq, nq_diag = 0, 0
578         no = 0
579
580         for ii, jj in [
581             (i - 1, j - 1),
582             (i - 1, j),
583             (i - 1, j + 1),
584             (i, j - 1),
585             (i, j + 1),
586             (i + 1, j - 1),
587             (i + 1, j),
588             (i + 1, j + 1),
589         ]:
590             if ii >= 0 and ii < self.height and jj >= 0 and jj < self.width:
591                 if X[ii, jj] != 0 and X[ii, jj] != q:
592                     no += 1
593
594         for ii, jj in [
595             (i - 1, j - 1),
596             (i - 1, j + 1),
597             (i + 1, j - 1),
598             (i + 1, j + 1),
599         ]:
600             if ii >= 0 and ii < self.height and jj >= 0 and jj < self.width:
601                 if X[ii, jj] == q and X[i, jj] != q and X[ii, j] != q:
602                     nq_diag += 1
603
604         for ii, jj in [(i - 1, j), (i, j - 1), (i, j + 1), (i + 1, j)]:
605             if ii >= 0 and ii < self.height and jj >= 0 and jj < self.width:
606                 if X[ii, jj] == q:
607                     nq += 1
608
609         return no, nq, nq_diag
610
611     def task_count(self, A, f_A, B, f_B):
612         N = torch.randint(4, (1,)).item() + 2
613         c = torch.randperm(len(self.colors) - 1)[:N] + 1
614
615         for X, f_X in [(A, f_A), (B, f_B)]:
616             l_q = torch.randperm(self.height * self.width)[
617                 : self.height * self.width // 20
618             ]
619             l_d = torch.randint(N, l_q.size())
620             nb = torch.zeros(N, dtype=torch.int64)
621
622             for q, e in zip(l_q, l_d):
623                 d = c[e]
624                 i, j = q % self.height, q // self.height
625                 if (
626                     nb[e] < self.width
627                     and X[max(0, i - 1) : i + 2, max(0, j - 1) : j + 2] == 0
628                 ).all():
629                     X[i, j] = d
630                     nb[e] += 1
631
632             l_q = torch.randperm((self.height - 2) * (self.width - 2))[
633                 : self.height * self.width // 2
634             ]
635             l_d = torch.randint(N, l_q.size())
636             for q, e in zip(l_q, l_d):
637                 d = c[e]
638                 i, j = q % (self.height - 2) + 1, q // (self.height - 2) + 1
639                 a1, a2, a3 = X[i - 1, j - 1 : j + 2]
640                 a8, a4 = X[i, j - 1], X[i, j + 1]
641                 a7, a6, a5 = X[i + 1, j - 1 : j + 2]
642                 if (
643                     X[i, j] == 0
644                     and nb[e] < self.width
645                     and (a2 == 0 or a2 == d)
646                     and (a4 == 0 or a4 == d)
647                     and (a6 == 0 or a6 == d)
648                     and (a8 == 0 or a8 == d)
649                     and (a1 == 0 or a2 == d or a8 == d)
650                     and (a3 == 0 or a4 == d or a2 == d)
651                     and (a5 == 0 or a6 == d or a4 == d)
652                     and (a7 == 0 or a8 == d or a6 == d)
653                 ):
654                     o = (
655                         (a2 != 0).long()
656                         + (a4 != 0).long()
657                         + (a6 != 0).long()
658                         + (a8 != 0).long()
659                     )
660                     if o <= 1:
661                         X[i, j] = d
662                         nb[e] += 1 - o
663
664             for e in range(N):
665                 for j in range(nb[e]):
666                     f_X[e, j] = c[e]
667
668     # @torch.compile
669     def task_trajectory(self, A, f_A, B, f_B):
670         c = torch.randperm(len(self.colors) - 1)[:2] + 1
671         for X, f_X in [(A, f_A), (B, f_B)]:
672             while True:
673                 di, dj = torch.randint(7, (2,)) - 3
674                 i, j = (
675                     torch.randint(self.height, (1,)).item(),
676                     torch.randint(self.width, (1,)).item(),
677                 )
678                 if (
679                     abs(di) + abs(dj) > 0
680                     and i + 2 * di >= 0
681                     and i + 2 * di < self.height
682                     and j + 2 * dj >= 0
683                     and j + 2 * dj < self.width
684                 ):
685                     break
686
687             k = 0
688             while (
689                 i + k * di >= 0
690                 and i + k * di < self.height
691                 and j + k * dj >= 0
692                 and j + k * dj < self.width
693             ):
694                 if k < 2:
695                     X[i + k * di, j + k * dj] = c[k]
696                 f_X[i + k * di, j + k * dj] = c[min(k, 1)]
697                 k += 1
698
699     # @torch.compile
700     def task_bounce(self, A, f_A, B, f_B):
701         c = torch.randperm(len(self.colors) - 1)[:3] + 1
702         for X, f_X in [(A, f_A), (B, f_B)]:
703             # @torch.compile
704             def free(i, j):
705                 return (
706                     i >= 0
707                     and i < self.height
708                     and j >= 0
709                     and j < self.width
710                     and f_X[i, j] == 0
711                 )
712
713             while True:
714                 f_X[...] = 0
715                 X[...] = 0
716
717                 for _ in range((self.height * self.width) // 10):
718                     i, j = (
719                         torch.randint(self.height, (1,)).item(),
720                         torch.randint(self.width, (1,)).item(),
721                     )
722                     X[i, j] = c[0]
723                     f_X[i, j] = c[0]
724
725                 while True:
726                     di, dj = torch.randint(7, (2,)) - 3
727                     if abs(di) + abs(dj) == 1:
728                         break
729
730                 i, j = (
731                     torch.randint(self.height, (1,)).item(),
732                     torch.randint(self.width, (1,)).item(),
733                 )
734
735                 X[i, j] = c[1]
736                 f_X[i, j] = c[1]
737                 l = 0
738
739                 while True:
740                     l += 1
741                     if free(i + di, j + dj):
742                         pass
743                     elif free(i - dj, j + di):
744                         di, dj = -dj, di
745                         if free(i + dj, j - di):
746                             if torch.rand(1) < 0.5:
747                                 di, dj = -di, -dj
748                     elif free(i + dj, j - di):
749                         di, dj = dj, -di
750                     else:
751                         break
752
753                     i, j = i + di, j + dj
754                     f_X[i, j] = c[2]
755                     if l <= 1:
756                         X[i, j] = c[2]
757
758                     if l >= self.width:
759                         break
760
761                 f_X[i, j] = c[1]
762                 X[i, j] = c[1]
763
764                 if l > 3:
765                     break
766
767     # @torch.compile
768     def task_scale(self, A, f_A, B, f_B):
769         c = torch.randperm(len(self.colors) - 1)[:2] + 1
770
771         i, j = (
772             torch.randint(self.height // 2, (1,)).item(),
773             torch.randint(self.width // 2, (1,)).item(),
774         )
775
776         for X, f_X in [(A, f_A), (B, f_B)]:
777             for _ in range(3):
778                 while True:
779                     i1, j1 = (
780                         torch.randint(self.height // 2 + 1, (1,)).item(),
781                         torch.randint(self.width // 2 + 1, (1,)).item(),
782                     )
783                     i2, j2 = (
784                         torch.randint(self.height // 2 + 1, (1,)).item(),
785                         torch.randint(self.width // 2 + 1, (1,)).item(),
786                     )
787                     if i1 < i2 and j1 < j2 and min(i2 - i1, j2 - j1) <= 3:
788                         break
789                 X[i + i1 : i + i2, j + j1 : j + j2] = c[0]
790                 f_X[2 * i1 : 2 * i2, 2 * j1 : 2 * j2] = c[0]
791
792             X[i, j] = c[1]
793             f_X[0:2, 0:2] = c[1]
794
795     # @torch.compile
796     def task_symbols(self, A, f_A, B, f_B):
797         nb_rec = 4
798         c = torch.randperm(len(self.colors) - 1)[: nb_rec + 1] + 1
799         delta = 3
800         for X, f_X in [(A, f_A), (B, f_B)]:
801             while True:
802                 i, j = torch.randint(self.height - delta + 1, (nb_rec,)), torch.randint(
803                     self.width - delta + 1, (nb_rec,)
804                 )
805                 d = (i[None, :] - i[:, None]).abs().max((j[None, :] - j[:, None]).abs())
806                 d.fill_diagonal_(delta + 1)
807                 if d.min() > delta:
808                     break
809
810             for k in range(1, nb_rec):
811                 X[i[k] : i[k] + delta, j[k] : j[k] + delta] = c[k]
812
813             ai, aj = i.float().mean(), j.float().mean()
814
815             q = torch.randint(3, (1,)).item() + 1
816
817             X[i[0] + delta // 2 - 1, j[0] + delta // 2 - 1] = c[0]
818             X[i[0] + delta // 2 - 1, j[0] + delta // 2 + 1] = c[0]
819             X[i[0] + delta // 2 + 1, j[0] + delta // 2 - 1] = c[0]
820             X[i[0] + delta // 2 + 1, j[0] + delta // 2 + 1] = c[0]
821
822             assert i[q] != ai and j[q] != aj
823
824             X[
825                 i[0] + delta // 2 + (i[q] - ai).sign().long(),
826                 j[0] + delta // 2 + (j[q] - aj).sign().long(),
827             ] = c[nb_rec]
828
829             f_X[i[0] : i[0] + delta, j[0] : j[0] + delta] = c[q]
830
831     # @torch.compile
832     def task_isometry(self, A, f_A, B, f_B):
833         nb_rec = 3
834         di, dj = torch.randint(3, (2,)) - 1
835         o = torch.tensor([[0.0, 1.0], [-1.0, 0.0]])
836         m = torch.eye(2)
837         for _ in range(torch.randint(4, (1,)).item()):
838             m = m @ o
839         if torch.rand(1) < 0.5:
840             m[0, :] = -m[0, :]
841
842         ci, cj = (self.height - 1) / 2, (self.width - 1) / 2
843
844         for X, f_X in [(A, f_A), (B, f_B)]:
845             while True:
846                 X[...] = 0
847                 f_X[...] = 0
848
849                 c = torch.randperm(len(self.colors) - 1)[:nb_rec] + 1
850
851                 for r in range(nb_rec):
852                     while True:
853                         i1, i2 = torch.randint(self.height - 2, (2,)) + 1
854                         j1, j2 = torch.randint(self.width - 2, (2,)) + 1
855                         if (
856                             i2 >= i1
857                             and j2 >= j1
858                             and max(i2 - i1, j2 - j1) >= 2
859                             and min(i2 - i1, j2 - j1) <= 3
860                         ):
861                             break
862                     X[i1 : i2 + 1, j1 : j2 + 1] = c[r]
863
864                     i1, j1, i2, j2 = i1 - ci, j1 - cj, i2 - ci, j2 - cj
865
866                     i1, j1 = m[0, 0] * i1 + m[0, 1] * j1, m[1, 0] * i1 + m[1, 1] * j1
867                     i2, j2 = m[0, 0] * i2 + m[0, 1] * j2, m[1, 0] * i2 + m[1, 1] * j2
868
869                     i1, j1, i2, j2 = i1 + ci, j1 + cj, i2 + ci, j2 + cj
870                     i1, i2 = i1.long() + di, i2.long() + di
871                     j1, j2 = j1.long() + dj, j2.long() + dj
872                     if i1 > i2:
873                         i1, i2 = i2, i1
874                     if j1 > j2:
875                         j1, j2 = j2, j1
876
877                     f_X[i1 : i2 + 1, j1 : j2 + 1] = c[r]
878
879                 n = F.one_hot(X.flatten()).sum(dim=0)[1:]
880                 if (
881                     n.sum() > self.height * self.width // 4
882                     and (n > 0).long().sum() == nb_rec
883                 ):
884                     break
885
886     def compute_distance(self, walls, goal_i, goal_j, start_i, start_j):
887         max_length = walls.numel()
888         dist = torch.full_like(walls, max_length)
889
890         dist[goal_i, goal_j] = 0
891         pred_dist = torch.empty_like(dist)
892
893         while True:
894             pred_dist.copy_(dist)
895             d = (
896                 torch.cat(
897                     (
898                         dist[None, 1:-1, 0:-2],
899                         dist[None, 2:, 1:-1],
900                         dist[None, 1:-1, 2:],
901                         dist[None, 0:-2, 1:-1],
902                     ),
903                     0,
904                 ).min(dim=0)[0]
905                 + 1
906             )
907
908             dist[1:-1, 1:-1].minimum_(d)  # = torch.min(dist[1:-1, 1:-1], d)
909             dist = walls * max_length + (1 - walls) * dist
910
911             if dist[start_i, start_j] < max_length or dist.equal(pred_dist):
912                 return dist * (1 - walls)
913
914     # @torch.compile
915     def task_path(self, A, f_A, B, f_B):
916         c = torch.randperm(len(self.colors) - 1)[:3] + 1
917         dist = torch.empty(self.height + 2, self.width + 2)
918         for X, f_X in [(A, f_A), (B, f_B)]:
919             nb_rec = torch.randint(3, (1,)).item() + 1
920             while True:
921                 r = self.rec_coo(nb_rec, prevent_overlap=True)
922                 X[...] = 0
923                 f_X[...] = 0
924                 for n in range(nb_rec):
925                     i1, j1, i2, j2 = r[n]
926                     X[i1:i2, j1:j2] = c[0]
927                     f_X[i1:i2, j1:j2] = c[0]
928                 while True:
929                     i0, j0 = (
930                         torch.randint(self.height, (1,)).item(),
931                         torch.randint(self.width, (1,)).item(),
932                     )
933                     if X[i0, j0] == 0:
934                         break
935                 while True:
936                     i1, j1 = (
937                         torch.randint(self.height, (1,)).item(),
938                         torch.randint(self.width, (1,)).item(),
939                     )
940                     if X[i1, j1] == 0:
941                         break
942                 dist[...] = 1
943                 dist[1:-1, 1:-1] = (X != 0).long()
944                 dist[...] = self.compute_distance(dist, i1 + 1, j1 + 1, i0 + 1, j0 + 1)
945                 if dist[i0 + 1, j0 + 1] >= 1 and dist[i0 + 1, j0 + 1] < self.height * 4:
946                     break
947
948             dist[1:-1, 1:-1] += (X != 0).long() * self.height * self.width
949             dist[0, :] = self.height * self.width
950             dist[-1, :] = self.height * self.width
951             dist[:, 0] = self.height * self.width
952             dist[:, -1] = self.height * self.width
953             # dist += torch.rand(dist.size())
954
955             i, j = i0 + 1, j0 + 1
956             while i != i1 + 1 or j != j1 + 1:
957                 f_X[i - 1, j - 1] = c[2]
958                 r, s, t, u = (
959                     dist[i - 1, j],
960                     dist[i, j - 1],
961                     dist[i + 1, j],
962                     dist[i, j + 1],
963                 )
964                 m = min(r, s, t, u)
965                 if r == m:
966                     i = i - 1
967                 elif t == m:
968                     i = i + 1
969                 elif s == m:
970                     j = j - 1
971                 else:
972                     j = j + 1
973
974             X[i0, j0] = c[2]
975             # f_X[i0, j0] = c[1]
976
977             X[i1, j1] = c[1]
978             f_X[i1, j1] = c[1]
979
980     # for X, f_X in [(A, f_A), (B, f_B)]:
981     # n = torch.arange(self.height * self.width).reshape(self.height, self.width)
982     # k = torch.randperm(self.height * self.width)
983     # X[...]=-1
984     # for q in k:
985     # i,j=q%self.height,q//self.height
986     # if
987
988     # @torch.compile
989     def task_puzzle(self, A, f_A, B, f_B):
990         S = 4
991         i0, j0 = (self.height - S) // 2, (self.width - S) // 2
992         c = torch.randperm(len(self.colors) - 1)[:4] + 1
993         for X, f_X in [(A, f_A), (B, f_B)]:
994             while True:
995                 f_X[...] = 0
996                 h = list(torch.randperm(c.size(0)))
997                 n = torch.zeros(c.max() + 1)
998                 for _ in range(2):
999                     k = torch.randperm(S * S)
1000                     for q in k:
1001                         i, j = q % S + i0, q // S + j0
1002                         if f_X[i, j] == 0:
1003                             r, s, t, u = (
1004                                 f_X[i - 1, j],
1005                                 f_X[i, j - 1],
1006                                 f_X[i + 1, j],
1007                                 f_X[i, j + 1],
1008                             )
1009                             r, s, t, u = torch.tensor([r, s, t, u])[torch.randperm(4)]
1010                             if r > 0 and n[r] < 6:
1011                                 n[r] += 1
1012                                 f_X[i, j] = r
1013                             elif s > 0 and n[s] < 6:
1014                                 n[s] += 1
1015                                 f_X[i, j] = s
1016                             elif t > 0 and n[t] < 6:
1017                                 n[t] += 1
1018                                 f_X[i, j] = t
1019                             elif u > 0 and n[u] < 6:
1020                                 n[u] += 1
1021                                 f_X[i, j] = u
1022                             else:
1023                                 if len(h) > 0:
1024                                     d = c[h.pop()]
1025                                     n[d] += 1
1026                                     f_X[i, j] = d
1027
1028                 if n.sum() == S * S:
1029                     break
1030
1031             k = 0
1032             for d in range(4):
1033                 while True:
1034                     ii, jj = (
1035                         torch.randint(self.height, (1,)).item(),
1036                         torch.randint(self.width, (1,)).item(),
1037                     )
1038                     e = 0
1039                     for i in range(S):
1040                         for j in range(S):
1041                             if (
1042                                 ii + i >= self.height
1043                                 or jj + j >= self.width
1044                                 or (
1045                                     f_X[i + i0, j + j0] == c[d]
1046                                     and X[ii + i, jj + j] > 0
1047                                 )
1048                             ):
1049                                 e = 1
1050                     if e == 0:
1051                         break
1052                 for i in range(S):
1053                     for j in range(S):
1054                         if f_X[i + i0, j + j0] == c[d]:
1055                             X[ii + i, jj + j] = c[d]
1056
1057     def task_islands(self, A, f_A, B, f_B):
1058         c = torch.randperm(len(self.colors) - 1)[:2] + 1
1059         for X, f_X in [(A, f_A), (B, f_B)]:
1060             while True:
1061                 k = torch.randperm(self.height * self.width)
1062                 Z = torch.zeros(self.height + 2, self.width + 2)
1063
1064                 i0, j0 = (
1065                     torch.randint(self.height, (1,)).item() + 1,
1066                     torch.randint(self.width, (1,)).item() + 1,
1067                 )
1068
1069                 Z[i0 - 1 : i0 + 2, j0 - 1 : j0 + 2] = 1
1070
1071                 nb = 9
1072
1073                 for q in k:
1074                     i, j = q % self.height + 1, q // self.height + 1
1075
1076                     if Z[i, j] == 0:
1077                         r, s, t, u, v, w, x, y = (
1078                             Z[i - 1, j],
1079                             Z[i - 1, j + 1],
1080                             Z[i, j + 1],
1081                             Z[i + 1, j + 1],
1082                             Z[i + 1, j],
1083                             Z[i + 1, j - 1],
1084                             Z[i, j - 1],
1085                             Z[i - 1, j - 1],
1086                         )
1087
1088                         if (
1089                             (nb < 16 or r + s + t + u + v + w + x + y > 0)
1090                             and (s == 0 or r + t > 0)
1091                             and (u == 0 or t + v > 0)
1092                             and (w == 0 or x + v > 0)
1093                             and (y == 0 or x + r > 0)
1094                         ):
1095                             # if r+s+t+u+v+w+x+y==0:
1096                             Z[i, j] = 1
1097                             nb += 1
1098
1099                     if nb == self.height * self.width // 2:
1100                         break
1101
1102                 if nb == self.height * self.width // 2:
1103                     break
1104
1105             M = Z.clone()
1106             Z[i0, j0] = 2
1107             X[...] = (Z[1:-1, 1:-1] == 1) * c[0] + (Z[1:-1, 1:-1] == 2) * c[1]
1108
1109             for _ in range(self.height + self.width):
1110                 Z[1:-1, 1:-1] = Z[1:-1, 1:-1].maximum(
1111                     torch.maximum(
1112                         torch.maximum(Z[0:-2, 1:-1], Z[2:, 1:-1]),
1113                         torch.maximum(Z[1:-1, 0:-2], Z[1:-1, 2:]),
1114                     )
1115                 )
1116                 Z *= M
1117
1118             f_X[...] = (Z[1:-1, 1:-1] == 1) * c[0] + (Z[1:-1, 1:-1] == 2) * c[1]
1119
1120     ######################################################################
1121
1122     def trivial_prompts_and_answers(self, prompts, answers):
1123         S = self.height * self.width
1124         Bs = prompts[:, 2 * (S + 1) : 2 * (S + 1) + S]
1125         f_Bs = answers
1126         return (Bs == f_Bs).long().min(dim=-1).values > 0
1127
1128     def generate_prompts_and_answers_(self, nb, tasks=None, progress_bar=False):
1129         if tasks is None:
1130             tasks = self.all_tasks
1131
1132         S = self.height * self.width
1133         prompts = torch.zeros(nb, 3 * S + 2, dtype=torch.int64)
1134         answers = torch.zeros(nb, S, dtype=torch.int64)
1135
1136         bunch = zip(prompts, answers)
1137
1138         if progress_bar:
1139             bunch = tqdm.tqdm(
1140                 bunch,
1141                 dynamic_ncols=True,
1142                 desc="world generation",
1143                 total=prompts.size(0),
1144             )
1145
1146         for prompt, answer in bunch:
1147             A = prompt[0 * (S + 1) : 0 * (S + 1) + S].view(self.height, self.width)
1148             f_A = prompt[1 * (S + 1) : 1 * (S + 1) + S].view(self.height, self.width)
1149             B = prompt[2 * (S + 1) : 2 * (S + 1) + S].view(self.height, self.width)
1150             f_B = answer.view(self.height, self.width)
1151             task = tasks[torch.randint(len(tasks), (1,)).item()]
1152             task(A, f_A, B, f_B)
1153
1154         return prompts.flatten(1), answers.flatten(1)
1155
1156     def save_quiz_illustrations(
1157         self,
1158         result_dir,
1159         filename_prefix,
1160         prompts,
1161         answers,
1162         predicted_prompts=None,
1163         predicted_answers=None,
1164         nrow=4,
1165     ):
1166         self.save_image(
1167             result_dir,
1168             filename_prefix + ".png",
1169             prompts,
1170             answers,
1171             predicted_prompts,
1172             predicted_answers,
1173             nrow,
1174         )
1175
1176     def save_some_examples(self, result_dir):
1177         nb, nrow = 72, 4
1178         for t in self.all_tasks:
1179             print(t.__name__)
1180             prompts, answers = self.generate_prompts_and_answers_(nb, tasks=[t])
1181             self.save_quiz_illustrations(
1182                 result_dir, t.__name__, prompts[:nb], answers[:nb], nrow=nrow
1183             )
1184
1185
1186 ######################################################################
1187
1188 if __name__ == "__main__":
1189     import time
1190
1191     # grids = Grids(max_nb_cached_chunks=5, chunk_size=100, nb_threads=4)
1192     grids = Grids()
1193
1194     # nb = 1000
1195     # grids = problem.MultiThreadProblem(
1196     # grids, max_nb_cached_chunks=50, chunk_size=100, nb_threads=1
1197     # )
1198     #    time.sleep(10)
1199     # start_time = time.perf_counter()
1200     # prompts, answers = grids.generate_prompts_and_answers(nb)
1201     # delay = time.perf_counter() - start_time
1202     # print(f"{prompts.size(0)/delay:02f} seq/s")
1203     # exit(0)
1204
1205     # if True:
1206     nb, nrow = 72, 4
1207     # nb, nrow = 8, 2
1208
1209     # for t in grids.all_tasks:
1210     for t in [grids.task_count]:
1211         print(t.__name__)
1212         prompts, answers = grids.generate_prompts_and_answers_(nb, tasks=[t])
1213         grids.save_quiz_illustrations(
1214             "/tmp", t.__name__, prompts[:nb], answers[:nb], nrow=nrow
1215         )
1216
1217     exit(0)
1218
1219     nb = 1000
1220
1221     # for t in grids.all_tasks:
1222     for t in [grids.task_islands]:
1223         start_time = time.perf_counter()
1224         prompts, answers = grids.generate_prompts_and_answers_(nb, tasks=[t])
1225         delay = time.perf_counter() - start_time
1226         print(f"{t.__name__} {prompts.size(0)/delay:02f} seq/s")
1227
1228     exit(0)
1229
1230     m = torch.randint(2, (prompts.size(0),))
1231     predicted_prompts = m * (torch.randint(2, (prompts.size(0),)) * 2 - 1)
1232     predicted_answers = (1 - m) * (torch.randint(2, (prompts.size(0),)) * 2 - 1)
1233
1234     grids.save_quiz_illustrations(
1235         "/tmp",
1236         "test",
1237         prompts[:nb],
1238         answers[:nb],
1239         # You can add a bool to put a frame around the predicted parts
1240         predicted_prompts[:nb],
1241         predicted_answers[:nb],
1242     )