3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
6 # Written by Francois Fleuret <francois@fleuret.org>
8 import math, sys, tqdm, os, warnings, cairo
10 import torch, torchvision
13 from torch.nn import functional as F
15 ######################################################################
18 def text_img(height, width, text):
19 pixel_map = torch.full((height, width, 4), 255, dtype=torch.uint8)
21 surface = cairo.ImageSurface.create_for_data(
22 pixel_map.numpy(), cairo.FORMAT_ARGB32, pixel_map.size(1), pixel_map.size(0)
25 ctx = cairo.Context(surface)
26 ctx.set_source_rgb(0, 0, 0)
28 ctx.select_font_face("courier", cairo.FONT_SLANT_NORMAL, cairo.FONT_WEIGHT_NORMAL)
30 for line in text.split("\n"):
31 xbearing, ybearing, width, height, dx, dy = ctx.text_extents(line)
42 return pixel_map.permute(2, 0, 1)[None, :3].contiguous()
45 ######################################################################
50 def grow_islands(nb, height, width, nb_seeds, nb_iterations):
51 w = torch.empty(5, 1, 3, 3)
53 w[0, 0] = torch.tensor(
61 w[1, 0] = torch.tensor(
69 w[2, 0] = torch.tensor(
77 w[3, 0] = torch.tensor(
85 w[4, 0] = torch.tensor(
93 Z = torch.zeros(nb, height, width)
96 for _ in range(nb_seeds):
97 M = F.conv2d(Z[:, None, :, :], w, padding=1)
98 M = torch.cat([M[:, :1], M[:, 1:].min(dim=1, keepdim=True).values], dim=1)
99 M = ((M[:, 0] == 0) & (Z == 0)).long()
100 Q = (M.flatten(1).max(dim=1).values > 0).long()[:, None]
101 M = M * torch.rand(M.size())
103 M = F.one_hot(M.argmax(dim=1), num_classes=M.size(1))
106 for _ in range(nb_iterations):
107 M = F.conv2d(Z[:, None, :, :], w, padding=1)
108 M = torch.cat([M[:, :1], M[:, 1:].min(dim=1, keepdim=True).values], dim=1)
109 M = ((M[:, 1] >= 0) & (Z == 0)).long()
110 Q = (M.flatten(1).max(dim=1).values > 0).long()[:, None]
111 M = M * torch.rand(M.size())
113 M = F.one_hot(M.argmax(dim=1), num_classes=M.size(1))
118 Z = Z * (torch.arange(Z.size(1) * Z.size(2)) + 1).reshape(1, Z.size(1), Z.size(2))
122 Z = F.max_pool2d(Z, 3, 1, 1) * M
128 V = F.one_hot(U).max(dim=1).values
129 W = V.cumsum(dim=1) - V
130 N = torch.arange(Z.size(0))[:, None, None].expand_as(Z)
136 class Grids(problem.Problem):
138 ("white", [255, 255, 255]),
139 ("red", [255, 0, 0]),
140 ("green", [0, 192, 0]),
141 ("blue", [0, 0, 255]),
142 ("yellow", [255, 224, 0]),
143 ("cyan", [0, 255, 255]),
144 ("violet", [224, 128, 255]),
145 ("lightgreen", [192, 255, 192]),
146 ("brown", [165, 42, 42]),
147 ("lightblue", [192, 192, 255]),
148 ("gray", [128, 128, 128]),
151 def check_structure(self, quizzes, struct):
152 S = self.height * self.width
155 (quizzes[:, 0 * (S + 1)] == self.l2tok[struct[0]])
156 & (quizzes[:, 1 * (S + 1)] == self.l2tok[struct[1]])
157 & (quizzes[:, 2 * (S + 1)] == self.l2tok[struct[2]])
158 & (quizzes[:, 3 * (S + 1)] == self.l2tok[struct[3]])
161 def get_structure(self, quizzes):
162 S = self.height * self.width
165 for n in quizzes.reshape(quizzes.size(0), 4, S + 1)[0, :, 0]
167 self.check_structure(quizzes, struct)
170 def inject_noise(self, quizzes, noise, struct, mask):
171 assert self.check_structure(quizzes, struct=struct)
172 S = self.height * self.width
174 mask = torch.tensor(mask, device=quizzes.device)
175 mask = mask[None, :, None].expand(1, 4, S + 1).clone()
177 mask = mask.reshape(1, -1).expand_as(quizzes)
178 mask = mask * (torch.rand(mask.size(), device=mask.device) <= noise).long()
179 random = torch.randint(self.nb_colors, mask.size())
180 quizzes = mask * random + (1 - mask) * quizzes
185 def reconfigure(self, quizzes, struct=("A", "f_A", "B", "f_B")):
186 if torch.is_tensor(quizzes):
187 return self.reconfigure([quizzes], struct=struct)[0]
189 S = self.height * self.width
190 result = [x.new(x.size()) for x in quizzes]
192 struct_from = self.get_structure(quizzes[0][:1])
193 i = self.indices_select(quizzes[0], struct_from)
195 sf = dict((l, n) for n, l in enumerate(struct_from))
199 for x, y in zip(quizzes, result):
201 y[i, q * l : (q + 1) * l] = x[i, k * l : (k + 1) * l]
207 self.reconfigure([x[j] for x in quizzes], struct=struct), result
213 def trivial(self, quizzes):
214 S = self.height * self.width
215 assert self.check_structure(quizzes, struct=("A", "f_A", "B", "f_B"))
216 a = quizzes.reshape(quizzes.size(0), 4, S + 1)[:, :, 1:]
217 return (a[:, 0] == a[:, 1]).min(dim=1).values | (a[:, 2] == a[:, 3]).min(
222 self, quizzes, struct=("A", "f_A", "B", "f_B"), mask=(0, 0, 0, 1)
224 assert self.check_structure(quizzes, struct)
226 ar_mask = quizzes.new_zeros(quizzes.size())
228 S = self.height * self.width
229 a = ar_mask.reshape(ar_mask.size(0), 4, S + 1)[:, :, 1:]
237 def indices_select(self, quizzes, struct=("A", "f_A", "B", "f_B")):
238 S = self.height * self.width
239 q = quizzes.reshape(quizzes.size(0), 4, S + 1)
241 (q[:, 0, 0] == self.l2tok[struct[0]])
242 & (q[:, 1, 0] == self.l2tok[struct[1]])
243 & (q[:, 2, 0] == self.l2tok[struct[2]])
244 & (q[:, 3, 0] == self.l2tok[struct[3]])
249 max_nb_cached_chunks=None,
254 self.colors = torch.tensor([c for _, c in self.named_colors])
256 self.nb_colors = len(self.colors)
257 self.token_A = self.nb_colors
258 self.token_f_A = self.token_A + 1
259 self.token_B = self.token_f_A + 1
260 self.token_f_B = self.token_B + 1
263 self.rfree = torch.tensor([])
267 "f_A": self.token_f_A,
269 "f_B": self.token_f_B,
274 self.token_f_A: "f_A",
276 self.token_f_B: "f_B",
281 self.seq_len = 4 * (1 + self.height * self.width)
282 self.nb_token_values = self.token_f_B + 1
284 self.cache_rec_coo = {}
287 self.task_replace_color,
299 ############################################ hard ones
301 self.task_trajectory,
303 # self.task_count, # NOT REVERSIBLE
304 # self.task_islands, # TOO MESSY
308 self.all_tasks = all_tasks
310 self.all_tasks = [getattr(self, "task_" + t) for t in tasks.split(",")]
312 super().__init__(max_nb_cached_chunks, chunk_size, nb_threads)
314 ######################################################################
316 def grid2img(self, x, scale=15):
317 m = torch.logical_and(x >= 0, x < self.nb_colors).long()
318 y = self.colors[x * m].permute(0, 3, 1, 2)
320 y = y[:, :, :, None, :, None].expand(-1, -1, -1, scale, -1, scale)
321 y = y.reshape(s[0], s[1], s[2] * scale, s[3] * scale)
323 y[:, :, :, torch.arange(0, y.size(3), scale)] = 64
324 y[:, :, torch.arange(0, y.size(2), scale), :] = 64
326 for n in range(m.size(0)):
327 for i in range(m.size(1)):
328 for j in range(m.size(2)):
330 for k in range(3, scale - 2):
331 y[n, :, i * scale + k, j * scale + k] = 0
332 y[n, :, i * scale + k, j * scale + scale - k] = 0
338 def add_frame(self, img, colors, thickness):
342 img.size(2) + 2 * thickness,
343 img.size(3) + 2 * thickness,
346 result[...] = colors[:, :, None, None]
347 result[:, :, thickness:-thickness, thickness:-thickness] = img
351 def save_quizzes_as_image(
356 predicted_parts=None,
363 quizzes = quizzes.to("cpu")
365 to_reconfigure = [quizzes]
366 if predicted_parts is not None:
367 to_reconfigure.append(predicted_parts)
368 if correct_parts is not None:
369 to_reconfigure.append(correct_parts)
371 to_reconfigure = self.reconfigure(to_reconfigure, ("A", "f_A", "B", "f_B"))
373 quizzes = to_reconfigure.pop(0)
374 if predicted_parts is not None:
375 predicted_parts = to_reconfigure.pop(0)
376 if correct_parts is not None:
377 correct_parts = to_reconfigure.pop(0)
379 S = self.height * self.width
382 quizzes.reshape(quizzes.size(0), 4, S + 1)[:, :, 1:]
383 .reshape(quizzes.size(0), 4, self.height, self.width)
387 frame, white, gray, green, red = torch.tensor(
388 [[64, 64, 64], [255, 255, 255], [200, 200, 200], [0, 255, 0], [255, 0, 0]],
389 device=quizzes.device,
392 img_A = self.add_frame(self.grid2img(A), frame[None, :], thickness=1)
393 img_f_A = self.add_frame(self.grid2img(f_A), frame[None, :], thickness=1)
394 img_B = self.add_frame(self.grid2img(B), frame[None, :], thickness=1)
395 img_f_B = self.add_frame(self.grid2img(f_B), frame[None, :], thickness=1)
397 # predicted_parts Nx4
400 if predicted_parts is None:
401 colors = white[None, None, :].expand(-1, 4, -1)
403 predicted_parts = predicted_parts.to("cpu")
404 if correct_parts is None:
406 predicted_parts[:, :, None] * gray[None, None, :]
407 + (1 - predicted_parts[:, :, None]) * white[None, None, :]
410 correct_parts = correct_parts.to("cpu")
412 predicted_parts[:, :, None]
414 (correct_parts[:, :, None] == 1).long() * green[None, None, :]
415 + (correct_parts[:, :, None] == 0).long() * gray[None, None, :]
416 + (correct_parts[:, :, None] == -1).long() * red[None, None, :]
418 + (1 - predicted_parts[:, :, None]) * white[None, None, :]
421 img_A = self.add_frame(img_A, colors[:, 0], thickness=8)
422 img_f_A = self.add_frame(img_f_A, colors[:, 1], thickness=8)
423 img_B = self.add_frame(img_B, colors[:, 2], thickness=8)
424 img_f_B = self.add_frame(img_f_B, colors[:, 3], thickness=8)
426 img_A = self.add_frame(img_A, white[None, :], thickness=2)
427 img_f_A = self.add_frame(img_f_A, white[None, :], thickness=2)
428 img_B = self.add_frame(img_B, white[None, :], thickness=2)
429 img_f_B = self.add_frame(img_f_B, white[None, :], thickness=2)
431 img = torch.cat([img_A, img_f_A, img_B, img_f_B], dim=3)
433 if comments is not None:
434 comment_img = [text_img(comment_height, img.size(3), t) for t in comments]
435 comment_img = torch.cat(comment_img, dim=0)
436 img = torch.cat([img, comment_img], dim=2)
438 image_name = os.path.join(result_dir, filename)
440 torchvision.utils.save_image(
448 ######################################################################
457 prevent_overlap=False,
459 if surface_max is None:
460 surface_max = self.height * self.width // 2
462 signature = (nb_rec, min_height, min_width, surface_max)
465 return self.cache_rec_coo[signature].pop()
474 i = torch.randint(self.height, (N * nb_rec, 2)).sort(dim=-1).values
475 j = torch.randint(self.width, (N * nb_rec, 2)).sort(dim=-1).values
479 (i[:, 1] >= i[:, 0] + min_height)
480 & (j[:, 1] >= j[:, 0] + min_height)
481 & ((i[:, 1] - i[:, 0]) * (j[:, 1] - j[:, 0]) <= surface_max)
484 i, j = i[big_enough], j[big_enough]
486 n = i.size(0) - i.size(0) % nb_rec
491 i = i[:n].reshape(n // nb_rec, nb_rec, -1)
492 j = j[:n].reshape(n // nb_rec, nb_rec, -1)
495 can_fit = ((i[:, :, 1] - i[:, :, 0]) * (j[:, :, 1] - j[:, :, 0])).sum(
497 ) <= self.height * self.width
498 i, j = i[can_fit], j[can_fit]
500 A_i1, A_i2, A_j1, A_j2 = (
506 B_i1, B_i2, B_j1, B_j2 = (
518 i, j = (i[no_overlap], j[no_overlap])
520 A_i1, A_i2, A_j1, A_j2 = (
526 B_i1, B_i2, B_j1, B_j2 = (
532 C_i1, C_i2, C_j1, C_j2 = (
558 i, j = (i[no_overlap], j[no_overlap])
565 self.cache_rec_coo[signature] = [
573 for k in range(nb_rec)
575 for n in range(i.size(0))
578 return self.cache_rec_coo[signature].pop()
580 ######################################################################
582 def contact_matrices(self, rn, ri, rj, rz):
583 n = torch.arange(self.nb_rec_max)
588 (ri[:, :, None, 0] == ri[:, None, :, 1] + 1)
589 | (ri[:, :, None, 1] + 1 == ri[:, None, :, 0])
591 & (rj[:, :, None, 0] <= rj[:, None, :, 1])
592 & (rj[:, :, None, 1] >= rj[:, None, :, 0])
596 (rj[:, :, None, 0] == rj[:, None, :, 1] + 1)
597 | (rj[:, :, None, 1] + 1 == rj[:, None, :, 0])
599 & (ri[:, :, None, 0] <= ri[:, None, :, 1])
600 & (ri[:, :, None, 1] >= ri[:, None, :, 0])
603 # & (rz[:, :, None] == rz[:, None, :])
604 & (n[None, :, None] < rn[:, None, None])
605 & (n[None, None, :] < n[None, :, None])
608 def sample_rworld_states(self, N=1000):
611 torch.randint(self.height - 2, (N, self.nb_rec_max, 2))
617 torch.randint(self.width - 2, (N, self.nb_rec_max, 2))
622 rn = torch.randint(self.nb_rec_max - 1, (N,)) + 2
623 rz = torch.randint(2, (N, self.nb_rec_max))
624 rc = torch.randint(self.nb_colors - 1, (N, self.nb_rec_max)) + 1
625 n = torch.arange(self.nb_rec_max)
628 (ri[:, :, None, 0] <= ri[:, None, :, 1])
629 & (ri[:, :, None, 1] >= ri[:, None, :, 0])
630 & (rj[:, :, None, 0] <= rj[:, None, :, 1])
631 & (rj[:, :, None, 1] >= rj[:, None, :, 0])
632 & (rz[:, :, None] == rz[:, None, :])
633 & (n[None, :, None] < rn[:, None, None])
634 & (n[None, None, :] < n[None, :, None])
641 no_collision = nb_collisions == 0
643 if no_collision.any():
644 print(no_collision.long().sum() / N)
645 self.rn = rn[no_collision]
646 self.ri = ri[no_collision]
647 self.rj = rj[no_collision]
648 self.rz = rz[no_collision]
649 self.rc = rc[no_collision]
652 self.contact_matrices(rn, ri, rj, rz).long().flatten(1).sum(dim=1)
655 self.rcontact = nb_contact > 0
656 self.rfree = torch.full((self.rn.size(0),), True)
660 def get_recworld_state(self):
661 if not self.rfree.any():
662 self.sample_rworld_states()
663 k = torch.arange(self.rn.size(0))[self.rfree]
664 k = k[torch.randint(k.size(0), (1,))].item()
665 self.rfree[k] = False
666 return self.rn[k], self.ri[k], self.rj[k], self.rz[k], self.rc[k]
668 def draw_state(self, X, rn, ri, rj, rz, rc):
669 for n in sorted(list(range(rn)), key=lambda n: rz[n].item()):
670 X[ri[n, 0] : ri[n, 1] + 1, rj[n, 0] : rj[n, 1] + 1] = rc[n]
672 def task_recworld_immobile(self, A, f_A, B, f_B):
673 for X, f_X in [(A, f_A), (B, f_B)]:
674 rn, ri, rj, rz, rc = self.get_recworld_state()
675 self.draw_state(X, rn, ri, rj, rz, rc)
677 self.draw_state(f_X, rn, ri, rj, rz, rc)
679 ######################################################################
682 def task_replace_color(self, A, f_A, B, f_B):
684 c = torch.randperm(self.nb_colors - 1)[: nb_rec + 1] + 1
685 for X, f_X in [(A, f_A), (B, f_B)]:
686 r = self.rec_coo(nb_rec, prevent_overlap=True)
687 for n in range(nb_rec):
688 i1, j1, i2, j2 = r[n]
689 X[i1:i2, j1:j2] = c[n]
690 f_X[i1:i2, j1:j2] = c[n if n > 0 else -1]
693 def task_translate(self, A, f_A, B, f_B):
695 di, dj = torch.randint(3, (2,)) - 1
696 if di.abs() + dj.abs() > 0:
700 c = torch.randperm(self.nb_colors - 1)[:nb_rec] + 1
701 for X, f_X in [(A, f_A), (B, f_B)]:
703 r = self.rec_coo(nb_rec, prevent_overlap=True)
704 i1, j1, i2, j2 = r[nb_rec - 1]
707 and i2 + di < X.size(0)
709 and j2 + dj < X.size(1)
713 for n in range(nb_rec):
714 i1, j1, i2, j2 = r[n]
715 X[i1:i2, j1:j2] = c[n]
717 f_X[i1 + di : i2 + di, j1 + dj : j2 + dj] = c[n]
719 f_X[i1:i2, j1:j2] = c[n]
722 def task_grow(self, A, f_A, B, f_B):
723 di, dj = torch.randint(2, (2,)) * 2 - 1
725 c = torch.randperm(self.nb_colors - 1)[:nb_rec] + 1
726 direction = torch.randint(2, (1,)).item()
727 for X, f_X in [(A, f_A), (B, f_B)]:
729 r = self.rec_coo(nb_rec, prevent_overlap=True)
730 i1, j1, i2, j2 = r[nb_rec - 1]
731 if i1 + 3 < i2 and j1 + 3 < j2:
734 for n in range(nb_rec):
735 i1, j1, i2, j2 = r[n]
738 X[i1 + 1 : i2 - 1, j1 + 1 : j2 - 1] = c[n]
739 f_X[i1:i2, j1:j2] = c[n]
741 X[i1:i2, j1:j2] = c[n]
742 f_X[i1 + 1 : i2 - 1, j1 + 1 : j2 - 1] = c[n]
744 X[i1:i2, j1:j2] = c[n]
745 f_X[i1:i2, j1:j2] = c[n]
748 def task_half_fill(self, A, f_A, B, f_B):
749 di, dj = torch.randint(2, (2,)) * 2 - 1
751 c = torch.randperm(self.nb_colors - 1)[: 2 * nb_rec] + 1
752 direction = torch.randint(4, (1,)).item()
753 for X, f_X in [(A, f_A), (B, f_B)]:
754 r = self.rec_coo(nb_rec, prevent_overlap=True)
755 for n in range(nb_rec):
756 i1, j1, i2, j2 = r[n]
757 X[i1:i2, j1:j2] = c[2 * n]
758 f_X[i1:i2, j1:j2] = c[2 * n]
759 # Not my proudest moment
762 X[i : i + 1, j1:j2] = c[2 * n + 1]
764 f_X[i:i2, j1:j2] = c[2 * n + 1]
766 f_X[i : i + 1, j1:j2] = c[2 * n + 1]
768 i = (i1 + i2 - 1) // 2
769 X[i : i + 1, j1:j2] = c[2 * n + 1]
771 f_X[i1 : i + 1, j1:j2] = c[2 * n + 1]
773 f_X[i : i + 1, j1:j2] = c[2 * n + 1]
776 X[i1:i2, j : j + 1] = c[2 * n + 1]
778 f_X[i1:i2, j:j2] = c[2 * n + 1]
780 f_X[i1:i2, j : j + 1] = c[2 * n + 1]
782 j = (j1 + j2 - 1) // 2
783 X[i1:i2, j : j + 1] = c[2 * n + 1]
785 f_X[i1:i2, j1 : j + 1] = c[2 * n + 1]
787 f_X[i1:i2, j : j + 1] = c[2 * n + 1]
790 def task_frame(self, A, f_A, B, f_B):
792 c = torch.randperm(self.nb_colors - 1)[: nb_rec + 1] + 1
793 for X, f_X in [(A, f_A), (B, f_B)]:
794 r = self.rec_coo(nb_rec, prevent_overlap=True)
795 for n in range(nb_rec):
796 i1, j1, i2, j2 = r[n]
797 X[i1:i2, j1:j2] = c[n]
799 f_X[i1:i2, j1] = c[n]
800 f_X[i1:i2, j2 - 1] = c[n]
801 f_X[i1, j1:j2] = c[n]
802 f_X[i2 - 1, j1:j2] = c[n]
804 f_X[i1:i2, j1:j2] = c[n]
807 def task_detect(self, A, f_A, B, f_B):
809 c = torch.randperm(self.nb_colors - 1)[: nb_rec + 1] + 1
810 for X, f_X in [(A, f_A), (B, f_B)]:
811 r = self.rec_coo(nb_rec, prevent_overlap=True)
812 for n in range(nb_rec):
813 i1, j1, i2, j2 = r[n]
814 X[i1:i2, j1:j2] = c[n]
815 f_X[i1:i2, j1:j2] = c[n]
818 f_X[i1 + k, j1] = c[-1]
819 f_X[i1, j1 + k] = c[-1]
822 def contact(self, X, i, j, q):
836 if ii >= 0 and ii < self.height and jj >= 0 and jj < self.width:
837 if X[ii, jj] != 0 and X[ii, jj] != q:
846 if ii >= 0 and ii < self.height and jj >= 0 and jj < self.width:
847 if X[ii, jj] == q and X[i, jj] != q and X[ii, j] != q:
850 for ii, jj in [(i - 1, j), (i, j - 1), (i, j + 1), (i + 1, j)]:
851 if ii >= 0 and ii < self.height and jj >= 0 and jj < self.width:
855 return no, nq, nq_diag
857 def REMOVED_task_count(self, A, f_A, B, f_B):
862 c = torch.zeros(N + 2, dtype=torch.int64)
863 c[1:] = torch.randperm(self.nb_colors - 1)[: N + 1] + 1
865 for X, f_X in [(A, f_A), (B, f_B)]:
866 if not hasattr(self, "cache_count") or len(self.cache_count) == 0:
867 self.cache_count = list(
872 nb_seeds=self.height * self.width // 8,
873 nb_iterations=self.height * self.width // 5,
877 X[...] = self.cache_count.pop()
879 # k = (X.max() + 1 + (c.size(0) - 1)).item()
880 # V = torch.arange(k) // (c.size(0) - 1)
881 # V = (V + torch.rand(V.size())).sort().indices[: X.max() + 1] % (
885 V = torch.randint(N, (X.max() + 1,)) + 1
887 NB = F.one_hot(c[V]).sum(dim=0)
891 if F.one_hot(X.flatten()).max(dim=0).values.sum().item() >= 3:
893 if (NB[c[:-1]] == m).long().sum() == 1:
894 for e in range(1, N + 1):
896 a = (f_X == c[e]).long()
897 f_X[...] = (1 - a) * f_X + a * c[-1]
905 assert F.one_hot(A.flatten()).max(dim=0).values.sum() >= 3
908 def task_trajectory(self, A, f_A, B, f_B):
909 c = torch.randperm(self.nb_colors - 1)[:2] + 1
910 for X, f_X in [(A, f_A), (B, f_B)]:
912 di, dj = torch.randint(7, (2,)) - 3
914 torch.randint(self.height, (1,)).item(),
915 torch.randint(self.width, (1,)).item(),
918 abs(di) + abs(dj) > 0
920 and i + 2 * di < self.height
922 and j + 2 * dj < self.width
929 and i + k * di < self.height
931 and j + k * dj < self.width
934 X[i + k * di, j + k * dj] = c[k]
935 f_X[i + k * di, j + k * dj] = c[min(k, 1)]
939 def task_bounce(self, A, f_A, B, f_B):
940 c = torch.randperm(self.nb_colors - 1)[:3] + 1
941 for X, f_X in [(A, f_A), (B, f_B)]:
956 for _ in range((self.height * self.width) // 10):
958 torch.randint(self.height, (1,)).item(),
959 torch.randint(self.width, (1,)).item(),
965 di, dj = torch.randint(7, (2,)) - 3
966 if abs(di) + abs(dj) == 1:
970 torch.randint(self.height, (1,)).item(),
971 torch.randint(self.width, (1,)).item(),
980 if free(i + di, j + dj):
982 elif free(i - dj, j + di):
984 if free(i + dj, j - di):
985 if torch.rand(1) < 0.5:
987 elif free(i + dj, j - di):
992 i, j = i + di, j + dj
1008 def task_scale(self, A, f_A, B, f_B):
1009 c = torch.randperm(self.nb_colors - 1)[:2] + 1
1012 torch.randint(self.height // 2, (1,)).item(),
1013 torch.randint(self.width // 2, (1,)).item(),
1016 for X, f_X in [(A, f_A), (B, f_B)]:
1020 torch.randint(self.height // 2 + 1, (1,)).item(),
1021 torch.randint(self.width // 2 + 1, (1,)).item(),
1024 torch.randint(self.height // 2 + 1, (1,)).item(),
1025 torch.randint(self.width // 2 + 1, (1,)).item(),
1027 if i1 < i2 and j1 < j2 and min(i2 - i1, j2 - j1) <= 3:
1029 X[i + i1 : i + i2, j + j1 : j + j2] = c[0]
1030 f_X[2 * i1 : 2 * i2, 2 * j1 : 2 * j2] = c[0]
1035 f_X[i + k, j] = c[1]
1036 f_X[i, j + k] = c[1]
1039 def task_symbols(self, A, f_A, B, f_B):
1041 c = torch.randperm(self.nb_colors - 1)[: nb_rec + 1] + 1
1043 for X, f_X in [(A, f_A), (B, f_B)]:
1045 i, j = torch.randint(self.height - delta + 1, (nb_rec,)), torch.randint(
1046 self.width - delta + 1, (nb_rec,)
1048 d = (i[None, :] - i[:, None]).abs().max((j[None, :] - j[:, None]).abs())
1049 d.fill_diagonal_(delta + 1)
1053 ai, aj = i.float().mean(), j.float().mean()
1055 q = torch.randint(3, (1,)).item() + 1
1057 assert i[q] != ai and j[q] != aj
1060 for k in range(0, nb_rec):
1061 Z[i[k] : i[k] + delta, j[k] : j[k] + delta] = c[k]
1062 # Z[i[0] + delta // 2 - 1, j[0] + delta // 2 - 1] = c[0]
1063 # Z[i[0] + delta // 2 - 1, j[0] + delta // 2 + 1] = c[0]
1064 # Z[i[0] + delta // 2 + 1, j[0] + delta // 2 - 1] = c[0]
1065 # Z[i[0] + delta // 2 + 1, j[0] + delta // 2 + 1] = c[0]
1067 # f_X[i[0] : i[0] + delta, j[0] : j[0] + delta] = c[q]
1069 f_X[i[0] + delta // 2, j[0] + delta // 2] = c[q]
1070 # f_X[i[0] : i[0] + delta, j[0] : j[0] + delta] = c[q]
1073 i[0] + delta // 2 + (i[q] - ai).sign().long(),
1074 j[0] + delta // 2 + (j[q] - aj).sign().long(),
1077 X[ii, jj] = c[nb_rec]
1078 X[i[0] + delta // 2, jj] = c[nb_rec]
1079 X[ii, j[0] + delta // 2] = c[nb_rec]
1081 f_X[ii, jj] = c[nb_rec]
1082 f_X[i[0] + delta // 2, jj] = c[nb_rec]
1083 f_X[ii, j[0] + delta // 2] = c[nb_rec]
1086 def task_isometry(self, A, f_A, B, f_B):
1088 di, dj = torch.randint(3, (2,)) - 1
1089 o = torch.tensor([[0.0, 1.0], [-1.0, 0.0]])
1091 for _ in range(torch.randint(4, (1,)).item()):
1093 if torch.rand(1) < 0.5:
1096 ci, cj = (self.height - 1) / 2, (self.width - 1) / 2
1098 for X, f_X in [(A, f_A), (B, f_B)]:
1103 c = torch.randperm(self.nb_colors - 1)[:nb_rec] + 1
1105 for r in range(nb_rec):
1107 i1, i2 = torch.randint(self.height - 2, (2,)) + 1
1108 j1, j2 = torch.randint(self.width - 2, (2,)) + 1
1112 and max(i2 - i1, j2 - j1) >= 2
1113 and min(i2 - i1, j2 - j1) <= 3
1116 X[i1 : i2 + 1, j1 : j2 + 1] = c[r]
1118 i1, j1, i2, j2 = i1 - ci, j1 - cj, i2 - ci, j2 - cj
1120 i1, j1 = m[0, 0] * i1 + m[0, 1] * j1, m[1, 0] * i1 + m[1, 1] * j1
1121 i2, j2 = m[0, 0] * i2 + m[0, 1] * j2, m[1, 0] * i2 + m[1, 1] * j2
1123 i1, j1, i2, j2 = i1 + ci, j1 + cj, i2 + ci, j2 + cj
1124 i1, i2 = i1.long() + di, i2.long() + di
1125 j1, j2 = j1.long() + dj, j2.long() + dj
1131 f_X[i1 : i2 + 1, j1 : j2 + 1] = c[r]
1133 n = F.one_hot(X.flatten()).sum(dim=0)[1:]
1135 n.sum() > self.height * self.width // 4
1136 and (n > 0).long().sum() == nb_rec
1140 def compute_distance(self, walls, goal_i, goal_j):
1141 max_length = walls.numel()
1142 dist = torch.full_like(walls, max_length)
1144 dist[goal_i, goal_j] = 0
1145 pred_dist = torch.empty_like(dist)
1148 pred_dist.copy_(dist)
1149 dist[1:-1, 1:-1] = (
1152 dist[None, 1:-1, 1:-1],
1153 dist[None, 1:-1, 0:-2],
1154 dist[None, 2:, 1:-1],
1155 dist[None, 1:-1, 2:],
1156 dist[None, 0:-2, 1:-1],
1163 dist = walls * max_length + (1 - walls) * dist
1165 if dist.equal(pred_dist):
1166 return dist * (1 - walls)
1169 def REMOVED_task_distance(self, A, f_A, B, f_B):
1170 c = torch.randperm(self.nb_colors - 1)[:3] + 1
1171 dist0 = torch.empty(self.height + 2, self.width + 2)
1172 dist1 = torch.empty(self.height + 2, self.width + 2)
1173 for X, f_X in [(A, f_A), (B, f_B)]:
1174 nb_rec = torch.randint(3, (1,)).item() + 1
1176 r = self.rec_coo(nb_rec, prevent_overlap=True)
1179 for n in range(nb_rec):
1180 i1, j1, i2, j2 = r[n]
1181 X[i1:i2, j1:j2] = c[0]
1182 f_X[i1:i2, j1:j2] = c[0]
1185 torch.randint(self.height, (1,)).item(),
1186 torch.randint(self.width, (1,)).item(),
1192 torch.randint(self.height, (1,)).item(),
1193 torch.randint(self.width, (1,)).item(),
1198 dist1[1:-1, 1:-1] = (X != 0).long()
1199 dist1[...] = self.compute_distance(dist1, i1 + 1, j1 + 1)
1201 dist1[i0 + 1, j0 + 1] >= 1
1202 and dist1[i0 + 1, j0 + 1] < self.height * 4
1207 dist0[1:-1, 1:-1] = (X != 0).long()
1208 dist0[...] = self.compute_distance(dist0, i0 + 1, j0 + 1)
1210 dist0 = dist0[1:-1, 1:-1]
1211 dist1 = dist1[1:-1, 1:-1]
1214 for d in range(1, D):
1215 M = (dist0 == d) & (dist1 == D - d)
1216 f_X[...] = (1 - M) * f_X + M * c[1]
1223 # for X, f_X in [(A, f_A), (B, f_B)]:
1224 # n = torch.arange(self.height * self.width).reshape(self.height, self.width)
1225 # k = torch.randperm(self.height * self.width)
1228 # i,j=q%self.height,q//self.height
1232 def TOO_HARD_task_puzzle(self, A, f_A, B, f_B):
1234 i0, j0 = (self.height - S) // 2, (self.width - S) // 2
1235 c = torch.randperm(self.nb_colors - 1)[:4] + 1
1236 for X, f_X in [(A, f_A), (B, f_B)]:
1239 h = list(torch.randperm(c.size(0)))
1240 n = torch.zeros(c.max() + 1)
1242 k = torch.randperm(S * S)
1244 i, j = q % S + i0, q // S + j0
1252 r, s, t, u = torch.tensor([r, s, t, u])[torch.randperm(4)]
1253 if r > 0 and n[r] < 6:
1256 elif s > 0 and n[s] < 6:
1259 elif t > 0 and n[t] < 6:
1262 elif u > 0 and n[u] < 6:
1271 if n.sum() == S * S:
1278 torch.randint(self.height, (1,)).item(),
1279 torch.randint(self.width, (1,)).item(),
1285 ii + i >= self.height
1286 or jj + j >= self.width
1288 f_X[i + i0, j + j0] == c[d]
1289 and X[ii + i, jj + j] > 0
1297 if f_X[i + i0, j + j0] == c[d]:
1298 X[ii + i, jj + j] = c[d]
1300 def TOO_MESSY_task_islands(self, A, f_A, B, f_B):
1301 c = torch.randperm(self.nb_colors - 1)[:2] + 1
1302 for X, f_X in [(A, f_A), (B, f_B)]:
1303 if not hasattr(self, "cache_islands") or len(self.cache_islands) == 0:
1304 self.cache_islands = list(
1309 nb_seeds=self.height * self.width // 20,
1310 nb_iterations=self.height * self.width // 2,
1314 A = self.cache_islands.pop()
1318 torch.randint(self.height // 2, (1,)).item(),
1319 torch.randint(self.width // 2, (1,)).item(),
1324 X[...] = (A > 0) * c[0]
1325 f_X[...] = (A == A[i, j]) * c[1] + ((A > 0) & (A != A[i, j])) * c[0]
1330 def TOO_HARD_task_stack(self, A, f_A, B, f_B):
1332 c = torch.randperm(self.nb_colors - 1)[:N] + 1
1333 for X, f_X in [(A, f_A), (B, f_B)]:
1335 self.height // 2 - 1,
1336 self.width // 2 - 1,
1337 self.height // 2 + 1,
1338 self.width // 2 + 1,
1340 op = torch.tensor((0, 1, 2, 3) * 4)
1341 op = op[torch.randperm(op.size(0))[:9]]
1342 for q in range(op.size(0)):
1345 d = c[torch.randint(N, (1,)).item()]
1347 if op[q] == 0: # right
1348 X[u : u + 3, v + 2] = d
1349 elif op[q] == 1: # let
1351 elif op[q] == 2: # bottom
1352 X[u + 2, v : v + 3] = d
1353 elif op[q] == 3: # top
1357 f_X[i1:i2, j1:j2] = d
1358 elif op[q] == 0: # right
1361 elif op[q] == 1: # let
1364 elif op[q] == 2: # bottom
1367 elif op[q] == 3: # top
1371 def randint(self, *m):
1373 return (torch.rand(m.size()) * m).long()
1375 def TOO_HARD_task_matrices(self, A, f_A, B, f_B):
1377 c = torch.randperm(self.nb_colors - 1)[:N] + 1
1379 for X, f_X in [(A, f_A), (B, f_B)]:
1380 M1 = torch.randint(2, (5, 5))
1381 M2 = torch.randint(2, (5, 5))
1385 X[i, j] = c[M1[i, j]]
1386 X[i, j + 5] = c[M2[i, j]]
1387 f_X[i, j] = c[M1[i, j]]
1388 f_X[i, j + 5] = c[M2[i, j]]
1389 f_X[i + 5, j + 5] = c[P[i, j]]
1391 def TOO_HARD_task_compute(self, A, f_A, B, f_B):
1393 c = torch.randperm(self.nb_colors - 1)[:N] + 1
1394 for X, f_X in [(A, f_A), (B, f_B)]:
1395 v = torch.randint((self.width - 1) // 2, (N,)) + 1
1396 chain = torch.randperm(N)
1398 for i in range(chain.size(0) - 1):
1399 i1, i2 = chain[i], chain[i + 1]
1400 v1, v2 = v[i1], v[i2]
1401 k = torch.arange(self.width // 2) + 1
1402 d = ((k[None, :] * v1 - k[:, None] * v2) == 0).nonzero() + 1
1403 d = d[torch.randint(d.size(0), (1,)).item()]
1405 eq.append((c[i1], w1, c[i2], w2))
1407 ii = torch.randperm(self.height - 2)[: len(eq)]
1409 for k, x in enumerate(eq):
1412 s = torch.randint(self.width - (w1 + w2) + 1, (1,)).item()
1413 X[i, s : s + w1] = c1
1414 X[i, s + w1 : s + w1 + w2] = c2
1415 f_X[i, s : s + w1] = c1
1416 f_X[i, s + w1 : s + w1 + w2] = c2
1418 i1, i2 = torch.randperm(N)[:2]
1419 v1, v2 = v[i1], v[i2]
1420 k = torch.arange(self.width // 2) + 1
1421 d = ((k[None, :] * v1 - k[:, None] * v2) == 0).nonzero() + 1
1422 d = d[torch.randint(d.size(0), (1,)).item()]
1424 c1, c2 = c[i1], c[i2]
1425 s = 0 # torch.randint(self.width - (w1 + w2) + 1, (1,)).item()
1427 X[i, s : s + w1] = c1
1428 X[i, s + w1 : s + w1 + 1] = c2
1429 f_X[i, s : s + w1] = c1
1430 f_X[i, s + w1 : s + w1 + w2] = c2
1433 # [ai1,ai2] [bi1,bi2]
1434 def task_contact(self, A, f_A, B, f_B):
1436 ai1, aj1, ai2, aj2 = a
1437 bi1, bj1, bi2, bj2 = b
1438 v = max(ai1 - bi2, bi1 - ai2)
1439 h = max(aj1 - bj2, bj1 - aj2)
1440 return min(max(v, 0) + max(h + 1, 0), max(v + 1, 0) + max(h, 0))
1443 c = torch.randperm(self.nb_colors - 1)[:nb_rec] + 1
1444 for X, f_X in [(A, f_A), (B, f_B)]:
1446 r = self.rec_coo(nb_rec, prevent_overlap=True)
1447 d = [rec_dist(r[0], r[k]) for k in range(nb_rec)]
1451 for n in range(nb_rec):
1452 i1, j1, i2, j2 = r[n]
1453 X[i1:i2, j1:j2] = c[n]
1454 f_X[i1:i2, j1:j2] = c[n]
1456 f_X[i1, j1:j2] = c[0]
1457 f_X[i2 - 1, j1:j2] = c[0]
1458 f_X[i1:i2, j1] = c[0]
1459 f_X[i1:i2, j2 - 1] = c[0]
1462 # [ai1,ai2] [bi1,bi2]
1463 def task_corners(self, A, f_A, B, f_B):
1464 polarity = torch.randint(2, (1,)).item()
1466 c = torch.randperm(self.nb_colors - 1)[:nb_rec] + 1
1467 for X, f_X in [(A, f_A), (B, f_B)]:
1468 r = self.rec_coo(nb_rec, prevent_overlap=True)
1470 for n in range(nb_rec):
1471 i1, j1, i2, j2 = r[n]
1474 X[i1 + k, j1] = c[n]
1475 X[i2 - 1 - k, j2 - 1] = c[n]
1476 X[i1, j1 + k] = c[n]
1477 X[i2 - 1, j2 - 1 - k] = c[n]
1479 X[i1 + k, j2 - 1] = c[n]
1480 X[i2 - 1 - k, j1] = c[n]
1481 X[i1, j2 - 1 - k] = c[n]
1482 X[i2 - 1, j1 + k] = c[n]
1483 f_X[i1:i2, j1:j2] = c[n]
1485 def compdist(self, X, i, j):
1486 dd = X.new_full((self.height + 2, self.width + 2), self.height * self.width)
1494 d.min(dd[:-2, 1:-1] + 1)
1495 .min(dd[2:, 1:-1] + 1)
1496 .min(dd[1:-1, :-2] + 1)
1497 .min(dd[1:-1, 2:] + 1)
1499 d[...] = (1 - m) * d + m * self.height * self.width
1506 def task_path(self, A, f_A, B, f_B):
1508 c = torch.randperm(self.nb_colors - 1)[: nb_rec + 2] + 1
1509 for X, f_X in [(A, f_A), (B, f_B)]:
1514 r = self.rec_coo(nb_rec, prevent_overlap=True)
1515 for n in range(nb_rec):
1516 i1, j1, i2, j2 = r[n]
1517 X[i1:i2, j1:j2] = c[n]
1518 f_X[i1:i2, j1:j2] = c[n]
1520 i1, i2 = torch.randint(self.height, (2,))
1521 j1, j2 = torch.randint(self.width, (2,))
1523 abs(i1 - i2) + abs(j1 - j2) > 2
1527 d2 = self.compdist(X, i2, j2)
1528 d = self.compdist(X, i1, j1)
1530 if d2[i1, j1] < 2 * self.width:
1533 m = ((d + d2) == d[i2, j2]).long()
1534 f_X[...] = m * c[-1] + (1 - m) * f_X
1542 def task_fill(self, A, f_A, B, f_B):
1544 c = torch.randperm(self.nb_colors - 1)[: nb_rec + 1] + 1
1545 for X, f_X in [(A, f_A), (B, f_B)]:
1546 accept_full = torch.rand(1) < 0.5
1552 r = self.rec_coo(nb_rec, prevent_overlap=True)
1553 for n in range(nb_rec):
1554 i1, j1, i2, j2 = r[n]
1555 X[i1:i2, j1:j2] = c[n]
1556 f_X[i1:i2, j1:j2] = c[n]
1560 torch.randint(self.height, (1,)).item(),
1561 torch.randint(self.width, (1,)).item(),
1566 d = self.compdist(X, i, j)
1567 m = (d < self.height * self.width).long()
1569 f_X[...] = m * c[-1] + (1 - m) * f_X
1572 if accept_full or (d * (X == 0)).max() == self.height * self.width:
1575 def TOO_HARD_task_addition(self, A, f_A, B, f_B):
1576 c = torch.randperm(self.nb_colors - 1)[:4] + 1
1577 for X, f_X in [(A, f_A), (B, f_B)]:
1578 N1 = torch.randint(2 ** (self.width - 1) - 1, (1,)).item()
1579 N2 = torch.randint(2 ** (self.width - 1) - 1, (1,)).item()
1581 for j in range(self.width):
1582 r1 = (N1 // (2**j)) % 2
1583 X[0, -j - 1] = c[r1]
1584 f_X[0, -j - 1] = c[r1]
1585 r2 = (N2 // (2**j)) % 2
1586 X[1, -j - 1] = c[r2]
1587 f_X[1, -j - 1] = c[r2]
1588 rs = (S // (2**j)) % 2
1589 f_X[2, -j - 1] = c[2 + rs]
1591 def task_science_implicit(self, A, f_A, B, f_B):
1593 c = torch.randperm(self.nb_colors - 1)[:nb_rec] + 1
1595 for X, f_X in [(A, f_A), (B, f_B)]:
1597 i1, i2 = torch.randint(self.height, (2,)).sort().values
1598 if i1 >= 1 and i2 < self.height and i1 + 3 < i2:
1602 j1, j2 = torch.randint(self.width, (2,)).sort().values
1603 if j1 >= 1 and j2 < self.width and j1 + 3 < j2:
1606 f_X[i1:i2, j1:j2] = c[0]
1608 # ---------------------
1611 ii1, ii2 = torch.randint(self.height, (2,)).sort().values
1612 if ii1 >= i1 and ii2 <= i2 and ii1 + 1 < ii2:
1614 jj = torch.randint(j1, (1,))
1615 X[ii1:ii2, jj:j1] = c[1]
1616 f_X[ii1:ii2, jj:j1] = c[1]
1619 ii1, ii2 = torch.randint(self.height, (2,)).sort().values
1620 if ii1 >= i1 and ii2 <= i2 and ii1 + 1 < ii2:
1622 jj = torch.randint(self.width - j2, (1,)) + j2 + 1
1623 X[ii1:ii2, j2:jj] = c[2]
1624 f_X[ii1:ii2, j2:jj] = c[2]
1626 # ---------------------
1629 jj1, jj2 = torch.randint(self.width, (2,)).sort().values
1630 if jj1 >= j1 and jj2 <= j2 and jj1 + 1 < jj2:
1632 ii = torch.randint(i1, (1,))
1633 X[ii:i1, jj1:jj2] = c[3]
1634 f_X[ii:i1, jj1:jj2] = c[3]
1637 jj1, jj2 = torch.randint(self.width, (2,)).sort().values
1638 if jj1 >= j1 and jj2 <= j2 and jj1 + 1 < jj2:
1640 ii = torch.randint(self.height - i2, (1,)) + i2 + 1
1641 X[i2:ii, jj1:jj2] = c[4]
1642 f_X[i2:ii, jj1:jj2] = c[4]
1644 def task_science_dot(self, A, f_A, B, f_B):
1646 c = torch.randperm(self.nb_colors - 1)[: nb_rec + 1] + 1
1647 for X, f_X in [(A, f_A), (B, f_B)]:
1651 r = self.rec_coo(nb_rec, prevent_overlap=True)
1653 torch.randint(self.height, (1,)).item(),
1654 torch.randint(self.width, (1,)).item(),
1657 for n in range(nb_rec):
1658 i1, j1, i2, j2 = r[n]
1659 X[i1:i2, j1:j2] = c[n]
1660 f_X[i1:i2, j1:j2] = c[n]
1661 if i >= i1 and i < i2:
1663 f_X[i, j1:j2] = c[-1]
1664 if j >= j1 and j < j2:
1666 f_X[i1:i2, j] = c[-1]
1672 def collide(self, s, r, rs):
1675 if abs(i - i2) < s and abs(j - j2) < s:
1679 def task_science_tag(self, A, f_A, B, f_B):
1680 c = torch.randperm(self.nb_colors - 1)[:4] + 1
1681 for X, f_X in [(A, f_A), (B, f_B)]:
1685 torch.randint(self.height - 3, (1,)).item(),
1686 torch.randint(self.width - 3, (1,)).item(),
1688 if not self.collide(s=3, r=(i, j), rs=rs):
1691 for k in range(len(rs)):
1694 X[i, j : j + 3] = c[q]
1695 X[i + 2, j : j + 3] = c[q]
1696 X[i : i + 3, j] = c[q]
1697 X[i : i + 3, j + 2] = c[q]
1699 f_X[i, j : j + 3] = c[q]
1700 f_X[i + 2, j : j + 3] = c[q]
1701 f_X[i : i + 3, j] = c[q]
1702 f_X[i : i + 3, j + 2] = c[q]
1704 f_X[i + 1, j + 1] = c[-1]
1708 ######################################################################
1710 def create_empty_quizzes(self, nb, struct=("A", "f_A", "B", "f_B")):
1711 S = self.height * self.width
1712 quizzes = torch.zeros(nb, 4 * (S + 1), dtype=torch.int64)
1713 quizzes[:, 0 * (S + 1)] = self.l2tok[struct[0]]
1714 quizzes[:, 1 * (S + 1)] = self.l2tok[struct[1]]
1715 quizzes[:, 2 * (S + 1)] = self.l2tok[struct[2]]
1716 quizzes[:, 3 * (S + 1)] = self.l2tok[struct[3]]
1720 def generate_w_quizzes_(self, nb, tasks=None, progress_bar=False):
1721 S = self.height * self.width
1724 tasks = self.all_tasks
1726 quizzes = self.create_empty_quizzes(nb, ("A", "f_A", "B", "f_B"))
1729 quizzes = tqdm.tqdm(
1732 desc="world quizzes generation",
1733 total=quizzes.size(0),
1736 for quiz in quizzes:
1737 q = quiz.reshape(4, S + 1)[:, 1:].reshape(4, self.height, self.width)
1740 task = tasks[torch.randint(len(tasks), (1,)).item()]
1741 task(A, f_A, B, f_B)
1745 def save_some_examples(self, result_dir, prefix=""):
1747 for t in self.all_tasks:
1749 quizzes = self.generate_w_quizzes_(nb, tasks=[t])
1750 self.save_quizzes_as_image(
1751 result_dir, prefix + t.__name__ + ".png", quizzes, nrow=nrow
1755 ######################################################################
1757 if __name__ == "__main__":
1760 # grids = Grids(max_nb_cached_chunks=5, chunk_size=100, nb_threads=4)
1765 # quizzes = grids.generate_w_quizzes_(nb, tasks=[grids.task_fill])
1767 # print(grids.get_structure(quizzes))
1768 # quizzes = grids.reconfigure(quizzes, struct=("A", "B", "f_A", "f_B"))
1769 # print("DEBUG2", quizzes)
1770 # print(grids.get_structure(quizzes))
1773 # i = torch.rand(quizzes.size(0)) < 0.5
1775 # quizzes[i] = grids.reconfigure(quizzes[i], struct=("f_B", "f_A", "B", "A"))
1777 # j = grids.indices_select(quizzes, struct=("f_B", "f_A", "B", "A"))
1781 # grids.get_structure(quizzes[j]),
1782 # grids.get_structure(quizzes[j == False]),
1788 # grids = problem.MultiThreadProblem(
1789 # grids, max_nb_cached_chunks=50, chunk_size=100, nb_threads=1
1792 # start_time = time.perf_counter()
1793 # prompts, answers = grids.generate_w_quizzes(nb)
1794 # delay = time.perf_counter() - start_time
1795 # print(f"{prompts.size(0)/delay:02f} seq/s")
1802 # for t in grids.all_tasks:
1804 for t in [grids.task_recworld_immobile]:
1806 w_quizzes = grids.generate_w_quizzes_(nb, tasks=[t])
1807 grids.save_quizzes_as_image(
1809 t.__name__ + ".png",
1811 comments=[f"{t.__name__} #{k}" for k in range(w_quizzes.size(0))],
1819 # grids.task_bounce,
1820 # grids.task_contact,
1821 # grids.task_corners,
1822 # grids.task_detect,
1826 # grids.task_half_fill,
1827 # grids.task_isometry,
1829 # grids.task_replace_color,
1832 # grids.task_trajectory,
1833 # grids.task_translate,
1835 # for t in [grids.task_path]:
1836 start_time = time.perf_counter()
1837 w_quizzes = grids.generate_w_quizzes_(nb, tasks=[t])
1838 delay = time.perf_counter() - start_time
1839 print(f"{t.__name__} {w_quizzes.size(0)/delay:02f} seq/s")
1840 grids.save_quizzes_as_image("/tmp", t.__name__ + ".png", w_quizzes[:128])
1844 m = torch.randint(2, (prompts.size(0),))
1845 predicted_prompts = m * (torch.randint(2, (prompts.size(0),)) * 2 - 1)
1846 predicted_answers = (1 - m) * (torch.randint(2, (prompts.size(0),)) * 2 - 1)
1848 grids.save_quizzes_as_image(
1853 # You can add a bool to put a frame around the predicted parts
1854 predicted_prompts[:nb],
1855 predicted_answers[:nb],