Merge branch 'dev'
[culture.git] / grids.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import math, sys, tqdm, os, warnings, cairo
9
10 import torch, torchvision
11
12 from torch import nn
13 from torch.nn import functional as F
14
15 ######################################################################
16
17
18 def text_img(height, width, text):
19     pixel_map = torch.full((height, width, 4), 255, dtype=torch.uint8)
20
21     surface = cairo.ImageSurface.create_for_data(
22         pixel_map.numpy(), cairo.FORMAT_ARGB32, pixel_map.size(1), pixel_map.size(0)
23     )
24
25     ctx = cairo.Context(surface)
26     ctx.set_source_rgb(0, 0, 0)
27     ctx.set_font_size(16)
28     ctx.select_font_face("courier", cairo.FONT_SLANT_NORMAL, cairo.FONT_WEIGHT_NORMAL)
29     y = None
30     for line in text.split("\n"):
31         xbearing, ybearing, width, height, dx, dy = ctx.text_extents(line)
32         if y is None:
33             y = height * 1.5
34             x = height * 0.5
35
36         ctx.move_to(x, y)
37         ctx.show_text(line)
38         y += height * 1.5
39
40     ctx.stroke()
41
42     return pixel_map.permute(2, 0, 1)[None, :3].contiguous()
43
44
45 ######################################################################
46
47 import problem
48
49
50 def grow_islands(nb, height, width, nb_seeds, nb_iterations):
51     w = torch.empty(5, 1, 3, 3)
52
53     w[0, 0] = torch.tensor(
54         [
55             [1.0, 1.0, 1.0],
56             [1.0, 0.0, 1.0],
57             [1.0, 1.0, 1.0],
58         ]
59     )
60
61     w[1, 0] = torch.tensor(
62         [
63             [-1.0, 1.0, 0.0],
64             [1.0, 0.0, 0.0],
65             [0.0, 0.0, 0.0],
66         ]
67     )
68
69     w[2, 0] = torch.tensor(
70         [
71             [0.0, 1.0, -1.0],
72             [0.0, 0.0, 1.0],
73             [0.0, 0.0, 0.0],
74         ]
75     )
76
77     w[3, 0] = torch.tensor(
78         [
79             [0.0, 0.0, 0.0],
80             [0.0, 0.0, 1.0],
81             [0.0, 1.0, -1.0],
82         ]
83     )
84
85     w[4, 0] = torch.tensor(
86         [
87             [0.0, 0.0, 0.0],
88             [1.0, 0.0, 0.0],
89             [-1.0, 1.0, 0.0],
90         ]
91     )
92
93     Z = torch.zeros(nb, height, width)
94     U = Z.flatten(1)
95
96     for _ in range(nb_seeds):
97         M = F.conv2d(Z[:, None, :, :], w, padding=1)
98         M = torch.cat([M[:, :1], M[:, 1:].min(dim=1, keepdim=True).values], dim=1)
99         M = ((M[:, 0] == 0) & (Z == 0)).long()
100         Q = (M.flatten(1).max(dim=1).values > 0).long()[:, None]
101         M = M * torch.rand(M.size())
102         M = M.flatten(1)
103         M = F.one_hot(M.argmax(dim=1), num_classes=M.size(1))
104         U += M * Q
105
106     for _ in range(nb_iterations):
107         M = F.conv2d(Z[:, None, :, :], w, padding=1)
108         M = torch.cat([M[:, :1], M[:, 1:].min(dim=1, keepdim=True).values], dim=1)
109         M = ((M[:, 1] >= 0) & (Z == 0)).long()
110         Q = (M.flatten(1).max(dim=1).values > 0).long()[:, None]
111         M = M * torch.rand(M.size())
112         M = M.flatten(1)
113         M = F.one_hot(M.argmax(dim=1), num_classes=M.size(1))
114         U = Z.flatten(1)
115         U += M * Q
116
117     M = Z.clone()
118     Z = Z * (torch.arange(Z.size(1) * Z.size(2)) + 1).reshape(1, Z.size(1), Z.size(2))
119
120     while True:
121         W = Z.clone()
122         Z = F.max_pool2d(Z, 3, 1, 1) * M
123         if Z.equal(W):
124             break
125
126     Z = Z.long()
127     U = Z.flatten(1)
128     V = F.one_hot(U).max(dim=1).values
129     W = V.cumsum(dim=1) - V
130     N = torch.arange(Z.size(0))[:, None, None].expand_as(Z)
131     Z = W[N, Z]
132
133     return Z
134
135
136 class Grids(problem.Problem):
137     named_colors = [
138         ("white", [255, 255, 255]),
139         ("red", [255, 0, 0]),
140         ("green", [0, 192, 0]),
141         ("blue", [0, 0, 255]),
142         ("yellow", [255, 224, 0]),
143         ("cyan", [0, 255, 255]),
144         ("violet", [224, 128, 255]),
145         ("lightgreen", [192, 255, 192]),
146         ("brown", [165, 42, 42]),
147         ("lightblue", [192, 192, 255]),
148         ("gray", [128, 128, 128]),
149     ]
150
151     def check_structure(self, quizzes, struct):
152         S = self.height * self.width
153
154         return (
155             (quizzes[:, 0 * (S + 1)] == self.l2tok[struct[0]])
156             & (quizzes[:, 1 * (S + 1)] == self.l2tok[struct[1]])
157             & (quizzes[:, 2 * (S + 1)] == self.l2tok[struct[2]])
158             & (quizzes[:, 3 * (S + 1)] == self.l2tok[struct[3]])
159         ).all()
160
161     def get_structure(self, quizzes):
162         S = self.height * self.width
163         struct = tuple(
164             self.tok2l[n.item()]
165             for n in quizzes.reshape(quizzes.size(0), 4, S + 1)[0, :, 0]
166         )
167         self.check_structure(quizzes, struct)
168         return struct
169
170     def inject_noise(self, quizzes, noise, struct, mask):
171         assert self.check_structure(quizzes, struct=struct)
172         S = self.height * self.width
173
174         mask = torch.tensor(mask, device=quizzes.device)
175         mask = mask[None, :, None].expand(1, 4, S + 1).clone()
176         mask[:, :, 0] = 0
177         mask = mask.reshape(1, -1).expand_as(quizzes)
178         mask = mask * (torch.rand(mask.size(), device=mask.device) <= noise).long()
179         random = torch.randint(self.nb_colors, mask.size())
180         quizzes = mask * random + (1 - mask) * quizzes
181
182         return quizzes
183
184     # What a mess
185     def reconfigure(self, quizzes, struct=("A", "f_A", "B", "f_B")):
186         if torch.is_tensor(quizzes):
187             return self.reconfigure([quizzes], struct=struct)[0]
188
189         S = self.height * self.width
190         result = [x.new(x.size()) for x in quizzes]
191
192         struct_from = self.get_structure(quizzes[0][:1])
193         i = self.indices_select(quizzes[0], struct_from)
194
195         sf = dict((l, n) for n, l in enumerate(struct_from))
196
197         for q in range(4):
198             k = sf[struct[q]]
199             for x, y in zip(quizzes, result):
200                 l = x.size(1) // 4
201                 y[i, q * l : (q + 1) * l] = x[i, k * l : (k + 1) * l]
202
203         j = i == False
204
205         if j.any():
206             for z, y in zip(
207                 self.reconfigure([x[j] for x in quizzes], struct=struct), result
208             ):
209                 y[j] = z
210
211         return result
212
213     def trivial(self, quizzes):
214         S = self.height * self.width
215         assert self.check_structure(quizzes, struct=("A", "f_A", "B", "f_B"))
216         a = quizzes.reshape(quizzes.size(0), 4, S + 1)[:, :, 1:]
217         return (a[:, 0] == a[:, 1]).min(dim=1).values | (a[:, 2] == a[:, 3]).min(
218             dim=1
219         ).values
220
221     def make_quiz_mask(
222         self, quizzes, struct=("A", "f_A", "B", "f_B"), mask=(0, 0, 0, 1)
223     ):
224         assert self.check_structure(quizzes, struct)
225
226         ar_mask = quizzes.new_zeros(quizzes.size())
227
228         S = self.height * self.width
229         a = ar_mask.reshape(ar_mask.size(0), 4, S + 1)[:, :, 1:]
230         a[:, 0, :] = mask[0]
231         a[:, 1, :] = mask[1]
232         a[:, 2, :] = mask[2]
233         a[:, 3, :] = mask[3]
234
235         return ar_mask
236
237     def indices_select(self, quizzes, struct=("A", "f_A", "B", "f_B")):
238         S = self.height * self.width
239         q = quizzes.reshape(quizzes.size(0), 4, S + 1)
240         return (
241             (q[:, 0, 0] == self.l2tok[struct[0]])
242             & (q[:, 1, 0] == self.l2tok[struct[1]])
243             & (q[:, 2, 0] == self.l2tok[struct[2]])
244             & (q[:, 3, 0] == self.l2tok[struct[3]])
245         )
246
247     def __init__(
248         self,
249         max_nb_cached_chunks=None,
250         chunk_size=None,
251         nb_threads=-1,
252         tasks=None,
253     ):
254         self.colors = torch.tensor([c for _, c in self.named_colors])
255
256         self.nb_colors = len(self.colors)
257         self.token_A = self.nb_colors
258         self.token_f_A = self.token_A + 1
259         self.token_B = self.token_f_A + 1
260         self.token_f_B = self.token_B + 1
261
262         self.nb_rec_max = 5
263         self.rfree = torch.tensor([])
264
265         self.l2tok = {
266             "A": self.token_A,
267             "f_A": self.token_f_A,
268             "B": self.token_B,
269             "f_B": self.token_f_B,
270         }
271
272         self.tok2l = {
273             self.token_A: "A",
274             self.token_f_A: "f_A",
275             self.token_B: "B",
276             self.token_f_B: "f_B",
277         }
278
279         self.height = 10
280         self.width = 10
281         self.seq_len = 4 * (1 + self.height * self.width)
282         self.nb_token_values = self.token_f_B + 1
283
284         self.cache_rec_coo = {}
285
286         all_tasks = [
287             self.task_replace_color,
288             self.task_translate,
289             self.task_grow,
290             self.task_half_fill,
291             self.task_frame,
292             self.task_detect,
293             self.task_scale,
294             self.task_symbols,
295             self.task_corners,
296             self.task_contact,
297             self.task_path,
298             self.task_fill,
299             ############################################ hard ones
300             self.task_isometry,
301             self.task_trajectory,
302             self.task_bounce,
303             # self.task_count, # NOT REVERSIBLE
304             # self.task_islands, # TOO MESSY
305         ]
306
307         if tasks is None:
308             self.all_tasks = all_tasks
309         else:
310             self.all_tasks = [getattr(self, "task_" + t) for t in tasks.split(",")]
311
312         super().__init__(max_nb_cached_chunks, chunk_size, nb_threads)
313
314     ######################################################################
315
316     def grid2img(self, x, scale=15):
317         m = torch.logical_and(x >= 0, x < self.nb_colors).long()
318         y = self.colors[x * m].permute(0, 3, 1, 2)
319         s = y.shape
320         y = y[:, :, :, None, :, None].expand(-1, -1, -1, scale, -1, scale)
321         y = y.reshape(s[0], s[1], s[2] * scale, s[3] * scale)
322
323         y[:, :, :, torch.arange(0, y.size(3), scale)] = 64
324         y[:, :, torch.arange(0, y.size(2), scale), :] = 64
325
326         for n in range(m.size(0)):
327             for i in range(m.size(1)):
328                 for j in range(m.size(2)):
329                     if m[n, i, j] == 0:
330                         for k in range(3, scale - 2):
331                             y[n, :, i * scale + k, j * scale + k] = 0
332                             y[n, :, i * scale + k, j * scale + scale - k] = 0
333
334         y = y[:, :, 1:, 1:]
335
336         return y
337
338     def add_frame(self, img, colors, thickness):
339         result = img.new(
340             img.size(0),
341             img.size(1),
342             img.size(2) + 2 * thickness,
343             img.size(3) + 2 * thickness,
344         )
345
346         result[...] = colors[:, :, None, None]
347         result[:, :, thickness:-thickness, thickness:-thickness] = img
348
349         return result
350
351     def save_quizzes_as_image(
352         self,
353         result_dir,
354         filename,
355         quizzes,
356         predicted_parts=None,
357         correct_parts=None,
358         comments=None,
359         comment_height=48,
360         nrow=4,
361         margin=8,
362     ):
363         quizzes = quizzes.to("cpu")
364
365         to_reconfigure = [quizzes]
366         if predicted_parts is not None:
367             to_reconfigure.append(predicted_parts)
368         if correct_parts is not None:
369             to_reconfigure.append(correct_parts)
370
371         to_reconfigure = self.reconfigure(to_reconfigure, ("A", "f_A", "B", "f_B"))
372
373         quizzes = to_reconfigure.pop(0)
374         if predicted_parts is not None:
375             predicted_parts = to_reconfigure.pop(0)
376         if correct_parts is not None:
377             correct_parts = to_reconfigure.pop(0)
378
379         S = self.height * self.width
380
381         A, f_A, B, f_B = (
382             quizzes.reshape(quizzes.size(0), 4, S + 1)[:, :, 1:]
383             .reshape(quizzes.size(0), 4, self.height, self.width)
384             .permute(1, 0, 2, 3)
385         )
386
387         frame, white, gray, green, red = torch.tensor(
388             [[64, 64, 64], [255, 255, 255], [200, 200, 200], [0, 255, 0], [255, 0, 0]],
389             device=quizzes.device,
390         )
391
392         img_A = self.add_frame(self.grid2img(A), frame[None, :], thickness=1)
393         img_f_A = self.add_frame(self.grid2img(f_A), frame[None, :], thickness=1)
394         img_B = self.add_frame(self.grid2img(B), frame[None, :], thickness=1)
395         img_f_B = self.add_frame(self.grid2img(f_B), frame[None, :], thickness=1)
396
397         # predicted_parts Nx4
398         # correct_parts Nx4
399
400         if predicted_parts is None:
401             colors = white[None, None, :].expand(-1, 4, -1)
402         else:
403             predicted_parts = predicted_parts.to("cpu")
404             if correct_parts is None:
405                 colors = (
406                     predicted_parts[:, :, None] * gray[None, None, :]
407                     + (1 - predicted_parts[:, :, None]) * white[None, None, :]
408                 )
409             else:
410                 correct_parts = correct_parts.to("cpu")
411                 colors = (
412                     predicted_parts[:, :, None]
413                     * (
414                         (correct_parts[:, :, None] == 1).long() * green[None, None, :]
415                         + (correct_parts[:, :, None] == 0).long() * gray[None, None, :]
416                         + (correct_parts[:, :, None] == -1).long() * red[None, None, :]
417                     )
418                     + (1 - predicted_parts[:, :, None]) * white[None, None, :]
419                 )
420
421         img_A = self.add_frame(img_A, colors[:, 0], thickness=8)
422         img_f_A = self.add_frame(img_f_A, colors[:, 1], thickness=8)
423         img_B = self.add_frame(img_B, colors[:, 2], thickness=8)
424         img_f_B = self.add_frame(img_f_B, colors[:, 3], thickness=8)
425
426         img_A = self.add_frame(img_A, white[None, :], thickness=2)
427         img_f_A = self.add_frame(img_f_A, white[None, :], thickness=2)
428         img_B = self.add_frame(img_B, white[None, :], thickness=2)
429         img_f_B = self.add_frame(img_f_B, white[None, :], thickness=2)
430
431         img = torch.cat([img_A, img_f_A, img_B, img_f_B], dim=3)
432
433         if comments is not None:
434             comment_img = [text_img(comment_height, img.size(3), t) for t in comments]
435             comment_img = torch.cat(comment_img, dim=0)
436             img = torch.cat([img, comment_img], dim=2)
437
438         image_name = os.path.join(result_dir, filename)
439
440         torchvision.utils.save_image(
441             img.float() / 255.0,
442             image_name,
443             nrow=nrow,
444             padding=margin * 4,
445             pad_value=1.0,
446         )
447
448     ######################################################################
449
450     # @torch.compile
451     def rec_coo(
452         self,
453         nb_rec,
454         min_height=3,
455         min_width=3,
456         surface_max=None,
457         prevent_overlap=False,
458     ):
459         if surface_max is None:
460             surface_max = self.height * self.width // 2
461
462         signature = (nb_rec, min_height, min_width, surface_max)
463
464         try:
465             return self.cache_rec_coo[signature].pop()
466         except IndexError:
467             pass
468         except KeyError:
469             pass
470
471         N = 10000
472         while True:
473             while True:
474                 i = torch.randint(self.height, (N * nb_rec, 2)).sort(dim=-1).values
475                 j = torch.randint(self.width, (N * nb_rec, 2)).sort(dim=-1).values
476                 i[:, 1] += 1
477                 j[:, 1] += 1
478                 big_enough = (
479                     (i[:, 1] >= i[:, 0] + min_height)
480                     & (j[:, 1] >= j[:, 0] + min_height)
481                     & ((i[:, 1] - i[:, 0]) * (j[:, 1] - j[:, 0]) <= surface_max)
482                 )
483
484                 i, j = i[big_enough], j[big_enough]
485
486                 n = i.size(0) - i.size(0) % nb_rec
487
488                 if n > 0:
489                     break
490
491             i = i[:n].reshape(n // nb_rec, nb_rec, -1)
492             j = j[:n].reshape(n // nb_rec, nb_rec, -1)
493
494             if prevent_overlap:
495                 can_fit = ((i[:, :, 1] - i[:, :, 0]) * (j[:, :, 1] - j[:, :, 0])).sum(
496                     dim=-1
497                 ) <= self.height * self.width
498                 i, j = i[can_fit], j[can_fit]
499                 if nb_rec == 2:
500                     A_i1, A_i2, A_j1, A_j2 = (
501                         i[:, 0, 0],
502                         i[:, 0, 1],
503                         j[:, 0, 0],
504                         j[:, 0, 1],
505                     )
506                     B_i1, B_i2, B_j1, B_j2 = (
507                         i[:, 1, 0],
508                         i[:, 1, 1],
509                         j[:, 1, 0],
510                         j[:, 1, 1],
511                     )
512                     no_overlap = (
513                         (A_i1 >= B_i2)
514                         | (A_i2 <= B_i1)
515                         | (A_j1 >= B_j2)
516                         | (A_j2 <= B_j1)
517                     )
518                     i, j = (i[no_overlap], j[no_overlap])
519                 elif nb_rec == 3:
520                     A_i1, A_i2, A_j1, A_j2 = (
521                         i[:, 0, 0],
522                         i[:, 0, 1],
523                         j[:, 0, 0],
524                         j[:, 0, 1],
525                     )
526                     B_i1, B_i2, B_j1, B_j2 = (
527                         i[:, 1, 0],
528                         i[:, 1, 1],
529                         j[:, 1, 0],
530                         j[:, 1, 1],
531                     )
532                     C_i1, C_i2, C_j1, C_j2 = (
533                         i[:, 2, 0],
534                         i[:, 2, 1],
535                         j[:, 2, 0],
536                         j[:, 2, 1],
537                     )
538                     no_overlap = (
539                         (
540                             (A_i1 >= B_i2)
541                             | (A_i2 <= B_i1)
542                             | (A_j1 >= B_j2)
543                             | (A_j2 <= B_j1)
544                         )
545                         & (
546                             (A_i1 >= C_i2)
547                             | (A_i2 <= C_i1)
548                             | (A_j1 >= C_j2)
549                             | (A_j2 <= C_j1)
550                         )
551                         & (
552                             (B_i1 >= C_i2)
553                             | (B_i2 <= C_i1)
554                             | (B_j1 >= C_j2)
555                             | (B_j2 <= C_j1)
556                         )
557                     )
558                     i, j = (i[no_overlap], j[no_overlap])
559                 else:
560                     assert nb_rec == 1
561
562             if i.size(0) > 1:
563                 break
564
565         self.cache_rec_coo[signature] = [
566             [
567                 (
568                     i[n, k, 0].item(),
569                     j[n, k, 0].item(),
570                     i[n, k, 1].item(),
571                     j[n, k, 1].item(),
572                 )
573                 for k in range(nb_rec)
574             ]
575             for n in range(i.size(0))
576         ]
577
578         return self.cache_rec_coo[signature].pop()
579
580     ######################################################################
581
582     def contact_matrices(self, rn, ri, rj, rz):
583         n = torch.arange(self.nb_rec_max)
584         return (
585             (
586                 (
587                     (
588                         (ri[:, :, None, 0] == ri[:, None, :, 1] + 1)
589                         | (ri[:, :, None, 1] + 1 == ri[:, None, :, 0])
590                     )
591                     & (rj[:, :, None, 0] <= rj[:, None, :, 1])
592                     & (rj[:, :, None, 1] >= rj[:, None, :, 0])
593                 )
594                 | (
595                     (
596                         (rj[:, :, None, 0] == rj[:, None, :, 1] + 1)
597                         | (rj[:, :, None, 1] + 1 == rj[:, None, :, 0])
598                     )
599                     & (ri[:, :, None, 0] <= ri[:, None, :, 1])
600                     & (ri[:, :, None, 1] >= ri[:, None, :, 0])
601                 )
602             )
603             # & (rz[:, :, None] == rz[:, None, :])
604             & (n[None, :, None] < rn[:, None, None])
605             & (n[None, None, :] < n[None, :, None])
606         )
607
608     def sample_rworld_states(self, N=1000):
609         while True:
610             ri = (
611                 torch.randint(self.height - 2, (N, self.nb_rec_max, 2))
612                 .sort(dim=2)
613                 .values
614             )
615             ri[:, :, 1] += 2
616             rj = (
617                 torch.randint(self.width - 2, (N, self.nb_rec_max, 2))
618                 .sort(dim=2)
619                 .values
620             )
621             rj[:, :, 1] += 2
622             rn = torch.randint(self.nb_rec_max - 1, (N,)) + 2
623             rz = torch.randint(2, (N, self.nb_rec_max))
624             rc = torch.randint(self.nb_colors - 1, (N, self.nb_rec_max)) + 1
625             n = torch.arange(self.nb_rec_max)
626             nb_collisions = (
627                 (
628                     (ri[:, :, None, 0] <= ri[:, None, :, 1])
629                     & (ri[:, :, None, 1] >= ri[:, None, :, 0])
630                     & (rj[:, :, None, 0] <= rj[:, None, :, 1])
631                     & (rj[:, :, None, 1] >= rj[:, None, :, 0])
632                     & (rz[:, :, None] == rz[:, None, :])
633                     & (n[None, :, None] < rn[:, None, None])
634                     & (n[None, None, :] < n[None, :, None])
635                 )
636                 .long()
637                 .flatten(1)
638                 .sum(dim=1)
639             )
640
641             no_collision = nb_collisions == 0
642
643             if no_collision.any():
644                 print(no_collision.long().sum() / N)
645                 self.rn = rn[no_collision]
646                 self.ri = ri[no_collision]
647                 self.rj = rj[no_collision]
648                 self.rz = rz[no_collision]
649                 self.rc = rc[no_collision]
650
651                 nb_contact = (
652                     self.contact_matrices(rn, ri, rj, rz).long().flatten(1).sum(dim=1)
653                 )
654
655                 self.rcontact = nb_contact > 0
656                 self.rfree = torch.full((self.rn.size(0),), True)
657
658                 break
659
660     def get_recworld_state(self):
661         if not self.rfree.any():
662             self.sample_rworld_states()
663         k = torch.arange(self.rn.size(0))[self.rfree]
664         k = k[torch.randint(k.size(0), (1,))].item()
665         self.rfree[k] = False
666         return self.rn[k], self.ri[k], self.rj[k], self.rz[k], self.rc[k]
667
668     def draw_state(self, X, rn, ri, rj, rz, rc):
669         for n in sorted(list(range(rn)), key=lambda n: rz[n].item()):
670             X[ri[n, 0] : ri[n, 1] + 1, rj[n, 0] : rj[n, 1] + 1] = rc[n]
671
672     def task_recworld_immobile(self, A, f_A, B, f_B):
673         for X, f_X in [(A, f_A), (B, f_B)]:
674             rn, ri, rj, rz, rc = self.get_recworld_state()
675             self.draw_state(X, rn, ri, rj, rz, rc)
676             ri += 1
677             self.draw_state(f_X, rn, ri, rj, rz, rc)
678
679     ######################################################################
680
681     # @torch.compile
682     def task_replace_color(self, A, f_A, B, f_B):
683         nb_rec = 3
684         c = torch.randperm(self.nb_colors - 1)[: nb_rec + 1] + 1
685         for X, f_X in [(A, f_A), (B, f_B)]:
686             r = self.rec_coo(nb_rec, prevent_overlap=True)
687             for n in range(nb_rec):
688                 i1, j1, i2, j2 = r[n]
689                 X[i1:i2, j1:j2] = c[n]
690                 f_X[i1:i2, j1:j2] = c[n if n > 0 else -1]
691
692     # @torch.compile
693     def task_translate(self, A, f_A, B, f_B):
694         while True:
695             di, dj = torch.randint(3, (2,)) - 1
696             if di.abs() + dj.abs() > 0:
697                 break
698
699         nb_rec = 3
700         c = torch.randperm(self.nb_colors - 1)[:nb_rec] + 1
701         for X, f_X in [(A, f_A), (B, f_B)]:
702             while True:
703                 r = self.rec_coo(nb_rec, prevent_overlap=True)
704                 i1, j1, i2, j2 = r[nb_rec - 1]
705                 if (
706                     i1 + di >= 0
707                     and i2 + di < X.size(0)
708                     and j1 + dj >= 0
709                     and j2 + dj < X.size(1)
710                 ):
711                     break
712
713             for n in range(nb_rec):
714                 i1, j1, i2, j2 = r[n]
715                 X[i1:i2, j1:j2] = c[n]
716                 if n == nb_rec - 1:
717                     f_X[i1 + di : i2 + di, j1 + dj : j2 + dj] = c[n]
718                 else:
719                     f_X[i1:i2, j1:j2] = c[n]
720
721     # @torch.compile
722     def task_grow(self, A, f_A, B, f_B):
723         di, dj = torch.randint(2, (2,)) * 2 - 1
724         nb_rec = 3
725         c = torch.randperm(self.nb_colors - 1)[:nb_rec] + 1
726         direction = torch.randint(2, (1,)).item()
727         for X, f_X in [(A, f_A), (B, f_B)]:
728             while True:
729                 r = self.rec_coo(nb_rec, prevent_overlap=True)
730                 i1, j1, i2, j2 = r[nb_rec - 1]
731                 if i1 + 3 < i2 and j1 + 3 < j2:
732                     break
733
734             for n in range(nb_rec):
735                 i1, j1, i2, j2 = r[n]
736                 if n == nb_rec - 1:
737                     if direction == 0:
738                         X[i1 + 1 : i2 - 1, j1 + 1 : j2 - 1] = c[n]
739                         f_X[i1:i2, j1:j2] = c[n]
740                     else:
741                         X[i1:i2, j1:j2] = c[n]
742                         f_X[i1 + 1 : i2 - 1, j1 + 1 : j2 - 1] = c[n]
743                 else:
744                     X[i1:i2, j1:j2] = c[n]
745                     f_X[i1:i2, j1:j2] = c[n]
746
747     # @torch.compile
748     def task_half_fill(self, A, f_A, B, f_B):
749         di, dj = torch.randint(2, (2,)) * 2 - 1
750         nb_rec = 3
751         c = torch.randperm(self.nb_colors - 1)[: 2 * nb_rec] + 1
752         direction = torch.randint(4, (1,)).item()
753         for X, f_X in [(A, f_A), (B, f_B)]:
754             r = self.rec_coo(nb_rec, prevent_overlap=True)
755             for n in range(nb_rec):
756                 i1, j1, i2, j2 = r[n]
757                 X[i1:i2, j1:j2] = c[2 * n]
758                 f_X[i1:i2, j1:j2] = c[2 * n]
759                 # Not my proudest moment
760                 if direction == 0:
761                     i = (i1 + i2) // 2
762                     X[i : i + 1, j1:j2] = c[2 * n + 1]
763                     if n == nb_rec - 1:
764                         f_X[i:i2, j1:j2] = c[2 * n + 1]
765                     else:
766                         f_X[i : i + 1, j1:j2] = c[2 * n + 1]
767                 elif direction == 1:
768                     i = (i1 + i2 - 1) // 2
769                     X[i : i + 1, j1:j2] = c[2 * n + 1]
770                     if n == nb_rec - 1:
771                         f_X[i1 : i + 1, j1:j2] = c[2 * n + 1]
772                     else:
773                         f_X[i : i + 1, j1:j2] = c[2 * n + 1]
774                 elif direction == 2:
775                     j = (j1 + j2) // 2
776                     X[i1:i2, j : j + 1] = c[2 * n + 1]
777                     if n == nb_rec - 1:
778                         f_X[i1:i2, j:j2] = c[2 * n + 1]
779                     else:
780                         f_X[i1:i2, j : j + 1] = c[2 * n + 1]
781                 elif direction == 3:
782                     j = (j1 + j2 - 1) // 2
783                     X[i1:i2, j : j + 1] = c[2 * n + 1]
784                     if n == nb_rec - 1:
785                         f_X[i1:i2, j1 : j + 1] = c[2 * n + 1]
786                     else:
787                         f_X[i1:i2, j : j + 1] = c[2 * n + 1]
788
789     # @torch.compile
790     def task_frame(self, A, f_A, B, f_B):
791         nb_rec = 3
792         c = torch.randperm(self.nb_colors - 1)[: nb_rec + 1] + 1
793         for X, f_X in [(A, f_A), (B, f_B)]:
794             r = self.rec_coo(nb_rec, prevent_overlap=True)
795             for n in range(nb_rec):
796                 i1, j1, i2, j2 = r[n]
797                 X[i1:i2, j1:j2] = c[n]
798                 if n == nb_rec - 1:
799                     f_X[i1:i2, j1] = c[n]
800                     f_X[i1:i2, j2 - 1] = c[n]
801                     f_X[i1, j1:j2] = c[n]
802                     f_X[i2 - 1, j1:j2] = c[n]
803                 else:
804                     f_X[i1:i2, j1:j2] = c[n]
805
806     # @torch.compile
807     def task_detect(self, A, f_A, B, f_B):
808         nb_rec = 3
809         c = torch.randperm(self.nb_colors - 1)[: nb_rec + 1] + 1
810         for X, f_X in [(A, f_A), (B, f_B)]:
811             r = self.rec_coo(nb_rec, prevent_overlap=True)
812             for n in range(nb_rec):
813                 i1, j1, i2, j2 = r[n]
814                 X[i1:i2, j1:j2] = c[n]
815                 f_X[i1:i2, j1:j2] = c[n]
816                 if n < nb_rec - 1:
817                     for k in range(2):
818                         f_X[i1 + k, j1] = c[-1]
819                         f_X[i1, j1 + k] = c[-1]
820
821     # @torch.compile
822     def contact(self, X, i, j, q):
823         nq, nq_diag = 0, 0
824         no = 0
825
826         for ii, jj in [
827             (i - 1, j - 1),
828             (i - 1, j),
829             (i - 1, j + 1),
830             (i, j - 1),
831             (i, j + 1),
832             (i + 1, j - 1),
833             (i + 1, j),
834             (i + 1, j + 1),
835         ]:
836             if ii >= 0 and ii < self.height and jj >= 0 and jj < self.width:
837                 if X[ii, jj] != 0 and X[ii, jj] != q:
838                     no += 1
839
840         for ii, jj in [
841             (i - 1, j - 1),
842             (i - 1, j + 1),
843             (i + 1, j - 1),
844             (i + 1, j + 1),
845         ]:
846             if ii >= 0 and ii < self.height and jj >= 0 and jj < self.width:
847                 if X[ii, jj] == q and X[i, jj] != q and X[ii, j] != q:
848                     nq_diag += 1
849
850         for ii, jj in [(i - 1, j), (i, j - 1), (i, j + 1), (i + 1, j)]:
851             if ii >= 0 and ii < self.height and jj >= 0 and jj < self.width:
852                 if X[ii, jj] == q:
853                     nq += 1
854
855         return no, nq, nq_diag
856
857     def REMOVED_task_count(self, A, f_A, B, f_B):
858         while True:
859             error = False
860
861             N = 3
862             c = torch.zeros(N + 2, dtype=torch.int64)
863             c[1:] = torch.randperm(self.nb_colors - 1)[: N + 1] + 1
864
865             for X, f_X in [(A, f_A), (B, f_B)]:
866                 if not hasattr(self, "cache_count") or len(self.cache_count) == 0:
867                     self.cache_count = list(
868                         grow_islands(
869                             1000,
870                             self.height,
871                             self.width,
872                             nb_seeds=self.height * self.width // 8,
873                             nb_iterations=self.height * self.width // 5,
874                         )
875                     )
876
877                 X[...] = self.cache_count.pop()
878
879                 # k = (X.max() + 1 + (c.size(0) - 1)).item()
880                 # V = torch.arange(k) // (c.size(0) - 1)
881                 # V = (V + torch.rand(V.size())).sort().indices[: X.max() + 1] % (
882                 # c.size(0) - 1
883                 # ) + 1
884
885                 V = torch.randint(N, (X.max() + 1,)) + 1
886                 V[0] = 0
887                 NB = F.one_hot(c[V]).sum(dim=0)
888                 X[...] = c[V[X]]
889                 f_X[...] = X
890
891                 if F.one_hot(X.flatten()).max(dim=0).values.sum().item() >= 3:
892                     m = NB[c[:-1]].max()
893                     if (NB[c[:-1]] == m).long().sum() == 1:
894                         for e in range(1, N + 1):
895                             if NB[c[e]] == m:
896                                 a = (f_X == c[e]).long()
897                                 f_X[...] = (1 - a) * f_X + a * c[-1]
898                 else:
899                     error = True
900                     break
901
902             if not error:
903                 break
904
905         assert F.one_hot(A.flatten()).max(dim=0).values.sum() >= 3
906
907     # @torch.compile
908     def task_trajectory(self, A, f_A, B, f_B):
909         c = torch.randperm(self.nb_colors - 1)[:2] + 1
910         for X, f_X in [(A, f_A), (B, f_B)]:
911             while True:
912                 di, dj = torch.randint(7, (2,)) - 3
913                 i, j = (
914                     torch.randint(self.height, (1,)).item(),
915                     torch.randint(self.width, (1,)).item(),
916                 )
917                 if (
918                     abs(di) + abs(dj) > 0
919                     and i + 2 * di >= 0
920                     and i + 2 * di < self.height
921                     and j + 2 * dj >= 0
922                     and j + 2 * dj < self.width
923                 ):
924                     break
925
926             k = 0
927             while (
928                 i + k * di >= 0
929                 and i + k * di < self.height
930                 and j + k * dj >= 0
931                 and j + k * dj < self.width
932             ):
933                 if k < 2:
934                     X[i + k * di, j + k * dj] = c[k]
935                 f_X[i + k * di, j + k * dj] = c[min(k, 1)]
936                 k += 1
937
938     # @torch.compile
939     def task_bounce(self, A, f_A, B, f_B):
940         c = torch.randperm(self.nb_colors - 1)[:3] + 1
941         for X, f_X in [(A, f_A), (B, f_B)]:
942             # @torch.compile
943             def free(i, j):
944                 return (
945                     i >= 0
946                     and i < self.height
947                     and j >= 0
948                     and j < self.width
949                     and f_X[i, j] == 0
950                 )
951
952             while True:
953                 f_X[...] = 0
954                 X[...] = 0
955
956                 for _ in range((self.height * self.width) // 10):
957                     i, j = (
958                         torch.randint(self.height, (1,)).item(),
959                         torch.randint(self.width, (1,)).item(),
960                     )
961                     X[i, j] = c[0]
962                     f_X[i, j] = c[0]
963
964                 while True:
965                     di, dj = torch.randint(7, (2,)) - 3
966                     if abs(di) + abs(dj) == 1:
967                         break
968
969                 i, j = (
970                     torch.randint(self.height, (1,)).item(),
971                     torch.randint(self.width, (1,)).item(),
972                 )
973
974                 X[i, j] = c[1]
975                 f_X[i, j] = c[1]
976                 l = 0
977
978                 while True:
979                     l += 1
980                     if free(i + di, j + dj):
981                         pass
982                     elif free(i - dj, j + di):
983                         di, dj = -dj, di
984                         if free(i + dj, j - di):
985                             if torch.rand(1) < 0.5:
986                                 di, dj = -di, -dj
987                     elif free(i + dj, j - di):
988                         di, dj = dj, -di
989                     else:
990                         break
991
992                     i, j = i + di, j + dj
993                     f_X[i, j] = c[2]
994                     if l <= 1:
995                         X[i, j] = c[2]
996                         f_X[i, j] = c[1]
997
998                     if l >= self.width:
999                         break
1000
1001                 f_X[i, j] = c[1]
1002                 X[i, j] = c[1]
1003
1004                 if l > 3:
1005                     break
1006
1007     # @torch.compile
1008     def task_scale(self, A, f_A, B, f_B):
1009         c = torch.randperm(self.nb_colors - 1)[:2] + 1
1010
1011         i, j = (
1012             torch.randint(self.height // 2, (1,)).item(),
1013             torch.randint(self.width // 2, (1,)).item(),
1014         )
1015
1016         for X, f_X in [(A, f_A), (B, f_B)]:
1017             for _ in range(3):
1018                 while True:
1019                     i1, j1 = (
1020                         torch.randint(self.height // 2 + 1, (1,)).item(),
1021                         torch.randint(self.width // 2 + 1, (1,)).item(),
1022                     )
1023                     i2, j2 = (
1024                         torch.randint(self.height // 2 + 1, (1,)).item(),
1025                         torch.randint(self.width // 2 + 1, (1,)).item(),
1026                     )
1027                     if i1 < i2 and j1 < j2 and min(i2 - i1, j2 - j1) <= 3:
1028                         break
1029                 X[i + i1 : i + i2, j + j1 : j + j2] = c[0]
1030                 f_X[2 * i1 : 2 * i2, 2 * j1 : 2 * j2] = c[0]
1031
1032             for k in range(2):
1033                 X[i + k, j] = c[1]
1034                 X[i, j + k] = c[1]
1035                 f_X[i + k, j] = c[1]
1036                 f_X[i, j + k] = c[1]
1037
1038     # @torch.compile
1039     def task_symbols(self, A, f_A, B, f_B):
1040         nb_rec = 4
1041         c = torch.randperm(self.nb_colors - 1)[: nb_rec + 1] + 1
1042         delta = 3
1043         for X, f_X in [(A, f_A), (B, f_B)]:
1044             while True:
1045                 i, j = torch.randint(self.height - delta + 1, (nb_rec,)), torch.randint(
1046                     self.width - delta + 1, (nb_rec,)
1047                 )
1048                 d = (i[None, :] - i[:, None]).abs().max((j[None, :] - j[:, None]).abs())
1049                 d.fill_diagonal_(delta + 1)
1050                 if d.min() > delta:
1051                     break
1052
1053             ai, aj = i.float().mean(), j.float().mean()
1054
1055             q = torch.randint(3, (1,)).item() + 1
1056
1057             assert i[q] != ai and j[q] != aj
1058
1059             for Z in [X, f_X]:
1060                 for k in range(0, nb_rec):
1061                     Z[i[k] : i[k] + delta, j[k] : j[k] + delta] = c[k]
1062                 # Z[i[0] + delta // 2 - 1, j[0] + delta // 2 - 1] = c[0]
1063                 # Z[i[0] + delta // 2 - 1, j[0] + delta // 2 + 1] = c[0]
1064                 # Z[i[0] + delta // 2 + 1, j[0] + delta // 2 - 1] = c[0]
1065                 # Z[i[0] + delta // 2 + 1, j[0] + delta // 2 + 1] = c[0]
1066
1067             # f_X[i[0] : i[0] + delta, j[0] : j[0] + delta] = c[q]
1068
1069             f_X[i[0] + delta // 2, j[0] + delta // 2] = c[q]
1070             # f_X[i[0] : i[0] + delta, j[0] : j[0] + delta] = c[q]
1071
1072             ii, jj = (
1073                 i[0] + delta // 2 + (i[q] - ai).sign().long(),
1074                 j[0] + delta // 2 + (j[q] - aj).sign().long(),
1075             )
1076
1077             X[ii, jj] = c[nb_rec]
1078             X[i[0] + delta // 2, jj] = c[nb_rec]
1079             X[ii, j[0] + delta // 2] = c[nb_rec]
1080
1081             f_X[ii, jj] = c[nb_rec]
1082             f_X[i[0] + delta // 2, jj] = c[nb_rec]
1083             f_X[ii, j[0] + delta // 2] = c[nb_rec]
1084
1085     # @torch.compile
1086     def task_isometry(self, A, f_A, B, f_B):
1087         nb_rec = 3
1088         di, dj = torch.randint(3, (2,)) - 1
1089         o = torch.tensor([[0.0, 1.0], [-1.0, 0.0]])
1090         m = torch.eye(2)
1091         for _ in range(torch.randint(4, (1,)).item()):
1092             m = m @ o
1093         if torch.rand(1) < 0.5:
1094             m[0, :] = -m[0, :]
1095
1096         ci, cj = (self.height - 1) / 2, (self.width - 1) / 2
1097
1098         for X, f_X in [(A, f_A), (B, f_B)]:
1099             while True:
1100                 X[...] = 0
1101                 f_X[...] = 0
1102
1103                 c = torch.randperm(self.nb_colors - 1)[:nb_rec] + 1
1104
1105                 for r in range(nb_rec):
1106                     while True:
1107                         i1, i2 = torch.randint(self.height - 2, (2,)) + 1
1108                         j1, j2 = torch.randint(self.width - 2, (2,)) + 1
1109                         if (
1110                             i2 >= i1
1111                             and j2 >= j1
1112                             and max(i2 - i1, j2 - j1) >= 2
1113                             and min(i2 - i1, j2 - j1) <= 3
1114                         ):
1115                             break
1116                     X[i1 : i2 + 1, j1 : j2 + 1] = c[r]
1117
1118                     i1, j1, i2, j2 = i1 - ci, j1 - cj, i2 - ci, j2 - cj
1119
1120                     i1, j1 = m[0, 0] * i1 + m[0, 1] * j1, m[1, 0] * i1 + m[1, 1] * j1
1121                     i2, j2 = m[0, 0] * i2 + m[0, 1] * j2, m[1, 0] * i2 + m[1, 1] * j2
1122
1123                     i1, j1, i2, j2 = i1 + ci, j1 + cj, i2 + ci, j2 + cj
1124                     i1, i2 = i1.long() + di, i2.long() + di
1125                     j1, j2 = j1.long() + dj, j2.long() + dj
1126                     if i1 > i2:
1127                         i1, i2 = i2, i1
1128                     if j1 > j2:
1129                         j1, j2 = j2, j1
1130
1131                     f_X[i1 : i2 + 1, j1 : j2 + 1] = c[r]
1132
1133                 n = F.one_hot(X.flatten()).sum(dim=0)[1:]
1134                 if (
1135                     n.sum() > self.height * self.width // 4
1136                     and (n > 0).long().sum() == nb_rec
1137                 ):
1138                     break
1139
1140     def compute_distance(self, walls, goal_i, goal_j):
1141         max_length = walls.numel()
1142         dist = torch.full_like(walls, max_length)
1143
1144         dist[goal_i, goal_j] = 0
1145         pred_dist = torch.empty_like(dist)
1146
1147         while True:
1148             pred_dist.copy_(dist)
1149             dist[1:-1, 1:-1] = (
1150                 torch.cat(
1151                     (
1152                         dist[None, 1:-1, 1:-1],
1153                         dist[None, 1:-1, 0:-2],
1154                         dist[None, 2:, 1:-1],
1155                         dist[None, 1:-1, 2:],
1156                         dist[None, 0:-2, 1:-1],
1157                     ),
1158                     0,
1159                 ).min(dim=0)[0]
1160                 + 1
1161             )
1162
1163             dist = walls * max_length + (1 - walls) * dist
1164
1165             if dist.equal(pred_dist):
1166                 return dist * (1 - walls)
1167
1168     # @torch.compile
1169     def REMOVED_task_distance(self, A, f_A, B, f_B):
1170         c = torch.randperm(self.nb_colors - 1)[:3] + 1
1171         dist0 = torch.empty(self.height + 2, self.width + 2)
1172         dist1 = torch.empty(self.height + 2, self.width + 2)
1173         for X, f_X in [(A, f_A), (B, f_B)]:
1174             nb_rec = torch.randint(3, (1,)).item() + 1
1175             while True:
1176                 r = self.rec_coo(nb_rec, prevent_overlap=True)
1177                 X[...] = 0
1178                 f_X[...] = 0
1179                 for n in range(nb_rec):
1180                     i1, j1, i2, j2 = r[n]
1181                     X[i1:i2, j1:j2] = c[0]
1182                     f_X[i1:i2, j1:j2] = c[0]
1183                 while True:
1184                     i0, j0 = (
1185                         torch.randint(self.height, (1,)).item(),
1186                         torch.randint(self.width, (1,)).item(),
1187                     )
1188                     if X[i0, j0] == 0:
1189                         break
1190                 while True:
1191                     i1, j1 = (
1192                         torch.randint(self.height, (1,)).item(),
1193                         torch.randint(self.width, (1,)).item(),
1194                     )
1195                     if X[i1, j1] == 0:
1196                         break
1197                 dist1[...] = 1
1198                 dist1[1:-1, 1:-1] = (X != 0).long()
1199                 dist1[...] = self.compute_distance(dist1, i1 + 1, j1 + 1)
1200                 if (
1201                     dist1[i0 + 1, j0 + 1] >= 1
1202                     and dist1[i0 + 1, j0 + 1] < self.height * 4
1203                 ):
1204                     break
1205
1206             dist0[...] = 1
1207             dist0[1:-1, 1:-1] = (X != 0).long()
1208             dist0[...] = self.compute_distance(dist0, i0 + 1, j0 + 1)
1209
1210             dist0 = dist0[1:-1, 1:-1]
1211             dist1 = dist1[1:-1, 1:-1]
1212
1213             D = dist1[i0, j0]
1214             for d in range(1, D):
1215                 M = (dist0 == d) & (dist1 == D - d)
1216                 f_X[...] = (1 - M) * f_X + M * c[1]
1217
1218             X[i0, j0] = c[2]
1219             f_X[i0, j0] = c[2]
1220             X[i1, j1] = c[2]
1221             f_X[i1, j1] = c[2]
1222
1223     # for X, f_X in [(A, f_A), (B, f_B)]:
1224     # n = torch.arange(self.height * self.width).reshape(self.height, self.width)
1225     # k = torch.randperm(self.height * self.width)
1226     # X[...]=-1
1227     # for q in k:
1228     # i,j=q%self.height,q//self.height
1229     # if
1230
1231     # @torch.compile
1232     def TOO_HARD_task_puzzle(self, A, f_A, B, f_B):
1233         S = 4
1234         i0, j0 = (self.height - S) // 2, (self.width - S) // 2
1235         c = torch.randperm(self.nb_colors - 1)[:4] + 1
1236         for X, f_X in [(A, f_A), (B, f_B)]:
1237             while True:
1238                 f_X[...] = 0
1239                 h = list(torch.randperm(c.size(0)))
1240                 n = torch.zeros(c.max() + 1)
1241                 for _ in range(2):
1242                     k = torch.randperm(S * S)
1243                     for q in k:
1244                         i, j = q % S + i0, q // S + j0
1245                         if f_X[i, j] == 0:
1246                             r, s, t, u = (
1247                                 f_X[i - 1, j],
1248                                 f_X[i, j - 1],
1249                                 f_X[i + 1, j],
1250                                 f_X[i, j + 1],
1251                             )
1252                             r, s, t, u = torch.tensor([r, s, t, u])[torch.randperm(4)]
1253                             if r > 0 and n[r] < 6:
1254                                 n[r] += 1
1255                                 f_X[i, j] = r
1256                             elif s > 0 and n[s] < 6:
1257                                 n[s] += 1
1258                                 f_X[i, j] = s
1259                             elif t > 0 and n[t] < 6:
1260                                 n[t] += 1
1261                                 f_X[i, j] = t
1262                             elif u > 0 and n[u] < 6:
1263                                 n[u] += 1
1264                                 f_X[i, j] = u
1265                             else:
1266                                 if len(h) > 0:
1267                                     d = c[h.pop()]
1268                                     n[d] += 1
1269                                     f_X[i, j] = d
1270
1271                 if n.sum() == S * S:
1272                     break
1273
1274             k = 0
1275             for d in range(4):
1276                 while True:
1277                     ii, jj = (
1278                         torch.randint(self.height, (1,)).item(),
1279                         torch.randint(self.width, (1,)).item(),
1280                     )
1281                     e = 0
1282                     for i in range(S):
1283                         for j in range(S):
1284                             if (
1285                                 ii + i >= self.height
1286                                 or jj + j >= self.width
1287                                 or (
1288                                     f_X[i + i0, j + j0] == c[d]
1289                                     and X[ii + i, jj + j] > 0
1290                                 )
1291                             ):
1292                                 e = 1
1293                     if e == 0:
1294                         break
1295                 for i in range(S):
1296                     for j in range(S):
1297                         if f_X[i + i0, j + j0] == c[d]:
1298                             X[ii + i, jj + j] = c[d]
1299
1300     def TOO_MESSY_task_islands(self, A, f_A, B, f_B):
1301         c = torch.randperm(self.nb_colors - 1)[:2] + 1
1302         for X, f_X in [(A, f_A), (B, f_B)]:
1303             if not hasattr(self, "cache_islands") or len(self.cache_islands) == 0:
1304                 self.cache_islands = list(
1305                     grow_islands(
1306                         1000,
1307                         self.height,
1308                         self.width,
1309                         nb_seeds=self.height * self.width // 20,
1310                         nb_iterations=self.height * self.width // 2,
1311                     )
1312                 )
1313
1314             A = self.cache_islands.pop()
1315
1316             while True:
1317                 i, j = (
1318                     torch.randint(self.height // 2, (1,)).item(),
1319                     torch.randint(self.width // 2, (1,)).item(),
1320                 )
1321                 if A[i, j] > 0:
1322                     break
1323
1324             X[...] = (A > 0) * c[0]
1325             f_X[...] = (A == A[i, j]) * c[1] + ((A > 0) & (A != A[i, j])) * c[0]
1326             f_X[i, j] = X[i, j]
1327             X[i, j] = c[1]
1328
1329     # @torch.compile
1330     def TOO_HARD_task_stack(self, A, f_A, B, f_B):
1331         N = 5
1332         c = torch.randperm(self.nb_colors - 1)[:N] + 1
1333         for X, f_X in [(A, f_A), (B, f_B)]:
1334             i1, j1, i2, j2 = (
1335                 self.height // 2 - 1,
1336                 self.width // 2 - 1,
1337                 self.height // 2 + 1,
1338                 self.width // 2 + 1,
1339             )
1340             op = torch.tensor((0, 1, 2, 3) * 4)
1341             op = op[torch.randperm(op.size(0))[:9]]
1342             for q in range(op.size(0)):
1343                 u = 3 * (q // 3)
1344                 v = 3 * (q % 3)
1345                 d = c[torch.randint(N, (1,)).item()]
1346                 # X[u+1,v+1]=d
1347                 if op[q] == 0:  # right
1348                     X[u : u + 3, v + 2] = d
1349                 elif op[q] == 1:  # let
1350                     X[u : u + 3, v] = d
1351                 elif op[q] == 2:  # bottom
1352                     X[u + 2, v : v + 3] = d
1353                 elif op[q] == 3:  # top
1354                     X[u, v : v + 3] = d
1355
1356                 if q == 0:
1357                     f_X[i1:i2, j1:j2] = d
1358                 elif op[q] == 0:  # right
1359                     f_X[i1:i2, j2] = d
1360                     j2 += 1
1361                 elif op[q] == 1:  # let
1362                     j1 -= 1
1363                     f_X[i1:i2, j1] = d
1364                 elif op[q] == 2:  # bottom
1365                     f_X[i2, j1:j2] = d
1366                     i2 += 1
1367                 elif op[q] == 3:  # top
1368                     i1 -= 1
1369                     f_X[i1, j1:j2] = d
1370
1371     def randint(self, *m):
1372         m = torch.tensor(m)
1373         return (torch.rand(m.size()) * m).long()
1374
1375     def TOO_HARD_task_matrices(self, A, f_A, B, f_B):
1376         N = 6
1377         c = torch.randperm(self.nb_colors - 1)[:N] + 1
1378
1379         for X, f_X in [(A, f_A), (B, f_B)]:
1380             M1 = torch.randint(2, (5, 5))
1381             M2 = torch.randint(2, (5, 5))
1382             P = M1 @ M2
1383             for i in range(5):
1384                 for j in range(5):
1385                     X[i, j] = c[M1[i, j]]
1386                     X[i, j + 5] = c[M2[i, j]]
1387                     f_X[i, j] = c[M1[i, j]]
1388                     f_X[i, j + 5] = c[M2[i, j]]
1389                     f_X[i + 5, j + 5] = c[P[i, j]]
1390
1391     def TOO_HARD_task_compute(self, A, f_A, B, f_B):
1392         N = 6
1393         c = torch.randperm(self.nb_colors - 1)[:N] + 1
1394         for X, f_X in [(A, f_A), (B, f_B)]:
1395             v = torch.randint((self.width - 1) // 2, (N,)) + 1
1396             chain = torch.randperm(N)
1397             eq = []
1398             for i in range(chain.size(0) - 1):
1399                 i1, i2 = chain[i], chain[i + 1]
1400                 v1, v2 = v[i1], v[i2]
1401                 k = torch.arange(self.width // 2) + 1
1402                 d = ((k[None, :] * v1 - k[:, None] * v2) == 0).nonzero() + 1
1403                 d = d[torch.randint(d.size(0), (1,)).item()]
1404                 w1, w2 = d
1405                 eq.append((c[i1], w1, c[i2], w2))
1406
1407             ii = torch.randperm(self.height - 2)[: len(eq)]
1408
1409             for k, x in enumerate(eq):
1410                 i = ii[k]
1411                 c1, w1, c2, w2 = x
1412                 s = torch.randint(self.width - (w1 + w2) + 1, (1,)).item()
1413                 X[i, s : s + w1] = c1
1414                 X[i, s + w1 : s + w1 + w2] = c2
1415                 f_X[i, s : s + w1] = c1
1416                 f_X[i, s + w1 : s + w1 + w2] = c2
1417
1418             i1, i2 = torch.randperm(N)[:2]
1419             v1, v2 = v[i1], v[i2]
1420             k = torch.arange(self.width // 2) + 1
1421             d = ((k[None, :] * v1 - k[:, None] * v2) == 0).nonzero() + 1
1422             d = d[torch.randint(d.size(0), (1,)).item()]
1423             w1, w2 = d
1424             c1, c2 = c[i1], c[i2]
1425             s = 0  # torch.randint(self.width - (w1 + w2) + 1, (1,)).item()
1426             i = self.height - 1
1427             X[i, s : s + w1] = c1
1428             X[i, s + w1 : s + w1 + 1] = c2
1429             f_X[i, s : s + w1] = c1
1430             f_X[i, s + w1 : s + w1 + w2] = c2
1431
1432     # @torch.compile
1433     # [ai1,ai2] [bi1,bi2]
1434     def task_contact(self, A, f_A, B, f_B):
1435         def rec_dist(a, b):
1436             ai1, aj1, ai2, aj2 = a
1437             bi1, bj1, bi2, bj2 = b
1438             v = max(ai1 - bi2, bi1 - ai2)
1439             h = max(aj1 - bj2, bj1 - aj2)
1440             return min(max(v, 0) + max(h + 1, 0), max(v + 1, 0) + max(h, 0))
1441
1442         nb_rec = 3
1443         c = torch.randperm(self.nb_colors - 1)[:nb_rec] + 1
1444         for X, f_X in [(A, f_A), (B, f_B)]:
1445             while True:
1446                 r = self.rec_coo(nb_rec, prevent_overlap=True)
1447                 d = [rec_dist(r[0], r[k]) for k in range(nb_rec)]
1448                 if min(d[1:]) == 0:
1449                     break
1450
1451             for n in range(nb_rec):
1452                 i1, j1, i2, j2 = r[n]
1453                 X[i1:i2, j1:j2] = c[n]
1454                 f_X[i1:i2, j1:j2] = c[n]
1455                 if d[n] == 0:
1456                     f_X[i1, j1:j2] = c[0]
1457                     f_X[i2 - 1, j1:j2] = c[0]
1458                     f_X[i1:i2, j1] = c[0]
1459                     f_X[i1:i2, j2 - 1] = c[0]
1460
1461     # @torch.compile
1462     # [ai1,ai2] [bi1,bi2]
1463     def task_corners(self, A, f_A, B, f_B):
1464         polarity = torch.randint(2, (1,)).item()
1465         nb_rec = 3
1466         c = torch.randperm(self.nb_colors - 1)[:nb_rec] + 1
1467         for X, f_X in [(A, f_A), (B, f_B)]:
1468             r = self.rec_coo(nb_rec, prevent_overlap=True)
1469
1470             for n in range(nb_rec):
1471                 i1, j1, i2, j2 = r[n]
1472                 for k in range(2):
1473                     if polarity == 0:
1474                         X[i1 + k, j1] = c[n]
1475                         X[i2 - 1 - k, j2 - 1] = c[n]
1476                         X[i1, j1 + k] = c[n]
1477                         X[i2 - 1, j2 - 1 - k] = c[n]
1478                     else:
1479                         X[i1 + k, j2 - 1] = c[n]
1480                         X[i2 - 1 - k, j1] = c[n]
1481                         X[i1, j2 - 1 - k] = c[n]
1482                         X[i2 - 1, j1 + k] = c[n]
1483                     f_X[i1:i2, j1:j2] = c[n]
1484
1485     def compdist(self, X, i, j):
1486         dd = X.new_full((self.height + 2, self.width + 2), self.height * self.width)
1487         d = dd[1:-1, 1:-1]
1488         m = (X > 0).long()
1489         d[i, j] = 0
1490         e = d.clone()
1491         while True:
1492             e[...] = d
1493             d[...] = (
1494                 d.min(dd[:-2, 1:-1] + 1)
1495                 .min(dd[2:, 1:-1] + 1)
1496                 .min(dd[1:-1, :-2] + 1)
1497                 .min(dd[1:-1, 2:] + 1)
1498             )
1499             d[...] = (1 - m) * d + m * self.height * self.width
1500             if e.equal(d):
1501                 break
1502
1503         return d
1504
1505     # @torch.compile
1506     def task_path(self, A, f_A, B, f_B):
1507         nb_rec = 2
1508         c = torch.randperm(self.nb_colors - 1)[: nb_rec + 2] + 1
1509         for X, f_X in [(A, f_A), (B, f_B)]:
1510             while True:
1511                 X[...] = 0
1512                 f_X[...] = 0
1513
1514                 r = self.rec_coo(nb_rec, prevent_overlap=True)
1515                 for n in range(nb_rec):
1516                     i1, j1, i2, j2 = r[n]
1517                     X[i1:i2, j1:j2] = c[n]
1518                     f_X[i1:i2, j1:j2] = c[n]
1519
1520                 i1, i2 = torch.randint(self.height, (2,))
1521                 j1, j2 = torch.randint(self.width, (2,))
1522                 if (
1523                     abs(i1 - i2) + abs(j1 - j2) > 2
1524                     and X[i1, j1] == 0
1525                     and X[i2, j2] == 0
1526                 ):
1527                     d2 = self.compdist(X, i2, j2)
1528                     d = self.compdist(X, i1, j1)
1529
1530                     if d2[i1, j1] < 2 * self.width:
1531                         break
1532
1533             m = ((d + d2) == d[i2, j2]).long()
1534             f_X[...] = m * c[-1] + (1 - m) * f_X
1535
1536             X[i1, j1] = c[-2]
1537             X[i2, j2] = c[-2]
1538             f_X[i1, j1] = c[-2]
1539             f_X[i2, j2] = c[-2]
1540
1541     # @torch.compile
1542     def task_fill(self, A, f_A, B, f_B):
1543         nb_rec = 3
1544         c = torch.randperm(self.nb_colors - 1)[: nb_rec + 1] + 1
1545         for X, f_X in [(A, f_A), (B, f_B)]:
1546             accept_full = torch.rand(1) < 0.5
1547
1548             while True:
1549                 X[...] = 0
1550                 f_X[...] = 0
1551
1552                 r = self.rec_coo(nb_rec, prevent_overlap=True)
1553                 for n in range(nb_rec):
1554                     i1, j1, i2, j2 = r[n]
1555                     X[i1:i2, j1:j2] = c[n]
1556                     f_X[i1:i2, j1:j2] = c[n]
1557
1558                 while True:
1559                     i, j = (
1560                         torch.randint(self.height, (1,)).item(),
1561                         torch.randint(self.width, (1,)).item(),
1562                     )
1563                     if X[i, j] == 0:
1564                         break
1565
1566                 d = self.compdist(X, i, j)
1567                 m = (d < self.height * self.width).long()
1568                 X[i, j] = c[-1]
1569                 f_X[...] = m * c[-1] + (1 - m) * f_X
1570                 f_X[i, j] = 0
1571
1572                 if accept_full or (d * (X == 0)).max() == self.height * self.width:
1573                     break
1574
1575     def TOO_HARD_task_addition(self, A, f_A, B, f_B):
1576         c = torch.randperm(self.nb_colors - 1)[:4] + 1
1577         for X, f_X in [(A, f_A), (B, f_B)]:
1578             N1 = torch.randint(2 ** (self.width - 1) - 1, (1,)).item()
1579             N2 = torch.randint(2 ** (self.width - 1) - 1, (1,)).item()
1580             S = N1 + N2
1581             for j in range(self.width):
1582                 r1 = (N1 // (2**j)) % 2
1583                 X[0, -j - 1] = c[r1]
1584                 f_X[0, -j - 1] = c[r1]
1585                 r2 = (N2 // (2**j)) % 2
1586                 X[1, -j - 1] = c[r2]
1587                 f_X[1, -j - 1] = c[r2]
1588                 rs = (S // (2**j)) % 2
1589                 f_X[2, -j - 1] = c[2 + rs]
1590
1591     def task_science_implicit(self, A, f_A, B, f_B):
1592         nb_rec = 5
1593         c = torch.randperm(self.nb_colors - 1)[:nb_rec] + 1
1594
1595         for X, f_X in [(A, f_A), (B, f_B)]:
1596             while True:
1597                 i1, i2 = torch.randint(self.height, (2,)).sort().values
1598                 if i1 >= 1 and i2 < self.height and i1 + 3 < i2:
1599                     break
1600
1601             while True:
1602                 j1, j2 = torch.randint(self.width, (2,)).sort().values
1603                 if j1 >= 1 and j2 < self.width and j1 + 3 < j2:
1604                     break
1605
1606             f_X[i1:i2, j1:j2] = c[0]
1607
1608             # ---------------------
1609
1610             while True:
1611                 ii1, ii2 = torch.randint(self.height, (2,)).sort().values
1612                 if ii1 >= i1 and ii2 <= i2 and ii1 + 1 < ii2:
1613                     break
1614             jj = torch.randint(j1, (1,))
1615             X[ii1:ii2, jj:j1] = c[1]
1616             f_X[ii1:ii2, jj:j1] = c[1]
1617
1618             while True:
1619                 ii1, ii2 = torch.randint(self.height, (2,)).sort().values
1620                 if ii1 >= i1 and ii2 <= i2 and ii1 + 1 < ii2:
1621                     break
1622             jj = torch.randint(self.width - j2, (1,)) + j2 + 1
1623             X[ii1:ii2, j2:jj] = c[2]
1624             f_X[ii1:ii2, j2:jj] = c[2]
1625
1626             # ---------------------
1627
1628             while True:
1629                 jj1, jj2 = torch.randint(self.width, (2,)).sort().values
1630                 if jj1 >= j1 and jj2 <= j2 and jj1 + 1 < jj2:
1631                     break
1632             ii = torch.randint(i1, (1,))
1633             X[ii:i1, jj1:jj2] = c[3]
1634             f_X[ii:i1, jj1:jj2] = c[3]
1635
1636             while True:
1637                 jj1, jj2 = torch.randint(self.width, (2,)).sort().values
1638                 if jj1 >= j1 and jj2 <= j2 and jj1 + 1 < jj2:
1639                     break
1640             ii = torch.randint(self.height - i2, (1,)) + i2 + 1
1641             X[i2:ii, jj1:jj2] = c[4]
1642             f_X[i2:ii, jj1:jj2] = c[4]
1643
1644     def task_science_dot(self, A, f_A, B, f_B):
1645         nb_rec = 3
1646         c = torch.randperm(self.nb_colors - 1)[: nb_rec + 1] + 1
1647         for X, f_X in [(A, f_A), (B, f_B)]:
1648             while True:
1649                 X[...] = 0
1650                 f_X[...] = 0
1651                 r = self.rec_coo(nb_rec, prevent_overlap=True)
1652                 i, j = (
1653                     torch.randint(self.height, (1,)).item(),
1654                     torch.randint(self.width, (1,)).item(),
1655                 )
1656                 q = 0
1657                 for n in range(nb_rec):
1658                     i1, j1, i2, j2 = r[n]
1659                     X[i1:i2, j1:j2] = c[n]
1660                     f_X[i1:i2, j1:j2] = c[n]
1661                     if i >= i1 and i < i2:
1662                         q += 1
1663                         f_X[i, j1:j2] = c[-1]
1664                     if j >= j1 and j < j2:
1665                         q += 1
1666                         f_X[i1:i2, j] = c[-1]
1667                 X[i, j] = c[-1]
1668                 f_X[i, j] = c[-1]
1669                 if q >= 2:
1670                     break
1671
1672     def collide(self, s, r, rs):
1673         i, j = r
1674         for i2, j2 in rs:
1675             if abs(i - i2) < s and abs(j - j2) < s:
1676                 return True
1677         return False
1678
1679     def task_science_tag(self, A, f_A, B, f_B):
1680         c = torch.randperm(self.nb_colors - 1)[:4] + 1
1681         for X, f_X in [(A, f_A), (B, f_B)]:
1682             rs = []
1683             while len(rs) < 4:
1684                 i, j = (
1685                     torch.randint(self.height - 3, (1,)).item(),
1686                     torch.randint(self.width - 3, (1,)).item(),
1687                 )
1688                 if not self.collide(s=3, r=(i, j), rs=rs):
1689                     rs.append((i, j))
1690
1691             for k in range(len(rs)):
1692                 i, j = rs[k]
1693                 q = min(k, 2)
1694                 X[i, j : j + 3] = c[q]
1695                 X[i + 2, j : j + 3] = c[q]
1696                 X[i : i + 3, j] = c[q]
1697                 X[i : i + 3, j + 2] = c[q]
1698
1699                 f_X[i, j : j + 3] = c[q]
1700                 f_X[i + 2, j : j + 3] = c[q]
1701                 f_X[i : i + 3, j] = c[q]
1702                 f_X[i : i + 3, j + 2] = c[q]
1703                 if q == 2:
1704                     f_X[i + 1, j + 1] = c[-1]
1705
1706     # end_tasks
1707
1708     ######################################################################
1709
1710     def create_empty_quizzes(self, nb, struct=("A", "f_A", "B", "f_B")):
1711         S = self.height * self.width
1712         quizzes = torch.zeros(nb, 4 * (S + 1), dtype=torch.int64)
1713         quizzes[:, 0 * (S + 1)] = self.l2tok[struct[0]]
1714         quizzes[:, 1 * (S + 1)] = self.l2tok[struct[1]]
1715         quizzes[:, 2 * (S + 1)] = self.l2tok[struct[2]]
1716         quizzes[:, 3 * (S + 1)] = self.l2tok[struct[3]]
1717
1718         return quizzes
1719
1720     def generate_w_quizzes_(self, nb, tasks=None, progress_bar=False):
1721         S = self.height * self.width
1722
1723         if tasks is None:
1724             tasks = self.all_tasks
1725
1726         quizzes = self.create_empty_quizzes(nb, ("A", "f_A", "B", "f_B"))
1727
1728         if progress_bar:
1729             quizzes = tqdm.tqdm(
1730                 quizzes,
1731                 dynamic_ncols=True,
1732                 desc="world quizzes generation",
1733                 total=quizzes.size(0),
1734             )
1735
1736         for quiz in quizzes:
1737             q = quiz.reshape(4, S + 1)[:, 1:].reshape(4, self.height, self.width)
1738             q[...] = 0
1739             A, f_A, B, f_B = q
1740             task = tasks[torch.randint(len(tasks), (1,)).item()]
1741             task(A, f_A, B, f_B)
1742
1743         return quizzes
1744
1745     def save_some_examples(self, result_dir, prefix=""):
1746         nb, nrow = 128, 4
1747         for t in self.all_tasks:
1748             print(t.__name__)
1749             quizzes = self.generate_w_quizzes_(nb, tasks=[t])
1750             self.save_quizzes_as_image(
1751                 result_dir, prefix + t.__name__ + ".png", quizzes, nrow=nrow
1752             )
1753
1754
1755 ######################################################################
1756
1757 if __name__ == "__main__":
1758     import time
1759
1760     # grids = Grids(max_nb_cached_chunks=5, chunk_size=100, nb_threads=4)
1761
1762     grids = Grids()
1763
1764     # nb = 5
1765     # quizzes = grids.generate_w_quizzes_(nb, tasks=[grids.task_fill])
1766     # print(quizzes)
1767     # print(grids.get_structure(quizzes))
1768     # quizzes = grids.reconfigure(quizzes, struct=("A", "B", "f_A", "f_B"))
1769     # print("DEBUG2", quizzes)
1770     # print(grids.get_structure(quizzes))
1771     # print(quizzes)
1772
1773     # i = torch.rand(quizzes.size(0)) < 0.5
1774
1775     # quizzes[i] = grids.reconfigure(quizzes[i], struct=("f_B", "f_A", "B", "A"))
1776
1777     # j = grids.indices_select(quizzes, struct=("f_B", "f_A", "B", "A"))
1778
1779     # print(
1780     # i.equal(j),
1781     # grids.get_structure(quizzes[j]),
1782     # grids.get_structure(quizzes[j == False]),
1783     # )
1784
1785     #   exit(0)
1786
1787     # nb = 1000
1788     # grids = problem.MultiThreadProblem(
1789     # grids, max_nb_cached_chunks=50, chunk_size=100, nb_threads=1
1790     # )
1791     #    time.sleep(10)
1792     # start_time = time.perf_counter()
1793     # prompts, answers = grids.generate_w_quizzes(nb)
1794     # delay = time.perf_counter() - start_time
1795     # print(f"{prompts.size(0)/delay:02f} seq/s")
1796     # exit(0)
1797
1798     # if True:
1799     nb, nrow = 128, 4
1800     # nb, nrow = 8, 2
1801
1802     # for t in grids.all_tasks:
1803
1804     for t in [grids.task_recworld_immobile]:
1805         print(t.__name__)
1806         w_quizzes = grids.generate_w_quizzes_(nb, tasks=[t])
1807         grids.save_quizzes_as_image(
1808             "/tmp",
1809             t.__name__ + ".png",
1810             w_quizzes,
1811             comments=[f"{t.__name__} #{k}" for k in range(w_quizzes.size(0))],
1812         )
1813
1814     exit(0)
1815
1816     nb = 1000
1817
1818     for t in [
1819         # grids.task_bounce,
1820         # grids.task_contact,
1821         # grids.task_corners,
1822         # grids.task_detect,
1823         # grids.task_fill,
1824         # grids.task_frame,
1825         # grids.task_grow,
1826         # grids.task_half_fill,
1827         # grids.task_isometry,
1828         # grids.task_path,
1829         # grids.task_replace_color,
1830         # grids.task_scale,
1831         grids.task_symbols,
1832         # grids.task_trajectory,
1833         # grids.task_translate,
1834     ]:
1835         # for t in [grids.task_path]:
1836         start_time = time.perf_counter()
1837         w_quizzes = grids.generate_w_quizzes_(nb, tasks=[t])
1838         delay = time.perf_counter() - start_time
1839         print(f"{t.__name__} {w_quizzes.size(0)/delay:02f} seq/s")
1840         grids.save_quizzes_as_image("/tmp", t.__name__ + ".png", w_quizzes[:128])
1841
1842     exit(0)
1843
1844     m = torch.randint(2, (prompts.size(0),))
1845     predicted_prompts = m * (torch.randint(2, (prompts.size(0),)) * 2 - 1)
1846     predicted_answers = (1 - m) * (torch.randint(2, (prompts.size(0),)) * 2 - 1)
1847
1848     grids.save_quizzes_as_image(
1849         "/tmp",
1850         "test.png",
1851         prompts[:nb],
1852         answers[:nb],
1853         # You can add a bool to put a frame around the predicted parts
1854         predicted_prompts[:nb],
1855         predicted_answers[:nb],
1856     )