Update.
[culture.git] / reasoning.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import math, sys, tqdm, os, warnings
9
10 import torch, torchvision
11
12 from torch import nn
13 from torch.nn import functional as F
14
15 ######################################################################
16
17 import problem
18
19
20 class Reasoning(problem.Problem):
21     named_colors = [
22         ("white", [255, 255, 255]),
23         ("red", [255, 0, 0]),
24         ("green", [0, 192, 0]),
25         ("blue", [0, 0, 255]),
26         ("orange", [255, 192, 0]),
27         ("cyan", [0, 255, 255]),
28         ("violet", [255, 0, 255]),
29         ("lightgreen", [192, 255, 192]),
30         ("brown", [165, 42, 42]),
31         ("lightblue", [192, 192, 255]),
32         ("gray", [128, 128, 128]),
33     ]
34
35     def __init__(self, device=torch.device("cpu")):
36         self.colors = torch.tensor([c for _, c in self.named_colors])
37         self.name2color = dict([(p[0], i) for i, p in enumerate(self.named_colors)])
38         self.height = 10
39         self.width = 10
40         self.device = device
41
42     ######################################################################
43
44     def frame2img(self, x, scale=15):
45         x = x.reshape(x.size(0), self.height, -1)
46         m = torch.logical_and(x >= 0, x < self.nb_token_values()).long()
47         x = self.colors[x * m].permute(0, 3, 1, 2)
48         s = x.shape
49         x = x[:, :, :, None, :, None].expand(-1, -1, -1, scale, -1, scale)
50         x = x.reshape(s[0], s[1], s[2] * scale, s[3] * scale)
51
52         x[:, :, :, torch.arange(0, x.size(3), scale)] = 0
53         x[:, :, torch.arange(0, x.size(2), scale), :] = 0
54         x = x[:, :, 1:, 1:]
55
56         for n in range(m.size(0)):
57             for i in range(m.size(1)):
58                 for j in range(m.size(2)):
59                     if m[n, i, j] == 0:
60                         for k in range(2, scale - 2):
61                             for l in [0, 1]:
62                                 x[n, :, i * scale + k, j * scale + k - l] = 0
63                                 x[
64                                     n, :, i * scale + scale - 1 - k, j * scale + k - l
65                                 ] = 0
66
67         return x
68
69     def frame2img_(self, x, scale=15):
70         x = x.reshape(x.size(0), self.height, -1)
71         x = self.colors[x].permute(0, 3, 1, 2)
72         s = x.shape
73         x = x[:, :, :, None, :, None].expand(-1, -1, -1, scale, -1, scale)
74         x = x.reshape(s[0], s[1], s[2] * scale, s[3] * scale)
75
76         x[:, :, :, torch.arange(0, x.size(3), scale)] = 0
77         x[:, :, torch.arange(0, x.size(2), scale), :] = 0
78         x = x[:, :, 1:, 1:]
79
80         return x
81
82     def save_image(
83         self,
84         result_dir,
85         filename,
86         prompts,
87         answers,
88         predicted_prompts=None,
89         predicted_answers=None,
90         nrow=4,
91     ):
92         prompts = prompts.reshape(prompts.size(0), self.height, -1)
93         answers = answers.reshape(answers.size(0), self.height, -1)
94
95         if predicted_prompts is None:
96             predicted_prompts = 255
97
98         if predicted_answers is None:
99             predicted_answers = 255
100
101         def add_frame(x, c, margin, bottom=False):
102             if bottom:
103                 h, w, di, dj = x.size(2) + margin, x.size(3), 0, 0
104             else:
105                 h, w, di, dj = (
106                     x.size(2) + 2 * margin,
107                     x.size(3) + 2 * margin,
108                     margin,
109                     margin,
110                 )
111
112             y = x.new_full((x.size(0), x.size(1), h, w), 0)
113
114             if type(c) is int:
115                 y[...] = c
116             else:
117                 c = c.long()[:, None]
118                 c = (
119                     (1 - ((c == 1).long() + (c == 0).long() + (c == -1).long()))
120                     * torch.tensor([192, 192, 192], device=c.device)
121                     + (c == 1).long() * torch.tensor([0, 255, 0], device=c.device)
122                     + (c == 0).long() * torch.tensor([255, 255, 255], device=c.device)
123                     + (c == -1).long() * torch.tensor([255, 0, 0], device=c.device)
124                 )
125                 y[...] = c[:, :, None, None]
126
127             y[:, :, di : di + x.size(2), dj : dj + x.size(3)] = x
128
129             return y
130
131         margin = 8
132
133         img_prompts = torch.cat(
134             [
135                 add_frame(
136                     add_frame(self.frame2img(x), c=0, margin=1),
137                     c=predicted_prompts,
138                     margin=margin,
139                 )
140                 for x in prompts.to("cpu").split(split_size=self.width, dim=2)
141             ],
142             dim=3,
143         )
144
145         h = img_prompts.size(2)
146         img_answers = add_frame(
147             add_frame(self.frame2img(answers.to("cpu")), c=0, margin=1),
148             c=predicted_answers,
149             margin=margin,
150         )
151
152         separator_size = 2 * margin
153
154         separator = img_prompts.new_full(
155             (
156                 img_prompts.size(0),
157                 img_prompts.size(1),
158                 img_prompts.size(2),
159                 separator_size,
160             ),
161             255,
162         )
163
164         marker = img_prompts.new_full(
165             (
166                 img_prompts.size(0),
167                 img_prompts.size(1),
168                 img_prompts.size(2),
169                 separator_size,
170             ),
171             255,
172         )
173
174         # marker[:, :, 0] = 0
175         # marker[:, :, h - 1] = 0
176
177         for k in range(1, 2 * separator_size - 8):
178             i = k - (separator_size - 4)
179             j = separator_size - 5 - abs(i)
180             marker[:, :, h // 2 - 1 + i, 2 + j] = 0
181             marker[:, :, h // 2 - 1 + i + 1, 2 + j] = 0
182
183         img = torch.cat(
184             [
185                 img_prompts,
186                 marker,
187                 img_answers,
188             ],
189             dim=3,
190         )
191
192         image_name = os.path.join(result_dir, filename)
193         torchvision.utils.save_image(
194             img.float() / 255.0,
195             image_name,
196             nrow=nrow,
197             padding=margin * 4,
198             pad_value=1.0,
199         )
200
201     ######################################################################
202
203     def nb_token_values(self):
204         return len(self.colors)
205
206     # That's quite a tensorial spaghetti mess to sample
207     # non-overlapping rectangles quickly, but made the generation of
208     # 100k samples go from 1h50 with a lame pure python code to 3min30s
209     # with this one.
210     def rec_coo(self, nb_rec, min_height=3, min_width=3):
211         nb_trials = 200
212
213         while True:
214             v = (
215                 (
216                     torch.rand(nb_trials * nb_rec, self.height + 1, device=self.device)
217                     .sort(dim=-1)
218                     .indices
219                     < 2
220                 )
221                 .long()
222                 .cumsum(dim=1)
223                 == 1
224             ).long()
225
226             h = (
227                 (
228                     torch.rand(nb_trials * nb_rec, self.width + 1, device=self.device)
229                     .sort(dim=-1)
230                     .indices
231                     < 2
232                 )
233                 .long()
234                 .cumsum(dim=1)
235                 == 1
236             ).long()
237
238             i = torch.logical_and(
239                 v.sum(dim=-1) >= min_height, h.sum(dim=-1) >= min_width
240             )
241
242             v, h = v[i], h[i]
243             v = v[: v.size(0) - v.size(0) % nb_rec]
244             h = h[: h.size(0) - h.size(0) % nb_rec]
245             v = v.reshape(v.size(0) // nb_rec, nb_rec, -1)
246             h = h.reshape(h.size(0) // nb_rec, nb_rec, -1)
247
248             r = v[:, :, :, None] * h[:, :, None, :]
249
250             valid = r.sum(dim=1).flatten(1).max(dim=-1).values == 1
251
252             v = v[valid]
253             h = h[valid]
254
255             if v.size(0) > 0:
256                 break
257
258         av = torch.arange(v.size(2), device=self.device)[None, :]
259         ah = torch.arange(h.size(2), device=self.device)[None, :]
260
261         return [
262             (i1.item(), j1.item(), i2.item() + 1, j2.item() + 1)
263             for i1, j1, i2, j2 in zip(
264                 v.size(2) - (v[0] * (v.size(2) - av)).max(dim=-1).values,
265                 h.size(2) - (h[0] * (h.size(2) - ah)).max(dim=-1).values,
266                 (v[0] * av).max(dim=-1).values,
267                 (h[0] * ah).max(dim=-1).values,
268             )
269         ]
270
271     def rec_coo_(self, x, n, min_height=3, min_width=3):
272         collision = x.new(x.size())
273         while True:
274             collision[...] = 0
275             result = []
276             for _ in range(n):
277                 while True:
278                     i1, i2 = torch.randint(x.size(0), (2,))
279                     if i1 + min_height <= i2:
280                         break
281                 while True:
282                     j1, j2 = torch.randint(x.size(1), (2,))
283                     if j1 + min_width <= j2:
284                         break
285                 collision[i1:i2, j1:j2] += 1
286                 if collision.max() > 1:
287                     break
288                 result.append((i1, j1, i2, j2))
289             if collision.max() == 1:
290                 break
291         return result
292
293     ######################################################################
294
295     def task_replace_color(self, A, f_A, B, f_B):
296         nb_rec = 3
297         c = torch.randperm(len(self.colors) - 1)[: nb_rec + 1] + 1
298         for X, f_X in [(A, f_A), (B, f_B)]:
299             r = self.rec_coo(nb_rec)
300             for n in range(nb_rec):
301                 i1, j1, i2, j2 = r[n]
302                 X[i1:i2, j1:j2] = c[n]
303                 f_X[i1:i2, j1:j2] = c[n if n > 0 else -1]
304
305     def task_translate(self, A, f_A, B, f_B):
306         di, dj = torch.randint(3, (2,)) - 1
307         nb_rec = 3
308         c = torch.randperm(len(self.colors) - 1)[:nb_rec] + 1
309         for X, f_X in [(A, f_A), (B, f_B)]:
310             while True:
311                 r = self.rec_coo(nb_rec)
312                 i1, j1, i2, j2 = r[nb_rec - 1]
313                 if (
314                     i1 + di >= 0
315                     and i2 + di < X.size(0)
316                     and j1 + dj >= 0
317                     and j2 + dj < X.size(1)
318                 ):
319                     break
320
321             for n in range(nb_rec):
322                 i1, j1, i2, j2 = r[n]
323                 X[i1:i2, j1:j2] = c[n]
324                 if n == nb_rec - 1:
325                     f_X[i1 + di : i2 + di, j1 + dj : j2 + dj] = c[n]
326                 else:
327                     f_X[i1:i2, j1:j2] = c[n]
328
329     def task_grow(self, A, f_A, B, f_B):
330         di, dj = torch.randint(2, (2,)) * 2 - 1
331         nb_rec = 3
332         c = torch.randperm(len(self.colors) - 1)[:nb_rec] + 1
333         direction = torch.randint(2, (1,))
334         for X, f_X in [(A, f_A), (B, f_B)]:
335             while True:
336                 r = self.rec_coo(nb_rec)
337                 i1, j1, i2, j2 = r[nb_rec - 1]
338                 if i1 + 3 < i2 and j1 + 3 < j2:
339                     break
340
341             for n in range(nb_rec):
342                 i1, j1, i2, j2 = r[n]
343                 if n == nb_rec - 1:
344                     if direction == 0:
345                         X[i1 + 1 : i2 - 1, j1 + 1 : j2 - 1] = c[n]
346                         f_X[i1:i2, j1:j2] = c[n]
347                     else:
348                         X[i1:i2, j1:j2] = c[n]
349                         f_X[i1 + 1 : i2 - 1, j1 + 1 : j2 - 1] = c[n]
350                 else:
351                     X[i1:i2, j1:j2] = c[n]
352                     f_X[i1:i2, j1:j2] = c[n]
353
354     def task_color_grow(self, A, f_A, B, f_B):
355         di, dj = torch.randint(2, (2,)) * 2 - 1
356         nb_rec = 3
357         c = torch.randperm(len(self.colors) - 1)[: 2 * nb_rec] + 1
358         direction = torch.randint(4, (1,))
359         for X, f_X in [(A, f_A), (B, f_B)]:
360             r = self.rec_coo(nb_rec)
361             for n in range(nb_rec):
362                 i1, j1, i2, j2 = r[n]
363                 X[i1:i2, j1:j2] = c[2 * n]
364                 f_X[i1:i2, j1:j2] = c[2 * n]
365                 # Not my proudest moment
366                 if direction == 0:
367                     i = (i1 + i2) // 2
368                     X[i : i + 1, j1:j2] = c[2 * n + 1]
369                     if n == nb_rec - 1:
370                         f_X[i:i2, j1:j2] = c[2 * n + 1]
371                     else:
372                         f_X[i : i + 1, j1:j2] = c[2 * n + 1]
373                 elif direction == 1:
374                     i = (i1 + i2 - 1) // 2
375                     X[i : i + 1, j1:j2] = c[2 * n + 1]
376                     if n == nb_rec - 1:
377                         f_X[i1 : i + 1, j1:j2] = c[2 * n + 1]
378                     else:
379                         f_X[i : i + 1, j1:j2] = c[2 * n + 1]
380                 elif direction == 2:
381                     j = (j1 + j2) // 2
382                     X[i1:i2, j : j + 1] = c[2 * n + 1]
383                     if n == nb_rec - 1:
384                         f_X[i1:i2, j:j2] = c[2 * n + 1]
385                     else:
386                         f_X[i1:i2, j : j + 1] = c[2 * n + 1]
387                 elif direction == 3:
388                     j = (j1 + j2 - 1) // 2
389                     X[i1:i2, j : j + 1] = c[2 * n + 1]
390                     if n == nb_rec - 1:
391                         f_X[i1:i2, j1 : j + 1] = c[2 * n + 1]
392                     else:
393                         f_X[i1:i2, j : j + 1] = c[2 * n + 1]
394
395     def task_frame(self, A, f_A, B, f_B):
396         nb_rec = 3
397         c = torch.randperm(len(self.colors) - 1)[: nb_rec + 1] + 1
398         for X, f_X in [(A, f_A), (B, f_B)]:
399             r = self.rec_coo(nb_rec)
400             for n in range(nb_rec):
401                 i1, j1, i2, j2 = r[n]
402                 X[i1:i2, j1:j2] = c[n]
403                 f_X[i1:i2, j1:j2] = c[n]
404                 if n == nb_rec - 1:
405                     f_X[i1 + 1 : i2 - 1, j1 + 1 : j2 - 1] = 0
406
407     def task_detect(self, A, f_A, B, f_B):
408         nb_rec = 3
409         c = torch.randperm(len(self.colors) - 1)[: nb_rec + 1] + 1
410         for X, f_X in [(A, f_A), (B, f_B)]:
411             r = self.rec_coo(nb_rec)
412             for n in range(nb_rec):
413                 i1, j1, i2, j2 = r[n]
414                 X[i1:i2, j1:j2] = c[n]
415                 if n < nb_rec - 1:
416                     f_X[i1, j1] = c[-1]
417
418     def task_count(self, A, f_A, B, f_B):
419         N = torch.randint(4, (1,)) + 2
420         c = torch.randperm(len(self.colors) - 1)[:N] + 1
421
422         for X, f_X in [(A, f_A), (B, f_B)]:
423
424             def contact(i, j, q):
425                 nq, nq_diag = 0, 0
426                 no = 0
427
428                 for ii, jj in [
429                     (i - 1, j - 1),
430                     (i - 1, j),
431                     (i - 1, j + 1),
432                     (i, j - 1),
433                     (i, j + 1),
434                     (i + 1, j - 1),
435                     (i + 1, j),
436                     (i + 1, j + 1),
437                 ]:
438                     if ii >= 0 and ii < self.height and jj >= 0 and jj < self.width:
439                         if X[ii, jj] != 0 and X[ii, jj] != q:
440                             no += 1
441
442                 for ii, jj in [
443                     (i - 1, j - 1),
444                     (i - 1, j + 1),
445                     (i + 1, j - 1),
446                     (i + 1, j + 1),
447                 ]:
448                     if ii >= 0 and ii < self.height and jj >= 0 and jj < self.width:
449                         if X[ii, jj] == q and X[i, jj] != q and X[ii, j] != q:
450                             nq_diag += 1
451
452                 for ii, jj in [(i - 1, j), (i, j - 1), (i, j + 1), (i + 1, j)]:
453                     if ii >= 0 and ii < self.height and jj >= 0 and jj < self.width:
454                         if X[ii, jj] == q:
455                             nq += 1
456
457                 return no, nq, nq_diag
458
459             nb = torch.zeros(N, dtype=torch.int64)
460             q = torch.randint(N, (self.height * self.width,))
461             k = torch.randperm(self.height * self.width)
462             for p in range(self.height * self.width):
463                 i, j = k[p] % self.height, k[p] // self.height
464                 no, nq, nq_diag = contact(i, j, c[q[p]])
465                 if no == 0 and nq_diag == 0:
466                     if nq == 0:
467                         if nb[q[p]] < self.width:
468                             X[i, j] = c[q[p]]
469                             nb[q[p]] += 1
470                     if nq == 1:
471                         X[i, j] = c[q[p]]
472
473             for n in range(N):
474                 for j in range(nb[n]):
475                     f_X[n, j] = c[n]
476
477     def task_trajectory(self, A, f_A, B, f_B):
478         c = torch.randperm(len(self.colors) - 1)[:2] + 1
479         for X, f_X in [(A, f_A), (B, f_B)]:
480             while True:
481                 di, dj = torch.randint(7, (2,)) - 3
482                 i, j = torch.randint(self.height, (1,)), torch.randint(self.width, (1,))
483                 if (
484                     abs(di) + abs(dj) > 0
485                     and i + 2 * di >= 0
486                     and i + 2 * di < self.height
487                     and j + 2 * dj >= 0
488                     and j + 2 * dj < self.width
489                 ):
490                     break
491
492             k = 0
493             while (
494                 i + k * di >= 0
495                 and i + k * di < self.height
496                 and j + k * dj >= 0
497                 and j + k * dj < self.width
498             ):
499                 if k < 2:
500                     X[i + k * di, j + k * dj] = c[k]
501                 f_X[i + k * di, j + k * dj] = c[min(k, 1)]
502                 k += 1
503
504     def task_bounce(self, A, f_A, B, f_B):
505         c = torch.randperm(len(self.colors) - 1)[:3] + 1
506         for X, f_X in [(A, f_A), (B, f_B)]:
507
508             def free(i, j):
509                 return (
510                     i >= 0
511                     and i < self.height
512                     and j >= 0
513                     and j < self.width
514                     and f_X[i, j] == 0
515                 )
516
517             while True:
518                 f_X[...] = 0
519                 X[...] = 0
520
521                 for _ in range((self.height * self.width) // 10):
522                     i, j = torch.randint(self.height, (1,)), torch.randint(
523                         self.width, (1,)
524                     )
525                     X[i, j] = c[0]
526                     f_X[i, j] = c[0]
527
528                 while True:
529                     di, dj = torch.randint(7, (2,)) - 3
530                     if abs(di) + abs(dj) == 1:
531                         break
532
533                 i, j = torch.randint(self.height, (1,)), torch.randint(self.width, (1,))
534
535                 X[i, j] = c[1]
536                 f_X[i, j] = c[1]
537                 l = 0
538
539                 while True:
540                     l += 1
541                     if free(i + di, j + dj):
542                         pass
543                     elif free(i - dj, j + di):
544                         di, dj = -dj, di
545                         if free(i + dj, j - di):
546                             if torch.rand(1) < 0.5:
547                                 di, dj = -di, -dj
548                     elif free(i + dj, j - di):
549                         di, dj = dj, -di
550                     else:
551                         break
552
553                     i, j = i + di, j + dj
554                     f_X[i, j] = c[2]
555                     if l <= 1:
556                         X[i, j] = c[2]
557
558                     if l >= self.width:
559                         break
560
561                 f_X[i, j] = c[1]
562                 X[i, j] = c[1]
563
564                 if l > 3:
565                     break
566
567     def task_scale(self, A, f_A, B, f_B):
568         c = torch.randperm(len(self.colors) - 1)[:2] + 1
569
570         i, j = torch.randint(self.height // 2, (1,)), torch.randint(
571             self.width // 2, (1,)
572         )
573
574         for X, f_X in [(A, f_A), (B, f_B)]:
575             for _ in range(3):
576                 while True:
577                     i1, j1 = torch.randint(self.height // 2 + 1, (1,)), torch.randint(
578                         self.width // 2 + 1, (1,)
579                     )
580                     i2, j2 = torch.randint(self.height // 2 + 1, (1,)), torch.randint(
581                         self.width // 2 + 1, (1,)
582                     )
583                     if i1 < i2 and j1 < j2 and min(i2 - i1, j2 - j1) <= 3:
584                         break
585                 X[i + i1 : i + i2, j + j1 : j + j2] = c[0]
586                 f_X[2 * i1 : 2 * i2, 2 * j1 : 2 * j2] = c[0]
587
588             X[i, j] = c[1]
589             f_X[0:2, 0:2] = c[1]
590
591     def task_islands(self, A, f_A, B, f_B):
592         for X, f_X in [(A, f_A), (B, f_B)]:
593             while True:
594                 i, j = torch.randint(self.height, (1,)), torch.randint(self.width, (1,))
595                 if (
596                     i == 0
597                     or i == self.height - 1
598                     or j == 0
599                     or j == self.width - 1
600                     or X[i, j] == 1
601                 ):
602                     break
603             while True:
604                 di, dj = torch.randint(3, (2,)) - 1
605                 if abs(di) + abs(dj) > 0:
606                     break
607             X[i, j] = 1
608             while True:
609                 i, j = i + di, j + dj
610                 if i < 0 or i >= self.height or j < 0 or j >= self.width:
611                     break
612                 b = (
613                     i == 0
614                     or i == self.height - 1
615                     or j == 0
616                     or j == self.width - 1
617                     or X[i, j] == 1
618                 )
619                 X[i, j] = 1
620                 if b:
621                     break
622
623     ######################################################################
624
625     def all_tasks(self):
626         return [
627             self.task_replace_color,
628             self.task_translate,
629             self.task_grow,
630             self.task_color_grow,
631             self.task_frame,
632             self.task_detect,
633             self.task_count,
634             self.task_trajectory,
635             self.task_bounce,
636             self.task_scale,
637             self.task_islands,
638         ]
639
640     def generate_prompts_and_answers(self, nb, tasks=None, device="cpu"):
641         if tasks is None:
642             tasks = self.all_tasks()
643
644         prompts = torch.zeros(nb, self.height, self.width * 3, dtype=torch.int64)
645         answers = torch.zeros(nb, self.height, self.width, dtype=torch.int64)
646         w = self.width
647
648         for prompt, answer in tqdm.tqdm(
649             zip(prompts, answers),
650             dynamic_ncols=True,
651             desc="world generation",
652             total=prompts.size(0),
653         ):
654             A = prompt[:, 0 * w : 1 * w]
655             f_A = prompt[:, 1 * w : 2 * w]
656             B = prompt[:, 2 * w : 3 * w]
657             f_B = answer
658             task = tasks[torch.randint(len(tasks), (1,))]
659             task(A, f_A, B, f_B)
660
661         return prompts.flatten(1), answers.flatten(1)
662
663     def save_quizzes(
664         self,
665         result_dir,
666         filename_prefix,
667         prompts,
668         answers,
669         predicted_prompts=None,
670         predicted_answers=None,
671         nrow=4,
672     ):
673         self.save_image(
674             result_dir,
675             filename_prefix + ".png",
676             prompts,
677             answers,
678             predicted_prompts,
679             predicted_answers,
680             nrow,
681         )
682
683
684 ######################################################################
685
686 if __name__ == "__main__":
687     import time
688
689     nb = 4
690
691     reasoning = Reasoning()
692
693     for t in [reasoning.task_islands]:  # reasoning.all_tasks():
694         print(t.__name__)
695         prompts, answers = reasoning.generate_prompts_and_answers(nb, tasks=[t])
696         reasoning.save_quizzes("/tmp", t.__name__, prompts[:nb], answers[:nb], nrow=1)
697
698     exit(0)
699
700     nb = 72
701
702     start_time = time.perf_counter()
703     prompts, answers = reasoning.generate_prompts_and_answers(nb)
704     delay = time.perf_counter() - start_time
705     print(f"{prompts.size(0)/delay:02f} seq/s")
706
707     m = torch.randint(2, (prompts.size(0),))
708     predicted_prompts = m * (torch.randint(2, (prompts.size(0),)) * 2 - 1)
709     predicted_answers = (1 - m) * (torch.randint(2, (prompts.size(0),)) * 2 - 1)
710
711     reasoning.save_quizzes(
712         "/tmp",
713         "test",
714         prompts[:nb],
715         answers[:nb],
716         # You can add a bool to put a frame around the predicted parts
717         predicted_prompts[:nb],
718         predicted_answers[:nb],
719     )