1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import math, os, tqdm, warnings
9
10 import torch, torchvision
11
12 from torch import nn
13 from torch.nn import functional as F
14
15 from mygpt import BracketedSequence
16
17 # from graph import save_attention_image
18 save_attention_image = None
19
20 ######################################################################
21
22
23 def masked_inplace_autoregression(
24     model,
25     batch_size,
26     input,
27     ar_mask,
28     deterministic_synthesis,
29     forbidden_tokens=None,
30     logit_biases=None,
31     progress_bar_desc="autoregression",
32     device=torch.device("cpu"),
33 ):
34     assert input.size() == ar_mask.size()
35
36     batches = zip(input.split(batch_size), ar_mask.split(batch_size))
37
38     if progress_bar_desc is not None:
39         batches = tqdm.tqdm(
40             batches,
41             dynamic_ncols=True,
42             desc=progress_bar_desc,
43             total=(input.size(0) + batch_size - 1) // batch_size,
44         )
45
46     with torch.autograd.no_grad():
47         t = model.training
48         model.eval()
49
50         for input, ar_mask in batches:
51             model.masked_inplace_autoregression(
52                 input,
53                 ar_mask,
54                 deterministic_synthesis,
55                 forbidden_tokens,
56                 logit_biases,
57             )
58
59         model.train(t)
60
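# Illustrative usage sketch (not part of the original file). It shows how a
# task typically calls masked_inplace_autoregression; some_model, seq and mask
# are made-up names, and mask is 1 at the positions to generate and 0 at the
# positions kept from the prompt.
#
# seq = torch.full((100, 32), 3, dtype=torch.int64)
# mask = torch.zeros_like(seq)
# mask[:, 16:] = 1  # keep the first half as prompt, generate the second half
# masked_inplace_autoregression(
#     some_model,
#     batch_size=25,
#     input=seq,
#     ar_mask=mask,
#     deterministic_synthesis=False,
#     device=torch.device("cpu"),
# )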
61
62 ######################################################################
63
64
65 class Task:
66     def batches(self, split="train", nb_to_use=-1, desc=None):
67         pass
68
69     def vocabulary_size(self):
70         pass
71
72     def produce_results(
73         self, n_epoch, model, result_dir, logger, deterministic_synthesis
74     ):
75         pass
76
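# Illustrative sketch (not part of the original file) of the minimal interface
# a concrete task is expected to provide; MyTask is a made-up name.
#
# class MyTask(Task):
#     def __init__(self, train_input, test_input, batch_size, device):
#         self.train_input, self.test_input = train_input, test_input
#         self.batch_size, self.device = batch_size, device
#
#     def batches(self, split="train", nb_to_use=-1, desc=None):
#         input = self.train_input if split == "train" else self.test_input
#         for batch in input.split(self.batch_size):
#             yield batch.to(self.device)
#
#     def vocabulary_size(self):
#         return int(max(self.train_input.max(), self.test_input.max())) + 1
#
#     def produce_results(
#         self, n_epoch, model, result_dir, logger, deterministic_synthesis
#     ):
#         logger(f"main_test_accuracy {n_epoch} 0.0")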
77
78 class TaskFromFile(Task):
79     def tensorize(self, pairs, shuffle):
80         len_max = max([len(x[0]) for x in pairs])
81
82         input = torch.cat(
83             [
84                 torch.tensor(
85                     [
86                         [self.char2id[c] for c in s[0] + "#" * (len_max - len(s[0]))]
87                         for s in pairs
88                     ]
89                 )
90             ],
91             0,
92         ).to("cpu")
93
94         pred_mask = torch.cat(
95             [
96                 torch.tensor(
97                     [
98                         [int(c) for c in s[1] + "0" * (len_max - len(s[1]))]
99                         for s in pairs
100                     ]
101                 )
102             ],
103             0,
104         ).to("cpu")
105
106         if shuffle:
107             i = torch.randperm(input.size(0))
108             input = input[i].contiguous()
109             pred_mask = pred_mask[i].contiguous()
110
111         return input, pred_mask
112
113     # Trim the tensors to remove as many padding tokens as possible on the
114     # left and right of the first tensor. If z is a tuple, all its elements
115     # are trimmed according to the trimming computed for the first one.
116     def trim(self, z, token="#"):
117         n = self.char2id[token]
118         if type(z) == tuple:
119             x = z[0]
120             i = (1 - (F.pad(x, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
121             a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
122             return tuple([t[:, a:b] for t in z])
123         else:
124             i = (1 - (F.pad(z, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
125             a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
126             return z[:, a:b]
127
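    # Worked example (not part of the original file), assuming the padding
    # token "#" has id 0:
    #
    #   x = torch.tensor([[0, 0, 5, 2, 0, 0],
    #                     [0, 3, 4, 0, 0, 0]])
    #
    # Columns that contain only padding in every row are stripped from both
    # sides, so self.trim(x) keeps columns 1 to 3 and returns
    #
    #   tensor([[0, 5, 2],
    #           [3, 4, 0]])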
128     def __init__(
129         self,
130         train_filename,
131         test_filename,
132         nb_train_samples,
133         nb_test_samples,
134         batch_size,
135         shuffle=False,
136         device=torch.device("cpu"),
137     ):
138         self.batch_size = batch_size
139         self.device = device
140
141         def read_file(filename, nb=-1):
142             pairs = []
143             with open(filename, "r") as f:
144                 while True:
145                     sequence = f.readline().strip()
146                     if not sequence:
147                         break
148                     pred_mask = f.readline().strip()
149                     assert len(sequence) == len(pred_mask)
150                     assert set(pred_mask).issubset({"0", "1", "2"}), f"{set(pred_mask)}"
151                     pairs.append((sequence, pred_mask))
152                     if len(pairs) == nb:
153                         break
154
155             if nb > 0:
156                 pairs = pairs[:nb]
157                 assert len(pairs) == nb
158
159             return pairs
160
161         train_pairs = read_file(train_filename, nb_train_samples)
162         test_pairs = read_file(test_filename, nb_test_samples)
163
164         symbols = ["#"] + list(
165             set("".join([x[0] for x in train_pairs + test_pairs])) - set(["#"])
166         )
167         self.char2id = dict([(c, n) for n, c in enumerate(symbols)])
168         self.id2char = dict([(n, c) for c, n in self.char2id.items()])
169
170         self.train_input, self.train_pred_masks = self.tensorize(
171             train_pairs, shuffle=shuffle
172         )
173         self.test_input, self.test_pred_masks = self.tensorize(
174             test_pairs, shuffle=shuffle
175         )
176
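    # Sketch (not part of the original file, samples are made up) of the text
    # format read_file expects: two lines per sample, the sequence followed by
    # a mask of the same length, where 0 marks prompt characters, 1 marks
    # characters to generate, and 2 marks characters that are both generated
    # and scored in produce_results.
    #
    #   ab+ba=abab
    #   0000001111
    #   aa+b=aab
    #   00000222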
177     def batches(self, split="train", nb_to_use=-1, desc=None):
178         assert split in {"train", "test"}
179         input = self.train_input if split == "train" else self.test_input
180         if nb_to_use > 0:
181             input = input[:nb_to_use]
182         if desc is None:
183             desc = f"epoch-{split}"
184         for batch in tqdm.tqdm(
185             input.split(self.batch_size), dynamic_ncols=True, desc=desc
186         ):
187             yield self.trim(batch).to(self.device)
188
189     def vocabulary_size(self):
190         return len(self.char2id)
191
192     def tensor2str(self, t):
193         return ["".join([self.id2char[x.item()] for x in s]) for s in t]
194
195     def produce_results(
196         self, n_epoch, model, result_dir, logger, deterministic_synthesis
197     ):
198         correct = self.trim(self.test_input[:1000]).to(self.device)
199         result = correct.clone()
200         pred_mask = self.test_pred_masks[:1000, : result.size(1)].to(self.device)
201         ar_mask = (pred_mask > 0).long()
202         result *= 1 - ar_mask  # paranoia: wipe the tokens that will be generated
203
204         logger(f"----------------------------------------------------------")
205
206         for e in self.tensor2str(result[:50]):
207             logger(f"test_before {e}")
208
209         masked_inplace_autoregression(
210             model,
211             self.batch_size,
212             result,
213             ar_mask,
214             deterministic_synthesis,
215             device=self.device,
216         )
217
218         logger(f"----------------------------------------------------------")
219
220         for e, c in zip(self.tensor2str(result[:50]), self.tensor2str(correct[:50])):
221             logger(f"test_after  {e}")
222             logger(f"correct     {c}")
223
224         logger(f"----------------------------------------------------------")
225
226         err_mask = (pred_mask == 2).long()
227         nb_total = err_mask.sum().item()
228         nb_correct = ((correct == result).long() * err_mask).sum().item()
229
230         logger(f"test_performance {n_epoch} {nb_total=} {nb_correct=}")
231         logger(f"main_test_accuracy {n_epoch} {nb_correct / nb_total}")
232
233
234 ####################
235
236 import problems
237
238
239 class SandBox(Task):
240     def __init__(
241         self,
242         problem,
243         nb_train_samples,
244         nb_test_samples,
245         batch_size,
246         logger=None,
247         device=torch.device("cpu"),
248         max_nb_codes=1024,
249     ):
250         super().__init__()
251
252         self.batch_size = batch_size
253         self.device = device
254         self.problem = problem
255
256         self.train_input, self.train_ar_mask = self.problem.generate_sequences(
257             nb_train_samples
258         )
259         self.test_input, self.test_ar_mask = self.problem.generate_sequences(
260             nb_test_samples
261         )
262
263         self.train_input, self.train_ar_mask = self.train_input.to(
264             device
265         ), self.train_ar_mask.to(device)
266         self.test_input, self.test_ar_mask = self.test_input.to(
267             device
268         ), self.test_ar_mask.to(device)
269
270         self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
271
272         # A bit of paranoia never hurts
273         assert self.nb_codes <= max_nb_codes
274         assert self.train_input.min() >= 0
275         assert self.test_input.min() >= 0
276         assert tuple(x.item() for x in self.train_ar_mask.unique()) in {
277             (0,),
278             (1,),
279             (0, 1),
280         }
281         assert tuple(x.item() for x in self.test_ar_mask.unique()) in {
282             (0,),
283             (1,),
284             (0, 1),
285         }
286
287         if logger is not None:
288             for s, a in zip(self.train_input[:100], self.train_ar_mask[:100]):
289                 logger(f"train_sequences {self.problem.seq2str(s)}")
290                 a = "".join(["01"[x.item()] for x in a])
291                 logger(f"                {a}")
292
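    # Illustrative sketch (not part of the original file) of the interface a
    # problem from problems.py must expose for SandBox; ToyProblem is a
    # made-up name.
    #
    # class ToyProblem:
    #     def generate_sequences(self, nb):
    #         sequences = torch.randint(10, (nb, 8))
    #         ar_mask = torch.zeros_like(sequences)
    #         ar_mask[:, 4:] = 1  # the second half has to be generated
    #         return sequences, ar_mask
    #
    #     def seq2str(self, seq):
    #         return " ".join([str(x.item()) for x in seq])
    #
    #     def compute_nb_correct(self, input, ar_mask, result):
    #         nb_total = ar_mask.sum().item()
    #         nb_correct = ((result == input).long() * ar_mask).sum().item()
    #         return nb_total, nb_correct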
293     def batches(self, split="train", nb_to_use=-1, desc=None):
294         assert split in {"train", "test"}
295         input = self.train_input if split == "train" else self.test_input
296         if nb_to_use > 0:
297             input = input[:nb_to_use]
298         if desc is None:
299             desc = f"epoch-{split}"
300         for batch in tqdm.tqdm(
301             input.split(self.batch_size), dynamic_ncols=True, desc=desc
302         ):
303             yield batch
304
305     def vocabulary_size(self):
306         return self.nb_codes
307
308     def produce_results(
309         self, n_epoch, model, result_dir, logger, deterministic_synthesis, nmax=1000
310     ):
311         def compute_accuracy(input, ar_mask, logger=None):
312             input, ar_mask = input[:nmax], ar_mask[:nmax]
313             result = input.clone() * (1 - ar_mask)
314
315             masked_inplace_autoregression(
316                 model,
317                 self.batch_size,
318                 result,
319                 ar_mask,
320                 deterministic_synthesis,
321                 progress_bar_desc=None,
322                 device=self.device,
323             )
324
325             log_ground_truth = ar_mask.min() == 0
326
327             if logger is not None:
328                 for sp, st in zip(result[:10], input[:10]):
329                     logger(
330                         f"test_sequences {n_epoch} prediction   {self.problem.seq2str(sp)}"
331                     )
332                     if log_ground_truth:
333                         logger(
334                             f"               {n_epoch} ground truth {self.problem.seq2str(st)}"
335                         )
336
337             nb_total, nb_correct = self.problem.compute_nb_correct(
338                 input, ar_mask, result
339             )
340
341             # nb_total = ar_mask.sum().item()
342             # nb_correct = ((result == input).long() * ar_mask).sum().item()
343
344             return nb_total, nb_correct
345
346         train_nb_total, train_nb_correct = compute_accuracy(
347             self.train_input, self.train_ar_mask
348         )
349
350         logger(
351             f"accuracy_train {n_epoch} nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%"
352         )
353
354         test_nb_total, test_nb_correct = compute_accuracy(
355             self.test_input, self.test_ar_mask, logger
356         )
357
358         logger(
359             f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
360         )
361
362         logger(f"main_test_accuracy {n_epoch} {test_nb_correct/test_nb_total}")
363
364         if save_attention_image is not None:
365             for k in range(10):
366                 ns = torch.randint(self.test_input.size(0), (1,)).item()
367                 input = self.test_input[ns : ns + 1].clone()
368
369                 with torch.autograd.no_grad():
370                     t = model.training
371                     model.eval()
372                     # model.record_attention(True)
373                     model(BracketedSequence(input))
374                     model.train(t)
375                     # ram = model.retrieve_attention()
376                     # model.record_attention(False)
377
378                 # tokens_output = [c for c in self.problem.seq2str(input[0])]
379                 # tokens_input = ["n/a"] + tokens_output[:-1]
380                 # for n_head in range(ram[0].size(1)):
381                 # filename = os.path.join(
382                 # result_dir, f"sandbox_attention_{k}_h{n_head}.pdf"
383                 # )
384                 # attention_matrices = [m[0, n_head] for m in ram]
385                 # save_attention_image(
386                 # filename,
387                 # tokens_input,
388                 # tokens_output,
389                 # attention_matrices,
390                 # k_top=10,
391                 ##min_total_attention=0.9,
392                 # token_gap=12,
393                 # layer_gap=50,
394                 # )
395                 # logger(f"wrote {filename}")
396
397
398 ######################################################################
399
400 import world
401
402
403 class World(Task):
404     def __init__(
405         self,
406         nb_train_samples,
407         nb_test_samples,
408         batch_size,
409         logger=None,
410         device=torch.device("cpu"),
411     ):
412         super().__init__()
413
414         self.batch_size = batch_size
415         self.device = device
416         self.height = 6
417         self.width = 8
418
419         self.train_input = world.generate(
420             nb_train_samples, height=self.height, width=self.width
421         )
422         self.train_ar_mask = (
423             (torch.arange(self.train_input.size(1)) > self.train_input.size(1) // 2)
424             .long()[None, :]
425             .expand_as(self.train_input)
426         )
427
428         self.test_input = world.generate(
429             nb_test_samples, height=self.height, width=self.width
430         )
431         self.test_ar_mask = (
432             (torch.arange(self.test_input.size(1)) > self.test_input.size(1) // 2)
433             .long()[None, :]
434             .expand_as(self.test_input)
435         )
436
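        # The masks above flag the second half of every sequence as the part
        # to generate. Illustration (not part of the original file): with
        # sequences of length 8, each row of the mask is
        # [0, 0, 0, 0, 0, 1, 1, 1], i.e. the positions strictly after the
        # midpoint.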
437         self.train_input, self.train_ar_mask = self.train_input.to(
438             device
439         ), self.train_ar_mask.to(device)
440         self.test_input, self.test_ar_mask = self.test_input.to(
441             device
442         ), self.test_ar_mask.to(device)
443
444         self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
445
446     def batches(self, split="train", nb_to_use=-1, desc=None):
447         assert split in {"train", "test"}
448         input = self.train_input if split == "train" else self.test_input
449         if nb_to_use > 0:
450             input = input[:nb_to_use]
451         if desc is None:
452             desc = f"epoch-{split}"
453         for batch in tqdm.tqdm(
454             input.split(self.batch_size), dynamic_ncols=True, desc=desc
455         ):
456             yield batch
457
458     def vocabulary_size(self):
459         return self.nb_codes
460
461     def produce_results(
462         self, n_epoch, model, result_dir, logger, deterministic_synthesis, nmax=1000
463     ):
464         def compute_accuracy(input, ar_mask, logger=None):
465             input, ar_mask = input[:nmax], ar_mask[:nmax]
466             result = input.clone() * (1 - ar_mask)
467
468             masked_inplace_autoregression(
469                 model,
470                 self.batch_size,
471                 result,
472                 ar_mask,
473                 deterministic_synthesis,
474                 progress_bar_desc=None,
475                 device=self.device,
476             )
477
478             nb_total, nb_correct = (
479                 input.size(0),
480                 (input == result).long().min(dim=1).values.sum(),
481             )
482
483             return nb_total, nb_correct
484
485         train_nb_total, train_nb_correct = compute_accuracy(
486             self.train_input, self.train_ar_mask
487         )
488
489         logger(
490             f"accuracy_train {n_epoch} nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%"
491         )
492
493         test_nb_total, test_nb_correct = compute_accuracy(
494             self.test_input, self.test_ar_mask, logger
495         )
496
497         logger(
498             f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
499         )
500
501         logger(f"main_test_accuracy {n_epoch} {test_nb_correct/test_nb_total}")
502
503         ##############################
504
505         input, ar_mask = self.test_input[:64], self.test_ar_mask[:64]
506         result = input.clone() * (1 - ar_mask)
507
508         masked_inplace_autoregression(
509             model,
510             self.batch_size,
511             result,
512             ar_mask,
513             deterministic_synthesis,
514             progress_bar_desc=None,
515             device=self.device,
516         )
517
518         img = world.sample2img(result.to("cpu"), self.height, self.width)
519
520         image_name = os.path.join(result_dir, f"world_result_{n_epoch:04d}.png")
521         torchvision.utils.save_image(img.float() / 255.0, image_name, nrow=8, padding=2)
522         logger(f"wrote {image_name}")
523
524
525 ######################################################################
526
527 import picoclvr
528
529
530 class PicoCLVR(Task):
531     # Make a tensor from a list of strings
532     def tensorize(self, descr):
533         token_descr = [s.strip().split(" ") for s in descr]
534         l = max([len(s) for s in token_descr])
535         token_descr = [s + ["<nul>"] * (l - len(s)) for s in token_descr]
536         id_descr = [[self.token2id[u] for u in s] for s in token_descr]
537         return torch.tensor(id_descr, device=self.device)
538
539     # Make a list of strings from a tensor
540     def detensorize(self, x):
541         return [" ".join([self.id2token[t.item()] for t in r]) for r in x]
542
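    # Illustrative round trip (not part of the original file, descriptions are
    # made up): tensorize pads every description to the longest one with
    # "<nul>" and maps tokens to ids, detensorize maps back to strings.
    #
    # descr = ["there is red <sep> there is blue <img>", "red above green <img>"]
    # x = self.tensorize(descr)  # LongTensor of shape (2, 7), second row padded
    # self.detensorize(x)[1]     # "red above green <img> <nul> <nul> <nul>"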
543     # Trim the tensors to remove as many padding tokens as possible on the
544     # left and right of the first tensor. If z is a tuple, all its elements
545     # are trimmed according to the trimming computed for the first one.
546     def trim(self, z, token="<nul>"):
547         n = self.token2id[token]
548         if type(z) == tuple:
549             x = z[0]
550             i = (1 - (F.pad(x, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
551             a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
552             return tuple([t[:, a:b] for t in z])
553         else:
554             i = (1 - (F.pad(z, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
555             a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
556             return z[:, a:b]
557
558     ######################
559
560     def __init__(
561         self,
562         nb_train_samples,
563         nb_test_samples,
564         batch_size,
565         height,
566         width,
567         nb_colors=5,
568         logger=None,
569         device=torch.device("cpu"),
570         pruner_train=None,
571         pruner_eval=None,
572     ):
573         super().__init__()
574
575         def generate_descr(nb, cache_suffix, pruner):
576             return picoclvr.generate(
577                 nb,
578                 height=self.height,
579                 width=self.width,
580                 nb_colors=nb_colors,
581                 pruner=pruner,
582             )
583
584         self.height = height
585         self.width = width
586         self.batch_size = batch_size
587         self.device = device
588         self.pruner_train = pruner_train
589         self.pruner_eval = pruner_eval
590
591         if logger is not None:
592             logger(
593                 f"generating {nb_train_samples+nb_test_samples} samples (can take some time)"
594             )
595
596         self.train_descr = generate_descr(
597             nb_train_samples, "train", pruner=self.pruner_train
598         )
599         self.test_descr = generate_descr(nb_test_samples, "test", pruner=None)
600
601         # Build the tokenizer
602         tokens = {"<nul>", "<img>"}
603         for d in [self.train_descr, self.test_descr]:
604             for s in d:
605                 for t in s.strip().split(" "):
606                     tokens.add(t)
607         # turn this set into a sorted list so that the same descriptions
608         # always produce the same tensors
609         tokens = list(tokens)
610         tokens.sort()
611         self.token2id = dict([(t, n) for n, t in enumerate(tokens)])
612         self.id2token = dict([(n, t) for n, t in enumerate(tokens)])
613         self.t_img, self.t_nul = self.token2id["<img>"], self.token2id["<nul>"]
614
615         # Tokenize the train and test sets
616         self.train_input = self.tensorize(self.train_descr)
617         self.test_input = self.tensorize(self.test_descr)
618
619     def batches(self, split="train", nb_to_use=-1, desc=None):
620         assert split in {"train", "test"}
621         input = self.train_input if split == "train" else self.test_input
622         for batch in tqdm.tqdm(
623             input.split(self.batch_size), dynamic_ncols=True, desc=f"epoch-{split}"
624         ):
625             yield self.trim(batch)
626
627     def vocabulary_size(self):
628         return len(self.token2id)
629
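    # compute_missing_properties regenerates the image part of every test
    # sequence and counts how many requested properties the generated image
    # misses. Illustration of the masking (not part of the original file):
    #
    #   tokens:  red above green <img> p_1 ... p_HW
    #   ar_mask:  0     0     0    1    1  ...   1
    #
    # (result == self.t_img).long().cumsum(dim=1).clamp(max=1) is 1 from the
    # "<img>" token onward, and those positions are reset to "<nul>" before
    # the autoregressive generation.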
630     def compute_missing_properties(
631         self, n_epoch, model, logger, deterministic_synthesis, pruner=None
632     ):
633         acc_nb_requested_properties = []
634         acc_nb_missing_properties = []
635         acc_nb_results = 0
636
637         for input in tqdm.tqdm(
638             self.test_input.split(self.batch_size),
639             dynamic_ncols=True,
640             desc=f"test-properties",
641         ):
642             result = input.clone()
643             ar_mask = (result == self.t_img).long().cumsum(dim=1).clamp(max=1)
644             result = (1 - ar_mask) * result + ar_mask * self.t_nul
645             masked_inplace_autoregression(
646                 model,
647                 self.batch_size,
648                 result,
649                 ar_mask,
650                 deterministic_synthesis,
651                 progress_bar_desc=None,
652                 device=self.device,
653             )
654
655             result_descr = self.detensorize(result)
656             np = picoclvr.nb_properties(
657                 result_descr,
658                 height=self.height,
659                 width=self.width,
660                 pruner=pruner,
661             )
662             nb_requested_properties, _, nb_missing_properties = zip(*np)
663             acc_nb_requested_properties += nb_requested_properties
664             acc_nb_missing_properties += nb_missing_properties
665             acc_nb_results += len(result_descr)
666
667         nb_requested_properties = sum(acc_nb_requested_properties)
668         nb_missing_properties = sum(acc_nb_missing_properties)
669
670         prefix = "" if pruner is None else "pruned_"
671         logger(f"nb_{prefix}samples {n_epoch} {acc_nb_results}")
672         logger(
673             f"property_{prefix}nb {n_epoch} requested {sum(acc_nb_requested_properties)} missing {sum(acc_nb_missing_properties)}"
674         )
675         logger(
676             f"property_{prefix}miss {n_epoch} {100*nb_missing_properties/nb_requested_properties:.02f}%"
677         )
678
679         logger(
680             f"main_test_accuracy {n_epoch} {1-nb_missing_properties/nb_requested_properties}"
681         )
682
683     ######################################################################
684
685     def produce_results(
686         self, n_epoch, model, result_dir, logger, deterministic_synthesis
687     ):
688         self.compute_missing_properties(n_epoch, model, logger, deterministic_synthesis)
689
690         if self.pruner_eval is not None:
691             self.compute_missing_properties(
                n_epoch, model, logger, deterministic_synthesis, self.pruner_eval
            )
692
693         nb_tokens_to_generate = self.height * self.width + 3
694         result_descr = []
695         nb_per_primer = 8
696         primer = []
697
698         for primer_descr in [
699             "red above green <sep> green top <sep> blue right of red",
700             "there is red <sep> there is yellow <sep> there is blue",
701             "red below yellow <sep> yellow below green <sep> green below blue <sep> red right <sep> yellow left <sep> green right <sep> blue left",
702             "green bottom <sep> yellow bottom <sep> green left of blue <sep> yellow right of blue <sep> blue top",
703         ]:
704             primer += [primer_descr + " <img>"] * nb_per_primer
705
706         result = self.tensorize(primer)
707         fill = result.new_full(
708             result.size()[:-1] + (self.height * self.width + 1,), self.t_nul
709         )
710         result = torch.cat((result, fill), 1)
711         ar_mask = (result == self.t_nul).long()
712         masked_inplace_autoregression(
713             model,
714             self.batch_size,
715             result,
716             ar_mask,
717             deterministic_synthesis,
718             device=self.device,
719         )
720         result_descr = self.detensorize(result)
721
722         np = picoclvr.nb_properties(result_descr, height=self.height, width=self.width)
723
724         acc_nb_requested_properties, _, acc_nb_missing_properties = zip(*np)
725         acc_nb_results = len(result_descr)
726
727         nb_requested_properties = sum(acc_nb_requested_properties)
728         nb_missing_properties = sum(acc_nb_missing_properties)
729
730         prefix = "demo_"
731         logger(f"nb_{prefix}samples {n_epoch} {acc_nb_results}")
732         logger(
733             f"property_{prefix}nb {n_epoch} requested {sum(acc_nb_requested_properties)} missing {sum(acc_nb_missing_properties)}"
734         )
735         logger(
736             f"property_{prefix}miss {n_epoch} {100*nb_missing_properties/nb_requested_properties:.02f}%"
737         )
738
739         img = picoclvr.descr2img(result_descr, height=self.height, width=self.width)
740
741         if img.dim() == 5:
742             if img.size(1) == 1:
743                 img = F.pad(img.squeeze(1), pad=(1, 1, 1, 1), value=64)
744             else:
745                 img = torch.cat(
746                     [
747                         torchvision.utils.make_grid(x, padding=1, pad_value=64)[None]
748                         for x in img
749                     ],
750                     0,
751                 )
752
753         image_name = os.path.join(result_dir, f"picoclvr_result_{n_epoch:04d}.png")
754         torchvision.utils.save_image(
755             img / 255.0, image_name, nrow=nb_per_primer, padding=1, pad_value=0.0
756         )
757         logger(f"wrote {image_name}")
758
759
760 ######################################################################
761
762
763 class MNIST(Task):
764     def __init__(
765         self, nb_train_samples, nb_test_samples, batch_size, device=torch.device("cpu")
766     ):
767         super().__init__()
768
769         self.nb_train_samples = nb_train_samples
770         self.nb_test_samples = nb_test_samples
771         self.batch_size = batch_size
772         self.device = device
773         data_set = torchvision.datasets.MNIST(root="./data", train=True, download=True)
774         self.train_input = data_set.data[:nb_train_samples].view(-1, 28 * 28).long()
775         data_set = torchvision.datasets.MNIST(root="./data", train=False, download=True)
776         self.test_input = data_set.data[:nb_test_samples].view(-1, 28 * 28).long()
777
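    # Each MNIST image is flattened into a sequence of 28 * 28 = 784 tokens,
    # one per pixel, with the raw grey level (0-255) used directly as the
    # token id, hence vocabulary_size() == 256. Illustration (not part of the
    # original file):
    #
    # task = MNIST(nb_train_samples=1000, nb_test_samples=100, batch_size=25)
    # task.train_input.size()  # torch.Size([1000, 784]), values in [0, 255]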
778     def batches(self, split="train", nb_to_use=-1, desc=None):
779         assert split in {"train", "test"}
780         input = self.train_input if split == "train" else self.test_input
781         if nb_to_use > 0:
782             input = input[:nb_to_use]
783         if desc is None:
784             desc = f"epoch-{split}"
785         for batch in tqdm.tqdm(
786             input.split(self.batch_size), dynamic_ncols=True, desc=desc
787         ):
788             yield batch
789
790     def vocabulary_size(self):
791         return 256
792
793     def produce_results(
794         self, n_epoch, model, result_dir, logger, deterministic_synthesis
795     ):
796         results = torch.empty(64, 28 * 28, device=self.device, dtype=torch.int64)
797         ar_mask = torch.full_like(results, 1)
798         masked_inplace_autoregression(
799             model,
800             self.batch_size,
801             results,
802             ar_mask,
803             deterministic_synthesis,
804             device=self.device,
805         )
806         image_name = os.path.join(result_dir, f"mnist_result_{n_epoch:04d}.png")
807         torchvision.utils.save_image(
808             1 - results.reshape(-1, 1, 28, 28) / 255.0,
809             image_name,
810             nrow=16,
811             pad_value=0.8,
812         )
813         logger(f"wrote {image_name}")
814
815
816 ######################################################################
817
818 import maze
819
820
821 class Maze(Task):
822     def map2seq(self, *m):
823         return torch.cat([x.flatten(1) for x in m], 1)
824
825     def seq2map(self, s):
826         s = s.reshape(s.size(0), -1, self.height, self.width)
827         return (s[:, k] for k in range(s.size(1)))
828
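    # Illustrative example (not part of the original file), assuming
    # self.height == 5 and self.width == 7: map2seq flattens each
    # (N, height, width) map and concatenates them along the sequence
    # dimension, and seq2map splits such a sequence back into maps.
    #
    # m = torch.zeros(4, 5, 7, dtype=torch.int64)  # four mazes
    # p = torch.zeros(4, 5, 7, dtype=torch.int64)  # the corresponding paths
    # s = self.map2seq(m, p)                       # shape (4, 70)
    # m2, p2 = self.seq2map(s)                     # two (4, 5, 7) tensors again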
829     def __init__(
830         self,
831         nb_train_samples,
832         nb_test_samples,
833         batch_size,
834         height,
835         width,
836         nb_walls,
837         device=torch.device("cpu"),
838     ):
839         super().__init__()
840
841         self.batch_size = batch_size
842         self.height = height
843         self.width = width
844         self.device = device
845
846         train_mazes, train_paths, _ = maze.create_maze_data(
847             nb_train_samples,
848             height=height,
849             width=width,
850             nb_walls=nb_walls,
851             progress_bar=lambda x: tqdm.tqdm(x, dynamic_ncols=True, desc=f"data-train"),
852         )
853         self.train_input = self.map2seq(train_mazes.to(device), train_paths.to(device))
854
855         test_mazes, test_paths, _ = maze.create_maze_data(
856             nb_test_samples,
857             height=height,
858             width=width,
859             nb_walls=nb_walls,
860             progress_bar=lambda x: tqdm.tqdm(x, dynamic_ncols=True, desc=f"data-test"),
861         )
862         self.test_input = self.map2seq(test_mazes.to(device), test_paths.to(device))
863
864         self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
865
866     def batches(self, split="train", nb_to_use=-1, desc=None):
867         assert split in {"train", "test"}
868         input = self.train_input if split == "train" else self.test_input
869         if nb_to_use > 0:
870             input = input[:nb_to_use]
871         if desc is None:
872             desc = f"epoch-{split}"
873         for batch in tqdm.tqdm(
874             input.split(self.batch_size), dynamic_ncols=True, desc=desc
875         ):
876             yield batch
877
878     def vocabulary_size(self):
879         return self.nb_codes
880
881     def compute_error(
882         self, model, split="train", nb_to_use=-1, deterministic_synthesis=False
883     ):
884         model_device = next(model.parameters()).device
885         nb_total, nb_correct = 0, 0
886         count = torch.zeros(
887             self.width * self.height,
888             self.width * self.height,
889             device=model_device,
890             dtype=torch.int64,
891         )
892
893         for input in self.batches(split, nb_to_use):
894             input = input.to(model_device)
895             result = input.clone()
896             ar_mask = result.new_zeros(result.size())
897             ar_mask[:, self.height * self.width :] = 1
898             result *= 1 - ar_mask
899             masked_inplace_autoregression(
900                 model,
901                 self.batch_size,
902                 result,
903                 ar_mask,
904                 deterministic_synthesis,
905                 progress_bar_desc=None,
906                 device=self.device,
907             )
908             mazes, paths = self.seq2map(result)
909             path_correctness = maze.path_correctness(mazes, paths)
910             nb_correct += path_correctness.long().sum()
911             nb_total += mazes.size(0)
912
913             optimal_path_lengths = (
914                 (input[:, self.height * self.width :] == maze.v_path).long().sum(1)
915             )
916             predicted_path_lengths = (
917                 (result[:, self.height * self.width :] == maze.v_path).long().sum(1)
918             )
919             optimal_path_lengths = optimal_path_lengths[path_correctness]
920             predicted_path_lengths = predicted_path_lengths[path_correctness]
921             count[optimal_path_lengths, predicted_path_lengths] += 1
922
923         if count.max() == 0:
924             count = None
925         else:
926             count = count[
927                 : count.sum(1).nonzero().max() + 1, : count.sum(0).nonzero().max() + 1
928             ]
929
930         return nb_total, nb_correct, count
931
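    # In compute_error above, count[i, j] is the number of correctly solved
    # mazes whose optimal path occupies i cells while the predicted path
    # occupies j cells; its diagonal therefore corresponds to predictions that
    # are not only valid but also optimal (used for proportion_optimal below).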
932     def produce_results(
933         self, n_epoch, model, result_dir, logger, deterministic_synthesis
934     ):
935         train_nb_total, train_nb_correct, count = self.compute_error(
936             model,
937             "train",
938             nb_to_use=1000,
939             deterministic_synthesis=deterministic_synthesis,
940         )
941         logger(
942             f"accuracy_train {n_epoch} nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%"
943         )
944
945         test_nb_total, test_nb_correct, count = self.compute_error(
946             model,
947             "test",
948             nb_to_use=1000,
949             deterministic_synthesis=deterministic_synthesis,
950         )
951         logger(
952             f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
953         )
954
955         logger(f"main_test_accuracy {n_epoch} {test_nb_correct/test_nb_total}")
956
957         if count is not None:
958             proportion_optimal = count.diagonal().sum().float() / count.sum()
959             logger(f"proportion_optimal_test {proportion_optimal*100:.02f}%")
960             with open(
961                 os.path.join(result_dir, f"maze_result_{n_epoch:04d}.txt"), "w"
962             ) as f:
963                 for i in range(count.size(0)):
964                     for j in range(count.size(1)):
965                         eol = " " if j < count.size(1) - 1 else "\n"
966                         f.write(f"{count[i,j]}{eol}")
967
968         input = self.test_input[:48].to(next(model.parameters()).device)
969         result = input.clone()
970         ar_mask = result.new_zeros(result.size())
971         ar_mask[:, self.height * self.width :] = 1
972         result *= 1 - ar_mask
973         masked_inplace_autoregression(
974             model,
975             self.batch_size,
976             result,
977             ar_mask,
978             deterministic_synthesis,
979             device=self.device,
980         )
981
982         mazes, paths = self.seq2map(input)
983         _, predicted_paths = self.seq2map(result)
984
985         filename = os.path.join(result_dir, f"maze_result_{n_epoch:04d}.png")
986         maze.save_image(
987             filename,
988             mazes=mazes,
989             target_paths=paths,
990             predicted_paths=predicted_paths,
991             path_correct=maze.path_correctness(mazes, predicted_paths),
992             path_optimal=maze.path_optimality(paths, predicted_paths),
993         )
994         logger(f"wrote {filename}")
995
996
997 ######################################################################
998
999
1000 import snake
1001
1002
1003 class Snake(Task):
1004     def __init__(
1005         self,
1006         nb_train_samples,
1007         nb_test_samples,
1008         batch_size,
1009         height,
1010         width,
1011         nb_colors,
1012         length,
1013         prompt_length,
1014         device=torch.device("cpu"),
1015     ):
1016         super().__init__()
1017
1018         self.batch_size = batch_size
1019         self.height = height
1020         self.width = width
1021         self.device = device
1022         self.prompt_length = prompt_length
1023
1024         self.train_input, self.train_prior_visits, _, _ = snake.generate_sequences(
1025             nb_train_samples,
1026             height,
1027             width,
1028             nb_colors,
1029             length,
1030             prompt_length,
1031             self.device,
1032         )
1033         self.test_input, self.test_prior_visits, _, _ = snake.generate_sequences(
1034             nb_test_samples,
1035             height,
1036             width,
1037             nb_colors,
1038             length,
1039             prompt_length,
1040             self.device,
1041         )
1042
1043         self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
1044
1045     def batches(self, split="train", nb_to_use=-1, desc=None):
1046         assert split in {"train", "test"}
1047         input = self.train_input if split == "train" else self.test_input
1048         if nb_to_use > 0:
1049             input = input[:nb_to_use]
1050         if desc is None:
1051             desc = f"epoch-{split}"
1052         for batch in tqdm.tqdm(
1053             input.split(self.batch_size), dynamic_ncols=True, desc=desc
1054         ):
1055             yield batch
1056
1057     def vocabulary_size(self):
1058         return self.nb_codes
1059
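    # In produce_results below, the autoregressive mask selects the even
    # positions located after the prompt (i >= 2 * prompt_length and
    # i % 2 == 0), which appears to match the way snake.generate_sequences
    # interleaves two kinds of tokens, only one of which has to be predicted.
    # Illustration (not part of the original file) with prompt_length == 2:
    #
    #   position: 0 1 2 3 4 5 6 7
    #   ar_mask:  0 0 0 0 1 0 1 0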
1060     def produce_results(
1061         self, n_epoch, model, result_dir, logger, deterministic_synthesis
1062     ):
1063         def compute_nb_correct(input, prior_visits):
1064             result = input.clone()
1065             i = torch.arange(result.size(1), device=result.device)[None, :]
1066             ar_mask = (
1067                 torch.logical_and(i >= self.prompt_length * 2, i % 2 == 0)
1068                 .long()
1069                 .expand_as(result)
1070             )
1071             result *= 1 - ar_mask
1072
1073             masked_inplace_autoregression(
1074                 model,
1075                 self.batch_size,
1076                 result,
1077                 ar_mask,
1078                 deterministic_synthesis,
1079                 device=self.device,
1080             )
1081
1082             nb_total = ((prior_visits > 0) * ar_mask).sum()
1083
1084             nb_correct = ((result == input).long() * (prior_visits > 0) * ar_mask).sum()
1085
1086             return nb_total, nb_correct
1087
1088         test_nb_total, test_nb_correct = compute_nb_correct(
1089             self.test_input[:1000], self.test_prior_visits[:1000]
1090         )
1091
1092         logger(
1093             f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
1094         )
1095
1096         logger(f"main_test_accuracy {n_epoch} {test_nb_correct/test_nb_total}")
1097
1098
1099 ######################################################################
1100
1101
1102 import stack
1103
1104
1105 class Stack(Task):
1106     def __init__(
1107         self,
1108         nb_train_samples,
1109         nb_test_samples,
1110         batch_size,
1111         logger,
1112         nb_steps,
1113         nb_stacks,
1114         nb_digits,
1115         fraction_values_for_train=None,
1116         device=torch.device("cpu"),
1117     ):
1118         super().__init__()
1119
1120         self.batch_size = batch_size
1121         self.nb_steps = nb_steps
1122         self.nb_stacks = nb_stacks
1123         self.nb_digits = nb_digits
1124         self.device = device
1125
1126         if fraction_values_for_train is None:
1127             values_for_train = None
1128             values_for_test = None
1129         else:
1130             all = torch.randperm(10**nb_digits)
1131             nb_for_train = int(all.size(0) * fraction_values_for_train)
1132             values_for_train = all[:nb_for_train]
1133             values_for_test = all[nb_for_train:]
1134
1135         self.train_input, self.train_stack_counts = stack.generate_sequences(
1136             nb_train_samples,
1137             nb_steps,
1138             nb_stacks,
1139             nb_digits,
1140             values_for_train,
1141             self.device,
1142         )
1143
1144         self.test_input, self.test_stack_counts = stack.generate_sequences(
1145             nb_test_samples,
1146             nb_steps,
1147             nb_stacks,
1148             nb_digits,
1149             values_for_test,
1150             self.device,
1151         )
1152
1153         i = torch.logical_and(self.test_input % 2 == 1, self.test_input < 2 * nb_stacks)
1154         counts = self.test_stack_counts.flatten()[i.flatten()]
1155         counts = F.one_hot(counts).sum(0)
1156         logger(f"test_pop_stack_counts {counts}")
1157
1158         self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
1159
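    # Inferred from the way the sequences are used below (the reshape to
    # (-1, 1 + nb_digits) and the test "input % 2 == 1 and input < 2 * nb_stacks"),
    # not from stack.py itself: each time step appears to be encoded as one
    # operation token (push/pop on one of the nb_stacks stacks, odd tokens
    # being pops) followed by nb_digits value tokens, e.g. with nb_digits == 2:
    #
    #   [op, d1, d0, op, d1, d0, ...]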
1160     def batches(self, split="train", nb_to_use=-1, desc=None):
1161         assert split in {"train", "test"}
1162         input = self.train_input if split == "train" else self.test_input
1163         if nb_to_use > 0:
1164             input = input[:nb_to_use]
1165         if desc is None:
1166             desc = f"epoch-{split}"
1167         for batch in tqdm.tqdm(
1168             input.split(self.batch_size), dynamic_ncols=True, desc=desc
1169         ):
1170             yield batch
1171
1172     def vocabulary_size(self):
1173         return self.nb_codes
1174
1175     def produce_results(
1176         self, n_epoch, model, result_dir, logger, deterministic_synthesis
1177     ):
1178         def compute_nb_correct(input):
1179             result = input.clone()
1180             stack.remove_popped_values(result, self.nb_stacks, self.nb_digits)
1181             ar_mask = (result != input).long()
1182             masked_inplace_autoregression(
1183                 model,
1184                 self.batch_size,
1185                 result,
1186                 ar_mask,
1187                 deterministic_synthesis,
1188                 device=self.device,
1189             )
1190
1191             errors = ((result != input).long() * ar_mask).reshape(
1192                 -1, 1 + self.nb_digits
1193             )
1194             ar_mask = ar_mask.reshape(-1, 1 + self.nb_digits)
1195
1196             nb_total = ar_mask.max(1).values.sum()
1197             nb_correct = nb_total - errors.max(1).values.sum()
1198
1199             return nb_total, nb_correct
1200
1201         test_nb_total, test_nb_correct = compute_nb_correct(self.test_input[:1000])
1202
1203         logger(
1204             f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
1205         )
1206
1207         logger(f"main_test_accuracy {n_epoch} {test_nb_correct/test_nb_total}")
1208
1209         ##############################################################
1210         # Log a few generated sequences
1211         input = self.test_input[:10, : 12 * (1 + self.nb_digits)]
1212         result = input.clone()
1213         stack.remove_popped_values(result, self.nb_stacks, self.nb_digits)
1214         ar_mask = (result != input).long()
1215
1216         # for n in range(result.size(0)):
1217         # logger(
1218         # f"test_before {stack.seq_to_str(result[n],nb_stacks=self.nb_stacks,nb_digits=self.nb_digits)}"
1219         # )
1220
1221         masked_inplace_autoregression(
1222             model,
1223             self.batch_size,
1224             result,
1225             ar_mask,
1226             deterministic_synthesis,
1227             device=self.device,
1228         )
1229
1230         # Log the per-token probability assigned to a few train and test sequences
1231         for label, input in [
1232             ("train", self.train_input[:32]),
1233             ("test", self.test_input[:32]),
1234         ]:
1235             output = model(BracketedSequence(input)).x
1236             output = output.log_softmax(dim=-1)
1237             filename = os.path.join(
1238                 result_dir, f"stack_with_crossentropy_{n_epoch:04d}_{label}.txt"
1239             )
1240             with open(filename, "w") as f:
1241                 for n in range(input.size(0)):
1242                     s = stack.seq_to_str(
1243                         input[n], nb_stacks=self.nb_stacks, nb_digits=self.nb_digits
1244                     )
1245                     for t, k, w in zip(range(input[n].size(0)), input[n], s.split(" ")):
1246                         u = (
1247                             " " * (10 - len(w))
1248                             + w
1249                             + " "
1250                             + str(output[n][t][k].exp().item())
1251                             + "\n"
1252                         )
1253                         f.write(u)
1254                     f.write("\n")
1255             logger(f"wrote {filename}")
1256         # (end of per-token probability logging)
1257
1258         for n in range(result.size(0)):
1259             logger(
1260                 f"test_after  {stack.seq_to_str(result[n],nb_stacks=self.nb_stacks,nb_digits=self.nb_digits)}"
1261             )
1262         ##############################################################
1263
1264
1265 ######################################################################
1266
1267 import rpl
1268
1269
1270 class RPL(Task):
1271     def tensorize(self, sequences):
1272         len_max = max([len(x) for x in sequences])
1273         return torch.cat(
1274             [
1275                 torch.tensor(
1276                     [
1277                         [
1278                             self.token2id[str(c)]
1279                             for c in s + ["<nul>"] * (len_max - len(s))
1280                         ]
1281                         for s in sequences
1282                     ]
1283                 )
1284             ],
1285             0,
1286         )
1287
1288     def seq2str(self, seq):
1289         return " ".join([self.id2token[i] for i in seq])
1290
1291     def __init__(
1292         self,
1293         nb_train_samples,
1294         nb_test_samples,
1295         batch_size,
1296         nb_starting_values=3,
1297         max_input=9,
1298         prog_len=6,
1299         nb_runs=5,
1300         no_prog=False,
1301         logger=None,
1302         device=torch.device("cpu"),
1303     ):
1304         super().__init__()
1305
1306         self.batch_size = batch_size
1307         self.device = device
1308         self.no_prog = no_prog
1309
1310         train_sequences = [
1311             rpl.generate(
1312                 nb_starting_values=nb_starting_values,
1313                 nb_result_values_max=4 * nb_starting_values,
1314                 max_input=max_input,
1315                 prog_len=prog_len,
1316                 nb_runs=nb_runs,
1317             )
1318             for _ in tqdm.tqdm(range(nb_train_samples), desc="train-data")
1319         ]
1320
1321         test_sequences = [
1322             rpl.generate(
1323                 nb_starting_values=nb_starting_values,
1324                 nb_result_values_max=4 * nb_starting_values,
1325                 max_input=max_input,
1326                 prog_len=prog_len,
1327                 nb_runs=nb_runs,
1328             )
1329             for _ in tqdm.tqdm(range(nb_test_samples), desc="test-data")
1330         ]
1331
1332         symbols = list(
1333             set(["<nul>"] + [x for l in train_sequences + test_sequences for x in l])
1334         )
1335         val_max = max([x if type(x) is int else 0 for x in symbols])
1336         symbols = list(filter(lambda x: type(x) is str, symbols))
1337         symbols.sort()
1338         symbols += [str(n) for n in range(val_max + 1)]
1339         self.token2id = dict([(c, n) for n, c in enumerate(symbols)])
1340         self.id2token = dict([(n, c) for c, n in self.token2id.items()])
1341
1342         self.t_nul = self.token2id["<nul>"]
1343         self.t_input = self.token2id["<in>"]
1344         self.t_output = self.token2id["<out>"]
1345         self.t_prog = self.token2id["<prg>"]
1346         self.t_end = self.token2id["<end>"]
1347
1348         self.train_input = self.tensorize(train_sequences)
1349         self.test_input = self.tensorize(test_sequences)
1350
1351         if no_prog:
1352             # Excise the program from every train and test example
1353             k = torch.arange(self.train_input.size(1), device=self.train_input.device)[
1354                 None, :
1355             ]
1356             p = (
1357                 ((self.train_input == self.t_prog).long() * k)
1358                 .max(1, keepdim=True)
1359                 .values
1360             )
1361             self.train_input = (
1362                 self.train_input * (k <= p).long()
1363                 + self.t_end * (k == p + 1).long()
1364                 + self.t_nul * (k > p + 1).long()
1365             )
1366             k = torch.arange(self.test_input.size(1), device=self.test_input.device)[
1367                 None, :
1368             ]
1369             p = (
1370                 ((self.test_input == self.t_prog).long() * k)
1371                 .max(1, keepdim=True)
1372                 .values
1373             )
1374             self.test_input = (
1375                 self.test_input * (k <= p).long()
1376                 + self.t_end * (k == p + 1).long()
1377                 + self.t_nul * (k > p + 1).long()
1378             )
1379
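        # Illustration of the excision above (not part of the original file,
        # the program tokens are made up): p is the position of the last <prg>
        # token, everything up to and including it is kept, the next position
        # becomes <end>, and the rest is filled with <nul>.
        #
        #   before: <in> 1 2 3 <out> 4 5 <prg> add dup <end> <nul>
        #   after:  <in> 1 2 3 <out> 4 5 <prg> <end> <nul> <nul> <nul>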
1380         if logger is not None:
1381             logger(f"value_max {val_max}")
1382             for x in self.train_input[:25]:
1383                 end = (x != self.t_nul).nonzero().max().item() + 1
1384                 seq = [self.id2token[i.item()] for i in x[:end]]
1385                 s = " ".join(seq)
1386                 logger(f"example_seq {s}")
1387
1388         self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
1389
1390     def batches(self, split="train", nb_to_use=-1, desc=None):
1391         assert split in {"train", "test"}
1392         input = self.train_input if split == "train" else self.test_input
1393         if nb_to_use > 0:
1394             input = input[:nb_to_use]
1395         if desc is None:
1396             desc = f"epoch-{split}"
1397         for batch in tqdm.tqdm(
1398             input.split(self.batch_size), dynamic_ncols=True, desc=desc
1399         ):
1400             last = (batch != self.t_nul).max(0).values.nonzero().max() + 3
1401             batch = batch[:, :last].to(self.device)
1402             yield batch
1403
1404     def vocabulary_size(self):
1405         return self.nb_codes
1406
1407     def produce_results(
1408         self, n_epoch, model, result_dir, logger, deterministic_synthesis
1409     ):
1410         # --------------------------------------------------------------------
1411         def compute_nb_errors_prog(input, nb_to_log=0):
1412             result = input.clone()
1413             s = (result == self.t_prog).long()
1414             ar_mask = (s.cumsum(dim=1) - s).clamp(min=0, max=1)
1415             result = (1 - ar_mask) * result + ar_mask * self.t_nul
1416
1417             masked_inplace_autoregression(
1418                 model,
1419                 self.batch_size,
1420                 result,
1421                 ar_mask,
1422                 deterministic_synthesis,
1423                 device=self.device,
1424             )
1425
1426             sum_nb_total, sum_nb_errors = 0, 0
1427             for one_input, one_result in zip(input, result):
1428                 seq = [self.id2token[i.item()] for i in one_result]
1429                 nb_total, nb_errors, prog, stacks = rpl.compute_nb_errors(seq)
1430                 sum_nb_total += 1
1431                 sum_nb_errors += 0 if nb_errors == 0 else 1
1432                 if nb_to_log > 0:
1433                     gt_seq = [self.id2token[i.item()] for i in one_input]
1434                     _, _, gt_prog, _ = rpl.compute_nb_errors(gt_seq)
1435                     gt_prog = " ".join([str(x) for x in gt_prog])
1436                     prog = " ".join([str(x) for x in prog])
1437                     comment = "*" if nb_errors == 0 else "-"
1438                     logger(f"{comment} PROG [{gt_prog}] PREDICTED [{prog}]")
1439                     for start_stack, target_stack, result_stack, correct in stacks:
1440                         comment = "*" if correct else "-"
1441                         start_stack = " ".join([str(x) for x in start_stack])
1442                         target_stack = " ".join([str(x) for x in target_stack])
1443                         result_stack = " ".join([str(x) for x in result_stack])
1444                         logger(
1445                             f"  {comment} [{start_stack}] -> [{target_stack}] PREDICTED [{result_stack}]"
1446                         )
1447                     nb_to_log -= 1
1448
1449             return sum_nb_total, sum_nb_errors
1450
1451         # --------------------------------------------------------------------
1452         def compute_nb_errors_output(input, nb_to_log=0):
1453             result = input.clone()
1454             k = torch.arange(result.size(1), device=result.device)[None, :]
1455             last_output_idx = (
1456                 ((result == self.t_output) * k).max(dim=1, keepdim=True).values
1457             )
1458             first_prog_idx = (
1459                 ((result == self.t_prog) * k).max(dim=1, keepdim=True).values
1460             )
1461             ar_mask = (k > last_output_idx).long() * (k < first_prog_idx).long()
1462             result = (1 - ar_mask) * result + ar_mask * self.t_nul
1463
1464             masked_inplace_autoregression(
1465                 model,
1466                 self.batch_size,
1467                 result,
1468                 ar_mask,
1469                 deterministic_synthesis,
1470                 device=self.device,
1471             )
1472
1473             sum_nb_total, sum_nb_errors = 0, 0
1474             for one_input, one_result, i, j in zip(
1475                 input, result, last_output_idx, first_prog_idx
1476             ):
1477                 seq = [self.id2token[i.item()] for i in one_result]
1478                 sum_nb_total += 1
1479                 correct = (one_input - one_result).abs().max() == 0
1480                 sum_nb_errors += 0 if correct else 1
1481                 if nb_to_log > 0:
1482                     result_stack = [
1483                         self.id2token[u.item()] for u in one_result[i : j + 1]
1484                     ]
1485                     target_stack = [
1486                         self.id2token[u.item()] for u in one_input[i : j + 1]
1487                     ]
1488                     comment = "*" if correct else "-"
1489                     result_stack = " ".join([str(x) for x in result_stack])
1490                     target_stack = " ".join([str(x) for x in target_stack])
1491                     logger(
1492                         f"output_test {comment} [{target_stack}] PREDICTED [{result_stack}]"
1493                     )
1494                     nb_to_log -= 1
1495
1496             return sum_nb_total, sum_nb_errors
1497
1498         # --------------------------------------------------------------------
1499
1500         if not self.no_prog:
1501             test_nb_total, test_nb_errors = compute_nb_errors_prog(
1502                 self.test_input[:1000].to(self.device), nb_to_log=10
1503             )
1504
1505             logger(
1506                 f"accuracy_prog_test {n_epoch} nb_total {test_nb_total} nb_errors {test_nb_errors} accuracy {100.0*(1-test_nb_errors/test_nb_total):.02f}%"
1507             )
1508
1509             logger(f"main_test_accuracy {n_epoch} {1-test_nb_errors/test_nb_total}")
1510
1511         test_nb_total, test_nb_errors = compute_nb_errors_output(
1512             self.test_input[:1000].to(self.device), nb_to_log=10
1513         )
1514
1515         logger(
1516             f"accuracy_output_test {n_epoch} nb_total {test_nb_total} nb_errors {test_nb_errors} accuracy {100.0*(1-test_nb_errors/test_nb_total):.02f}%"
1517         )
1518
1519         if save_attention_image is None:
1520             logger("no save_attention_image (is pycairo installed?)")
1521         else:
1522             ns = torch.randint(self.test_input.size(0), (1,)).item()
1523             input = self.test_input[ns : ns + 1].clone()
1524             last = (input != self.t_nul).max(0).values.nonzero().max() + 3
1525             input = input[:, :last].to(self.device)
1526
1527             with torch.autograd.no_grad():
1528                 t = model.training
1529                 model.eval()
1530                 model.record_attention(True)
1531                 model(BracketedSequence(input))
1532                 model.train(t)
1533                 ram = model.retrieve_attention()
1534                 model.record_attention(False)
1535
1536             tokens_output = [self.id2token[i.item()] for i in input[0]]
1537             tokens_input = ["n/a"] + tokens_output[:-1]
1538             for n_head in range(ram[0].size(1)):
1539                 filename = os.path.join(
1540                     result_dir, f"rpl_attention_{n_epoch}_h{n_head}.pdf"
1541                 )
1542                 attention_matrices = [m[0, n_head] for m in ram]
1543                 save_attention_image(
1544                     filename,
1545                     tokens_input,
1546                     tokens_output,
1547                     attention_matrices,
1548                     k_top=10,
1549                     # min_total_attention=0.9,
1550                     token_gap=12,
1551                     layer_gap=50,
1552                 )
1553                 logger(f"wrote {filename}")
1554
1555
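# Editor's illustrative sketch of the masking pattern used by
# compute_nb_errors_output above; not called by any task. The function name
# and the token values 7, 9 and 0 (standing in for hypothetical "output",
# "prog" and null ids) are made up for the example.


def _demo_span_mask(x, t_output=7, t_prog=9, t_nul=0):
    # k holds the position of every column, broadcast over the batch
    k = torch.arange(x.size(1), device=x.device)[None, :]
    # largest position holding the marker token (0 if it does not occur)
    last_output = ((x == t_output) * k).max(dim=1, keepdim=True).values
    last_prog = ((x == t_prog) * k).max(dim=1, keepdim=True).values
    # 1 strictly between the two markers, 0 elsewhere
    ar_mask = (k > last_output).long() * (k < last_prog).long()
    # blank the masked positions with the null token
    return (1 - ar_mask) * x + ar_mask * t_nul, ar_mask


# e.g. _demo_span_mask(torch.tensor([[1, 7, 3, 4, 9, 5]])) returns
# (tensor([[1, 7, 0, 0, 9, 5]]), tensor([[0, 0, 1, 1, 0, 0]]))
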
1556 ######################################################################
1557
1558
1559 import expr
1560
1561
1562 class Expr(Task):
1563     def tensorize(self, sequences):
1564         len_max = max([len(x) for x in sequences])
1565         return torch.cat(
1566             [
1567                 torch.tensor(
1568                     [
1569                         [self.char2id[c] for c in s + "#" * (len_max - len(s))]
1570                         for s in sequences
1571                     ]
1572                 )
1573             ],
1574             0,
1575         ).to(self.device)
1576
1577     def __init__(
1578         self,
1579         nb_train_samples,
1580         nb_test_samples,
1581         nb_variables,
1582         sequence_length,
1583         operand_max,
1584         result_max,
1585         batch_size,
1586         device=torch.device("cpu"),
1587     ):
1588         super().__init__()
1589
1590         self.batch_size = batch_size
1591         self.device = device
1592
1593         train_sequences = expr.generate_sequences(
1594             nb_train_samples,
1595             nb_variables=nb_variables,
1596             length=sequence_length,
1597             operand_max=operand_max,
1598             result_max=result_max,
1599         )
1600
1601         test_sequences = expr.generate_sequences(
1602             nb_test_samples,
1603             nb_variables=nb_variables,
1604             length=sequence_length,
1605             operand_max=operand_max,
1606             result_max=result_max,
1607         )
1608
1609         symbols = list(set("#" + "".join(train_sequences + test_sequences)))
1610         symbols.sort()
1611
1612         self.char2id = dict([(c, n) for n, c in enumerate(symbols)])
1613         self.id2char = dict([(n, c) for c, n in self.char2id.items()])
1614
1615         self.filler, self.space = self.char2id["#"], self.char2id[" "]
1616
1617         self.train_input = self.tensorize(train_sequences)
1618         self.test_input = self.tensorize(test_sequences)
1619
1620         self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
1621
1622     def batches(self, split="train", nb_to_use=-1, desc=None):
1623         assert split in {"train", "test"}
1624         input = self.train_input if split == "train" else self.test_input
1625         if nb_to_use > 0:
1626             input = input[:nb_to_use]
1627         if desc is None:
1628             desc = f"epoch-{split}"
1629         for batch in tqdm.tqdm(
1630             input.split(self.batch_size), dynamic_ncols=True, desc=desc
1631         ):
1632             last = (batch != self.filler).max(0).values.nonzero().max() + 3
1633             batch = batch[:, :last]
1634             yield batch
1635
1636     def vocabulary_size(self):
1637         return self.nb_codes
1638
1639     def seq2str(self, s):
1640         return "".join([self.id2char[k.item()] for k in s])
1641
1642     def produce_results(
1643         self,
1644         n_epoch,
1645         model,
1646         result_dir,
1647         logger,
1648         deterministic_synthesis,
1649         input_file=None,
1650     ):
1651         def compute_nb_correct(input):
1652             result = input.clone()
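            # ar_mask is 1 strictly after the first space: that suffix is
            # blanked with the filler token and re-generated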
1653             s = (result == self.space).long()
1654             ar_mask = (s.cumsum(dim=1) - s).clamp(min=0, max=1)
1655             result = (1 - ar_mask) * result + ar_mask * self.filler
1656             masked_inplace_autoregression(
1657                 model,
1658                 self.batch_size,
1659                 result,
1660                 ar_mask,
1661                 deterministic_synthesis,
1662                 device=self.device,
1663             )
1664
1665             nb_total = input.size(0)
1666             nb_correct = (input == result).long().min(1).values.sum()
1667
1668             #######################################################################
1669             # Compute predicted vs. true variable values
1670
1671             nb_delta = torch.zeros(5, dtype=torch.int64)
1672             nb_missed = 0
1673
1674             values_input = expr.extract_results([self.seq2str(s) for s in input])
1675             values_result = expr.extract_results([self.seq2str(s) for s in result])
1676
1677             filename = os.path.join(result_dir, f"expr_result_{n_epoch:04d}.txt")
1678
1679             with open(filename, "w") as f:
1680                 for i, r in zip(values_input, values_result):
1681                     for n, vi in i.items():
1682                         vr = r.get(n)
1683                         f.write(f"{vi} {-1 if vr is None else vr}\n")
1684
1685                         if vr is None or vr < 0:
1686                             nb_missed += 1
1687                         else:
1688                             d = abs(vr - vi)
1689                             if d >= nb_delta.size(0):
1690                                 nb_missed += 1
1691                             else:
1692                                 nb_delta[d] += 1
1693
1694             ######################################################################
1695
1696             return nb_total, nb_correct, nb_delta, nb_missed
1697
1698         (
1699             test_nb_total,
1700             test_nb_correct,
1701             test_nb_delta,
1702             test_nb_missed,
1703         ) = compute_nb_correct(self.test_input[:10000])
1704
1705         logger(
1706             f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
1707         )
1708
1709         logger(f"main_test_accuracy {n_epoch} {test_nb_correct/test_nb_total}")
1710
1711         nb_total = test_nb_delta.sum() + test_nb_missed
1712         for d in range(test_nb_delta.size(0)):
1713             logger(
1714                 f"error_value {n_epoch} delta {d} {test_nb_delta[d]} {test_nb_delta[d]*100/nb_total:.02f}%"
1715             )
1716         logger(
1717             f"error_value {n_epoch} missed {test_nb_missed} {test_nb_missed*100/nb_total:.02f}%"
1718         )
1719
1720         ##############################################################
1721         # Log a few generated sequences
1722         if input_file is None:
1723             input = self.test_input[:10]
1724         else:
1725             with open(input_file, "r") as f:
1726                 sequences = [e.strip() for e in f.readlines()]
1727                 sequences = [s + " " + "#" * 50 for s in sequences]
1728                 input = self.tensorize(sequences)
1729
1730         result = input.clone()
1731         s = (result == self.space).long()
1732         ar_mask = (s.cumsum(dim=1) - s).clamp(min=0, max=1)
1733         result = (1 - ar_mask) * result + ar_mask * self.filler
1734
1735         for n in range(result.size(0)):
1736             logger(f"test_before {self.seq2str(result[n])}")
1737
1738         masked_inplace_autoregression(
1739             model,
1740             self.batch_size,
1741             result,
1742             ar_mask,
1743             deterministic_synthesis,
1744             device=self.device,
1745         )
1746
1747         correct = (1 - ar_mask) * self.space + ar_mask * input
1748         for n in range(result.size(0)):
1749             comment = "GOOD" if (result[n] - input[n]).abs().max() == 0 else ""
1750             logger(f"test_after  {self.seq2str(result[n])} {comment}")
1751             logger(f"truth       {self.seq2str(correct[n])}")
1752         ##############################################################
1753
1754
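# Editor's illustrative sketch of the "mask everything after the first space"
# trick used by the Expr task above; not called by any task. The function name
# and the stand-in ids (32 for the space, 0 for the filler) are made up.


def _demo_after_first_space_mask(x, space=32, filler=0):
    s = (x == space).long()
    # cumsum(s) - s is 0 up to and including the first space, >= 1 afterwards
    ar_mask = (s.cumsum(dim=1) - s).clamp(min=0, max=1)
    # keep the prefix, blank the suffix with the filler token
    return (1 - ar_mask) * x + ar_mask * filler, ar_mask


# e.g. _demo_after_first_space_mask(torch.tensor([[5, 6, 32, 7, 8]])) returns
# (tensor([[5, 6, 32, 0, 0]]), tensor([[0, 0, 0, 1, 1]]))
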
1755 ######################################################################
1756
1757 import grid
1758
1759
1760 class Grid(Task):
1761     # Make a tensor from a list of strings
1762     def str2tensor(self, descr):
1763         token_descr = [s.strip().split(" ") for s in descr]
1764         len_max = max([len(s) for s in token_descr])
1765         token_descr = [s + ["#"] * (len_max - len(s)) for s in token_descr]
1766         id_descr = [[self.token2id[u] for u in s] for s in token_descr]
1767         return torch.tensor(id_descr, device=self.device)
1768
1769     # Make a list of strings from a tensor
1770     def tensor2str(self, x):
1771         return [" ".join([self.id2token[t.item()] for t in r]) for r in x]
1772
1773     # Trim the leading and trailing columns that contain only the given
1774     # padding token in the first tensor. If z is a tuple, all its elements
1775     # are trimmed according to the trimming computed for the first one.
1776     def trim(self, z, token="#"):
1777         n = self.token2id[token]
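        # Columns where every row holds the padding token are detected via the
        # min over the batch dimension; the cumsum of the complement is flat on
        # the all-padding margins, which yields the crop bounds a and b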
1778         if type(z) == tuple:
1779             x = z[0]
1780             i = (1 - (F.pad(x, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
1781             a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
1782             return tuple([t[:, a:b] for t in z])
1783         else:
1784             i = (1 - (F.pad(z, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
1785             a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
1786             return z[:, a:b]
1787
1788     ######################
1789
1790     def __init__(
1791         self,
1792         nb_train_samples,
1793         nb_test_samples,
1794         batch_size,
1795         size,
1796         fraction_play=0.0,
1797         logger=None,
1798         device=torch.device("cpu"),
1799     ):
1800         super().__init__()
1801
1802         self.device = device
1803         self.batch_size = batch_size
1804         self.grid_factory = grid.GridFactory(size=size)
1805         self.fraction_play = fraction_play
1806
1807         if logger is not None:
1808             logger(
1809                 f"generating {nb_train_samples+nb_test_samples} samples (can take some time)"
1810             )
1811
1812         self.train_descr = self.grid_factory.generate_samples(
1813             nb=nb_train_samples,
1814             fraction_play=fraction_play,
1815             progress_bar=lambda r: tqdm.tqdm(r),
1816         )
1817
1818         self.test_descr = self.grid_factory.generate_samples(
1819             nb=nb_test_samples, fraction_play=0.0, progress_bar=lambda r: tqdm.tqdm(r)
1820         )
1821
1822         if fraction_play > 0:
1823             self.play_descr = self.grid_factory.generate_samples(
1824                 nb=25, fraction_play=1.0, progress_bar=lambda r: tqdm.tqdm(r)
1825             )
1826         else:
1827             self.play_descr = []
1828
1829         # Build the tokenizer
1830         tokens = set()
1831         for d in [self.train_descr, self.test_descr, self.play_descr]:
1832             for s in d:
1833                 for t in s.strip().split(" "):
1834                     tokens.add(t)
1835         # make this set a sorted list to get the same tensors given
1836         # the same descr
1837         tokens = list(tokens)
1838         tokens.sort()
1839         tokens = ["#"] + tokens
1840         self.token2id = dict([(t, n) for n, t in enumerate(tokens)])
1841         self.id2token = dict([(n, t) for n, t in enumerate(tokens)])
1842         self.t_nul = self.token2id["#"]
1843         self.t_true = self.token2id["true"]
1844         self.t_false = self.token2id["false"]
1845         self.t_pipe = self.token2id.get("|")  # used below for play sequences
1846
1847         # Tokenize the train and test sets
1848         self.train_input = self.str2tensor(self.train_descr)
1849         self.test_input = self.str2tensor(self.test_descr)
1850         self.play_input = (
1851             None if len(self.play_descr) == 0 else self.str2tensor(self.play_descr)
1852         )
1853
1854     def batches(self, split="train", nb_to_use=-1, desc=None):
1855         assert split in {"train", "test"}
1856         input = self.train_input if split == "train" else self.test_input
1857         for batch in tqdm.tqdm(
1858             input.split(self.batch_size), dynamic_ncols=True, desc=f"epoch-{split}"
1859         ):
1860             yield self.trim(batch)
1861
1862     def vocabulary_size(self):
1863         return len(self.token2id)
1864
1865     def produce_results(
1866         self, n_epoch, model, result_dir, logger, deterministic_synthesis
1867     ):
1868         correct = self.test_input[:1000]
1869         result = correct.clone()
1870         ar_mask = torch.logical_or(result == self.t_true, result == self.t_false).long()
1871         result *= 1 - ar_mask  # paraaaaanoiaaaaaaa
1872
1873         logger(f"----------------------------------------------------------")
1874
1875         for e in self.tensor2str(result[:10]):
1876             logger(f"test_before {e}")
1877
1878         masked_inplace_autoregression(
1879             model,
1880             self.batch_size,
1881             result,
1882             ar_mask,
1883             deterministic_synthesis,
1884             device=self.device,
1885         )
1886
1887         logger(f"----------------------------------------------------------")
1888
1889         for e in self.tensor2str(result[:10]):
1890             logger(f"test_after  {e}")
1891
1892         logger(f"----------------------------------------------------------")
1893
1894         nb_total = ar_mask.sum().item()
1895         nb_correct = ((correct == result).long() * ar_mask).sum().item()
1896
1897         logger(f"test_performance {n_epoch} {nb_total=} {nb_correct=}")
1898         logger(f"main_test_accuracy {n_epoch} {nb_correct / nb_total}")
1899
1900         if self.play_input is not None:
1901             result = self.play_input.clone()
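            # re-generate everything from the first "|" token onward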
1902             ar_mask = (result == self.t_pipe).long().cumsum(dim=1).clamp(max=1)
1903             result *= 1 - ar_mask  # paraaaaanoiaaaaaaa
1904
1905             logger(f"----------------------------------------------------------")
1906
1907             for e in self.tensor2str(result[:10]):
1908                 logger(f"play_before {e}")
1909
1910             masked_inplace_autoregression(
1911                 model,
1912                 self.batch_size,
1913                 result,
1914                 ar_mask,
1915                 deterministic_synthesis,
1916                 device=self.device,
1917             )
1918
1919             logger(f"----------------------------------------------------------")
1920
1921             for e in self.tensor2str(result[:10]):
1922                 logger(f"play_after  {e}")
1923
1924             logger(f"----------------------------------------------------------")
1925
1926
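# Editor's illustrative sketch of the column-trimming trick used by Grid.trim
# above (and TaskFromFile.trim); not called by any task. The function name is
# made up, and 0 stands in for the padding id.


def _demo_trim_padding_columns(x, pad=0):
    # pad with one all-padding column on each side so the bounds always exist
    i = (1 - (F.pad(x, (1, 1), value=pad) == pad).min(0).values.long()).cumsum(0)
    # the one-column shift introduced by F.pad makes a and b the slice bounds
    # of the span from the first to the last column holding a real token
    a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
    return x[:, a:b]


# e.g. _demo_trim_padding_columns(torch.tensor([[0, 0, 3, 0], [0, 5, 6, 0]]))
# returns tensor([[0, 3], [5, 6]])
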
1927 ######################################################################
1928
1929 import qmlp
1930
1931
1932 class QMLP(Task):
1933     ######################
1934
1935     def __init__(
1936         self,
1937         nb_train_samples,
1938         nb_test_samples,
1939         batch_size,
1940         result_dir,
1941         logger=None,
1942         device=torch.device("cpu"),
1943     ):
1944         super().__init__()
1945
1946         self.device = device
1947         self.batch_size = batch_size
1948         self.nb_samples_per_mlp = 256
1949
1950         if logger is not None:
1951             logger(
1952                 f"generating {nb_train_samples+nb_test_samples} samples (can take some time)"
1953             )
1954
1955         seq, q_test_set, test_error = qmlp.generate_sequence_and_test_set(
1956             nb_mlps=nb_train_samples + nb_test_samples,
1957             nb_samples=self.nb_samples_per_mlp,
1958             device=self.device,
1959             batch_size=64,
1960             nb_epochs=250,
1961             nb_mlps_per_batch=1024,
1962         )
1963
1964         self.train_input = seq[:nb_train_samples]
1965         self.train_q_test_set = q_test_set[:nb_train_samples]
1966         self.train_ref_test_errors = test_error[:nb_train_samples]
1967         self.test_input = seq[nb_train_samples:]
1968         self.test_q_test_set = q_test_set[nb_train_samples:]
1969         self.test_ref_test_errors = test_error[nb_train_samples:]
1970
1971         filename = os.path.join(result_dir, f"train_errors_ref.dat")
1972         with open(filename, "w") as f:
1973             for e in self.train_ref_test_errors:
1974                 f.write(f"{e}\n")
1975
1976         filename = os.path.join(result_dir, f"test_errors_ref.dat")
1977         with open(filename, "w") as f:
1978             for e in self.test_ref_test_errors:
1979                 f.write(f"{e}\n")
1980
1981         self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
1982
1983     def batches(self, split="train", nb_to_use=-1, desc=None):
1984         assert split in {"train", "test"}
1985         input = self.train_input if split == "train" else self.test_input
1986         for batch in tqdm.tqdm(
1987             input.split(self.batch_size), dynamic_ncols=True, desc=f"epoch-{split}"
1988         ):
1989             yield batch
1990
1991     def vocabulary_size(self):
1992         return self.nb_codes
1993
1994     def produce_results(
1995         self, n_epoch, model, result_dir, logger, deterministic_synthesis
1996     ):
1997         correct = self.test_input[:1000]
1998         result = correct.clone()
1999         ar_mask = (
2000             torch.arange(result.size(1), device=result.device)
2001             > self.nb_samples_per_mlp * 3 + 1
2002         ).long()[None, :]
2003         ar_mask = ar_mask.expand_as(result)
2004         result *= 1 - ar_mask  # paraaaaanoiaaaaaaa
2005
2006         masked_inplace_autoregression(
2007             model,
2008             self.batch_size,
2009             result,
2010             ar_mask,
2011             deterministic_synthesis,
2012             device=self.device,
2013         )
2014
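        # The first nb_samples_per_mlp * 3 tokens encode the quantized train
        # set; the parameters to evaluate start one token later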
2015         q_train_set = result[:, : self.nb_samples_per_mlp * 3]
2016         q_params = result[:, self.nb_samples_per_mlp * 3 + 1 :]
2017         error_test = qmlp.evaluate_q_params(q_params, self.test_q_test_set)
2018
2019         filename = os.path.join(result_dir, f"test_errors_{n_epoch:04d}.dat")
2020         with open(filename, "w") as f:
2021             for e in error_test:
2022                 f.write(f"{e}\n")
2023
2024
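# Editor's illustrative sketch of the "re-generate everything past a fixed
# prefix" mask used by QMLP.produce_results above and, with a different
# threshold and comparison, by Greed.produce_results below; not called by any
# task. The function name is made up.


def _demo_prefix_mask(x, prefix_len):
    # 1 for every position at or beyond prefix_len, broadcast over the batch
    ar_mask = (torch.arange(x.size(1), device=x.device) >= prefix_len).long()
    ar_mask = ar_mask[None, :].expand_as(x)
    # zero the part to be re-generated, as the tasks do before autoregression
    return x * (1 - ar_mask), ar_mask
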
2025 ######################################################################
2026
2027 import greed
2028
2029
2030 class Greed(Task):
2031     def __init__(
2032         self,
2033         nb_train_samples,
2034         nb_test_samples,
2035         batch_size,
2036         height,
2037         width,
2038         T,
2039         nb_walls,
2040         nb_coins,
2041         logger=None,
2042         device=torch.device("cpu"),
2043     ):
2044         super().__init__()
2045
2046         self.batch_size = batch_size
2047         self.device = device
2048
2049         self.world = greed.GreedWorld(height, width, T, nb_walls, nb_coins)
2050
2051         states, actions, rewards = self.world.generate_episodes(
2052             nb_train_samples + nb_test_samples
2053         )
2054         seq = self.world.episodes2seq(states, actions, rewards)
2055         self.train_input = seq[:nb_train_samples].to(self.device)
2056         self.test_input = seq[nb_train_samples:].to(self.device)
2057
2058     def wipe_lookahead_rewards(self, batch):
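        # Replace the lookahead-reward token of every iteration up to a
        # random per-sequence position with the "unknown reward" code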
2059         t = torch.arange(batch.size(1), device=batch.device)[None, :]
2060         u = torch.randint(batch.size(1), (batch.size(0), 1), device=batch.device)
2061         lr_mask = (t <= u).long() * (
2062             t % self.world.it_len == self.world.index_lookahead_reward
2063         ).long()
2064
2065         return (
2066             lr_mask * self.world.lookahead_reward2code(greed.REWARD_UNKNOWN)
2067             + (1 - lr_mask) * batch
2068         )
2069
2070     def batches(self, split="train", nb_to_use=-1, desc=None):
2071         assert split in {"train", "test"}
2072         input = self.train_input if split == "train" else self.test_input
2073         if nb_to_use > 0:
2074             input = input[:nb_to_use]
2075         if desc is None:
2076             desc = f"epoch-{split}"
2077         for batch in tqdm.tqdm(
2078             input.split(self.batch_size), dynamic_ncols=True, desc=desc
2079         ):
2080             yield self.wipe_lookahead_rewards(batch)
2081
2082     def vocabulary_size(self):
2083         return self.world.nb_codes
2084
2085     def thinking_autoregression(
2086         self, n_epoch, model, result_dir, logger, deterministic_synthesis, nmax=1000
2087     ):
2088         snapshots = []
2089
2090         def ar(result, ar_mask, logit_biases=None):
2091             ar_mask = ar_mask.expand_as(result)
2092             result *= 1 - ar_mask
2093             masked_inplace_autoregression(
2094                 model,
2095                 self.batch_size,
2096                 result,
2097                 ar_mask,
2098                 deterministic_synthesis=deterministic_synthesis,
2099                 logit_biases=logit_biases,
2100                 device=self.device,
2101                 progress_bar_desc=None,
2102             )
2103             warnings.warn("keeping thinking snapshots", RuntimeWarning)
2104             snapshots.append(result[:100].detach().clone())
2105
2106         # Generate iteration after iteration
2107
2108         result = self.test_input[:250].clone()
2109         # Erase all the content but that of the first iteration
2110         result[:, self.world.it_len :] = -1
2111         # Set the lookahead_reward of the first iteration to UNKNOWN
2112         result[:, self.world.index_lookahead_reward] = self.world.lookahead_reward2code(
2113             greed.REWARD_UNKNOWN
2114         )
2115
2116         t = torch.arange(result.size(1), device=result.device)[None, :]
2117
2118         for u in tqdm.tqdm(
2119             range(0, result.size(1), self.world.it_len),
2120             desc="thinking",
2121         ):
2122             # Generate the next state but keep the initial one; the
2123             # lookahead_rewards of the previous iterations are set to
2124             # UNKNOWN
2125             if u > 0:
2126                 result[
2127                     :, u + self.world.index_lookahead_reward
2128                 ] = self.world.lookahead_reward2code(greed.REWARD_UNKNOWN)
2129                 ar_mask = (t >= u + self.world.index_states).long() * (
2130                     t < u + self.world.index_states + self.world.state_len
2131                 ).long()
2132                 ar(result, ar_mask)
2133
2134             # Generate the action and reward with lookahead_reward to +1
2135             result[
2136                 :, u + self.world.index_lookahead_reward
2137             ] = self.world.lookahead_reward2code(greed.REWARD_PLUS)
2138             ar_mask = (t >= u + self.world.index_reward).long() * (
2139                 t <= u + self.world.index_action
2140             ).long()
2141             ar(result, ar_mask)
2142
2143             # Set the lookahead_reward to UNKNOWN for the next iterations
2144             result[
2145                 :, u + self.world.index_lookahead_reward
2146             ] = self.world.lookahead_reward2code(greed.REWARD_UNKNOWN)
2147
2148         filename = os.path.join(result_dir, f"test_thinking_compute_{n_epoch:04d}.txt")
2149         with open(filename, "w") as f:
2150             for n in range(snapshots[0].size(0)):
2151                 for snapshot in snapshots:
2152                     lr, s, a, r = self.world.seq2episodes(
2153                         snapshot[n : n + 1],
2154                     )
2155                     txt = self.world.episodes2str(
2156                         lr, s, a, r, unicode=True, ansi_colors=True
2157                     )
2158                     f.write(txt)
2159                 f.write("\n\n")
2160
2161         # Saving the generated sequences
2162
2163         lr, s, a, r = self.world.seq2episodes(result)
2164         txt = self.world.episodes2str(lr, s, a, r, unicode=True, ansi_colors=True)
2165
2166         filename = os.path.join(result_dir, f"test_thinking_seq_{n_epoch:04d}.txt")
2167         with open(filename, "w") as f:
2168             f.write(txt)
2169             logger(f"wrote {filename}")
2170
2171     def produce_results(
2172         self, n_epoch, model, result_dir, logger, deterministic_synthesis, nmax=1000
2173     ):
2174         result = self.wipe_lookahead_rewards(self.test_input[:250].clone())
2175
2176         # Saving the ground truth
2177
2178         lr, s, a, r = self.world.seq2episodes(
2179             result,
2180         )
2181         txt = self.world.episodes2str(lr, s, a, r, unicode=True, ansi_colors=True)
2182
2183         filename = os.path.join(result_dir, f"test_true_seq_{n_epoch:04d}.txt")
2184         with open(filename, "w") as f:
2185             f.write(txt)
2186             logger(f"wrote {filename}")
2187
2188         # Re-generating from the first frame
2189
2190         ar_mask = (
2191             torch.arange(result.size(1), device=result.device) >= self.world.it_len
2192         ).long()[None, :]
2193         ar_mask = ar_mask.expand_as(result)
2194         result *= 1 - ar_mask  # paraaaaanoiaaaaaaa
2195
2196         masked_inplace_autoregression(
2197             model,
2198             self.batch_size,
2199             result,
2200             ar_mask,
2201             deterministic_synthesis,
2202             device=self.device,
2203         )
2204
2205         # Saving the generated sequences
2206
2207         lr, s, a, r = self.world.seq2episodes(
2208             result,
2209         )
2210         txt = self.world.episodes2str(lr, s, a, r, unicode=True, ansi_colors=True)
2211
2212         filename = os.path.join(result_dir, f"test_seq_{n_epoch:04d}.txt")
2213         with open(filename, "w") as f:
2214             f.write(txt)
2215             logger(f"wrote {filename}")
2216
2217         self.thinking_autoregression(
2218             n_epoch, model, result_dir, logger, deterministic_synthesis, nmax
2219         )
2220
2221
2222 ######################################################################
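

# Editor's illustrative sketch of the masking pattern of
# Greed.wipe_lookahead_rewards above; not called by any task. The function
# name and the period / slot / code values are made up for the example.


def _demo_wipe_periodic_slot(batch, it_len=4, slot=1, code=99):
    t = torch.arange(batch.size(1), device=batch.device)[None, :]
    # one random cut-off position per sequence
    u = torch.randint(batch.size(1), (batch.size(0), 1), device=batch.device)
    # 1 on the slot-th token of every it_len-long block up to the cut-off
    mask = (t <= u).long() * (t % it_len == slot).long()
    return mask * code + (1 - mask) * batch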