Merge branch 'dev'
[culture.git] / grids.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import math, sys, tqdm, os, warnings
9
10 import torch, torchvision
11
12 from torch import nn
13 from torch.nn import functional as F
14
15 ######################################################################
16
17 import problem
18
19
20 def grow_islands(nb, height, width, nb_seeds, nb_iterations):
21     w = torch.empty(5, 1, 3, 3)
22
23     w[0, 0] = torch.tensor(
24         [
25             [1.0, 1.0, 1.0],
26             [1.0, 0.0, 1.0],
27             [1.0, 1.0, 1.0],
28         ]
29     )
30
31     w[1, 0] = torch.tensor(
32         [
33             [-1.0, 1.0, 0.0],
34             [1.0, 0.0, 0.0],
35             [0.0, 0.0, 0.0],
36         ]
37     )
38
39     w[2, 0] = torch.tensor(
40         [
41             [0.0, 1.0, -1.0],
42             [0.0, 0.0, 1.0],
43             [0.0, 0.0, 0.0],
44         ]
45     )
46
47     w[3, 0] = torch.tensor(
48         [
49             [0.0, 0.0, 0.0],
50             [0.0, 0.0, 1.0],
51             [0.0, 1.0, -1.0],
52         ]
53     )
54
55     w[4, 0] = torch.tensor(
56         [
57             [0.0, 0.0, 0.0],
58             [1.0, 0.0, 0.0],
59             [-1.0, 1.0, 0.0],
60         ]
61     )
62
63     Z = torch.zeros(nb, height, width)
64     U = Z.flatten(1)
65
66     for _ in range(nb_seeds):
67         M = F.conv2d(Z[:, None, :, :], w, padding=1)
68         M = torch.cat([M[:, :1], M[:, 1:].min(dim=1, keepdim=True).values], dim=1)
69         M = ((M[:, 0] == 0) & (Z == 0)).long()
70         Q = (M.flatten(1).max(dim=1).values > 0).long()[:, None]
71         M = M * torch.rand(M.size())
72         M = M.flatten(1)
73         M = F.one_hot(M.argmax(dim=1), num_classes=M.size(1))
74         U += M * Q
75
76     for _ in range(nb_iterations):
77         M = F.conv2d(Z[:, None, :, :], w, padding=1)
78         M = torch.cat([M[:, :1], M[:, 1:].min(dim=1, keepdim=True).values], dim=1)
79         M = ((M[:, 1] >= 0) & (Z == 0)).long()
80         Q = (M.flatten(1).max(dim=1).values > 0).long()[:, None]
81         M = M * torch.rand(M.size())
82         M = M.flatten(1)
83         M = F.one_hot(M.argmax(dim=1), num_classes=M.size(1))
84         U = Z.flatten(1)
85         U += M * Q
86
87     M = Z.clone()
88     Z = Z * (torch.arange(Z.size(1) * Z.size(2)) + 1).reshape(1, Z.size(1), Z.size(2))
89
90     while True:
91         W = Z.clone()
92         Z = F.max_pool2d(Z, 3, 1, 1) * M
93         if Z.equal(W):
94             break
95
96     Z = Z.long()
97     U = Z.flatten(1)
98     V = F.one_hot(U).max(dim=1).values
99     W = V.cumsum(dim=1) - V
100     N = torch.arange(Z.size(0))[:, None, None].expand_as(Z)
101     Z = W[N, Z]
102
103     return Z
104
105
106 class Grids(problem.Problem):
107     named_colors = [
108         ("white", [255, 255, 255]),
109         ("red", [255, 0, 0]),
110         ("green", [0, 192, 0]),
111         ("blue", [0, 0, 255]),
112         ("yellow", [255, 224, 0]),
113         ("cyan", [0, 255, 255]),
114         ("violet", [224, 128, 255]),
115         ("lightgreen", [192, 255, 192]),
116         ("brown", [165, 42, 42]),
117         ("lightblue", [192, 192, 255]),
118         ("gray", [128, 128, 128]),
119     ]
120
121     def __init__(
122         self,
123         max_nb_cached_chunks=None,
124         chunk_size=None,
125         nb_threads=-1,
126         tasks=None,
127     ):
128         self.colors = torch.tensor([c for _, c in self.named_colors])
129         self.height = 10
130         self.width = 10
131         self.cache_rec_coo = {}
132
133         all_tasks = [
134             self.task_replace_color,
135             self.task_translate,
136             self.task_grow,
137             self.task_half_fill,
138             self.task_frame,
139             self.task_detect,
140             self.task_count,
141             self.task_trajectory,
142             self.task_bounce,
143             self.task_scale,
144             self.task_symbols,
145             self.task_isometry,
146             #            self.task_islands,
147         ]
148
149         if tasks is None:
150             self.all_tasks = all_tasks
151         else:
152             self.all_tasks = [getattr(self, "task_" + t) for t in tasks.split(",")]
153
154         super().__init__(max_nb_cached_chunks, chunk_size, nb_threads)
155
156     ######################################################################
157
158     def frame2img(self, x, scale=15):
159         x = x.reshape(x.size(0), self.height, -1)
160         m = torch.logical_and(x >= 0, x < self.nb_token_values()).long()
161         x = self.colors[x * m].permute(0, 3, 1, 2)
162         s = x.shape
163         x = x[:, :, :, None, :, None].expand(-1, -1, -1, scale, -1, scale)
164         x = x.reshape(s[0], s[1], s[2] * scale, s[3] * scale)
165
166         x[:, :, :, torch.arange(0, x.size(3), scale)] = 0
167         x[:, :, torch.arange(0, x.size(2), scale), :] = 0
168         x = x[:, :, 1:, 1:]
169
170         for n in range(m.size(0)):
171             for i in range(m.size(1)):
172                 for j in range(m.size(2)):
173                     if m[n, i, j] == 0:
174                         for k in range(2, scale - 2):
175                             for l in [0, 1]:
176                                 x[n, :, i * scale + k, j * scale + k - l] = 0
177                                 x[
178                                     n, :, i * scale + scale - 1 - k, j * scale + k - l
179                                 ] = 0
180
181         return x
182
183     def save_image(
184         self,
185         result_dir,
186         filename,
187         prompts,
188         answers,
189         predicted_prompts=None,
190         predicted_answers=None,
191         nrow=4,
192         margin=8,
193     ):
194         S = self.height * self.width
195         As = prompts[:, 0 * (S + 1) : 0 * (S + 1) + S].view(-1, self.height, self.width)
196         f_As = prompts[:, 1 * (S + 1) : 1 * (S + 1) + S].view(
197             -1, self.height, self.width
198         )
199         Bs = prompts[:, 2 * (S + 1) : 2 * (S + 1) + S].view(-1, self.height, self.width)
200         prompts = torch.cat([As, f_As, Bs], dim=2)
201         answers = answers.reshape(answers.size(0), self.height, self.width)
202
203         if predicted_prompts is None:
204             predicted_prompts = 255
205
206         if predicted_answers is None:
207             predicted_answers = 255
208
209         def add_frame(x, c, margin, bottom=False):
210             if bottom:
211                 h, w, di, dj = x.size(2) + margin, x.size(3), 0, 0
212             else:
213                 h, w, di, dj = (
214                     x.size(2) + 2 * margin,
215                     x.size(3) + 2 * margin,
216                     margin,
217                     margin,
218                 )
219
220             y = x.new_full((x.size(0), x.size(1), h, w), 0)
221
222             if type(c) is int:
223                 y[...] = c
224             else:
225                 c = c.long()[:, None]
226                 c = (
227                     (1 - ((c == 1).long() + (c == 0).long() + (c == -1).long()))
228                     * torch.tensor([64, 64, 64])
229                     + (c == 1).long() * torch.tensor([0, 255, 0])
230                     + (c == 0).long() * torch.tensor([255, 255, 255])
231                     + (c == -1).long() * torch.tensor([255, 0, 0])
232                 )
233                 y[...] = c[:, :, None, None]
234
235             y[:, :, di : di + x.size(2), dj : dj + x.size(3)] = x
236
237             return y
238
239         img_prompts = torch.cat(
240             [
241                 add_frame(
242                     add_frame(self.frame2img(x), c=0, margin=1),
243                     c=predicted_prompts,
244                     margin=margin,
245                 )
246                 for x in prompts.to("cpu").split(split_size=self.width, dim=2)
247             ],
248             dim=3,
249         )
250
251         h = img_prompts.size(2)
252         img_answers = add_frame(
253             add_frame(self.frame2img(answers.to("cpu")), c=0, margin=1),
254             c=predicted_answers,
255             margin=margin,
256         )
257
258         separator_size = 2 * margin
259
260         separator = img_prompts.new_full(
261             (
262                 img_prompts.size(0),
263                 img_prompts.size(1),
264                 img_prompts.size(2),
265                 separator_size,
266             ),
267             255,
268         )
269
270         marker = img_prompts.new_full(
271             (
272                 img_prompts.size(0),
273                 img_prompts.size(1),
274                 img_prompts.size(2),
275                 separator_size,
276             ),
277             255,
278         )
279
280         # marker[:, :, 0] = 0
281         # marker[:, :, h - 1] = 0
282
283         for k in range(1, 2 * separator_size - 8):
284             i = k - (separator_size - 4)
285             j = separator_size - 5 - abs(i)
286             marker[:, :, h // 2 - 1 + i, 2 + j] = 0
287             marker[:, :, h // 2 - 1 + i + 1, 2 + j] = 0
288
289         img = torch.cat(
290             [
291                 img_prompts,
292                 marker,
293                 img_answers,
294             ],
295             dim=3,
296         )
297
298         image_name = os.path.join(result_dir, filename)
299         torchvision.utils.save_image(
300             img.float() / 255.0,
301             image_name,
302             nrow=nrow,
303             padding=margin * 4,
304             pad_value=1.0,
305         )
306
307     ######################################################################
308
309     def nb_token_values(self):
310         return len(self.colors)
311
312     # @torch.compile
313     def rec_coo(
314         self,
315         nb_rec,
316         min_height=3,
317         min_width=3,
318         surface_max=None,
319         prevent_overlap=False,
320     ):
321         if surface_max is None:
322             surface_max = self.height * self.width // 2
323
324         signature = (nb_rec, min_height, min_width, surface_max)
325
326         try:
327             return self.cache_rec_coo[signature].pop()
328         except IndexError:
329             pass
330         except KeyError:
331             pass
332
333         N = 10000
334         while True:
335             while True:
336                 i = torch.randint(self.height, (N * nb_rec, 2)).sort(dim=-1).values
337                 j = torch.randint(self.width, (N * nb_rec, 2)).sort(dim=-1).values
338
339                 big_enough = (
340                     (i[:, 1] >= i[:, 0] + min_height)
341                     & (j[:, 1] >= j[:, 0] + min_height)
342                     & ((i[:, 1] - i[:, 0]) * (j[:, 1] - j[:, 0]) <= surface_max)
343                 )
344
345                 i, j = i[big_enough], j[big_enough]
346
347                 n = i.size(0) - i.size(0) % nb_rec
348
349                 if n > 0:
350                     break
351
352             i = i[:n].reshape(n // nb_rec, nb_rec, -1)
353             j = j[:n].reshape(n // nb_rec, nb_rec, -1)
354
355             if prevent_overlap:
356                 can_fit = ((i[:, :, 1] - i[:, :, 0]) * (j[:, :, 1] - j[:, :, 0])).sum(
357                     dim=-1
358                 ) <= self.height * self.width
359                 i, j = i[can_fit], j[can_fit]
360                 if nb_rec == 2:
361                     A_i1, A_i2, A_j1, A_j2 = (
362                         i[:, 0, 0],
363                         i[:, 0, 1],
364                         j[:, 0, 0],
365                         j[:, 0, 1],
366                     )
367                     B_i1, B_i2, B_j1, B_j2 = (
368                         i[:, 1, 0],
369                         i[:, 1, 1],
370                         j[:, 1, 0],
371                         j[:, 1, 1],
372                     )
373                     no_overlap = torch.logical_not(
374                         (A_i1 >= B_i2)
375                         & (A_i2 <= B_i1)
376                         & (A_j1 >= B_j1)
377                         & (A_j2 <= B_j1)
378                     )
379                     i, j = i[no_overlap], j[no_overlap]
380                 elif nb_rec == 3:
381                     A_i1, A_i2, A_j1, A_j2 = (
382                         i[:, 0, 0],
383                         i[:, 0, 1],
384                         j[:, 0, 0],
385                         j[:, 0, 1],
386                     )
387                     B_i1, B_i2, B_j1, B_j2 = (
388                         i[:, 1, 0],
389                         i[:, 1, 1],
390                         j[:, 1, 0],
391                         j[:, 1, 1],
392                     )
393                     C_i1, C_i2, C_j1, C_j2 = (
394                         i[:, 2, 0],
395                         i[:, 2, 1],
396                         j[:, 2, 0],
397                         j[:, 2, 1],
398                     )
399                     no_overlap = (
400                         (
401                             (A_i1 >= B_i2)
402                             | (A_i2 <= B_i1)
403                             | (A_j1 >= B_j2)
404                             | (A_j2 <= B_j1)
405                         )
406                         & (
407                             (A_i1 >= C_i2)
408                             | (A_i2 <= C_i1)
409                             | (A_j1 >= C_j2)
410                             | (A_j2 <= C_j1)
411                         )
412                         & (
413                             (B_i1 >= C_i2)
414                             | (B_i2 <= C_i1)
415                             | (B_j1 >= C_j2)
416                             | (B_j2 <= C_j1)
417                         )
418                     )
419                     i, j = (i[no_overlap], j[no_overlap])
420                 else:
421                     assert nb_rec == 1
422
423             if i.size(0) > 1:
424                 break
425
426         self.cache_rec_coo[signature] = [
427             [
428                 (
429                     i[n, k, 0].item(),
430                     j[n, k, 0].item(),
431                     i[n, k, 1].item(),
432                     j[n, k, 1].item(),
433                 )
434                 for k in range(nb_rec)
435             ]
436             for n in range(i.size(0))
437         ]
438
439         return self.cache_rec_coo[signature].pop()
440
441     ######################################################################
442
443     # @torch.compile
444     def task_replace_color(self, A, f_A, B, f_B):
445         nb_rec = 3
446         c = torch.randperm(len(self.colors) - 1)[: nb_rec + 1] + 1
447         for X, f_X in [(A, f_A), (B, f_B)]:
448             r = self.rec_coo(nb_rec, prevent_overlap=True)
449             for n in range(nb_rec):
450                 i1, j1, i2, j2 = r[n]
451                 X[i1:i2, j1:j2] = c[n]
452                 f_X[i1:i2, j1:j2] = c[n if n > 0 else -1]
453
454     # @torch.compile
455     def task_translate(self, A, f_A, B, f_B):
456         while True:
457             di, dj = torch.randint(3, (2,)) - 1
458             if di.abs() + dj.abs() > 0:
459                 break
460
461         nb_rec = 3
462         c = torch.randperm(len(self.colors) - 1)[:nb_rec] + 1
463         for X, f_X in [(A, f_A), (B, f_B)]:
464             while True:
465                 r = self.rec_coo(nb_rec, prevent_overlap=True)
466                 i1, j1, i2, j2 = r[nb_rec - 1]
467                 if (
468                     i1 + di >= 0
469                     and i2 + di < X.size(0)
470                     and j1 + dj >= 0
471                     and j2 + dj < X.size(1)
472                 ):
473                     break
474
475             for n in range(nb_rec):
476                 i1, j1, i2, j2 = r[n]
477                 X[i1:i2, j1:j2] = c[n]
478                 if n == nb_rec - 1:
479                     f_X[i1 + di : i2 + di, j1 + dj : j2 + dj] = c[n]
480                 else:
481                     f_X[i1:i2, j1:j2] = c[n]
482
483     # @torch.compile
484     def task_grow(self, A, f_A, B, f_B):
485         di, dj = torch.randint(2, (2,)) * 2 - 1
486         nb_rec = 3
487         c = torch.randperm(len(self.colors) - 1)[:nb_rec] + 1
488         direction = torch.randint(2, (1,)).item()
489         for X, f_X in [(A, f_A), (B, f_B)]:
490             while True:
491                 r = self.rec_coo(nb_rec, prevent_overlap=True)
492                 i1, j1, i2, j2 = r[nb_rec - 1]
493                 if i1 + 3 < i2 and j1 + 3 < j2:
494                     break
495
496             for n in range(nb_rec):
497                 i1, j1, i2, j2 = r[n]
498                 if n == nb_rec - 1:
499                     if direction == 0:
500                         X[i1 + 1 : i2 - 1, j1 + 1 : j2 - 1] = c[n]
501                         f_X[i1:i2, j1:j2] = c[n]
502                     else:
503                         X[i1:i2, j1:j2] = c[n]
504                         f_X[i1 + 1 : i2 - 1, j1 + 1 : j2 - 1] = c[n]
505                 else:
506                     X[i1:i2, j1:j2] = c[n]
507                     f_X[i1:i2, j1:j2] = c[n]
508
509     # @torch.compile
510     def task_half_fill(self, A, f_A, B, f_B):
511         di, dj = torch.randint(2, (2,)) * 2 - 1
512         nb_rec = 3
513         c = torch.randperm(len(self.colors) - 1)[: 2 * nb_rec] + 1
514         direction = torch.randint(4, (1,)).item()
515         for X, f_X in [(A, f_A), (B, f_B)]:
516             r = self.rec_coo(nb_rec, prevent_overlap=True)
517             for n in range(nb_rec):
518                 i1, j1, i2, j2 = r[n]
519                 X[i1:i2, j1:j2] = c[2 * n]
520                 f_X[i1:i2, j1:j2] = c[2 * n]
521                 # Not my proudest moment
522                 if direction == 0:
523                     i = (i1 + i2) // 2
524                     X[i : i + 1, j1:j2] = c[2 * n + 1]
525                     if n == nb_rec - 1:
526                         f_X[i:i2, j1:j2] = c[2 * n + 1]
527                     else:
528                         f_X[i : i + 1, j1:j2] = c[2 * n + 1]
529                 elif direction == 1:
530                     i = (i1 + i2 - 1) // 2
531                     X[i : i + 1, j1:j2] = c[2 * n + 1]
532                     if n == nb_rec - 1:
533                         f_X[i1 : i + 1, j1:j2] = c[2 * n + 1]
534                     else:
535                         f_X[i : i + 1, j1:j2] = c[2 * n + 1]
536                 elif direction == 2:
537                     j = (j1 + j2) // 2
538                     X[i1:i2, j : j + 1] = c[2 * n + 1]
539                     if n == nb_rec - 1:
540                         f_X[i1:i2, j:j2] = c[2 * n + 1]
541                     else:
542                         f_X[i1:i2, j : j + 1] = c[2 * n + 1]
543                 elif direction == 3:
544                     j = (j1 + j2 - 1) // 2
545                     X[i1:i2, j : j + 1] = c[2 * n + 1]
546                     if n == nb_rec - 1:
547                         f_X[i1:i2, j1 : j + 1] = c[2 * n + 1]
548                     else:
549                         f_X[i1:i2, j : j + 1] = c[2 * n + 1]
550
551     # @torch.compile
552     def task_frame(self, A, f_A, B, f_B):
553         nb_rec = 3
554         c = torch.randperm(len(self.colors) - 1)[: nb_rec + 1] + 1
555         for X, f_X in [(A, f_A), (B, f_B)]:
556             r = self.rec_coo(nb_rec, prevent_overlap=True)
557             for n in range(nb_rec):
558                 i1, j1, i2, j2 = r[n]
559                 X[i1:i2, j1:j2] = c[n]
560                 if n == nb_rec - 1:
561                     f_X[i1:i2, j1] = c[n]
562                     f_X[i1:i2, j2 - 1] = c[n]
563                     f_X[i1, j1:j2] = c[n]
564                     f_X[i2 - 1, j1:j2] = c[n]
565                 else:
566                     f_X[i1:i2, j1:j2] = c[n]
567
568     # @torch.compile
569     def task_detect(self, A, f_A, B, f_B):
570         nb_rec = 3
571         c = torch.randperm(len(self.colors) - 1)[: nb_rec + 1] + 1
572         for X, f_X in [(A, f_A), (B, f_B)]:
573             r = self.rec_coo(nb_rec, prevent_overlap=True)
574             for n in range(nb_rec):
575                 i1, j1, i2, j2 = r[n]
576                 X[i1:i2, j1:j2] = c[n]
577                 if n < nb_rec - 1:
578                     f_X[i1, j1] = c[-1]
579
580     # @torch.compile
581     def contact(self, X, i, j, q):
582         nq, nq_diag = 0, 0
583         no = 0
584
585         for ii, jj in [
586             (i - 1, j - 1),
587             (i - 1, j),
588             (i - 1, j + 1),
589             (i, j - 1),
590             (i, j + 1),
591             (i + 1, j - 1),
592             (i + 1, j),
593             (i + 1, j + 1),
594         ]:
595             if ii >= 0 and ii < self.height and jj >= 0 and jj < self.width:
596                 if X[ii, jj] != 0 and X[ii, jj] != q:
597                     no += 1
598
599         for ii, jj in [
600             (i - 1, j - 1),
601             (i - 1, j + 1),
602             (i + 1, j - 1),
603             (i + 1, j + 1),
604         ]:
605             if ii >= 0 and ii < self.height and jj >= 0 and jj < self.width:
606                 if X[ii, jj] == q and X[i, jj] != q and X[ii, j] != q:
607                     nq_diag += 1
608
609         for ii, jj in [(i - 1, j), (i, j - 1), (i, j + 1), (i + 1, j)]:
610             if ii >= 0 and ii < self.height and jj >= 0 and jj < self.width:
611                 if X[ii, jj] == q:
612                     nq += 1
613
614         return no, nq, nq_diag
615
616     def task_count(self, A, f_A, B, f_B):
617         while True:
618             error = False
619
620             N = torch.randint(5, (1,)).item() + 1
621             c = torch.zeros(N + 1)
622             c[1:] = torch.randperm(len(self.colors) - 1)[:N] + 1
623
624             for X, f_X in [(A, f_A), (B, f_B)]:
625                 if not hasattr(self, "cache_count") or len(self.cache_count) == 0:
626                     self.cache_count = list(
627                         grow_islands(
628                             1000,
629                             self.height,
630                             self.width,
631                             nb_seeds=self.height * self.width // 8,
632                             nb_iterations=self.height * self.width // 10,
633                         )
634                     )
635
636                 X[...] = self.cache_count.pop()
637
638                 k = (X.max() + 1 + (c.size(0) - 1)).item()
639                 V = torch.arange(k) // (c.size(0) - 1)
640                 V = (V + torch.rand(V.size())).sort().indices[: X.max() + 1] % (
641                     c.size(0) - 1
642                 ) + 1
643                 V[0] = 0
644                 X[...] = c[V[X]]
645
646                 if F.one_hot(X.flatten()).max(dim=0).values.sum().item() == N + 1:
647                     f_X[...] = 0
648                     for e in range(1, N + 1):
649                         for j in range((X == c[e]).sum() + 1):
650                             if j < self.width:
651                                 f_X[e - 1, j] = c[e]
652                             else:
653                                 error = True
654                                 break
655                 else:
656                     error = True
657                     break
658
659             if not error:
660                 break
661
662     # @torch.compile
663     def task_trajectory(self, A, f_A, B, f_B):
664         c = torch.randperm(len(self.colors) - 1)[:2] + 1
665         for X, f_X in [(A, f_A), (B, f_B)]:
666             while True:
667                 di, dj = torch.randint(7, (2,)) - 3
668                 i, j = (
669                     torch.randint(self.height, (1,)).item(),
670                     torch.randint(self.width, (1,)).item(),
671                 )
672                 if (
673                     abs(di) + abs(dj) > 0
674                     and i + 2 * di >= 0
675                     and i + 2 * di < self.height
676                     and j + 2 * dj >= 0
677                     and j + 2 * dj < self.width
678                 ):
679                     break
680
681             k = 0
682             while (
683                 i + k * di >= 0
684                 and i + k * di < self.height
685                 and j + k * dj >= 0
686                 and j + k * dj < self.width
687             ):
688                 if k < 2:
689                     X[i + k * di, j + k * dj] = c[k]
690                 f_X[i + k * di, j + k * dj] = c[min(k, 1)]
691                 k += 1
692
693     # @torch.compile
694     def task_bounce(self, A, f_A, B, f_B):
695         c = torch.randperm(len(self.colors) - 1)[:3] + 1
696         for X, f_X in [(A, f_A), (B, f_B)]:
697             # @torch.compile
698             def free(i, j):
699                 return (
700                     i >= 0
701                     and i < self.height
702                     and j >= 0
703                     and j < self.width
704                     and f_X[i, j] == 0
705                 )
706
707             while True:
708                 f_X[...] = 0
709                 X[...] = 0
710
711                 for _ in range((self.height * self.width) // 10):
712                     i, j = (
713                         torch.randint(self.height, (1,)).item(),
714                         torch.randint(self.width, (1,)).item(),
715                     )
716                     X[i, j] = c[0]
717                     f_X[i, j] = c[0]
718
719                 while True:
720                     di, dj = torch.randint(7, (2,)) - 3
721                     if abs(di) + abs(dj) == 1:
722                         break
723
724                 i, j = (
725                     torch.randint(self.height, (1,)).item(),
726                     torch.randint(self.width, (1,)).item(),
727                 )
728
729                 X[i, j] = c[1]
730                 f_X[i, j] = c[1]
731                 l = 0
732
733                 while True:
734                     l += 1
735                     if free(i + di, j + dj):
736                         pass
737                     elif free(i - dj, j + di):
738                         di, dj = -dj, di
739                         if free(i + dj, j - di):
740                             if torch.rand(1) < 0.5:
741                                 di, dj = -di, -dj
742                     elif free(i + dj, j - di):
743                         di, dj = dj, -di
744                     else:
745                         break
746
747                     i, j = i + di, j + dj
748                     f_X[i, j] = c[2]
749                     if l <= 1:
750                         X[i, j] = c[2]
751
752                     if l >= self.width:
753                         break
754
755                 f_X[i, j] = c[1]
756                 X[i, j] = c[1]
757
758                 if l > 3:
759                     break
760
761     # @torch.compile
762     def task_scale(self, A, f_A, B, f_B):
763         c = torch.randperm(len(self.colors) - 1)[:2] + 1
764
765         i, j = (
766             torch.randint(self.height // 2, (1,)).item(),
767             torch.randint(self.width // 2, (1,)).item(),
768         )
769
770         for X, f_X in [(A, f_A), (B, f_B)]:
771             for _ in range(3):
772                 while True:
773                     i1, j1 = (
774                         torch.randint(self.height // 2 + 1, (1,)).item(),
775                         torch.randint(self.width // 2 + 1, (1,)).item(),
776                     )
777                     i2, j2 = (
778                         torch.randint(self.height // 2 + 1, (1,)).item(),
779                         torch.randint(self.width // 2 + 1, (1,)).item(),
780                     )
781                     if i1 < i2 and j1 < j2 and min(i2 - i1, j2 - j1) <= 3:
782                         break
783                 X[i + i1 : i + i2, j + j1 : j + j2] = c[0]
784                 f_X[2 * i1 : 2 * i2, 2 * j1 : 2 * j2] = c[0]
785
786             X[i, j] = c[1]
787             f_X[0:2, 0:2] = c[1]
788
789     # @torch.compile
790     def task_symbols(self, A, f_A, B, f_B):
791         nb_rec = 4
792         c = torch.randperm(len(self.colors) - 1)[: nb_rec + 1] + 1
793         delta = 3
794         for X, f_X in [(A, f_A), (B, f_B)]:
795             while True:
796                 i, j = torch.randint(self.height - delta + 1, (nb_rec,)), torch.randint(
797                     self.width - delta + 1, (nb_rec,)
798                 )
799                 d = (i[None, :] - i[:, None]).abs().max((j[None, :] - j[:, None]).abs())
800                 d.fill_diagonal_(delta + 1)
801                 if d.min() > delta:
802                     break
803
804             for k in range(1, nb_rec):
805                 X[i[k] : i[k] + delta, j[k] : j[k] + delta] = c[k]
806
807             ai, aj = i.float().mean(), j.float().mean()
808
809             q = torch.randint(3, (1,)).item() + 1
810
811             X[i[0] + delta // 2 - 1, j[0] + delta // 2 - 1] = c[0]
812             X[i[0] + delta // 2 - 1, j[0] + delta // 2 + 1] = c[0]
813             X[i[0] + delta // 2 + 1, j[0] + delta // 2 - 1] = c[0]
814             X[i[0] + delta // 2 + 1, j[0] + delta // 2 + 1] = c[0]
815
816             assert i[q] != ai and j[q] != aj
817
818             X[
819                 i[0] + delta // 2 + (i[q] - ai).sign().long(),
820                 j[0] + delta // 2 + (j[q] - aj).sign().long(),
821             ] = c[nb_rec]
822
823             f_X[i[0] : i[0] + delta, j[0] : j[0] + delta] = c[q]
824
825     # @torch.compile
826     def task_isometry(self, A, f_A, B, f_B):
827         nb_rec = 3
828         di, dj = torch.randint(3, (2,)) - 1
829         o = torch.tensor([[0.0, 1.0], [-1.0, 0.0]])
830         m = torch.eye(2)
831         for _ in range(torch.randint(4, (1,)).item()):
832             m = m @ o
833         if torch.rand(1) < 0.5:
834             m[0, :] = -m[0, :]
835
836         ci, cj = (self.height - 1) / 2, (self.width - 1) / 2
837
838         for X, f_X in [(A, f_A), (B, f_B)]:
839             while True:
840                 X[...] = 0
841                 f_X[...] = 0
842
843                 c = torch.randperm(len(self.colors) - 1)[:nb_rec] + 1
844
845                 for r in range(nb_rec):
846                     while True:
847                         i1, i2 = torch.randint(self.height - 2, (2,)) + 1
848                         j1, j2 = torch.randint(self.width - 2, (2,)) + 1
849                         if (
850                             i2 >= i1
851                             and j2 >= j1
852                             and max(i2 - i1, j2 - j1) >= 2
853                             and min(i2 - i1, j2 - j1) <= 3
854                         ):
855                             break
856                     X[i1 : i2 + 1, j1 : j2 + 1] = c[r]
857
858                     i1, j1, i2, j2 = i1 - ci, j1 - cj, i2 - ci, j2 - cj
859
860                     i1, j1 = m[0, 0] * i1 + m[0, 1] * j1, m[1, 0] * i1 + m[1, 1] * j1
861                     i2, j2 = m[0, 0] * i2 + m[0, 1] * j2, m[1, 0] * i2 + m[1, 1] * j2
862
863                     i1, j1, i2, j2 = i1 + ci, j1 + cj, i2 + ci, j2 + cj
864                     i1, i2 = i1.long() + di, i2.long() + di
865                     j1, j2 = j1.long() + dj, j2.long() + dj
866                     if i1 > i2:
867                         i1, i2 = i2, i1
868                     if j1 > j2:
869                         j1, j2 = j2, j1
870
871                     f_X[i1 : i2 + 1, j1 : j2 + 1] = c[r]
872
873                 n = F.one_hot(X.flatten()).sum(dim=0)[1:]
874                 if (
875                     n.sum() > self.height * self.width // 4
876                     and (n > 0).long().sum() == nb_rec
877                 ):
878                     break
879
880     def compute_distance(self, walls, goal_i, goal_j):
881         max_length = walls.numel()
882         dist = torch.full_like(walls, max_length)
883
884         dist[goal_i, goal_j] = 0
885         pred_dist = torch.empty_like(dist)
886
887         while True:
888             pred_dist.copy_(dist)
889             dist[1:-1, 1:-1] = (
890                 torch.cat(
891                     (
892                         dist[None, 1:-1, 1:-1],
893                         dist[None, 1:-1, 0:-2],
894                         dist[None, 2:, 1:-1],
895                         dist[None, 1:-1, 2:],
896                         dist[None, 0:-2, 1:-1],
897                     ),
898                     0,
899                 ).min(dim=0)[0]
900                 + 1
901             )
902
903             dist = walls * max_length + (1 - walls) * dist
904
905             if dist.equal(pred_dist):
906                 return dist * (1 - walls)
907
908     # @torch.compile
909     def task_distance(self, A, f_A, B, f_B):
910         c = torch.randperm(len(self.colors) - 1)[:3] + 1
911         dist0 = torch.empty(self.height + 2, self.width + 2)
912         dist1 = torch.empty(self.height + 2, self.width + 2)
913         for X, f_X in [(A, f_A), (B, f_B)]:
914             nb_rec = torch.randint(3, (1,)).item() + 1
915             while True:
916                 r = self.rec_coo(nb_rec, prevent_overlap=True)
917                 X[...] = 0
918                 f_X[...] = 0
919                 for n in range(nb_rec):
920                     i1, j1, i2, j2 = r[n]
921                     X[i1:i2, j1:j2] = c[0]
922                     f_X[i1:i2, j1:j2] = c[0]
923                 while True:
924                     i0, j0 = (
925                         torch.randint(self.height, (1,)).item(),
926                         torch.randint(self.width, (1,)).item(),
927                     )
928                     if X[i0, j0] == 0:
929                         break
930                 while True:
931                     i1, j1 = (
932                         torch.randint(self.height, (1,)).item(),
933                         torch.randint(self.width, (1,)).item(),
934                     )
935                     if X[i1, j1] == 0:
936                         break
937                 dist1[...] = 1
938                 dist1[1:-1, 1:-1] = (X != 0).long()
939                 dist1[...] = self.compute_distance(dist1, i1 + 1, j1 + 1)
940                 if (
941                     dist1[i0 + 1, j0 + 1] >= 1
942                     and dist1[i0 + 1, j0 + 1] < self.height * 4
943                 ):
944                     break
945
946             dist0[...] = 1
947             dist0[1:-1, 1:-1] = (X != 0).long()
948             dist0[...] = self.compute_distance(dist0, i0 + 1, j0 + 1)
949
950             dist0 = dist0[1:-1, 1:-1]
951             dist1 = dist1[1:-1, 1:-1]
952
953             D = dist1[i0, j0]
954             for d in range(1, D):
955                 M = (dist0 == d) & (dist1 == D - d)
956                 f_X[...] = (1 - M) * f_X + M * c[1]
957
958             X[i0, j0] = c[2]
959             f_X[i0, j0] = c[2]
960             X[i1, j1] = c[2]
961             f_X[i1, j1] = c[2]
962
963     # for X, f_X in [(A, f_A), (B, f_B)]:
964     # n = torch.arange(self.height * self.width).reshape(self.height, self.width)
965     # k = torch.randperm(self.height * self.width)
966     # X[...]=-1
967     # for q in k:
968     # i,j=q%self.height,q//self.height
969     # if
970
971     # @torch.compile
972     def task_puzzle(self, A, f_A, B, f_B):
973         S = 4
974         i0, j0 = (self.height - S) // 2, (self.width - S) // 2
975         c = torch.randperm(len(self.colors) - 1)[:4] + 1
976         for X, f_X in [(A, f_A), (B, f_B)]:
977             while True:
978                 f_X[...] = 0
979                 h = list(torch.randperm(c.size(0)))
980                 n = torch.zeros(c.max() + 1)
981                 for _ in range(2):
982                     k = torch.randperm(S * S)
983                     for q in k:
984                         i, j = q % S + i0, q // S + j0
985                         if f_X[i, j] == 0:
986                             r, s, t, u = (
987                                 f_X[i - 1, j],
988                                 f_X[i, j - 1],
989                                 f_X[i + 1, j],
990                                 f_X[i, j + 1],
991                             )
992                             r, s, t, u = torch.tensor([r, s, t, u])[torch.randperm(4)]
993                             if r > 0 and n[r] < 6:
994                                 n[r] += 1
995                                 f_X[i, j] = r
996                             elif s > 0 and n[s] < 6:
997                                 n[s] += 1
998                                 f_X[i, j] = s
999                             elif t > 0 and n[t] < 6:
1000                                 n[t] += 1
1001                                 f_X[i, j] = t
1002                             elif u > 0 and n[u] < 6:
1003                                 n[u] += 1
1004                                 f_X[i, j] = u
1005                             else:
1006                                 if len(h) > 0:
1007                                     d = c[h.pop()]
1008                                     n[d] += 1
1009                                     f_X[i, j] = d
1010
1011                 if n.sum() == S * S:
1012                     break
1013
1014             k = 0
1015             for d in range(4):
1016                 while True:
1017                     ii, jj = (
1018                         torch.randint(self.height, (1,)).item(),
1019                         torch.randint(self.width, (1,)).item(),
1020                     )
1021                     e = 0
1022                     for i in range(S):
1023                         for j in range(S):
1024                             if (
1025                                 ii + i >= self.height
1026                                 or jj + j >= self.width
1027                                 or (
1028                                     f_X[i + i0, j + j0] == c[d]
1029                                     and X[ii + i, jj + j] > 0
1030                                 )
1031                             ):
1032                                 e = 1
1033                     if e == 0:
1034                         break
1035                 for i in range(S):
1036                     for j in range(S):
1037                         if f_X[i + i0, j + j0] == c[d]:
1038                             X[ii + i, jj + j] = c[d]
1039
1040     def task_islands(self, A, f_A, B, f_B):
1041         c = torch.randperm(len(self.colors) - 1)[:2] + 1
1042         for X, f_X in [(A, f_A), (B, f_B)]:
1043             if not hasattr(self, "cache_islands") or len(self.cache_islands) == 0:
1044                 self.cache_islands = list(
1045                     grow_islands(
1046                         1000,
1047                         self.height,
1048                         self.width,
1049                         nb_seeds=self.height * self.width // 20,
1050                         nb_iterations=self.height * self.width // 2,
1051                     )
1052                 )
1053
1054             A = self.cache_islands.pop()
1055
1056             while True:
1057                 i, j = (
1058                     torch.randint(self.height // 2, (1,)).item(),
1059                     torch.randint(self.width // 2, (1,)).item(),
1060                 )
1061                 if A[i, j] > 0:
1062                     break
1063
1064             X[...] = (A > 0) * c[0]
1065             X[i, j] = c[1]
1066             f_X[...] = (A == A[i, j]) * c[1] + ((A > 0) & (A != A[i, j])) * c[0]
1067
1068     ######################################################################
1069
1070     def trivial_prompts_and_answers(self, prompts, answers):
1071         S = self.height * self.width
1072         Bs = prompts[:, 2 * (S + 1) : 2 * (S + 1) + S]
1073         f_Bs = answers
1074         return (Bs == f_Bs).long().min(dim=-1).values > 0
1075
1076     def generate_prompts_and_answers_(self, nb, tasks=None, progress_bar=False):
1077         if tasks is None:
1078             tasks = self.all_tasks
1079
1080         S = self.height * self.width
1081         prompts = torch.zeros(nb, 3 * S + 2, dtype=torch.int64)
1082         answers = torch.zeros(nb, S, dtype=torch.int64)
1083
1084         bunch = zip(prompts, answers)
1085
1086         if progress_bar:
1087             bunch = tqdm.tqdm(
1088                 bunch,
1089                 dynamic_ncols=True,
1090                 desc="world generation",
1091                 total=prompts.size(0),
1092             )
1093
1094         for prompt, answer in bunch:
1095             A = prompt[0 * (S + 1) : 0 * (S + 1) + S].view(self.height, self.width)
1096             f_A = prompt[1 * (S + 1) : 1 * (S + 1) + S].view(self.height, self.width)
1097             B = prompt[2 * (S + 1) : 2 * (S + 1) + S].view(self.height, self.width)
1098             f_B = answer.view(self.height, self.width)
1099             task = tasks[torch.randint(len(tasks), (1,)).item()]
1100             task(A, f_A, B, f_B)
1101
1102         return prompts.flatten(1), answers.flatten(1)
1103
1104     def save_quiz_illustrations(
1105         self,
1106         result_dir,
1107         filename_prefix,
1108         prompts,
1109         answers,
1110         predicted_prompts=None,
1111         predicted_answers=None,
1112         nrow=4,
1113     ):
1114         self.save_image(
1115             result_dir,
1116             filename_prefix + ".png",
1117             prompts,
1118             answers,
1119             predicted_prompts,
1120             predicted_answers,
1121             nrow,
1122         )
1123
1124     def save_some_examples(self, result_dir):
1125         nb, nrow = 72, 4
1126         for t in self.all_tasks:
1127             print(t.__name__)
1128             prompts, answers = self.generate_prompts_and_answers_(nb, tasks=[t])
1129             self.save_quiz_illustrations(
1130                 result_dir, t.__name__, prompts[:nb], answers[:nb], nrow=nrow
1131             )
1132
1133
1134 ######################################################################
1135
1136 if __name__ == "__main__":
1137     import time
1138
1139     # grids = Grids(max_nb_cached_chunks=5, chunk_size=100, nb_threads=4)
1140     grids = Grids()
1141
1142     # nb = 1000
1143     # grids = problem.MultiThreadProblem(
1144     # grids, max_nb_cached_chunks=50, chunk_size=100, nb_threads=1
1145     # )
1146     #    time.sleep(10)
1147     # start_time = time.perf_counter()
1148     # prompts, answers = grids.generate_prompts_and_answers(nb)
1149     # delay = time.perf_counter() - start_time
1150     # print(f"{prompts.size(0)/delay:02f} seq/s")
1151     # exit(0)
1152
1153     # if True:
1154     nb, nrow = 72, 4
1155     # nb, nrow = 8, 2
1156
1157     # for t in grids.all_tasks:
1158     for t in [grids.task_distance]:
1159         print(t.__name__)
1160         prompts, answers = grids.generate_prompts_and_answers_(nb, tasks=[t])
1161         grids.save_quiz_illustrations(
1162             "/tmp", t.__name__, prompts[:nb], answers[:nb], nrow=nrow
1163         )
1164
1165     # exit(0)
1166
1167     nb = 1000
1168
1169     # for t in grids.all_tasks:
1170     for t in [grids.task_distance]:
1171         start_time = time.perf_counter()
1172         prompts, answers = grids.generate_prompts_and_answers_(nb, tasks=[t])
1173         delay = time.perf_counter() - start_time
1174         print(f"{t.__name__} {prompts.size(0)/delay:02f} seq/s")
1175
1176     exit(0)
1177
1178     m = torch.randint(2, (prompts.size(0),))
1179     predicted_prompts = m * (torch.randint(2, (prompts.size(0),)) * 2 - 1)
1180     predicted_answers = (1 - m) * (torch.randint(2, (prompts.size(0),)) * 2 - 1)
1181
1182     grids.save_quiz_illustrations(
1183         "/tmp",
1184         "test",
1185         prompts[:nb],
1186         answers[:nb],
1187         # You can add a bool to put a frame around the predicted parts
1188         predicted_prompts[:nb],
1189         predicted_answers[:nb],
1190     )