aa21543f849a927c41871c34897090693e9eeb32
[culture.git] / grids.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import math, sys, tqdm, os, warnings
9
10 import torch, torchvision
11
12 from torch import nn
13 from torch.nn import functional as F
14
15 ######################################################################
16
17 import problem
18
19
20 def grow_islands(nb, height, width, nb_seeds, nb_iterations):
21     w = torch.empty(5, 1, 3, 3)
22
23     w[0, 0] = torch.tensor(
24         [
25             [1.0, 1.0, 1.0],
26             [1.0, 0.0, 1.0],
27             [1.0, 1.0, 1.0],
28         ]
29     )
30
31     w[1, 0] = torch.tensor(
32         [
33             [-1.0, 1.0, 0.0],
34             [1.0, 0.0, 0.0],
35             [0.0, 0.0, 0.0],
36         ]
37     )
38
39     w[2, 0] = torch.tensor(
40         [
41             [0.0, 1.0, -1.0],
42             [0.0, 0.0, 1.0],
43             [0.0, 0.0, 0.0],
44         ]
45     )
46
47     w[3, 0] = torch.tensor(
48         [
49             [0.0, 0.0, 0.0],
50             [0.0, 0.0, 1.0],
51             [0.0, 1.0, -1.0],
52         ]
53     )
54
55     w[4, 0] = torch.tensor(
56         [
57             [0.0, 0.0, 0.0],
58             [1.0, 0.0, 0.0],
59             [-1.0, 1.0, 0.0],
60         ]
61     )
62
63     Z = torch.zeros(nb, height, width)
64     U = Z.flatten(1)
65
66     for _ in range(nb_seeds):
67         M = F.conv2d(Z[:, None, :, :], w, padding=1)
68         M = torch.cat([M[:, :1], M[:, 1:].min(dim=1, keepdim=True).values], dim=1)
69         M = ((M[:, 0] == 0) & (Z == 0)).long()
70         Q = (M.flatten(1).max(dim=1).values > 0).long()[:, None]
71         M = M * torch.rand(M.size())
72         M = M.flatten(1)
73         M = F.one_hot(M.argmax(dim=1), num_classes=M.size(1))
74         U += M * Q
75
76     for _ in range(nb_iterations):
77         M = F.conv2d(Z[:, None, :, :], w, padding=1)
78         M = torch.cat([M[:, :1], M[:, 1:].min(dim=1, keepdim=True).values], dim=1)
79         M = ((M[:, 1] >= 0) & (Z == 0)).long()
80         Q = (M.flatten(1).max(dim=1).values > 0).long()[:, None]
81         M = M * torch.rand(M.size())
82         M = M.flatten(1)
83         M = F.one_hot(M.argmax(dim=1), num_classes=M.size(1))
84         U = Z.flatten(1)
85         U += M * Q
86
87     M = Z.clone()
88     Z = Z * (torch.arange(Z.size(1) * Z.size(2)) + 1).reshape(1, Z.size(1), Z.size(2))
89
90     while True:
91         W = Z.clone()
92         Z = F.max_pool2d(Z, 3, 1, 1) * M
93         if Z.equal(W):
94             break
95
96     Z = Z.long()
97     U = Z.flatten(1)
98     V = F.one_hot(U).max(dim=1).values
99     W = V.cumsum(dim=1) - V
100     N = torch.arange(Z.size(0))[:, None, None].expand_as(Z)
101     Z = W[N, Z]
102
103     return Z
104
105
106 class Grids(problem.Problem):
107     named_colors = [
108         ("white", [255, 255, 255]),
109         ("red", [255, 0, 0]),
110         ("green", [0, 192, 0]),
111         ("blue", [0, 0, 255]),
112         ("yellow", [255, 224, 0]),
113         ("cyan", [0, 255, 255]),
114         ("violet", [224, 128, 255]),
115         ("lightgreen", [192, 255, 192]),
116         ("brown", [165, 42, 42]),
117         ("lightblue", [192, 192, 255]),
118         ("gray", [128, 128, 128]),
119     ]
120
121     def __init__(
122         self,
123         max_nb_cached_chunks=None,
124         chunk_size=None,
125         nb_threads=-1,
126         tasks=None,
127     ):
128         self.colors = torch.tensor([c for _, c in self.named_colors])
129         self.height = 10
130         self.width = 10
131         self.cache_rec_coo = {}
132
133         all_tasks = [
134             self.task_replace_color,
135             self.task_translate,
136             self.task_grow,
137             self.task_half_fill,
138             self.task_frame,
139             self.task_detect,
140             self.task_count,
141             self.task_trajectory,
142             self.task_bounce,
143             self.task_scale,
144             self.task_symbols,
145             self.task_isometry,
146             #            self.task_path,
147         ]
148
149         if tasks is None:
150             self.all_tasks = all_tasks
151         else:
152             self.all_tasks = [getattr(self, "task_" + t) for t in tasks.split(",")]
153
154         super().__init__(max_nb_cached_chunks, chunk_size, nb_threads)
155
156     ######################################################################
157
158     def frame2img(self, x, scale=15):
159         x = x.reshape(x.size(0), self.height, -1)
160         m = torch.logical_and(x >= 0, x < self.nb_token_values()).long()
161         x = self.colors[x * m].permute(0, 3, 1, 2)
162         s = x.shape
163         x = x[:, :, :, None, :, None].expand(-1, -1, -1, scale, -1, scale)
164         x = x.reshape(s[0], s[1], s[2] * scale, s[3] * scale)
165
166         x[:, :, :, torch.arange(0, x.size(3), scale)] = 0
167         x[:, :, torch.arange(0, x.size(2), scale), :] = 0
168         x = x[:, :, 1:, 1:]
169
170         for n in range(m.size(0)):
171             for i in range(m.size(1)):
172                 for j in range(m.size(2)):
173                     if m[n, i, j] == 0:
174                         for k in range(2, scale - 2):
175                             for l in [0, 1]:
176                                 x[n, :, i * scale + k, j * scale + k - l] = 0
177                                 x[
178                                     n, :, i * scale + scale - 1 - k, j * scale + k - l
179                                 ] = 0
180
181         return x
182
183     def save_image(
184         self,
185         result_dir,
186         filename,
187         prompts,
188         answers,
189         predicted_prompts=None,
190         predicted_answers=None,
191         nrow=4,
192         margin=8,
193     ):
194         S = self.height * self.width
195         As = prompts[:, 0 * (S + 1) : 0 * (S + 1) + S].view(-1, self.height, self.width)
196         f_As = prompts[:, 1 * (S + 1) : 1 * (S + 1) + S].view(
197             -1, self.height, self.width
198         )
199         Bs = prompts[:, 2 * (S + 1) : 2 * (S + 1) + S].view(-1, self.height, self.width)
200         prompts = torch.cat([As, f_As, Bs], dim=2)
201         answers = answers.reshape(answers.size(0), self.height, self.width)
202
203         if predicted_prompts is None:
204             predicted_prompts = 255
205
206         if predicted_answers is None:
207             predicted_answers = 255
208
209         def add_frame(x, c, margin, bottom=False):
210             if bottom:
211                 h, w, di, dj = x.size(2) + margin, x.size(3), 0, 0
212             else:
213                 h, w, di, dj = (
214                     x.size(2) + 2 * margin,
215                     x.size(3) + 2 * margin,
216                     margin,
217                     margin,
218                 )
219
220             y = x.new_full((x.size(0), x.size(1), h, w), 0)
221
222             if type(c) is int:
223                 y[...] = c
224             else:
225                 c = c.long()[:, None]
226                 c = (
227                     (1 - ((c == 1).long() + (c == 0).long() + (c == -1).long()))
228                     * torch.tensor([64, 64, 64])
229                     + (c == 1).long() * torch.tensor([0, 255, 0])
230                     + (c == 0).long() * torch.tensor([255, 255, 255])
231                     + (c == -1).long() * torch.tensor([255, 0, 0])
232                 )
233                 y[...] = c[:, :, None, None]
234
235             y[:, :, di : di + x.size(2), dj : dj + x.size(3)] = x
236
237             return y
238
239         img_prompts = torch.cat(
240             [
241                 add_frame(
242                     add_frame(self.frame2img(x), c=0, margin=1),
243                     c=predicted_prompts,
244                     margin=margin,
245                 )
246                 for x in prompts.to("cpu").split(split_size=self.width, dim=2)
247             ],
248             dim=3,
249         )
250
251         h = img_prompts.size(2)
252         img_answers = add_frame(
253             add_frame(self.frame2img(answers.to("cpu")), c=0, margin=1),
254             c=predicted_answers,
255             margin=margin,
256         )
257
258         separator_size = 2 * margin
259
260         separator = img_prompts.new_full(
261             (
262                 img_prompts.size(0),
263                 img_prompts.size(1),
264                 img_prompts.size(2),
265                 separator_size,
266             ),
267             255,
268         )
269
270         marker = img_prompts.new_full(
271             (
272                 img_prompts.size(0),
273                 img_prompts.size(1),
274                 img_prompts.size(2),
275                 separator_size,
276             ),
277             255,
278         )
279
280         # marker[:, :, 0] = 0
281         # marker[:, :, h - 1] = 0
282
283         for k in range(1, 2 * separator_size - 8):
284             i = k - (separator_size - 4)
285             j = separator_size - 5 - abs(i)
286             marker[:, :, h // 2 - 1 + i, 2 + j] = 0
287             marker[:, :, h // 2 - 1 + i + 1, 2 + j] = 0
288
289         img = torch.cat(
290             [
291                 img_prompts,
292                 marker,
293                 img_answers,
294             ],
295             dim=3,
296         )
297
298         image_name = os.path.join(result_dir, filename)
299         torchvision.utils.save_image(
300             img.float() / 255.0,
301             image_name,
302             nrow=nrow,
303             padding=margin * 4,
304             pad_value=1.0,
305         )
306
307     ######################################################################
308
309     def nb_token_values(self):
310         return len(self.colors)
311
312     # @torch.compile
313     def rec_coo(
314         self,
315         nb_rec,
316         min_height=3,
317         min_width=3,
318         surface_max=None,
319         prevent_overlap=False,
320     ):
321         if surface_max is None:
322             surface_max = self.height * self.width // 2
323
324         signature = (nb_rec, min_height, min_width, surface_max)
325
326         try:
327             return self.cache_rec_coo[signature].pop()
328         except IndexError:
329             pass
330         except KeyError:
331             pass
332
333         N = 10000
334         while True:
335             while True:
336                 i = torch.randint(self.height, (N * nb_rec, 2)).sort(dim=-1).values
337                 j = torch.randint(self.width, (N * nb_rec, 2)).sort(dim=-1).values
338
339                 big_enough = (
340                     (i[:, 1] >= i[:, 0] + min_height)
341                     & (j[:, 1] >= j[:, 0] + min_height)
342                     & ((i[:, 1] - i[:, 0]) * (j[:, 1] - j[:, 0]) <= surface_max)
343                 )
344
345                 i, j = i[big_enough], j[big_enough]
346
347                 n = i.size(0) - i.size(0) % nb_rec
348
349                 if n > 0:
350                     break
351
352             i = i[:n].reshape(n // nb_rec, nb_rec, -1)
353             j = j[:n].reshape(n // nb_rec, nb_rec, -1)
354
355             if prevent_overlap:
356                 can_fit = ((i[:, :, 1] - i[:, :, 0]) * (j[:, :, 1] - j[:, :, 0])).sum(
357                     dim=-1
358                 ) <= self.height * self.width
359                 i, j = i[can_fit], j[can_fit]
360                 if nb_rec == 2:
361                     A_i1, A_i2, A_j1, A_j2 = (
362                         i[:, 0, 0],
363                         i[:, 0, 1],
364                         j[:, 0, 0],
365                         j[:, 0, 1],
366                     )
367                     B_i1, B_i2, B_j1, B_j2 = (
368                         i[:, 1, 0],
369                         i[:, 1, 1],
370                         j[:, 1, 0],
371                         j[:, 1, 1],
372                     )
373                     no_overlap = torch.logical_not(
374                         (A_i1 >= B_i2)
375                         & (A_i2 <= B_i1)
376                         & (A_j1 >= B_j1)
377                         & (A_j2 <= B_j1)
378                     )
379                     i, j = i[no_overlap], j[no_overlap]
380                 elif nb_rec == 3:
381                     A_i1, A_i2, A_j1, A_j2 = (
382                         i[:, 0, 0],
383                         i[:, 0, 1],
384                         j[:, 0, 0],
385                         j[:, 0, 1],
386                     )
387                     B_i1, B_i2, B_j1, B_j2 = (
388                         i[:, 1, 0],
389                         i[:, 1, 1],
390                         j[:, 1, 0],
391                         j[:, 1, 1],
392                     )
393                     C_i1, C_i2, C_j1, C_j2 = (
394                         i[:, 2, 0],
395                         i[:, 2, 1],
396                         j[:, 2, 0],
397                         j[:, 2, 1],
398                     )
399                     no_overlap = (
400                         (
401                             (A_i1 >= B_i2)
402                             | (A_i2 <= B_i1)
403                             | (A_j1 >= B_j2)
404                             | (A_j2 <= B_j1)
405                         )
406                         & (
407                             (A_i1 >= C_i2)
408                             | (A_i2 <= C_i1)
409                             | (A_j1 >= C_j2)
410                             | (A_j2 <= C_j1)
411                         )
412                         & (
413                             (B_i1 >= C_i2)
414                             | (B_i2 <= C_i1)
415                             | (B_j1 >= C_j2)
416                             | (B_j2 <= C_j1)
417                         )
418                     )
419                     i, j = (i[no_overlap], j[no_overlap])
420                 else:
421                     assert nb_rec == 1
422
423             if i.size(0) > 1:
424                 break
425
426         self.cache_rec_coo[signature] = [
427             [
428                 (
429                     i[n, k, 0].item(),
430                     j[n, k, 0].item(),
431                     i[n, k, 1].item(),
432                     j[n, k, 1].item(),
433                 )
434                 for k in range(nb_rec)
435             ]
436             for n in range(i.size(0))
437         ]
438
439         return self.cache_rec_coo[signature].pop()
440
441     ######################################################################
442
443     # @torch.compile
444     def task_replace_color(self, A, f_A, B, f_B):
445         nb_rec = 3
446         c = torch.randperm(len(self.colors) - 1)[: nb_rec + 1] + 1
447         for X, f_X in [(A, f_A), (B, f_B)]:
448             r = self.rec_coo(nb_rec, prevent_overlap=True)
449             for n in range(nb_rec):
450                 i1, j1, i2, j2 = r[n]
451                 X[i1:i2, j1:j2] = c[n]
452                 f_X[i1:i2, j1:j2] = c[n if n > 0 else -1]
453
454     # @torch.compile
455     def task_translate(self, A, f_A, B, f_B):
456         while True:
457             di, dj = torch.randint(3, (2,)) - 1
458             if di.abs() + dj.abs() > 0:
459                 break
460
461         nb_rec = 3
462         c = torch.randperm(len(self.colors) - 1)[:nb_rec] + 1
463         for X, f_X in [(A, f_A), (B, f_B)]:
464             while True:
465                 r = self.rec_coo(nb_rec, prevent_overlap=True)
466                 i1, j1, i2, j2 = r[nb_rec - 1]
467                 if (
468                     i1 + di >= 0
469                     and i2 + di < X.size(0)
470                     and j1 + dj >= 0
471                     and j2 + dj < X.size(1)
472                 ):
473                     break
474
475             for n in range(nb_rec):
476                 i1, j1, i2, j2 = r[n]
477                 X[i1:i2, j1:j2] = c[n]
478                 if n == nb_rec - 1:
479                     f_X[i1 + di : i2 + di, j1 + dj : j2 + dj] = c[n]
480                 else:
481                     f_X[i1:i2, j1:j2] = c[n]
482
483     # @torch.compile
484     def task_grow(self, A, f_A, B, f_B):
485         di, dj = torch.randint(2, (2,)) * 2 - 1
486         nb_rec = 3
487         c = torch.randperm(len(self.colors) - 1)[:nb_rec] + 1
488         direction = torch.randint(2, (1,)).item()
489         for X, f_X in [(A, f_A), (B, f_B)]:
490             while True:
491                 r = self.rec_coo(nb_rec, prevent_overlap=True)
492                 i1, j1, i2, j2 = r[nb_rec - 1]
493                 if i1 + 3 < i2 and j1 + 3 < j2:
494                     break
495
496             for n in range(nb_rec):
497                 i1, j1, i2, j2 = r[n]
498                 if n == nb_rec - 1:
499                     if direction == 0:
500                         X[i1 + 1 : i2 - 1, j1 + 1 : j2 - 1] = c[n]
501                         f_X[i1:i2, j1:j2] = c[n]
502                     else:
503                         X[i1:i2, j1:j2] = c[n]
504                         f_X[i1 + 1 : i2 - 1, j1 + 1 : j2 - 1] = c[n]
505                 else:
506                     X[i1:i2, j1:j2] = c[n]
507                     f_X[i1:i2, j1:j2] = c[n]
508
509     # @torch.compile
510     def task_half_fill(self, A, f_A, B, f_B):
511         di, dj = torch.randint(2, (2,)) * 2 - 1
512         nb_rec = 3
513         c = torch.randperm(len(self.colors) - 1)[: 2 * nb_rec] + 1
514         direction = torch.randint(4, (1,)).item()
515         for X, f_X in [(A, f_A), (B, f_B)]:
516             r = self.rec_coo(nb_rec, prevent_overlap=True)
517             for n in range(nb_rec):
518                 i1, j1, i2, j2 = r[n]
519                 X[i1:i2, j1:j2] = c[2 * n]
520                 f_X[i1:i2, j1:j2] = c[2 * n]
521                 # Not my proudest moment
522                 if direction == 0:
523                     i = (i1 + i2) // 2
524                     X[i : i + 1, j1:j2] = c[2 * n + 1]
525                     if n == nb_rec - 1:
526                         f_X[i:i2, j1:j2] = c[2 * n + 1]
527                     else:
528                         f_X[i : i + 1, j1:j2] = c[2 * n + 1]
529                 elif direction == 1:
530                     i = (i1 + i2 - 1) // 2
531                     X[i : i + 1, j1:j2] = c[2 * n + 1]
532                     if n == nb_rec - 1:
533                         f_X[i1 : i + 1, j1:j2] = c[2 * n + 1]
534                     else:
535                         f_X[i : i + 1, j1:j2] = c[2 * n + 1]
536                 elif direction == 2:
537                     j = (j1 + j2) // 2
538                     X[i1:i2, j : j + 1] = c[2 * n + 1]
539                     if n == nb_rec - 1:
540                         f_X[i1:i2, j:j2] = c[2 * n + 1]
541                     else:
542                         f_X[i1:i2, j : j + 1] = c[2 * n + 1]
543                 elif direction == 3:
544                     j = (j1 + j2 - 1) // 2
545                     X[i1:i2, j : j + 1] = c[2 * n + 1]
546                     if n == nb_rec - 1:
547                         f_X[i1:i2, j1 : j + 1] = c[2 * n + 1]
548                     else:
549                         f_X[i1:i2, j : j + 1] = c[2 * n + 1]
550
551     # @torch.compile
552     def task_frame(self, A, f_A, B, f_B):
553         nb_rec = 3
554         c = torch.randperm(len(self.colors) - 1)[: nb_rec + 1] + 1
555         for X, f_X in [(A, f_A), (B, f_B)]:
556             r = self.rec_coo(nb_rec, prevent_overlap=True)
557             for n in range(nb_rec):
558                 i1, j1, i2, j2 = r[n]
559                 X[i1:i2, j1:j2] = c[n]
560                 if n == nb_rec - 1:
561                     f_X[i1:i2, j1] = c[n]
562                     f_X[i1:i2, j2 - 1] = c[n]
563                     f_X[i1, j1:j2] = c[n]
564                     f_X[i2 - 1, j1:j2] = c[n]
565                 else:
566                     f_X[i1:i2, j1:j2] = c[n]
567
568     # @torch.compile
569     def task_detect(self, A, f_A, B, f_B):
570         nb_rec = 3
571         c = torch.randperm(len(self.colors) - 1)[: nb_rec + 1] + 1
572         for X, f_X in [(A, f_A), (B, f_B)]:
573             r = self.rec_coo(nb_rec, prevent_overlap=True)
574             for n in range(nb_rec):
575                 i1, j1, i2, j2 = r[n]
576                 X[i1:i2, j1:j2] = c[n]
577                 if n < nb_rec - 1:
578                     f_X[i1, j1] = c[-1]
579
580     # @torch.compile
581     def contact(self, X, i, j, q):
582         nq, nq_diag = 0, 0
583         no = 0
584
585         for ii, jj in [
586             (i - 1, j - 1),
587             (i - 1, j),
588             (i - 1, j + 1),
589             (i, j - 1),
590             (i, j + 1),
591             (i + 1, j - 1),
592             (i + 1, j),
593             (i + 1, j + 1),
594         ]:
595             if ii >= 0 and ii < self.height and jj >= 0 and jj < self.width:
596                 if X[ii, jj] != 0 and X[ii, jj] != q:
597                     no += 1
598
599         for ii, jj in [
600             (i - 1, j - 1),
601             (i - 1, j + 1),
602             (i + 1, j - 1),
603             (i + 1, j + 1),
604         ]:
605             if ii >= 0 and ii < self.height and jj >= 0 and jj < self.width:
606                 if X[ii, jj] == q and X[i, jj] != q and X[ii, j] != q:
607                     nq_diag += 1
608
609         for ii, jj in [(i - 1, j), (i, j - 1), (i, j + 1), (i + 1, j)]:
610             if ii >= 0 and ii < self.height and jj >= 0 and jj < self.width:
611                 if X[ii, jj] == q:
612                     nq += 1
613
614         return no, nq, nq_diag
615
616     def task_count(self, A, f_A, B, f_B):
617         while True:
618             error = False
619
620             N = torch.randint(5, (1,)).item() + 1
621             c = torch.zeros(N + 1)
622             c[1:] = torch.randperm(len(self.colors) - 1)[:N] + 1
623
624             for X, f_X in [(A, f_A), (B, f_B)]:
625                 if not hasattr(self, "cache_count") or len(self.cache_count) == 0:
626                     self.cache_count = list(
627                         grow_islands(
628                             1000,
629                             self.height,
630                             self.width,
631                             nb_seeds=self.height * self.width // 9,
632                             nb_iterations=self.height * self.width // 20,
633                         )
634                     )
635
636                 X[...] = self.cache_count.pop()
637
638                 k = (X.max() + 1 + (c.size(0) - 1)).item()
639                 V = torch.arange(k) // (c.size(0) - 1)
640                 V = (V + torch.rand(V.size())).sort().indices[: X.max() + 1] % (
641                     c.size(0) - 1
642                 ) + 1
643                 V[0] = 0
644                 X[...] = c[V[X]]
645
646                 if F.one_hot(X.flatten()).max(dim=0).values.sum().item() == N + 1:
647                     f_X[...] = 0
648                     for e in range(1, N + 1):
649                         for j in range((X == c[e]).sum() + 1):
650                             if j < self.width:
651                                 f_X[e - 1, j] = c[e]
652                             else:
653                                 error = True
654                                 break
655                 else:
656                     error = True
657                     break
658
659             if not error:
660                 break
661
662     # @torch.compile
663     def task_trajectory(self, A, f_A, B, f_B):
664         c = torch.randperm(len(self.colors) - 1)[:2] + 1
665         for X, f_X in [(A, f_A), (B, f_B)]:
666             while True:
667                 di, dj = torch.randint(7, (2,)) - 3
668                 i, j = (
669                     torch.randint(self.height, (1,)).item(),
670                     torch.randint(self.width, (1,)).item(),
671                 )
672                 if (
673                     abs(di) + abs(dj) > 0
674                     and i + 2 * di >= 0
675                     and i + 2 * di < self.height
676                     and j + 2 * dj >= 0
677                     and j + 2 * dj < self.width
678                 ):
679                     break
680
681             k = 0
682             while (
683                 i + k * di >= 0
684                 and i + k * di < self.height
685                 and j + k * dj >= 0
686                 and j + k * dj < self.width
687             ):
688                 if k < 2:
689                     X[i + k * di, j + k * dj] = c[k]
690                 f_X[i + k * di, j + k * dj] = c[min(k, 1)]
691                 k += 1
692
693     # @torch.compile
694     def task_bounce(self, A, f_A, B, f_B):
695         c = torch.randperm(len(self.colors) - 1)[:3] + 1
696         for X, f_X in [(A, f_A), (B, f_B)]:
697             # @torch.compile
698             def free(i, j):
699                 return (
700                     i >= 0
701                     and i < self.height
702                     and j >= 0
703                     and j < self.width
704                     and f_X[i, j] == 0
705                 )
706
707             while True:
708                 f_X[...] = 0
709                 X[...] = 0
710
711                 for _ in range((self.height * self.width) // 10):
712                     i, j = (
713                         torch.randint(self.height, (1,)).item(),
714                         torch.randint(self.width, (1,)).item(),
715                     )
716                     X[i, j] = c[0]
717                     f_X[i, j] = c[0]
718
719                 while True:
720                     di, dj = torch.randint(7, (2,)) - 3
721                     if abs(di) + abs(dj) == 1:
722                         break
723
724                 i, j = (
725                     torch.randint(self.height, (1,)).item(),
726                     torch.randint(self.width, (1,)).item(),
727                 )
728
729                 X[i, j] = c[1]
730                 f_X[i, j] = c[1]
731                 l = 0
732
733                 while True:
734                     l += 1
735                     if free(i + di, j + dj):
736                         pass
737                     elif free(i - dj, j + di):
738                         di, dj = -dj, di
739                         if free(i + dj, j - di):
740                             if torch.rand(1) < 0.5:
741                                 di, dj = -di, -dj
742                     elif free(i + dj, j - di):
743                         di, dj = dj, -di
744                     else:
745                         break
746
747                     i, j = i + di, j + dj
748                     f_X[i, j] = c[2]
749                     if l <= 1:
750                         X[i, j] = c[2]
751
752                     if l >= self.width:
753                         break
754
755                 f_X[i, j] = c[1]
756                 X[i, j] = c[1]
757
758                 if l > 3:
759                     break
760
761     # @torch.compile
762     def task_scale(self, A, f_A, B, f_B):
763         c = torch.randperm(len(self.colors) - 1)[:2] + 1
764
765         i, j = (
766             torch.randint(self.height // 2, (1,)).item(),
767             torch.randint(self.width // 2, (1,)).item(),
768         )
769
770         for X, f_X in [(A, f_A), (B, f_B)]:
771             for _ in range(3):
772                 while True:
773                     i1, j1 = (
774                         torch.randint(self.height // 2 + 1, (1,)).item(),
775                         torch.randint(self.width // 2 + 1, (1,)).item(),
776                     )
777                     i2, j2 = (
778                         torch.randint(self.height // 2 + 1, (1,)).item(),
779                         torch.randint(self.width // 2 + 1, (1,)).item(),
780                     )
781                     if i1 < i2 and j1 < j2 and min(i2 - i1, j2 - j1) <= 3:
782                         break
783                 X[i + i1 : i + i2, j + j1 : j + j2] = c[0]
784                 f_X[2 * i1 : 2 * i2, 2 * j1 : 2 * j2] = c[0]
785
786             X[i, j] = c[1]
787             f_X[0:2, 0:2] = c[1]
788
789     # @torch.compile
790     def task_symbols(self, A, f_A, B, f_B):
791         nb_rec = 4
792         c = torch.randperm(len(self.colors) - 1)[: nb_rec + 1] + 1
793         delta = 3
794         for X, f_X in [(A, f_A), (B, f_B)]:
795             while True:
796                 i, j = torch.randint(self.height - delta + 1, (nb_rec,)), torch.randint(
797                     self.width - delta + 1, (nb_rec,)
798                 )
799                 d = (i[None, :] - i[:, None]).abs().max((j[None, :] - j[:, None]).abs())
800                 d.fill_diagonal_(delta + 1)
801                 if d.min() > delta:
802                     break
803
804             for k in range(1, nb_rec):
805                 X[i[k] : i[k] + delta, j[k] : j[k] + delta] = c[k]
806
807             ai, aj = i.float().mean(), j.float().mean()
808
809             q = torch.randint(3, (1,)).item() + 1
810
811             X[i[0] + delta // 2 - 1, j[0] + delta // 2 - 1] = c[0]
812             X[i[0] + delta // 2 - 1, j[0] + delta // 2 + 1] = c[0]
813             X[i[0] + delta // 2 + 1, j[0] + delta // 2 - 1] = c[0]
814             X[i[0] + delta // 2 + 1, j[0] + delta // 2 + 1] = c[0]
815
816             assert i[q] != ai and j[q] != aj
817
818             X[
819                 i[0] + delta // 2 + (i[q] - ai).sign().long(),
820                 j[0] + delta // 2 + (j[q] - aj).sign().long(),
821             ] = c[nb_rec]
822
823             f_X[i[0] : i[0] + delta, j[0] : j[0] + delta] = c[q]
824
825     # @torch.compile
826     def task_isometry(self, A, f_A, B, f_B):
827         nb_rec = 3
828         di, dj = torch.randint(3, (2,)) - 1
829         o = torch.tensor([[0.0, 1.0], [-1.0, 0.0]])
830         m = torch.eye(2)
831         for _ in range(torch.randint(4, (1,)).item()):
832             m = m @ o
833         if torch.rand(1) < 0.5:
834             m[0, :] = -m[0, :]
835
836         ci, cj = (self.height - 1) / 2, (self.width - 1) / 2
837
838         for X, f_X in [(A, f_A), (B, f_B)]:
839             while True:
840                 X[...] = 0
841                 f_X[...] = 0
842
843                 c = torch.randperm(len(self.colors) - 1)[:nb_rec] + 1
844
845                 for r in range(nb_rec):
846                     while True:
847                         i1, i2 = torch.randint(self.height - 2, (2,)) + 1
848                         j1, j2 = torch.randint(self.width - 2, (2,)) + 1
849                         if (
850                             i2 >= i1
851                             and j2 >= j1
852                             and max(i2 - i1, j2 - j1) >= 2
853                             and min(i2 - i1, j2 - j1) <= 3
854                         ):
855                             break
856                     X[i1 : i2 + 1, j1 : j2 + 1] = c[r]
857
858                     i1, j1, i2, j2 = i1 - ci, j1 - cj, i2 - ci, j2 - cj
859
860                     i1, j1 = m[0, 0] * i1 + m[0, 1] * j1, m[1, 0] * i1 + m[1, 1] * j1
861                     i2, j2 = m[0, 0] * i2 + m[0, 1] * j2, m[1, 0] * i2 + m[1, 1] * j2
862
863                     i1, j1, i2, j2 = i1 + ci, j1 + cj, i2 + ci, j2 + cj
864                     i1, i2 = i1.long() + di, i2.long() + di
865                     j1, j2 = j1.long() + dj, j2.long() + dj
866                     if i1 > i2:
867                         i1, i2 = i2, i1
868                     if j1 > j2:
869                         j1, j2 = j2, j1
870
871                     f_X[i1 : i2 + 1, j1 : j2 + 1] = c[r]
872
873                 n = F.one_hot(X.flatten()).sum(dim=0)[1:]
874                 if (
875                     n.sum() > self.height * self.width // 4
876                     and (n > 0).long().sum() == nb_rec
877                 ):
878                     break
879
880     def compute_distance(self, walls, goal_i, goal_j, start_i, start_j):
881         max_length = walls.numel()
882         dist = torch.full_like(walls, max_length)
883
884         dist[goal_i, goal_j] = 0
885         pred_dist = torch.empty_like(dist)
886
887         while True:
888             pred_dist.copy_(dist)
889             d = (
890                 torch.cat(
891                     (
892                         dist[None, 1:-1, 0:-2],
893                         dist[None, 2:, 1:-1],
894                         dist[None, 1:-1, 2:],
895                         dist[None, 0:-2, 1:-1],
896                     ),
897                     0,
898                 ).min(dim=0)[0]
899                 + 1
900             )
901
902             dist[1:-1, 1:-1].minimum_(d)  # = torch.min(dist[1:-1, 1:-1], d)
903             dist = walls * max_length + (1 - walls) * dist
904
905             if dist[start_i, start_j] < max_length or dist.equal(pred_dist):
906                 return dist * (1 - walls)
907
908     # @torch.compile
909     def task_path(self, A, f_A, B, f_B):
910         c = torch.randperm(len(self.colors) - 1)[:3] + 1
911         dist = torch.empty(self.height + 2, self.width + 2)
912         for X, f_X in [(A, f_A), (B, f_B)]:
913             nb_rec = torch.randint(3, (1,)).item() + 1
914             while True:
915                 r = self.rec_coo(nb_rec, prevent_overlap=True)
916                 X[...] = 0
917                 f_X[...] = 0
918                 for n in range(nb_rec):
919                     i1, j1, i2, j2 = r[n]
920                     X[i1:i2, j1:j2] = c[0]
921                     f_X[i1:i2, j1:j2] = c[0]
922                 while True:
923                     i0, j0 = (
924                         torch.randint(self.height, (1,)).item(),
925                         torch.randint(self.width, (1,)).item(),
926                     )
927                     if X[i0, j0] == 0:
928                         break
929                 while True:
930                     i1, j1 = (
931                         torch.randint(self.height, (1,)).item(),
932                         torch.randint(self.width, (1,)).item(),
933                     )
934                     if X[i1, j1] == 0:
935                         break
936                 dist[...] = 1
937                 dist[1:-1, 1:-1] = (X != 0).long()
938                 dist[...] = self.compute_distance(dist, i1 + 1, j1 + 1, i0 + 1, j0 + 1)
939                 if dist[i0 + 1, j0 + 1] >= 1 and dist[i0 + 1, j0 + 1] < self.height * 4:
940                     break
941
942             dist[1:-1, 1:-1] += (X != 0).long() * self.height * self.width
943             dist[0, :] = self.height * self.width
944             dist[-1, :] = self.height * self.width
945             dist[:, 0] = self.height * self.width
946             dist[:, -1] = self.height * self.width
947             # dist += torch.rand(dist.size())
948
949             i, j = i0 + 1, j0 + 1
950             while i != i1 + 1 or j != j1 + 1:
951                 f_X[i - 1, j - 1] = c[2]
952                 r, s, t, u = (
953                     dist[i - 1, j],
954                     dist[i, j - 1],
955                     dist[i + 1, j],
956                     dist[i, j + 1],
957                 )
958                 m = min(r, s, t, u)
959                 if r == m:
960                     i = i - 1
961                 elif t == m:
962                     i = i + 1
963                 elif s == m:
964                     j = j - 1
965                 else:
966                     j = j + 1
967
968             X[i0, j0] = c[2]
969             # f_X[i0, j0] = c[1]
970
971             X[i1, j1] = c[1]
972             f_X[i1, j1] = c[1]
973
974     # for X, f_X in [(A, f_A), (B, f_B)]:
975     # n = torch.arange(self.height * self.width).reshape(self.height, self.width)
976     # k = torch.randperm(self.height * self.width)
977     # X[...]=-1
978     # for q in k:
979     # i,j=q%self.height,q//self.height
980     # if
981
982     # @torch.compile
983     def task_puzzle(self, A, f_A, B, f_B):
984         S = 4
985         i0, j0 = (self.height - S) // 2, (self.width - S) // 2
986         c = torch.randperm(len(self.colors) - 1)[:4] + 1
987         for X, f_X in [(A, f_A), (B, f_B)]:
988             while True:
989                 f_X[...] = 0
990                 h = list(torch.randperm(c.size(0)))
991                 n = torch.zeros(c.max() + 1)
992                 for _ in range(2):
993                     k = torch.randperm(S * S)
994                     for q in k:
995                         i, j = q % S + i0, q // S + j0
996                         if f_X[i, j] == 0:
997                             r, s, t, u = (
998                                 f_X[i - 1, j],
999                                 f_X[i, j - 1],
1000                                 f_X[i + 1, j],
1001                                 f_X[i, j + 1],
1002                             )
1003                             r, s, t, u = torch.tensor([r, s, t, u])[torch.randperm(4)]
1004                             if r > 0 and n[r] < 6:
1005                                 n[r] += 1
1006                                 f_X[i, j] = r
1007                             elif s > 0 and n[s] < 6:
1008                                 n[s] += 1
1009                                 f_X[i, j] = s
1010                             elif t > 0 and n[t] < 6:
1011                                 n[t] += 1
1012                                 f_X[i, j] = t
1013                             elif u > 0 and n[u] < 6:
1014                                 n[u] += 1
1015                                 f_X[i, j] = u
1016                             else:
1017                                 if len(h) > 0:
1018                                     d = c[h.pop()]
1019                                     n[d] += 1
1020                                     f_X[i, j] = d
1021
1022                 if n.sum() == S * S:
1023                     break
1024
1025             k = 0
1026             for d in range(4):
1027                 while True:
1028                     ii, jj = (
1029                         torch.randint(self.height, (1,)).item(),
1030                         torch.randint(self.width, (1,)).item(),
1031                     )
1032                     e = 0
1033                     for i in range(S):
1034                         for j in range(S):
1035                             if (
1036                                 ii + i >= self.height
1037                                 or jj + j >= self.width
1038                                 or (
1039                                     f_X[i + i0, j + j0] == c[d]
1040                                     and X[ii + i, jj + j] > 0
1041                                 )
1042                             ):
1043                                 e = 1
1044                     if e == 0:
1045                         break
1046                 for i in range(S):
1047                     for j in range(S):
1048                         if f_X[i + i0, j + j0] == c[d]:
1049                             X[ii + i, jj + j] = c[d]
1050
1051     def task_islands(self, A, f_A, B, f_B):
1052         c = torch.randperm(len(self.colors) - 1)[:2] + 1
1053         for X, f_X in [(A, f_A), (B, f_B)]:
1054             while True:
1055                 k = torch.randperm(self.height * self.width)
1056                 Z = torch.zeros(self.height + 2, self.width + 2)
1057
1058                 i0, j0 = (
1059                     torch.randint(self.height, (1,)).item() + 1,
1060                     torch.randint(self.width, (1,)).item() + 1,
1061                 )
1062
1063                 Z[i0 - 1 : i0 + 2, j0 - 1 : j0 + 2] = 1
1064
1065                 nb = 9
1066
1067                 for q in k:
1068                     i, j = q % self.height + 1, q // self.height + 1
1069
1070                     if Z[i, j] == 0:
1071                         r, s, t, u, v, w, x, y = (
1072                             Z[i - 1, j],
1073                             Z[i - 1, j + 1],
1074                             Z[i, j + 1],
1075                             Z[i + 1, j + 1],
1076                             Z[i + 1, j],
1077                             Z[i + 1, j - 1],
1078                             Z[i, j - 1],
1079                             Z[i - 1, j - 1],
1080                         )
1081
1082                         if (
1083                             (nb < 16 or r + s + t + u + v + w + x + y > 0)
1084                             and (s == 0 or r + t > 0)
1085                             and (u == 0 or t + v > 0)
1086                             and (w == 0 or x + v > 0)
1087                             and (y == 0 or x + r > 0)
1088                         ):
1089                             # if r+s+t+u+v+w+x+y==0:
1090                             Z[i, j] = 1
1091                             nb += 1
1092
1093                     if nb == self.height * self.width // 2:
1094                         break
1095
1096                 if nb == self.height * self.width // 2:
1097                     break
1098
1099             M = Z.clone()
1100             Z[i0, j0] = 2
1101             X[...] = (Z[1:-1, 1:-1] == 1) * c[0] + (Z[1:-1, 1:-1] == 2) * c[1]
1102
1103             for _ in range(self.height + self.width):
1104                 Z[1:-1, 1:-1] = Z[1:-1, 1:-1].maximum(
1105                     torch.maximum(
1106                         torch.maximum(Z[0:-2, 1:-1], Z[2:, 1:-1]),
1107                         torch.maximum(Z[1:-1, 0:-2], Z[1:-1, 2:]),
1108                     )
1109                 )
1110                 Z *= M
1111
1112             f_X[...] = (Z[1:-1, 1:-1] == 1) * c[0] + (Z[1:-1, 1:-1] == 2) * c[1]
1113
1114     ######################################################################
1115
1116     def trivial_prompts_and_answers(self, prompts, answers):
1117         S = self.height * self.width
1118         Bs = prompts[:, 2 * (S + 1) : 2 * (S + 1) + S]
1119         f_Bs = answers
1120         return (Bs == f_Bs).long().min(dim=-1).values > 0
1121
1122     def generate_prompts_and_answers_(self, nb, tasks=None, progress_bar=False):
1123         if tasks is None:
1124             tasks = self.all_tasks
1125
1126         S = self.height * self.width
1127         prompts = torch.zeros(nb, 3 * S + 2, dtype=torch.int64)
1128         answers = torch.zeros(nb, S, dtype=torch.int64)
1129
1130         bunch = zip(prompts, answers)
1131
1132         if progress_bar:
1133             bunch = tqdm.tqdm(
1134                 bunch,
1135                 dynamic_ncols=True,
1136                 desc="world generation",
1137                 total=prompts.size(0),
1138             )
1139
1140         for prompt, answer in bunch:
1141             A = prompt[0 * (S + 1) : 0 * (S + 1) + S].view(self.height, self.width)
1142             f_A = prompt[1 * (S + 1) : 1 * (S + 1) + S].view(self.height, self.width)
1143             B = prompt[2 * (S + 1) : 2 * (S + 1) + S].view(self.height, self.width)
1144             f_B = answer.view(self.height, self.width)
1145             task = tasks[torch.randint(len(tasks), (1,)).item()]
1146             task(A, f_A, B, f_B)
1147
1148         return prompts.flatten(1), answers.flatten(1)
1149
1150     def save_quiz_illustrations(
1151         self,
1152         result_dir,
1153         filename_prefix,
1154         prompts,
1155         answers,
1156         predicted_prompts=None,
1157         predicted_answers=None,
1158         nrow=4,
1159     ):
1160         self.save_image(
1161             result_dir,
1162             filename_prefix + ".png",
1163             prompts,
1164             answers,
1165             predicted_prompts,
1166             predicted_answers,
1167             nrow,
1168         )
1169
1170     def save_some_examples(self, result_dir):
1171         nb, nrow = 72, 4
1172         for t in self.all_tasks:
1173             print(t.__name__)
1174             prompts, answers = self.generate_prompts_and_answers_(nb, tasks=[t])
1175             self.save_quiz_illustrations(
1176                 result_dir, t.__name__, prompts[:nb], answers[:nb], nrow=nrow
1177             )
1178
1179
1180 ######################################################################
1181
1182 if __name__ == "__main__":
1183     import time
1184
1185     # grids = Grids(max_nb_cached_chunks=5, chunk_size=100, nb_threads=4)
1186     grids = Grids()
1187
1188     # nb = 1000
1189     # grids = problem.MultiThreadProblem(
1190     # grids, max_nb_cached_chunks=50, chunk_size=100, nb_threads=1
1191     # )
1192     #    time.sleep(10)
1193     # start_time = time.perf_counter()
1194     # prompts, answers = grids.generate_prompts_and_answers(nb)
1195     # delay = time.perf_counter() - start_time
1196     # print(f"{prompts.size(0)/delay:02f} seq/s")
1197     # exit(0)
1198
1199     # if True:
1200     nb, nrow = 72, 4
1201     # nb, nrow = 8, 2
1202
1203     # for t in grids.all_tasks:
1204     for t in [grids.task_count]:
1205         print(t.__name__)
1206         prompts, answers = grids.generate_prompts_and_answers_(nb, tasks=[t])
1207         grids.save_quiz_illustrations(
1208             "/tmp", t.__name__, prompts[:nb], answers[:nb], nrow=nrow
1209         )
1210
1211     # exit(0)
1212
1213     nb = 1000
1214
1215     # for t in grids.all_tasks:
1216     for t in [grids.task_count]:
1217         start_time = time.perf_counter()
1218         prompts, answers = grids.generate_prompts_and_answers_(nb, tasks=[t])
1219         delay = time.perf_counter() - start_time
1220         print(f"{t.__name__} {prompts.size(0)/delay:02f} seq/s")
1221
1222     exit(0)
1223
1224     m = torch.randint(2, (prompts.size(0),))
1225     predicted_prompts = m * (torch.randint(2, (prompts.size(0),)) * 2 - 1)
1226     predicted_answers = (1 - m) * (torch.randint(2, (prompts.size(0),)) * 2 - 1)
1227
1228     grids.save_quiz_illustrations(
1229         "/tmp",
1230         "test",
1231         prompts[:nb],
1232         answers[:nb],
1233         # You can add a bool to put a frame around the predicted parts
1234         predicted_prompts[:nb],
1235         predicted_answers[:nb],
1236     )