3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
6 # Written by Francois Fleuret <francois@fleuret.org>
8 import math, sys, tqdm, os, warnings
10 import torch, torchvision
13 from torch.nn import functional as F
15 ######################################################################
20 class Grids(problem.Problem):
22 ("white", [255, 255, 255]),
24 ("green", [0, 192, 0]),
25 ("blue", [0, 0, 255]),
26 ("yellow", [255, 224, 0]),
27 ("cyan", [0, 255, 255]),
28 ("violet", [224, 128, 255]),
29 ("lightgreen", [192, 255, 192]),
30 ("brown", [165, 42, 42]),
31 ("lightblue", [192, 192, 255]),
32 ("gray", [128, 128, 128]),
37 max_nb_cached_chunks=None,
42 self.colors = torch.tensor([c for _, c in self.named_colors])
45 self.cache_rec_coo = {}
48 self.task_replace_color,
64 self.all_tasks = all_tasks
66 self.all_tasks = [getattr(self, "task_" + t) for t in tasks.split(",")]
68 super().__init__(max_nb_cached_chunks, chunk_size, nb_threads)
70 ######################################################################
72 def frame2img(self, x, scale=15):
73 x = x.reshape(x.size(0), self.height, -1)
74 m = torch.logical_and(x >= 0, x < self.nb_token_values()).long()
75 x = self.colors[x * m].permute(0, 3, 1, 2)
77 x = x[:, :, :, None, :, None].expand(-1, -1, -1, scale, -1, scale)
78 x = x.reshape(s[0], s[1], s[2] * scale, s[3] * scale)
80 x[:, :, :, torch.arange(0, x.size(3), scale)] = 0
81 x[:, :, torch.arange(0, x.size(2), scale), :] = 0
84 for n in range(m.size(0)):
85 for i in range(m.size(1)):
86 for j in range(m.size(2)):
88 for k in range(2, scale - 2):
90 x[n, :, i * scale + k, j * scale + k - l] = 0
92 n, :, i * scale + scale - 1 - k, j * scale + k - l
103 predicted_prompts=None,
104 predicted_answers=None,
108 S = self.height * self.width
109 As = prompts[:, 0 * (S + 1) : 0 * (S + 1) + S].view(-1, self.height, self.width)
110 f_As = prompts[:, 1 * (S + 1) : 1 * (S + 1) + S].view(
111 -1, self.height, self.width
113 Bs = prompts[:, 2 * (S + 1) : 2 * (S + 1) + S].view(-1, self.height, self.width)
114 prompts = torch.cat([As, f_As, Bs], dim=2)
115 answers = answers.reshape(answers.size(0), self.height, self.width)
117 if predicted_prompts is None:
118 predicted_prompts = 255
120 if predicted_answers is None:
121 predicted_answers = 255
123 def add_frame(x, c, margin, bottom=False):
125 h, w, di, dj = x.size(2) + margin, x.size(3), 0, 0
128 x.size(2) + 2 * margin,
129 x.size(3) + 2 * margin,
134 y = x.new_full((x.size(0), x.size(1), h, w), 0)
139 c = c.long()[:, None]
141 (1 - ((c == 1).long() + (c == 0).long() + (c == -1).long()))
142 * torch.tensor([64, 64, 64])
143 + (c == 1).long() * torch.tensor([0, 255, 0])
144 + (c == 0).long() * torch.tensor([255, 255, 255])
145 + (c == -1).long() * torch.tensor([255, 0, 0])
147 y[...] = c[:, :, None, None]
149 y[:, :, di : di + x.size(2), dj : dj + x.size(3)] = x
153 img_prompts = torch.cat(
156 add_frame(self.frame2img(x), c=0, margin=1),
160 for x in prompts.to("cpu").split(split_size=self.width, dim=2)
165 h = img_prompts.size(2)
166 img_answers = add_frame(
167 add_frame(self.frame2img(answers.to("cpu")), c=0, margin=1),
172 separator_size = 2 * margin
174 separator = img_prompts.new_full(
184 marker = img_prompts.new_full(
194 # marker[:, :, 0] = 0
195 # marker[:, :, h - 1] = 0
197 for k in range(1, 2 * separator_size - 8):
198 i = k - (separator_size - 4)
199 j = separator_size - 5 - abs(i)
200 marker[:, :, h // 2 - 1 + i, 2 + j] = 0
201 marker[:, :, h // 2 - 1 + i + 1, 2 + j] = 0
212 image_name = os.path.join(result_dir, filename)
213 torchvision.utils.save_image(
221 ######################################################################
def nb_token_values(self):
    """Return how many distinct token values a grid cell may take.

    Each token value ``k`` is rendered with the palette entry
    ``self.colors[k]``, so the vocabulary size is simply the number of
    palette entries.
    """
    palette_size = len(self.colors)
    return palette_size
233 prevent_overlap=False,
235 if surface_max is None:
236 surface_max = self.height * self.width // 2
238 signature = (nb_rec, min_height, min_width, surface_max)
241 return self.cache_rec_coo[signature].pop()
250 i = torch.randint(self.height, (N * nb_rec, 2)).sort(dim=-1).values
251 j = torch.randint(self.width, (N * nb_rec, 2)).sort(dim=-1).values
254 (i[:, 1] >= i[:, 0] + min_height)
255 & (j[:, 1] >= j[:, 0] + min_height)
256 & ((i[:, 1] - i[:, 0]) * (j[:, 1] - j[:, 0]) <= surface_max)
259 i, j = i[big_enough], j[big_enough]
261 n = i.size(0) - i.size(0) % nb_rec
266 i = i[:n].reshape(n // nb_rec, nb_rec, -1)
267 j = j[:n].reshape(n // nb_rec, nb_rec, -1)
270 can_fit = ((i[:, :, 1] - i[:, :, 0]) * (j[:, :, 1] - j[:, :, 0])).sum(
272 ) <= self.height * self.width
273 i, j = i[can_fit], j[can_fit]
275 A_i1, A_i2, A_j1, A_j2 = (
281 B_i1, B_i2, B_j1, B_j2 = (
287 no_overlap = torch.logical_not(
293 i, j = i[no_overlap], j[no_overlap]
295 A_i1, A_i2, A_j1, A_j2 = (
301 B_i1, B_i2, B_j1, B_j2 = (
307 C_i1, C_i2, C_j1, C_j2 = (
333 i, j = (i[no_overlap], j[no_overlap])
340 self.cache_rec_coo[signature] = [
348 for k in range(nb_rec)
350 for n in range(i.size(0))
353 return self.cache_rec_coo[signature].pop()
355 ######################################################################
358 def task_replace_color(self, A, f_A, B, f_B):
360 c = torch.randperm(len(self.colors) - 1)[: nb_rec + 1] + 1
361 for X, f_X in [(A, f_A), (B, f_B)]:
362 r = self.rec_coo(nb_rec, prevent_overlap=True)
363 for n in range(nb_rec):
364 i1, j1, i2, j2 = r[n]
365 X[i1:i2, j1:j2] = c[n]
366 f_X[i1:i2, j1:j2] = c[n if n > 0 else -1]
369 def task_translate(self, A, f_A, B, f_B):
371 di, dj = torch.randint(3, (2,)) - 1
372 if di.abs() + dj.abs() > 0:
376 c = torch.randperm(len(self.colors) - 1)[:nb_rec] + 1
377 for X, f_X in [(A, f_A), (B, f_B)]:
379 r = self.rec_coo(nb_rec, prevent_overlap=True)
380 i1, j1, i2, j2 = r[nb_rec - 1]
383 and i2 + di < X.size(0)
385 and j2 + dj < X.size(1)
389 for n in range(nb_rec):
390 i1, j1, i2, j2 = r[n]
391 X[i1:i2, j1:j2] = c[n]
393 f_X[i1 + di : i2 + di, j1 + dj : j2 + dj] = c[n]
395 f_X[i1:i2, j1:j2] = c[n]
398 def task_grow(self, A, f_A, B, f_B):
399 di, dj = torch.randint(2, (2,)) * 2 - 1
401 c = torch.randperm(len(self.colors) - 1)[:nb_rec] + 1
402 direction = torch.randint(2, (1,))
403 for X, f_X in [(A, f_A), (B, f_B)]:
405 r = self.rec_coo(nb_rec, prevent_overlap=True)
406 i1, j1, i2, j2 = r[nb_rec - 1]
407 if i1 + 3 < i2 and j1 + 3 < j2:
410 for n in range(nb_rec):
411 i1, j1, i2, j2 = r[n]
414 X[i1 + 1 : i2 - 1, j1 + 1 : j2 - 1] = c[n]
415 f_X[i1:i2, j1:j2] = c[n]
417 X[i1:i2, j1:j2] = c[n]
418 f_X[i1 + 1 : i2 - 1, j1 + 1 : j2 - 1] = c[n]
420 X[i1:i2, j1:j2] = c[n]
421 f_X[i1:i2, j1:j2] = c[n]
424 def task_half_fill(self, A, f_A, B, f_B):
425 di, dj = torch.randint(2, (2,)) * 2 - 1
427 c = torch.randperm(len(self.colors) - 1)[: 2 * nb_rec] + 1
428 direction = torch.randint(4, (1,))
429 for X, f_X in [(A, f_A), (B, f_B)]:
430 r = self.rec_coo(nb_rec, prevent_overlap=True)
431 for n in range(nb_rec):
432 i1, j1, i2, j2 = r[n]
433 X[i1:i2, j1:j2] = c[2 * n]
434 f_X[i1:i2, j1:j2] = c[2 * n]
435 # Not my proudest moment
438 X[i : i + 1, j1:j2] = c[2 * n + 1]
440 f_X[i:i2, j1:j2] = c[2 * n + 1]
442 f_X[i : i + 1, j1:j2] = c[2 * n + 1]
444 i = (i1 + i2 - 1) // 2
445 X[i : i + 1, j1:j2] = c[2 * n + 1]
447 f_X[i1 : i + 1, j1:j2] = c[2 * n + 1]
449 f_X[i : i + 1, j1:j2] = c[2 * n + 1]
452 X[i1:i2, j : j + 1] = c[2 * n + 1]
454 f_X[i1:i2, j:j2] = c[2 * n + 1]
456 f_X[i1:i2, j : j + 1] = c[2 * n + 1]
458 j = (j1 + j2 - 1) // 2
459 X[i1:i2, j : j + 1] = c[2 * n + 1]
461 f_X[i1:i2, j1 : j + 1] = c[2 * n + 1]
463 f_X[i1:i2, j : j + 1] = c[2 * n + 1]
466 def task_frame(self, A, f_A, B, f_B):
468 c = torch.randperm(len(self.colors) - 1)[: nb_rec + 1] + 1
469 for X, f_X in [(A, f_A), (B, f_B)]:
470 r = self.rec_coo(nb_rec, prevent_overlap=True)
471 for n in range(nb_rec):
472 i1, j1, i2, j2 = r[n]
473 X[i1:i2, j1:j2] = c[n]
475 f_X[i1:i2, j1] = c[n]
476 f_X[i1:i2, j2 - 1] = c[n]
477 f_X[i1, j1:j2] = c[n]
478 f_X[i2 - 1, j1:j2] = c[n]
480 f_X[i1:i2, j1:j2] = c[n]
483 def task_detect(self, A, f_A, B, f_B):
485 c = torch.randperm(len(self.colors) - 1)[: nb_rec + 1] + 1
486 for X, f_X in [(A, f_A), (B, f_B)]:
487 r = self.rec_coo(nb_rec, prevent_overlap=True)
488 for n in range(nb_rec):
489 i1, j1, i2, j2 = r[n]
490 X[i1:i2, j1:j2] = c[n]
495 def contact(self, X, i, j, q):
509 if ii >= 0 and ii < self.height and jj >= 0 and jj < self.width:
510 if X[ii, jj] != 0 and X[ii, jj] != q:
519 if ii >= 0 and ii < self.height and jj >= 0 and jj < self.width:
520 if X[ii, jj] == q and X[i, jj] != q and X[ii, j] != q:
523 for ii, jj in [(i - 1, j), (i, j - 1), (i, j + 1), (i + 1, j)]:
524 if ii >= 0 and ii < self.height and jj >= 0 and jj < self.width:
528 return no, nq, nq_diag
530 def task_count(self, A, f_A, B, f_B):
531 N = (torch.randint(4, (1,)) + 2).item()
532 c = torch.randperm(len(self.colors) - 1)[:N] + 1
534 for X, f_X in [(A, f_A), (B, f_B)]:
535 l_q = torch.randperm(self.height * self.width)[
536 : self.height * self.width // 20
538 l_d = torch.randint(N, l_q.size())
539 nb = torch.zeros(N, dtype=torch.int64)
541 for q, e in zip(l_q, l_d):
543 i, j = q % self.height, q // self.height
546 and X[max(0, i - 1) : i + 2, max(0, j - 1) : j + 2] == 0
551 l_q = torch.randperm((self.height - 2) * (self.width - 2))[
552 : self.height * self.width // 2
554 l_d = torch.randint(N, l_q.size())
555 for q, e in zip(l_q, l_d):
557 i, j = q % (self.height - 2) + 1, q // (self.height - 2) + 1
558 a1, a2, a3 = X[i - 1, j - 1 : j + 2]
559 a8, a4 = X[i, j - 1], X[i, j + 1]
560 a7, a6, a5 = X[i + 1, j - 1 : j + 2]
563 and nb[e] < self.width
564 and (a2 == 0 or a2 == d)
565 and (a4 == 0 or a4 == d)
566 and (a6 == 0 or a6 == d)
567 and (a8 == 0 or a8 == d)
568 and (a1 == 0 or a2 == d or a8 == d)
569 and (a3 == 0 or a4 == d or a2 == d)
570 and (a5 == 0 or a6 == d or a4 == d)
571 and (a7 == 0 or a8 == d or a6 == d)
584 for j in range(nb[e]):
588 def task_trajectory(self, A, f_A, B, f_B):
589 c = torch.randperm(len(self.colors) - 1)[:2] + 1
590 for X, f_X in [(A, f_A), (B, f_B)]:
592 di, dj = torch.randint(7, (2,)) - 3
593 i, j = torch.randint(self.height, (1,)), torch.randint(self.width, (1,))
595 abs(di) + abs(dj) > 0
597 and i + 2 * di < self.height
599 and j + 2 * dj < self.width
606 and i + k * di < self.height
608 and j + k * dj < self.width
611 X[i + k * di, j + k * dj] = c[k]
612 f_X[i + k * di, j + k * dj] = c[min(k, 1)]
616 def task_bounce(self, A, f_A, B, f_B):
617 c = torch.randperm(len(self.colors) - 1)[:3] + 1
618 for X, f_X in [(A, f_A), (B, f_B)]:
633 for _ in range((self.height * self.width) // 10):
634 i, j = torch.randint(self.height, (1,)), torch.randint(
641 di, dj = torch.randint(7, (2,)) - 3
642 if abs(di) + abs(dj) == 1:
645 i, j = torch.randint(self.height, (1,)), torch.randint(self.width, (1,))
653 if free(i + di, j + dj):
655 elif free(i - dj, j + di):
657 if free(i + dj, j - di):
658 if torch.rand(1) < 0.5:
660 elif free(i + dj, j - di):
665 i, j = i + di, j + dj
680 def task_scale(self, A, f_A, B, f_B):
681 c = torch.randperm(len(self.colors) - 1)[:2] + 1
683 i, j = torch.randint(self.height // 2, (1,)), torch.randint(
684 self.width // 2, (1,)
687 for X, f_X in [(A, f_A), (B, f_B)]:
690 i1, j1 = torch.randint(self.height // 2 + 1, (1,)), torch.randint(
691 self.width // 2 + 1, (1,)
693 i2, j2 = torch.randint(self.height // 2 + 1, (1,)), torch.randint(
694 self.width // 2 + 1, (1,)
696 if i1 < i2 and j1 < j2 and min(i2 - i1, j2 - j1) <= 3:
698 X[i + i1 : i + i2, j + j1 : j + j2] = c[0]
699 f_X[2 * i1 : 2 * i2, 2 * j1 : 2 * j2] = c[0]
705 def task_symbols(self, A, f_A, B, f_B):
707 c = torch.randperm(len(self.colors) - 1)[: nb_rec + 1] + 1
709 for X, f_X in [(A, f_A), (B, f_B)]:
711 i, j = torch.randint(self.height - delta + 1, (nb_rec,)), torch.randint(
712 self.width - delta + 1, (nb_rec,)
714 d = (i[None, :] - i[:, None]).abs().max((j[None, :] - j[:, None]).abs())
715 d.fill_diagonal_(delta + 1)
719 for k in range(1, nb_rec):
720 X[i[k] : i[k] + delta, j[k] : j[k] + delta] = c[k]
722 ai, aj = i.float().mean(), j.float().mean()
724 q = torch.randint(3, (1,)) + 1
726 X[i[0] + delta // 2 - 1, j[0] + delta // 2 - 1] = c[0]
727 X[i[0] + delta // 2 - 1, j[0] + delta // 2 + 1] = c[0]
728 X[i[0] + delta // 2 + 1, j[0] + delta // 2 - 1] = c[0]
729 X[i[0] + delta // 2 + 1, j[0] + delta // 2 + 1] = c[0]
731 assert i[q] != ai and j[q] != aj
734 i[0] + delta // 2 + (i[q] - ai).sign().long(),
735 j[0] + delta // 2 + (j[q] - aj).sign().long(),
738 f_X[i[0] : i[0] + delta, j[0] : j[0] + delta] = c[q]
741 def task_isometry(self, A, f_A, B, f_B):
743 di, dj = torch.randint(3, (2,)) - 1
744 o = torch.tensor([[0.0, 1.0], [-1.0, 0.0]])
746 for _ in range(torch.randint(4, (1,))):
748 if torch.rand(1) < 0.5:
751 ci, cj = (self.height - 1) / 2, (self.width - 1) / 2
753 for X, f_X in [(A, f_A), (B, f_B)]:
758 c = torch.randperm(len(self.colors) - 1)[:nb_rec] + 1
760 for r in range(nb_rec):
762 i1, i2 = torch.randint(self.height - 2, (2,)) + 1
763 j1, j2 = torch.randint(self.width - 2, (2,)) + 1
767 and max(i2 - i1, j2 - j1) >= 2
768 and min(i2 - i1, j2 - j1) <= 3
771 X[i1 : i2 + 1, j1 : j2 + 1] = c[r]
773 i1, j1, i2, j2 = i1 - ci, j1 - cj, i2 - ci, j2 - cj
775 i1, j1 = m[0, 0] * i1 + m[0, 1] * j1, m[1, 0] * i1 + m[1, 1] * j1
776 i2, j2 = m[0, 0] * i2 + m[0, 1] * j2, m[1, 0] * i2 + m[1, 1] * j2
778 i1, j1, i2, j2 = i1 + ci, j1 + cj, i2 + ci, j2 + cj
779 i1, i2 = i1.long() + di, i2.long() + di
780 j1, j2 = j1.long() + dj, j2.long() + dj
786 f_X[i1 : i2 + 1, j1 : j2 + 1] = c[r]
788 n = F.one_hot(X.flatten()).sum(dim=0)[1:]
790 n.sum() > self.height * self.width // 4
791 and (n > 0).long().sum() == nb_rec
795 def compute_distance(self, walls, goal_i, goal_j, start_i, start_j):
796 max_length = walls.numel()
797 dist = torch.full_like(walls, max_length)
799 dist[goal_i, goal_j] = 0
800 pred_dist = torch.empty_like(dist)
803 pred_dist.copy_(dist)
807 dist[None, 1:-1, 0:-2],
808 dist[None, 2:, 1:-1],
809 dist[None, 1:-1, 2:],
810 dist[None, 0:-2, 1:-1],
817 dist[1:-1, 1:-1].minimum_(d) # = torch.min(dist[1:-1, 1:-1], d)
818 dist = walls * max_length + (1 - walls) * dist
820 if dist[start_i, start_j] < max_length or dist.equal(pred_dist):
821 return dist * (1 - walls)
824 def task_path(self, A, f_A, B, f_B):
825 c = torch.randperm(len(self.colors) - 1)[:3] + 1
826 dist = torch.empty(self.height + 2, self.width + 2)
827 for X, f_X in [(A, f_A), (B, f_B)]:
828 nb_rec = torch.randint(3, (1,)) + 1
830 r = self.rec_coo(nb_rec, prevent_overlap=True)
833 for n in range(nb_rec):
834 i1, j1, i2, j2 = r[n]
835 X[i1:i2, j1:j2] = c[0]
836 f_X[i1:i2, j1:j2] = c[0]
838 i0, j0 = torch.randint(self.height, (1,)), torch.randint(
844 i1, j1 = torch.randint(self.height, (1,)), torch.randint(
850 dist[1:-1, 1:-1] = (X != 0).long()
851 dist[...] = self.compute_distance(dist, i1 + 1, j1 + 1, i0 + 1, j0 + 1)
852 if dist[i0 + 1, j0 + 1] >= 1 and dist[i0 + 1, j0 + 1] < self.height * 4:
855 dist[1:-1, 1:-1] += (X != 0).long() * self.height * self.width
856 dist[0, :] = self.height * self.width
857 dist[-1, :] = self.height * self.width
858 dist[:, 0] = self.height * self.width
859 dist[:, -1] = self.height * self.width
860 # dist += torch.rand(dist.size())
862 i, j = i0 + 1, j0 + 1
863 while i != i1 + 1 or j != j1 + 1:
864 f_X[i - 1, j - 1] = c[2]
887 # for X, f_X in [(A, f_A), (B, f_B)]:
888 # n = torch.arange(self.height * self.width).reshape(self.height, self.width)
889 # k = torch.randperm(self.height * self.width)
892 # i,j=q%self.height,q//self.height
896 def task_puzzle(self, A, f_A, B, f_B):
898 i0, j0 = (self.height - S) // 2, (self.width - S) // 2
899 c = torch.randperm(len(self.colors) - 1)[:4] + 1
900 for X, f_X in [(A, f_A), (B, f_B)]:
903 h = list(torch.randperm(c.size(0)))
904 n = torch.zeros(c.max() + 1)
906 k = torch.randperm(S * S)
908 i, j = q % S + i0, q // S + j0
916 r, s, t, u = torch.tensor([r, s, t, u])[torch.randperm(4)]
917 if r > 0 and n[r] < 6:
920 elif s > 0 and n[s] < 6:
923 elif t > 0 and n[t] < 6:
926 elif u > 0 and n[u] < 6:
941 ii, jj = torch.randint(self.height, (1,)), torch.randint(
948 ii + i >= self.height
949 or jj + j >= self.width
951 f_X[i + i0, j + j0] == c[d]
952 and X[ii + i, jj + j] > 0
960 if f_X[i + i0, j + j0] == c[d]:
961 X[ii + i, jj + j] = c[d]
963 ######################################################################
965 def trivial_prompts_and_answers(self, prompts, answers):
966 S = self.height * self.width
967 Bs = prompts[:, 2 * (S + 1) : 2 * (S + 1) + S]
969 return (Bs == f_Bs).long().min(dim=-1).values > 0
971 def generate_prompts_and_answers_(self, nb, tasks=None, progress_bar=False):
973 tasks = self.all_tasks
975 S = self.height * self.width
976 prompts = torch.zeros(nb, 3 * S + 2, dtype=torch.int64)
977 answers = torch.zeros(nb, S, dtype=torch.int64)
979 bunch = zip(prompts, answers)
985 desc="world generation",
986 total=prompts.size(0),
989 for prompt, answer in bunch:
990 A = prompt[0 * (S + 1) : 0 * (S + 1) + S].view(self.height, self.width)
991 f_A = prompt[1 * (S + 1) : 1 * (S + 1) + S].view(self.height, self.width)
992 B = prompt[2 * (S + 1) : 2 * (S + 1) + S].view(self.height, self.width)
993 f_B = answer.view(self.height, self.width)
994 task = tasks[torch.randint(len(tasks), (1,))]
997 return prompts.flatten(1), answers.flatten(1)
1005 predicted_prompts=None,
1006 predicted_answers=None,
1011 filename_prefix + ".png",
1019 def save_some_examples(self, result_dir):
1021 for t in self.all_tasks:
1023 prompts, answers = self.generate_prompts_and_answers_(nb, tasks=[t])
1025 result_dir, t.__name__, prompts[:nb], answers[:nb], nrow=nrow
1029 ######################################################################
1031 if __name__ == "__main__":
1034 # grids = Grids(max_nb_cached_chunks=5, chunk_size=100, nb_threads=4)
1038 # grids = problem.MultiThreadProblem(
1039 # grids, max_nb_cached_chunks=50, chunk_size=100, nb_threads=1
1042 # start_time = time.perf_counter()
1043 # prompts, answers = grids.generate_prompts_and_answers(nb)
1044 # delay = time.perf_counter() - start_time
1045 # print(f"{prompts.size(0)/delay:02f} seq/s")
1052 # for t in grids.all_tasks:
1054 grids.task_replace_color,
1058 prompts, answers = grids.generate_prompts_and_answers_(nb, tasks=[t])
1059 grids.save_quizzes("/tmp", t.__name__, prompts[:nb], answers[:nb], nrow=nrow)
1065 for t in grids.all_tasks:
1066 # for t in [ grids.task_replace_color ]: #grids.all_tasks:
1067 start_time = time.perf_counter()
1068 prompts, answers = grids.generate_prompts_and_answers_(nb, tasks=[t])
1069 delay = time.perf_counter() - start_time
1070 print(f"{t.__name__} {prompts.size(0)/delay:02f} seq/s")
1074 m = torch.randint(2, (prompts.size(0),))
1075 predicted_prompts = m * (torch.randint(2, (prompts.size(0),)) * 2 - 1)
1076 predicted_answers = (1 - m) * (torch.randint(2, (prompts.size(0),)) * 2 - 1)
1083 # You can add a bool to put a frame around the predicted parts
1084 predicted_prompts[:nb],
1085 predicted_answers[:nb],