3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
6 # Written by Francois Fleuret <francois@fleuret.org>
8 import math, sys, tqdm, os, warnings
10 import torch, torchvision
13 from torch.nn import functional as F
15 ######################################################################
20 def grow_islands(nb, height, width, nb_seeds, nb_iterations):
21 w = torch.empty(5, 1, 3, 3)
23 w[0, 0] = torch.tensor(
31 w[1, 0] = torch.tensor(
39 w[2, 0] = torch.tensor(
47 w[3, 0] = torch.tensor(
55 w[4, 0] = torch.tensor(
63 Z = torch.zeros(nb, height, width)
66 for _ in range(nb_seeds):
67 M = F.conv2d(Z[:, None, :, :], w, padding=1)
68 M = torch.cat([M[:, :1], M[:, 1:].min(dim=1, keepdim=True).values], dim=1)
69 M = ((M[:, 0] == 0) & (Z == 0)).long()
70 M = M * torch.rand(M.size())
72 M = F.one_hot(M.argmax(dim=1), num_classes=M.size(1))
75 for _ in range(nb_iterations):
76 M = F.conv2d(Z[:, None, :, :], w, padding=1)
77 M = torch.cat([M[:, :1], M[:, 1:].min(dim=1, keepdim=True).values], dim=1)
78 M = ((M[:, 1] >= 0) & (Z == 0)).long()
79 M = M * torch.rand(M.size())
81 M = F.one_hot(M.argmax(dim=1), num_classes=M.size(1))
86 Z = Z * (torch.arange(Z.size(1) * Z.size(2)) + 1).reshape(1, Z.size(1), Z.size(2))
89 Z = F.max_pool2d(Z, 3, 1, 1) * M
93 V = F.one_hot(U).max(dim=1).values
94 W = V.cumsum(dim=1) - V
95 N = torch.arange(Z.size(0))[:, None, None].expand_as(Z)
101 class Grids(problem.Problem):
103 ("white", [255, 255, 255]),
104 ("red", [255, 0, 0]),
105 ("green", [0, 192, 0]),
106 ("blue", [0, 0, 255]),
107 ("yellow", [255, 224, 0]),
108 ("cyan", [0, 255, 255]),
109 ("violet", [224, 128, 255]),
110 ("lightgreen", [192, 255, 192]),
111 ("brown", [165, 42, 42]),
112 ("lightblue", [192, 192, 255]),
113 ("gray", [128, 128, 128]),
118 max_nb_cached_chunks=None,
123 self.colors = torch.tensor([c for _, c in self.named_colors])
126 self.cache_rec_coo = {}
129 self.task_replace_color,
136 self.task_trajectory,
145 self.all_tasks = all_tasks
147 self.all_tasks = [getattr(self, "task_" + t) for t in tasks.split(",")]
149 super().__init__(max_nb_cached_chunks, chunk_size, nb_threads)
151 ######################################################################
153 def frame2img(self, x, scale=15):
154 x = x.reshape(x.size(0), self.height, -1)
155 m = torch.logical_and(x >= 0, x < self.nb_token_values()).long()
156 x = self.colors[x * m].permute(0, 3, 1, 2)
158 x = x[:, :, :, None, :, None].expand(-1, -1, -1, scale, -1, scale)
159 x = x.reshape(s[0], s[1], s[2] * scale, s[3] * scale)
161 x[:, :, :, torch.arange(0, x.size(3), scale)] = 0
162 x[:, :, torch.arange(0, x.size(2), scale), :] = 0
165 for n in range(m.size(0)):
166 for i in range(m.size(1)):
167 for j in range(m.size(2)):
169 for k in range(2, scale - 2):
171 x[n, :, i * scale + k, j * scale + k - l] = 0
173 n, :, i * scale + scale - 1 - k, j * scale + k - l
184 predicted_prompts=None,
185 predicted_answers=None,
189 S = self.height * self.width
190 As = prompts[:, 0 * (S + 1) : 0 * (S + 1) + S].view(-1, self.height, self.width)
191 f_As = prompts[:, 1 * (S + 1) : 1 * (S + 1) + S].view(
192 -1, self.height, self.width
194 Bs = prompts[:, 2 * (S + 1) : 2 * (S + 1) + S].view(-1, self.height, self.width)
195 prompts = torch.cat([As, f_As, Bs], dim=2)
196 answers = answers.reshape(answers.size(0), self.height, self.width)
198 if predicted_prompts is None:
199 predicted_prompts = 255
201 if predicted_answers is None:
202 predicted_answers = 255
204 def add_frame(x, c, margin, bottom=False):
206 h, w, di, dj = x.size(2) + margin, x.size(3), 0, 0
209 x.size(2) + 2 * margin,
210 x.size(3) + 2 * margin,
215 y = x.new_full((x.size(0), x.size(1), h, w), 0)
220 c = c.long()[:, None]
222 (1 - ((c == 1).long() + (c == 0).long() + (c == -1).long()))
223 * torch.tensor([64, 64, 64])
224 + (c == 1).long() * torch.tensor([0, 255, 0])
225 + (c == 0).long() * torch.tensor([255, 255, 255])
226 + (c == -1).long() * torch.tensor([255, 0, 0])
228 y[...] = c[:, :, None, None]
230 y[:, :, di : di + x.size(2), dj : dj + x.size(3)] = x
234 img_prompts = torch.cat(
237 add_frame(self.frame2img(x), c=0, margin=1),
241 for x in prompts.to("cpu").split(split_size=self.width, dim=2)
246 h = img_prompts.size(2)
247 img_answers = add_frame(
248 add_frame(self.frame2img(answers.to("cpu")), c=0, margin=1),
253 separator_size = 2 * margin
255 separator = img_prompts.new_full(
265 marker = img_prompts.new_full(
275 # marker[:, :, 0] = 0
276 # marker[:, :, h - 1] = 0
278 for k in range(1, 2 * separator_size - 8):
279 i = k - (separator_size - 4)
280 j = separator_size - 5 - abs(i)
281 marker[:, :, h // 2 - 1 + i, 2 + j] = 0
282 marker[:, :, h // 2 - 1 + i + 1, 2 + j] = 0
293 image_name = os.path.join(result_dir, filename)
294 torchvision.utils.save_image(
302 ######################################################################
304 def nb_token_values(self):
# Number of distinct token values: one per entry of the self.colors palette
# (built from self.named_colors). frame2img treats tokens outside
# [0, nb_token_values()) as invalid and maps them to palette index 0.
305 return len(self.colors)
314 prevent_overlap=False,
316 if surface_max is None:
317 surface_max = self.height * self.width // 2
319 signature = (nb_rec, min_height, min_width, surface_max)
322 return self.cache_rec_coo[signature].pop()
331 i = torch.randint(self.height, (N * nb_rec, 2)).sort(dim=-1).values
332 j = torch.randint(self.width, (N * nb_rec, 2)).sort(dim=-1).values
335 (i[:, 1] >= i[:, 0] + min_height)
336 & (j[:, 1] >= j[:, 0] + min_height)
337 & ((i[:, 1] - i[:, 0]) * (j[:, 1] - j[:, 0]) <= surface_max)
340 i, j = i[big_enough], j[big_enough]
342 n = i.size(0) - i.size(0) % nb_rec
347 i = i[:n].reshape(n // nb_rec, nb_rec, -1)
348 j = j[:n].reshape(n // nb_rec, nb_rec, -1)
351 can_fit = ((i[:, :, 1] - i[:, :, 0]) * (j[:, :, 1] - j[:, :, 0])).sum(
353 ) <= self.height * self.width
354 i, j = i[can_fit], j[can_fit]
356 A_i1, A_i2, A_j1, A_j2 = (
362 B_i1, B_i2, B_j1, B_j2 = (
368 no_overlap = torch.logical_not(
374 i, j = i[no_overlap], j[no_overlap]
376 A_i1, A_i2, A_j1, A_j2 = (
382 B_i1, B_i2, B_j1, B_j2 = (
388 C_i1, C_i2, C_j1, C_j2 = (
414 i, j = (i[no_overlap], j[no_overlap])
421 self.cache_rec_coo[signature] = [
429 for k in range(nb_rec)
431 for n in range(i.size(0))
434 return self.cache_rec_coo[signature].pop()
436 ######################################################################
439 def task_replace_color(self, A, f_A, B, f_B):
441 c = torch.randperm(len(self.colors) - 1)[: nb_rec + 1] + 1
442 for X, f_X in [(A, f_A), (B, f_B)]:
443 r = self.rec_coo(nb_rec, prevent_overlap=True)
444 for n in range(nb_rec):
445 i1, j1, i2, j2 = r[n]
446 X[i1:i2, j1:j2] = c[n]
447 f_X[i1:i2, j1:j2] = c[n if n > 0 else -1]
450 def task_translate(self, A, f_A, B, f_B):
452 di, dj = torch.randint(3, (2,)) - 1
453 if di.abs() + dj.abs() > 0:
457 c = torch.randperm(len(self.colors) - 1)[:nb_rec] + 1
458 for X, f_X in [(A, f_A), (B, f_B)]:
460 r = self.rec_coo(nb_rec, prevent_overlap=True)
461 i1, j1, i2, j2 = r[nb_rec - 1]
464 and i2 + di < X.size(0)
466 and j2 + dj < X.size(1)
470 for n in range(nb_rec):
471 i1, j1, i2, j2 = r[n]
472 X[i1:i2, j1:j2] = c[n]
474 f_X[i1 + di : i2 + di, j1 + dj : j2 + dj] = c[n]
476 f_X[i1:i2, j1:j2] = c[n]
479 def task_grow(self, A, f_A, B, f_B):
480 di, dj = torch.randint(2, (2,)) * 2 - 1
482 c = torch.randperm(len(self.colors) - 1)[:nb_rec] + 1
483 direction = torch.randint(2, (1,)).item()
484 for X, f_X in [(A, f_A), (B, f_B)]:
486 r = self.rec_coo(nb_rec, prevent_overlap=True)
487 i1, j1, i2, j2 = r[nb_rec - 1]
488 if i1 + 3 < i2 and j1 + 3 < j2:
491 for n in range(nb_rec):
492 i1, j1, i2, j2 = r[n]
495 X[i1 + 1 : i2 - 1, j1 + 1 : j2 - 1] = c[n]
496 f_X[i1:i2, j1:j2] = c[n]
498 X[i1:i2, j1:j2] = c[n]
499 f_X[i1 + 1 : i2 - 1, j1 + 1 : j2 - 1] = c[n]
501 X[i1:i2, j1:j2] = c[n]
502 f_X[i1:i2, j1:j2] = c[n]
505 def task_half_fill(self, A, f_A, B, f_B):
506 di, dj = torch.randint(2, (2,)) * 2 - 1
508 c = torch.randperm(len(self.colors) - 1)[: 2 * nb_rec] + 1
509 direction = torch.randint(4, (1,)).item()
510 for X, f_X in [(A, f_A), (B, f_B)]:
511 r = self.rec_coo(nb_rec, prevent_overlap=True)
512 for n in range(nb_rec):
513 i1, j1, i2, j2 = r[n]
514 X[i1:i2, j1:j2] = c[2 * n]
515 f_X[i1:i2, j1:j2] = c[2 * n]
516 # Not my proudest moment
519 X[i : i + 1, j1:j2] = c[2 * n + 1]
521 f_X[i:i2, j1:j2] = c[2 * n + 1]
523 f_X[i : i + 1, j1:j2] = c[2 * n + 1]
525 i = (i1 + i2 - 1) // 2
526 X[i : i + 1, j1:j2] = c[2 * n + 1]
528 f_X[i1 : i + 1, j1:j2] = c[2 * n + 1]
530 f_X[i : i + 1, j1:j2] = c[2 * n + 1]
533 X[i1:i2, j : j + 1] = c[2 * n + 1]
535 f_X[i1:i2, j:j2] = c[2 * n + 1]
537 f_X[i1:i2, j : j + 1] = c[2 * n + 1]
539 j = (j1 + j2 - 1) // 2
540 X[i1:i2, j : j + 1] = c[2 * n + 1]
542 f_X[i1:i2, j1 : j + 1] = c[2 * n + 1]
544 f_X[i1:i2, j : j + 1] = c[2 * n + 1]
547 def task_frame(self, A, f_A, B, f_B):
549 c = torch.randperm(len(self.colors) - 1)[: nb_rec + 1] + 1
550 for X, f_X in [(A, f_A), (B, f_B)]:
551 r = self.rec_coo(nb_rec, prevent_overlap=True)
552 for n in range(nb_rec):
553 i1, j1, i2, j2 = r[n]
554 X[i1:i2, j1:j2] = c[n]
556 f_X[i1:i2, j1] = c[n]
557 f_X[i1:i2, j2 - 1] = c[n]
558 f_X[i1, j1:j2] = c[n]
559 f_X[i2 - 1, j1:j2] = c[n]
561 f_X[i1:i2, j1:j2] = c[n]
564 def task_detect(self, A, f_A, B, f_B):
566 c = torch.randperm(len(self.colors) - 1)[: nb_rec + 1] + 1
567 for X, f_X in [(A, f_A), (B, f_B)]:
568 r = self.rec_coo(nb_rec, prevent_overlap=True)
569 for n in range(nb_rec):
570 i1, j1, i2, j2 = r[n]
571 X[i1:i2, j1:j2] = c[n]
576 def contact(self, X, i, j, q):
590 if ii >= 0 and ii < self.height and jj >= 0 and jj < self.width:
591 if X[ii, jj] != 0 and X[ii, jj] != q:
600 if ii >= 0 and ii < self.height and jj >= 0 and jj < self.width:
601 if X[ii, jj] == q and X[i, jj] != q and X[ii, j] != q:
604 for ii, jj in [(i - 1, j), (i, j - 1), (i, j + 1), (i + 1, j)]:
605 if ii >= 0 and ii < self.height and jj >= 0 and jj < self.width:
609 return no, nq, nq_diag
611 def task_count(self, A, f_A, B, f_B):
612 N = torch.randint(4, (1,)).item() + 2
613 c = torch.randperm(len(self.colors) - 1)[:N] + 1
615 for X, f_X in [(A, f_A), (B, f_B)]:
616 l_q = torch.randperm(self.height * self.width)[
617 : self.height * self.width // 20
619 l_d = torch.randint(N, l_q.size())
620 nb = torch.zeros(N, dtype=torch.int64)
622 for q, e in zip(l_q, l_d):
624 i, j = q % self.height, q // self.height
627 and X[max(0, i - 1) : i + 2, max(0, j - 1) : j + 2] == 0
632 l_q = torch.randperm((self.height - 2) * (self.width - 2))[
633 : self.height * self.width // 2
635 l_d = torch.randint(N, l_q.size())
636 for q, e in zip(l_q, l_d):
638 i, j = q % (self.height - 2) + 1, q // (self.height - 2) + 1
639 a1, a2, a3 = X[i - 1, j - 1 : j + 2]
640 a8, a4 = X[i, j - 1], X[i, j + 1]
641 a7, a6, a5 = X[i + 1, j - 1 : j + 2]
644 and nb[e] < self.width
645 and (a2 == 0 or a2 == d)
646 and (a4 == 0 or a4 == d)
647 and (a6 == 0 or a6 == d)
648 and (a8 == 0 or a8 == d)
649 and (a1 == 0 or a2 == d or a8 == d)
650 and (a3 == 0 or a4 == d or a2 == d)
651 and (a5 == 0 or a6 == d or a4 == d)
652 and (a7 == 0 or a8 == d or a6 == d)
665 for j in range(nb[e]):
669 def task_trajectory(self, A, f_A, B, f_B):
670 c = torch.randperm(len(self.colors) - 1)[:2] + 1
671 for X, f_X in [(A, f_A), (B, f_B)]:
673 di, dj = torch.randint(7, (2,)) - 3
675 torch.randint(self.height, (1,)).item(),
676 torch.randint(self.width, (1,)).item(),
679 abs(di) + abs(dj) > 0
681 and i + 2 * di < self.height
683 and j + 2 * dj < self.width
690 and i + k * di < self.height
692 and j + k * dj < self.width
695 X[i + k * di, j + k * dj] = c[k]
696 f_X[i + k * di, j + k * dj] = c[min(k, 1)]
700 def task_bounce(self, A, f_A, B, f_B):
701 c = torch.randperm(len(self.colors) - 1)[:3] + 1
702 for X, f_X in [(A, f_A), (B, f_B)]:
717 for _ in range((self.height * self.width) // 10):
719 torch.randint(self.height, (1,)).item(),
720 torch.randint(self.width, (1,)).item(),
726 di, dj = torch.randint(7, (2,)) - 3
727 if abs(di) + abs(dj) == 1:
731 torch.randint(self.height, (1,)).item(),
732 torch.randint(self.width, (1,)).item(),
741 if free(i + di, j + dj):
743 elif free(i - dj, j + di):
745 if free(i + dj, j - di):
746 if torch.rand(1) < 0.5:
748 elif free(i + dj, j - di):
753 i, j = i + di, j + dj
768 def task_scale(self, A, f_A, B, f_B):
769 c = torch.randperm(len(self.colors) - 1)[:2] + 1
772 torch.randint(self.height // 2, (1,)).item(),
773 torch.randint(self.width // 2, (1,)).item(),
776 for X, f_X in [(A, f_A), (B, f_B)]:
780 torch.randint(self.height // 2 + 1, (1,)).item(),
781 torch.randint(self.width // 2 + 1, (1,)).item(),
784 torch.randint(self.height // 2 + 1, (1,)).item(),
785 torch.randint(self.width // 2 + 1, (1,)).item(),
787 if i1 < i2 and j1 < j2 and min(i2 - i1, j2 - j1) <= 3:
789 X[i + i1 : i + i2, j + j1 : j + j2] = c[0]
790 f_X[2 * i1 : 2 * i2, 2 * j1 : 2 * j2] = c[0]
796 def task_symbols(self, A, f_A, B, f_B):
798 c = torch.randperm(len(self.colors) - 1)[: nb_rec + 1] + 1
800 for X, f_X in [(A, f_A), (B, f_B)]:
802 i, j = torch.randint(self.height - delta + 1, (nb_rec,)), torch.randint(
803 self.width - delta + 1, (nb_rec,)
805 d = (i[None, :] - i[:, None]).abs().max((j[None, :] - j[:, None]).abs())
806 d.fill_diagonal_(delta + 1)
810 for k in range(1, nb_rec):
811 X[i[k] : i[k] + delta, j[k] : j[k] + delta] = c[k]
813 ai, aj = i.float().mean(), j.float().mean()
815 q = torch.randint(3, (1,)).item() + 1
817 X[i[0] + delta // 2 - 1, j[0] + delta // 2 - 1] = c[0]
818 X[i[0] + delta // 2 - 1, j[0] + delta // 2 + 1] = c[0]
819 X[i[0] + delta // 2 + 1, j[0] + delta // 2 - 1] = c[0]
820 X[i[0] + delta // 2 + 1, j[0] + delta // 2 + 1] = c[0]
822 assert i[q] != ai and j[q] != aj
825 i[0] + delta // 2 + (i[q] - ai).sign().long(),
826 j[0] + delta // 2 + (j[q] - aj).sign().long(),
829 f_X[i[0] : i[0] + delta, j[0] : j[0] + delta] = c[q]
832 def task_isometry(self, A, f_A, B, f_B):
834 di, dj = torch.randint(3, (2,)) - 1
835 o = torch.tensor([[0.0, 1.0], [-1.0, 0.0]])
837 for _ in range(torch.randint(4, (1,)).item()):
839 if torch.rand(1) < 0.5:
842 ci, cj = (self.height - 1) / 2, (self.width - 1) / 2
844 for X, f_X in [(A, f_A), (B, f_B)]:
849 c = torch.randperm(len(self.colors) - 1)[:nb_rec] + 1
851 for r in range(nb_rec):
853 i1, i2 = torch.randint(self.height - 2, (2,)) + 1
854 j1, j2 = torch.randint(self.width - 2, (2,)) + 1
858 and max(i2 - i1, j2 - j1) >= 2
859 and min(i2 - i1, j2 - j1) <= 3
862 X[i1 : i2 + 1, j1 : j2 + 1] = c[r]
864 i1, j1, i2, j2 = i1 - ci, j1 - cj, i2 - ci, j2 - cj
866 i1, j1 = m[0, 0] * i1 + m[0, 1] * j1, m[1, 0] * i1 + m[1, 1] * j1
867 i2, j2 = m[0, 0] * i2 + m[0, 1] * j2, m[1, 0] * i2 + m[1, 1] * j2
869 i1, j1, i2, j2 = i1 + ci, j1 + cj, i2 + ci, j2 + cj
870 i1, i2 = i1.long() + di, i2.long() + di
871 j1, j2 = j1.long() + dj, j2.long() + dj
877 f_X[i1 : i2 + 1, j1 : j2 + 1] = c[r]
879 n = F.one_hot(X.flatten()).sum(dim=0)[1:]
881 n.sum() > self.height * self.width // 4
882 and (n > 0).long().sum() == nb_rec
886 def compute_distance(self, walls, goal_i, goal_j, start_i, start_j):
887 max_length = walls.numel()
888 dist = torch.full_like(walls, max_length)
890 dist[goal_i, goal_j] = 0
891 pred_dist = torch.empty_like(dist)
894 pred_dist.copy_(dist)
898 dist[None, 1:-1, 0:-2],
899 dist[None, 2:, 1:-1],
900 dist[None, 1:-1, 2:],
901 dist[None, 0:-2, 1:-1],
908 dist[1:-1, 1:-1].minimum_(d) # = torch.min(dist[1:-1, 1:-1], d)
909 dist = walls * max_length + (1 - walls) * dist
911 if dist[start_i, start_j] < max_length or dist.equal(pred_dist):
912 return dist * (1 - walls)
915 def task_path(self, A, f_A, B, f_B):
916 c = torch.randperm(len(self.colors) - 1)[:3] + 1
917 dist = torch.empty(self.height + 2, self.width + 2)
918 for X, f_X in [(A, f_A), (B, f_B)]:
919 nb_rec = torch.randint(3, (1,)).item() + 1
921 r = self.rec_coo(nb_rec, prevent_overlap=True)
924 for n in range(nb_rec):
925 i1, j1, i2, j2 = r[n]
926 X[i1:i2, j1:j2] = c[0]
927 f_X[i1:i2, j1:j2] = c[0]
930 torch.randint(self.height, (1,)).item(),
931 torch.randint(self.width, (1,)).item(),
937 torch.randint(self.height, (1,)).item(),
938 torch.randint(self.width, (1,)).item(),
943 dist[1:-1, 1:-1] = (X != 0).long()
944 dist[...] = self.compute_distance(dist, i1 + 1, j1 + 1, i0 + 1, j0 + 1)
945 if dist[i0 + 1, j0 + 1] >= 1 and dist[i0 + 1, j0 + 1] < self.height * 4:
948 dist[1:-1, 1:-1] += (X != 0).long() * self.height * self.width
949 dist[0, :] = self.height * self.width
950 dist[-1, :] = self.height * self.width
951 dist[:, 0] = self.height * self.width
952 dist[:, -1] = self.height * self.width
953 # dist += torch.rand(dist.size())
955 i, j = i0 + 1, j0 + 1
956 while i != i1 + 1 or j != j1 + 1:
957 f_X[i - 1, j - 1] = c[2]
980 # for X, f_X in [(A, f_A), (B, f_B)]:
981 # n = torch.arange(self.height * self.width).reshape(self.height, self.width)
982 # k = torch.randperm(self.height * self.width)
985 # i,j=q%self.height,q//self.height
989 def task_puzzle(self, A, f_A, B, f_B):
991 i0, j0 = (self.height - S) // 2, (self.width - S) // 2
992 c = torch.randperm(len(self.colors) - 1)[:4] + 1
993 for X, f_X in [(A, f_A), (B, f_B)]:
996 h = list(torch.randperm(c.size(0)))
997 n = torch.zeros(c.max() + 1)
999 k = torch.randperm(S * S)
1001 i, j = q % S + i0, q // S + j0
1009 r, s, t, u = torch.tensor([r, s, t, u])[torch.randperm(4)]
1010 if r > 0 and n[r] < 6:
1013 elif s > 0 and n[s] < 6:
1016 elif t > 0 and n[t] < 6:
1019 elif u > 0 and n[u] < 6:
1028 if n.sum() == S * S:
1035 torch.randint(self.height, (1,)).item(),
1036 torch.randint(self.width, (1,)).item(),
1042 ii + i >= self.height
1043 or jj + j >= self.width
1045 f_X[i + i0, j + j0] == c[d]
1046 and X[ii + i, jj + j] > 0
1054 if f_X[i + i0, j + j0] == c[d]:
1055 X[ii + i, jj + j] = c[d]
1057 def task_islands(self, A, f_A, B, f_B):
1058 c = torch.randperm(len(self.colors) - 1)[:2] + 1
1059 for X, f_X in [(A, f_A), (B, f_B)]:
1061 k = torch.randperm(self.height * self.width)
1062 Z = torch.zeros(self.height + 2, self.width + 2)
1065 torch.randint(self.height, (1,)).item() + 1,
1066 torch.randint(self.width, (1,)).item() + 1,
1069 Z[i0 - 1 : i0 + 2, j0 - 1 : j0 + 2] = 1
1074 i, j = q % self.height + 1, q // self.height + 1
1077 r, s, t, u, v, w, x, y = (
1089 (nb < 16 or r + s + t + u + v + w + x + y > 0)
1090 and (s == 0 or r + t > 0)
1091 and (u == 0 or t + v > 0)
1092 and (w == 0 or x + v > 0)
1093 and (y == 0 or x + r > 0)
1095 # if r+s+t+u+v+w+x+y==0:
1099 if nb == self.height * self.width // 2:
1102 if nb == self.height * self.width // 2:
1107 X[...] = (Z[1:-1, 1:-1] == 1) * c[0] + (Z[1:-1, 1:-1] == 2) * c[1]
1109 for _ in range(self.height + self.width):
1110 Z[1:-1, 1:-1] = Z[1:-1, 1:-1].maximum(
1112 torch.maximum(Z[0:-2, 1:-1], Z[2:, 1:-1]),
1113 torch.maximum(Z[1:-1, 0:-2], Z[1:-1, 2:]),
1118 f_X[...] = (Z[1:-1, 1:-1] == 1) * c[0] + (Z[1:-1, 1:-1] == 2) * c[1]
1120 ######################################################################
1122 def trivial_prompts_and_answers(self, prompts, answers):
1123 S = self.height * self.width
1124 Bs = prompts[:, 2 * (S + 1) : 2 * (S + 1) + S]
1126 return (Bs == f_Bs).long().min(dim=-1).values > 0
1128 def generate_prompts_and_answers_(self, nb, tasks=None, progress_bar=False):
1130 tasks = self.all_tasks
1132 S = self.height * self.width
1133 prompts = torch.zeros(nb, 3 * S + 2, dtype=torch.int64)
1134 answers = torch.zeros(nb, S, dtype=torch.int64)
1136 bunch = zip(prompts, answers)
1142 desc="world generation",
1143 total=prompts.size(0),
1146 for prompt, answer in bunch:
1147 A = prompt[0 * (S + 1) : 0 * (S + 1) + S].view(self.height, self.width)
1148 f_A = prompt[1 * (S + 1) : 1 * (S + 1) + S].view(self.height, self.width)
1149 B = prompt[2 * (S + 1) : 2 * (S + 1) + S].view(self.height, self.width)
1150 f_B = answer.view(self.height, self.width)
1151 task = tasks[torch.randint(len(tasks), (1,)).item()]
1152 task(A, f_A, B, f_B)
1154 return prompts.flatten(1), answers.flatten(1)
1156 def save_quiz_illustrations(
1162 predicted_prompts=None,
1163 predicted_answers=None,
1168 filename_prefix + ".png",
1176 def save_some_examples(self, result_dir):
1178 for t in self.all_tasks:
1180 prompts, answers = self.generate_prompts_and_answers_(nb, tasks=[t])
1181 self.save_quiz_illustrations(
1182 result_dir, t.__name__, prompts[:nb], answers[:nb], nrow=nrow
1186 ######################################################################
1188 if __name__ == "__main__":
1191 # grids = Grids(max_nb_cached_chunks=5, chunk_size=100, nb_threads=4)
1195 # grids = problem.MultiThreadProblem(
1196 # grids, max_nb_cached_chunks=50, chunk_size=100, nb_threads=1
1199 # start_time = time.perf_counter()
1200 # prompts, answers = grids.generate_prompts_and_answers(nb)
1201 # delay = time.perf_counter() - start_time
1202 # print(f"{prompts.size(0)/delay:02f} seq/s")
1209 # for t in grids.all_tasks:
1210 for t in [grids.task_count]:
1212 prompts, answers = grids.generate_prompts_and_answers_(nb, tasks=[t])
1213 grids.save_quiz_illustrations(
1214 "/tmp", t.__name__, prompts[:nb], answers[:nb], nrow=nrow
1221 # for t in grids.all_tasks:
1222 for t in [grids.task_islands]:
1223 start_time = time.perf_counter()
1224 prompts, answers = grids.generate_prompts_and_answers_(nb, tasks=[t])
1225 delay = time.perf_counter() - start_time
1226 print(f"{t.__name__} {prompts.size(0)/delay:02f} seq/s")
1230 m = torch.randint(2, (prompts.size(0),))
1231 predicted_prompts = m * (torch.randint(2, (prompts.size(0),)) * 2 - 1)
1232 predicted_answers = (1 - m) * (torch.randint(2, (prompts.size(0),)) * 2 - 1)
1234 grids.save_quiz_illustrations(
1239 # You can add a bool to put a frame around the predicted parts
1240 predicted_prompts[:nb],
1241 predicted_answers[:nb],