1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import math, os, tqdm, warnings
9
10 import torch, torchvision
11
12 from torch import nn
13 from torch.nn import functional as F
14
15 from mygpt import BracketedSequence
16
17 # from graph import save_attention_image
18 save_attention_image = None
19
20 ######################################################################
21
22
23 def masked_inplace_autoregression(
24     model,
25     batch_size,
26     input,
27     ar_mask,
28     deterministic_synthesis,
29     forbidden_tokens=None,
30     logit_biases=None,
31     progress_bar_desc="autoregression",
32     device=torch.device("cpu"),
33 ):
34     assert input.size() == ar_mask.size()
35
36     batches = zip(input.split(batch_size), ar_mask.split(batch_size))
37
38     if progress_bar_desc is not None:
39         batches = tqdm.tqdm(
40             batches,
41             dynamic_ncols=True,
42             desc=progress_bar_desc,
43             total=(input.size(0) + batch_size - 1) // batch_size,
44         )
45
46     with torch.autograd.no_grad():
47         t = model.training
48         model.eval()
49
50         for input, ar_mask in batches:
51             model.masked_inplace_autoregression(
52                 input,
53                 ar_mask,
54                 deterministic_synthesis,
55                 forbidden_tokens,
56                 logit_biases,
57             )
58
59         model.train(t)
60
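# A minimal, self-contained sketch (not part of the original file) of the
# interface the wrapper above expects from `model`: an nn.Module whose
# masked_inplace_autoregression(input, ar_mask, deterministic_synthesis,
# forbidden_tokens, logit_biases) method fills, in place, the positions where
# ar_mask is 1. The real models come from mygpt; the dummy below just writes
# the constant 7 at the masked positions.


def _example_masked_inplace_autoregression():
    class ConstantFiller(nn.Module):
        def masked_inplace_autoregression(
            self, input, ar_mask, deterministic_synthesis, forbidden_tokens, logit_biases
        ):
            input.mul_(1 - ar_mask).add_(7 * ar_mask)

    x = torch.zeros(4, 10, dtype=torch.int64)
    ar_mask = (torch.arange(10)[None, :] >= 5).long().expand_as(x).contiguous()
    masked_inplace_autoregression(
        ConstantFiller(), 2, x, ar_mask, True, progress_bar_desc=None
    )
    assert (x[:, :5] == 0).all() and (x[:, 5:] == 7).all()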
61
62 ######################################################################
63
64
65 class Task:
66     def batches(self, split="train", nb_to_use=-1, desc=None):
67         pass
68
69     def vocabulary_size(self):
70         pass
71
72     def produce_results(
73         self, n_epoch, model, result_dir, logger, deterministic_synthesis
74     ):
75         pass
76
77
78 class TaskFromFile(Task):
79     def tensorize(self, pairs, shuffle):
80         len_max = max([len(x[0]) for x in pairs])
81
82         input = torch.cat(
83             [
84                 torch.tensor(
85                     [
86                         [self.char2id[c] for c in s[0] + "#" * (len_max - len(s[0]))]
87                         for s in pairs
88                     ]
89                 )
90             ],
91             0,
92         ).to("cpu")
93
94         pred_mask = torch.cat(
95             [
96                 torch.tensor(
97                     [
98                         [int(c) for c in s[1] + "0" * (len_max - len(s[1]))]
99                         for s in pairs
100                     ]
101                 )
102             ],
103             0,
104         ).to("cpu")
105
106         if shuffle:
107             i = torch.randperm(input.size(0))
108             input = input[i].contiguous()
109             pred_mask = pred_mask[i].contiguous()
110
111         return input, pred_mask
112
113     # Trim the tensors to remove as many padding tokens ("#") as possible
114     # from the left and right of the first tensor. If z is a tuple, all its
115     # elements are trimmed according to the trimming computed for the first one.
116     def trim(self, z, token="#"):
117         n = self.char2id[token]
118         if type(z) == tuple:
119             x = z[0]
120             i = (1 - (F.pad(x, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
121             a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
122             return tuple([t[:, a:b] for t in z])
123         else:
124             i = (1 - (F.pad(z, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
125             a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
126             return z[:, a:b]
127
128     def __init__(
129         self,
130         train_filename,
131         test_filename,
132         nb_train_samples,
133         nb_test_samples,
134         batch_size,
135         shuffle=False,
136         device=torch.device("cpu"),
137     ):
138         self.batch_size = batch_size
139         self.device = device
140
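        # Expected file format (an illustration inferred from the parsing
        # below, not a specification from the original author): two lines per
        # sample, a character sequence followed by a mask of the same length
        # over {0, 1, 2}, e.g.
        #
        #   ABCdef
        #   000112
        #
        # Positions marked 1 or 2 are re-generated at evaluation time, and
        # only those marked 2 are counted in the test accuracy.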
141         def read_file(filename, nb=-1):
142             pairs = []
143             with open(filename, "r") as f:
144                 while True:
145                     sequence = f.readline().strip()
146                     if not sequence:
147                         break
148                     pred_mask = f.readline().strip()
149                     assert len(sequence) == len(pred_mask)
150                     assert set(pred_mask).issubset({"0", "1", "2"}), f"{set(pred_mask)}"
151                     pairs.append((sequence, pred_mask))
152                     if len(pairs) == nb:
153                         break
154
155             if nb > 0:
156                 pairs = pairs[:nb]
157                 assert len(pairs) == nb
158
159             return pairs
160
161         train_pairs = read_file(train_filename, nb_train_samples)
162         test_pairs = read_file(test_filename, nb_test_samples)
163
164         symbols = ["#"] + sorted(
165             set("".join([x[0] for x in train_pairs + test_pairs])) - set(["#"])
166         )
167         self.char2id = dict([(c, n) for n, c in enumerate(symbols)])
168         self.id2char = dict([(n, c) for c, n in self.char2id.items()])
169
170         self.train_input, self.train_pred_masks = self.tensorize(
171             train_pairs, shuffle=shuffle
172         )
173         self.test_input, self.test_pred_masks = self.tensorize(
174             test_pairs, shuffle=shuffle
175         )
176
177     def batches(self, split="train", nb_to_use=-1, desc=None):
178         assert split in {"train", "test"}
179         input = self.train_input if split == "train" else self.test_input
180         if nb_to_use > 0:
181             input = input[:nb_to_use]
182         if desc is None:
183             desc = f"epoch-{split}"
184         for batch in tqdm.tqdm(
185             input.split(self.batch_size), dynamic_ncols=True, desc=desc
186         ):
187             yield self.trim(batch).to(self.device)
188
189     def vocabulary_size(self):
190         return len(self.char2id)
191
192     def tensor2str(self, t):
193         return ["".join([self.id2char[x.item()] for x in s]) for s in t]
194
195     def produce_results(
196         self, n_epoch, model, result_dir, logger, deterministic_synthesis
197     ):
198         correct = self.trim(self.test_input[:1000]).to(self.device)
199         result = correct.clone()
200         pred_mask = self.test_pred_masks[:1000, : result.size(1)].to(self.device)
201         ar_mask = (pred_mask > 0).long()
202         result *= 1 - ar_mask  # paranoia: wipe the tokens that will be re-generated
203
204         logger(f"----------------------------------------------------------")
205
206         for e in self.tensor2str(result[:50]):
207             logger(f"test_before {e}")
208
209         masked_inplace_autoregression(
210             model,
211             self.batch_size,
212             result,
213             ar_mask,
214             deterministic_synthesis,
215             device=self.device,
216         )
217
218         logger(f"----------------------------------------------------------")
219
220         for e, c in zip(self.tensor2str(result[:50]), self.tensor2str(correct[:50])):
221             logger(f"test_after  {e}")
222             logger(f"correct     {c}")
223
224         logger(f"----------------------------------------------------------")
225
226         err_mask = (pred_mask == 2).long()
227         nb_total = err_mask.sum().item()
228         nb_correct = ((correct == result).long() * err_mask).sum().item()
229
230         logger(f"test_performance {n_epoch} {nb_total=} {nb_correct=}")
231         logger(f"main_test_accuracy {n_epoch} {nb_correct / nb_total}")
232
233
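# A minimal, self-contained sketch (not part of the original file) of the
# trimming logic used by TaskFromFile.trim and PicoCLVR.trim below, with 0
# standing for the padding token: columns made only of padding are removed
# from both ends of the batch, while interior padding is kept so that the
# rows stay aligned.


def _example_trim():
    x = torch.tensor([[0, 3, 4, 0], [0, 0, 5, 0]])
    n = 0
    i = (1 - (F.pad(x, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
    a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
    assert (x[:, a:b] == torch.tensor([[3, 4], [0, 5]])).all()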
234 ####################
235
236 import problems
237
238
239 class SandBox(Task):
240     def __init__(
241         self,
242         problem,
243         nb_train_samples,
244         nb_test_samples,
245         batch_size,
246         logger=None,
247         device=torch.device("cpu"),
248         max_nb_codes=1024,
249     ):
250         super().__init__()
251
252         self.batch_size = batch_size
253         self.device = device
254         self.problem = problem
255
256         self.train_input, self.train_ar_mask = self.problem.generate_sequences(
257             nb_train_samples
258         )
259         self.test_input, self.test_ar_mask = self.problem.generate_sequences(
260             nb_test_samples
261         )
262
263         self.train_input, self.train_ar_mask = self.train_input.to(
264             device
265         ), self.train_ar_mask.to(device)
266         self.test_input, self.test_ar_mask = self.test_input.to(
267             device
268         ), self.test_ar_mask.to(device)
269
270         self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
271
272         # A bit of paranoia never hurts
273         assert self.nb_codes <= max_nb_codes
274         assert self.train_input.min() >= 0
275         assert self.test_input.min() >= 0
276         assert tuple(x.item() for x in self.train_ar_mask.unique()) in {
277             (0,),
278             (1,),
279             (0, 1),
280         }
281         assert tuple(x.item() for x in self.test_ar_mask.unique()) in {
282             (0,),
283             (1,),
284             (0, 1),
285         }
286
287         if logger is not None:
288             for s, a in zip(self.train_input[:100], self.train_ar_mask[:100]):
289                 logger(f"train_sequences {self.problem.seq2str(s)}")
290                 a = "".join(["01"[x.item()] for x in a])
291                 logger(f"                {a}")
292
293     def batches(self, split="train", nb_to_use=-1, desc=None):
294         assert split in {"train", "test"}
295         input = self.train_input if split == "train" else self.test_input
296         if nb_to_use > 0:
297             input = input[:nb_to_use]
298         if desc is None:
299             desc = f"epoch-{split}"
300         for batch in tqdm.tqdm(
301             input.split(self.batch_size), dynamic_ncols=True, desc=desc
302         ):
303             yield batch
304
305     def vocabulary_size(self):
306         return self.nb_codes
307
308     def produce_results(
309         self, n_epoch, model, result_dir, logger, deterministic_synthesis, nmax=1000
310     ):
311         def compute_accuracy(input, ar_mask, logger=None):
312             input, ar_mask = input[:nmax], ar_mask[:nmax]
313             result = input.clone() * (1 - ar_mask)
314
315             masked_inplace_autoregression(
316                 model,
317                 self.batch_size,
318                 result,
319                 ar_mask,
320                 deterministic_synthesis,
321                 progress_bar_desc=None,
322                 device=self.device,
323             )
324
325             log_ground_truth = ar_mask.min() == 0
326
327             if logger is not None:
328                 for sp, st in zip(result[:10], input[:10]):
329                     logger(
330                         f"test_sequences {n_epoch} prediction   {self.problem.seq2str(sp)}"
331                     )
332                     if log_ground_truth:
333                         logger(
334                             f"               {n_epoch} ground truth {self.problem.seq2str(st)}"
335                         )
336
337             nb_total, nb_correct = self.problem.compute_nb_correct(
338                 input, ar_mask, result
339             )
340
341             # nb_total = ar_mask.sum().item()
342             # nb_correct = ((result == input).long() * ar_mask).sum().item()
343
344             return nb_total, nb_correct
345
346         train_nb_total, train_nb_correct = compute_accuracy(
347             self.train_input, self.train_ar_mask
348         )
349
350         logger(
351             f"accuracy_train {n_epoch} nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%"
352         )
353
354         test_nb_total, test_nb_correct = compute_accuracy(
355             self.test_input, self.test_ar_mask, logger
356         )
357
358         logger(
359             f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
360         )
361
362         logger(f"main_test_accuracy {n_epoch} {test_nb_correct/test_nb_total}")
363
364         if save_attention_image is not None:
365             for k in range(10):
366                 ns = torch.randint(self.test_input.size(0), (1,)).item()
367                 input = self.test_input[ns : ns + 1].clone()
368
369                 with torch.autograd.no_grad():
370                     t = model.training
371                     model.eval()
372                     # model.record_attention(True)
373                     model(BracketedSequence(input))
374                     model.train(t)
375                     # ram = model.retrieve_attention()
376                     # model.record_attention(False)
377
378                 # tokens_output = [c for c in self.problem.seq2str(input[0])]
379                 # tokens_input = ["n/a"] + tokens_output[:-1]
380                 # for n_head in range(ram[0].size(1)):
381                 # filename = os.path.join(
382                 # result_dir, f"sandbox_attention_{k}_h{n_head}.pdf"
383                 # )
384                 # attention_matrices = [m[0, n_head] for m in ram]
385                 # save_attention_image(
386                 # filename,
387                 # tokens_input,
388                 # tokens_output,
389                 # attention_matrices,
390                 # k_top=10,
391                 ##min_total_attention=0.9,
392                 # token_gap=12,
393                 # layer_gap=50,
394                 # )
395                 # logger(f"wrote {filename}")
396
397
398 ######################################################################
399
400 import world
401
402
403 class World(Task):
404     def __init__(
405         self,
406         nb_train_samples,
407         nb_test_samples,
408         batch_size,
409         logger=None,
410         device=torch.device("cpu"),
411     ):
412         super().__init__()
413
414         self.batch_size = batch_size
415         self.device = device
416         self.height = 6
417         self.width = 8
418
419         self.train_input = world.generate(
420             nb_train_samples, height=self.height, width=self.width
421         )
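        # The autoregressive mask selects the tokens after the mid-point of
        # each sequence: the model is conditioned on the first half and must
        # generate the rest.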
422         self.train_ar_mask = (
423             (torch.arange(self.train_input.size(1)) > self.train_input.size(1) // 2)
424             .long()[None, :]
425             .expand_as(self.train_input)
426         )
427
428         self.test_input = world.generate(
429             nb_test_samples, height=self.height, width=self.width
430         )
431         self.test_ar_mask = (
432             (torch.arange(self.test_input.size(1)) > self.test_input.size(1) // 2)
433             .long()[None, :]
434             .expand_as(self.test_input)
435         )
436
437         self.train_input, self.train_ar_mask = self.train_input.to(
438             device
439         ), self.train_ar_mask.to(device)
440         self.test_input, self.test_ar_mask = self.test_input.to(
441             device
442         ), self.test_ar_mask.to(device)
443
444         self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
445
446     def batches(self, split="train", nb_to_use=-1, desc=None):
447         assert split in {"train", "test"}
448         input = self.train_input if split == "train" else self.test_input
449         if nb_to_use > 0:
450             input = input[:nb_to_use]
451         if desc is None:
452             desc = f"epoch-{split}"
453         for batch in tqdm.tqdm(
454             input.split(self.batch_size), dynamic_ncols=True, desc=desc
455         ):
456             yield batch
457
458     def vocabulary_size(self):
459         return self.nb_codes
460
461     def produce_results(
462         self, n_epoch, model, result_dir, logger, deterministic_synthesis, nmax=1000
463     ):
464         def compute_accuracy(input, ar_mask, logger=None):
465             input, ar_mask = input[:nmax], ar_mask[:nmax]
466             result = input.clone() * (1 - ar_mask)
467
468             masked_inplace_autoregression(
469                 model,
470                 self.batch_size,
471                 result,
472                 ar_mask,
473                 deterministic_synthesis,
474                 progress_bar_desc=None,
475                 device=self.device,
476             )
477
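            # A sequence counts as correct only if it is reproduced exactly,
            # token for token (min over the time dimension).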
478             nb_total, nb_correct = (
479                 input.size(0),
480                 (input == result).long().min(dim=1).values.sum(),
481             )
482
483             return nb_total, nb_correct
484
485         train_nb_total, train_nb_correct = compute_accuracy(
486             self.train_input, self.train_ar_mask
487         )
488
489         logger(
490             f"accuracy_train {n_epoch} nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%"
491         )
492
493         test_nb_total, test_nb_correct = compute_accuracy(
494             self.test_input, self.test_ar_mask, logger
495         )
496
497         logger(
498             f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
499         )
500
501         logger(f"main_test_accuracy {n_epoch} {test_nb_correct/test_nb_total}")
502
503         if save_attention_image is not None:
504             for k in range(10):
505                 ns = torch.randint(self.test_input.size(0), (1,)).item()
506                 input = self.test_input[ns : ns + 1].clone()
507
508                 with torch.autograd.no_grad():
509                     t = model.training
510                     model.eval()
511                     # model.record_attention(True)
512                     model(BracketedSequence(input))
513                     model.train(t)
514                     # ram = model.retrieve_attention()
515                     # model.record_attention(False)
516
517                 # tokens_output = [c for c in self.problem.seq2str(input[0])]
518                 # tokens_input = ["n/a"] + tokens_output[:-1]
519                 # for n_head in range(ram[0].size(1)):
520                 # filename = os.path.join(
521                 # result_dir, f"sandbox_attention_{k}_h{n_head}.pdf"
522                 # )
523                 # attention_matrices = [m[0, n_head] for m in ram]
524                 # save_attention_image(
525                 # filename,
526                 # tokens_input,
527                 # tokens_output,
528                 # attention_matrices,
529                 # k_top=10,
530                 ##min_total_attention=0.9,
531                 # token_gap=12,
532                 # layer_gap=50,
533                 # )
534                 # logger(f"wrote {filename}")
535
536
537 ######################################################################
538
539 import picoclvr
540
541
542 class PicoCLVR(Task):
543     # Make a tensor from a list of strings
544     def tensorize(self, descr):
545         token_descr = [s.strip().split(" ") for s in descr]
546         l = max([len(s) for s in token_descr])
547         token_descr = [s + ["<nul>"] * (l - len(s)) for s in token_descr]
548         id_descr = [[self.token2id[u] for u in s] for s in token_descr]
549         return torch.tensor(id_descr, device=self.device)
550
551     # Make a list of strings from a tensor
552     def detensorize(self, x):
553         return [" ".join([self.id2token[t.item()] for t in r]) for r in x]
554
555     # Trim the tensors to remove as many padding tokens ("<nul>") as possible
556     # from the left and right of the first tensor. If z is a tuple, all its
557     # elements are trimmed according to the trimming computed for the first one.
558     def trim(self, z, token="<nul>"):
559         n = self.token2id[token]
560         if type(z) == tuple:
561             x = z[0]
562             i = (1 - (F.pad(x, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
563             a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
564             return tuple([t[:, a:b] for t in z])
565         else:
566             i = (1 - (F.pad(z, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
567             a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
568             return z[:, a:b]
569
570     ######################
571
572     def __init__(
573         self,
574         nb_train_samples,
575         nb_test_samples,
576         batch_size,
577         height,
578         width,
579         nb_colors=5,
580         logger=None,
581         device=torch.device("cpu"),
582         pruner_train=None,
583         pruner_eval=None,
584     ):
585         super().__init__()
586
587         def generate_descr(nb, cache_suffix, pruner):
588             return picoclvr.generate(
589                 nb,
590                 height=self.height,
591                 width=self.width,
592                 nb_colors=nb_colors,
593                 pruner=pruner,
594             )
595
596         self.height = height
597         self.width = width
598         self.batch_size = batch_size
599         self.device = device
600         self.pruner_train = pruner_train
601         self.pruner_eval = pruner_eval
602
603         if logger is not None:
604             logger(
605                 f"generating {nb_train_samples+nb_test_samples} samples (can take some time)"
606             )
607
608         self.train_descr = generate_descr(
609             nb_train_samples, "train", pruner=self.pruner_train
610         )
611         self.test_descr = generate_descr(nb_test_samples, "test", pruner=None)
612
613         # Build the tokenizer
614         tokens = {"<nul>", "<img>"}
615         for d in [self.train_descr, self.test_descr]:
616             for s in d:
617                 for t in s.strip().split(" "):
618                     tokens.add(t)
619         # Turn the set into a sorted list so that the same descriptions
620         # always produce the same tensors
621         tokens = list(tokens)
622         tokens.sort()
623         self.token2id = dict([(t, n) for n, t in enumerate(tokens)])
624         self.id2token = dict([(n, t) for n, t in enumerate(tokens)])
625         self.t_img, self.t_nul = self.token2id["<img>"], self.token2id["<nul>"]
626
627         # Tokenize the train and test sets
628         self.train_input = self.tensorize(self.train_descr)
629         self.test_input = self.tensorize(self.test_descr)
630
631     def batches(self, split="train", nb_to_use=-1, desc=None):
632         assert split in {"train", "test"}
633         input = self.train_input if split == "train" else self.test_input
634         for batch in tqdm.tqdm(
635             input.split(self.batch_size), dynamic_ncols=True, desc=f"epoch-{split}"
636         ):
637             yield self.trim(batch)
638
639     def vocabulary_size(self):
640         return len(self.token2id)
641
642     def compute_missing_properties(
643         self, n_epoch, model, logger, deterministic_synthesis, pruner=None
644     ):
645         acc_nb_requested_properties = []
646         acc_nb_missing_properties = []
647         acc_nb_results = 0
648
649         for input in tqdm.tqdm(
650             self.test_input.split(self.batch_size),
651             dynamic_ncols=True,
652             desc=f"test-properties",
653         ):
654             result = input.clone()
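            # Everything from the <img> token onward (the token itself
            # included) is wiped and re-generated; only the textual
            # description before it is kept.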
655             ar_mask = (result == self.t_img).long().cumsum(dim=1).clamp(max=1)
656             result = (1 - ar_mask) * result + ar_mask * self.t_nul
657             masked_inplace_autoregression(
658                 model,
659                 self.batch_size,
660                 result,
661                 ar_mask,
662                 deterministic_synthesis,
663                 progress_bar_desc=None,
664                 device=self.device,
665             )
666
667             result_descr = self.detensorize(result)
668             np = picoclvr.nb_properties(
669                 result_descr,
670                 height=self.height,
671                 width=self.width,
672                 pruner=pruner,
673             )
674             nb_requested_properties, _, nb_missing_properties = zip(*np)
675             acc_nb_requested_properties += nb_requested_properties
676             acc_nb_missing_properties += nb_missing_properties
677             acc_nb_results += len(result_descr)
678
679         nb_requested_properties = sum(acc_nb_requested_properties)
680         nb_missing_properties = sum(acc_nb_missing_properties)
681
682         prefix = "" if pruner is None else "pruned_"
683         logger(f"nb_{prefix}samples {n_epoch} {acc_nb_results}")
684         logger(
685             f"property_{prefix}nb {n_epoch} requested {sum(acc_nb_requested_properties)} missing {sum(acc_nb_missing_properties)}"
686         )
687         logger(
688             f"property_{prefix}miss {n_epoch} {100*nb_missing_properties/nb_requested_properties:.02f}%"
689         )
690
691         logger(
692             f"main_test_accuracy {n_epoch} {1-nb_missing_properties/nb_requested_properties}"
693         )
694
695     ######################################################################
696
697     def produce_results(
698         self, n_epoch, model, result_dir, logger, deterministic_synthesis
699     ):
700         self.compute_missing_properties(n_epoch, model, logger, deterministic_synthesis)
701
702         if self.pruner_eval is not None:
703             self.compute_missing_properties(n_epoch, model, logger, deterministic_synthesis, pruner=self.pruner_eval)
704
705         nb_tokens_to_generate = self.height * self.width + 3
706         result_descr = []
707         nb_per_primer = 8
708         primer = []
709
710         for primer_descr in [
711             "red above green <sep> green top <sep> blue right of red",
712             "there is red <sep> there is yellow <sep> there is blue",
713             "red below yellow <sep> yellow below green <sep> green below blue <sep> red right <sep> yellow left <sep> green right <sep> blue left",
714             "green bottom <sep> yellow bottom <sep> green left of blue <sep> yellow right of blue <sep> blue top",
715         ]:
716             primer += [primer_descr + " <img>"] * nb_per_primer
717
718         result = self.tensorize(primer)
719         fill = result.new_full(
720             result.size()[:-1] + (self.height * self.width + 1,), self.t_nul
721         )
722         result = torch.cat((result, fill), 1)
723         ar_mask = (result == self.t_nul).long()
724         masked_inplace_autoregression(
725             model,
726             self.batch_size,
727             result,
728             ar_mask,
729             deterministic_synthesis,
730             device=self.device,
731         )
732         result_descr = self.detensorize(result)
733
734         np = picoclvr.nb_properties(result_descr, height=self.height, width=self.width)
735
736         acc_nb_requested_properties, _, acc_nb_missing_properties = zip(*np)
737         acc_nb_results = len(result_descr)
738
739         nb_requested_properties = sum(acc_nb_requested_properties)
740         nb_missing_properties = sum(acc_nb_missing_properties)
741
742         prefix = "demo_"
743         logger(f"nb_{prefix}samples {n_epoch} {acc_nb_results}")
744         logger(
745             f"property_{prefix}nb {n_epoch} requested {sum(acc_nb_requested_properties)} missing {sum(acc_nb_missing_properties)}"
746         )
747         logger(
748             f"property_{prefix}miss {n_epoch} {100*nb_missing_properties/nb_requested_properties:.02f}%"
749         )
750
751         img = picoclvr.descr2img(result_descr, height=self.height, width=self.width)
752
753         if img.dim() == 5:
754             if img.size(1) == 1:
755                 img = F.pad(img.squeeze(1), pad=(1, 1, 1, 1), value=64)
756             else:
757                 img = torch.cat(
758                     [
759                         torchvision.utils.make_grid(x, padding=1, pad_value=64)[None]
760                         for x in img
761                     ],
762                     0,
763                 )
764
765         image_name = os.path.join(result_dir, f"picoclvr_result_{n_epoch:04d}.png")
766         torchvision.utils.save_image(
767             img / 255.0, image_name, nrow=nb_per_primer, padding=1, pad_value=0.0
768         )
769         logger(f"wrote {image_name}")
770
771
772 ######################################################################
773
774
775 class MNIST(Task):
776     def __init__(
777         self, nb_train_samples, nb_test_samples, batch_size, device=torch.device("cpu")
778     ):
779         super().__init__()
780
781         self.nb_train_samples = nb_train_samples
782         self.nb_test_samples = nb_test_samples
783         self.batch_size = batch_size
784         self.device = device
785         data_set = torchvision.datasets.MNIST(root="./data", train=True, download=True)
786         self.train_input = data_set.data[:nb_train_samples].view(-1, 28 * 28).long()
787         data_set = torchvision.datasets.MNIST(root="./data", train=False, download=True)
788         self.test_input = data_set.data[:nb_test_samples].view(-1, 28 * 28).long()
789
790     def batches(self, split="train", nb_to_use=-1, desc=None):
791         assert split in {"train", "test"}
792         input = self.train_input if split == "train" else self.test_input
793         if nb_to_use > 0:
794             input = input[:nb_to_use]
795         if desc is None:
796             desc = f"epoch-{split}"
797         for batch in tqdm.tqdm(
798             input.split(self.batch_size), dynamic_ncols=True, desc=desc
799         ):
800             yield batch
801
802     def vocabulary_size(self):
803         return 256
804
805     def produce_results(
806         self, n_epoch, model, result_dir, logger, deterministic_synthesis
807     ):
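        # Sample 64 images pixel by pixel from scratch: every position is
        # masked, so nothing is conditioned on.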
808         results = torch.empty(64, 28 * 28, device=self.device, dtype=torch.int64)
809         ar_mask = torch.full_like(results, 1)
810         masked_inplace_autoregression(
811             model,
812             self.batch_size,
813             results,
814             ar_mask,
815             deterministic_synthesis,
816             device=self.device,
817         )
818         image_name = os.path.join(result_dir, f"mnist_result_{n_epoch:04d}.png")
819         torchvision.utils.save_image(
820             1 - results.reshape(-1, 1, 28, 28) / 255.0,
821             image_name,
822             nrow=16,
823             pad_value=0.8,
824         )
825         logger(f"wrote {image_name}")
826
827
828 ######################################################################
829
830 import maze
831
832
833 class Maze(Task):
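    # map2seq flattens each height x width map and concatenates them along
    # the token dimension; seq2map is the inverse, yielding one map per
    # concatenated block.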
834     def map2seq(self, *m):
835         return torch.cat([x.flatten(1) for x in m], 1)
836
837     def seq2map(self, s):
838         s = s.reshape(s.size(0), -1, self.height, self.width)
839         return (s[:, k] for k in range(s.size(1)))
840
841     def __init__(
842         self,
843         nb_train_samples,
844         nb_test_samples,
845         batch_size,
846         height,
847         width,
848         nb_walls,
849         device=torch.device("cpu"),
850     ):
851         super().__init__()
852
853         self.batch_size = batch_size
854         self.height = height
855         self.width = width
856         self.device = device
857
858         train_mazes, train_paths, _ = maze.create_maze_data(
859             nb_train_samples,
860             height=height,
861             width=width,
862             nb_walls=nb_walls,
863             progress_bar=lambda x: tqdm.tqdm(x, dynamic_ncols=True, desc=f"data-train"),
864         )
865         self.train_input = self.map2seq(train_mazes.to(device), train_paths.to(device))
866
867         test_mazes, test_paths, _ = maze.create_maze_data(
868             nb_test_samples,
869             height=height,
870             width=width,
871             nb_walls=nb_walls,
872             progress_bar=lambda x: tqdm.tqdm(x, dynamic_ncols=True, desc=f"data-test"),
873         )
874         self.test_input = self.map2seq(test_mazes.to(device), test_paths.to(device))
875
876         self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
877
878     def batches(self, split="train", nb_to_use=-1, desc=None):
879         assert split in {"train", "test"}
880         input = self.train_input if split == "train" else self.test_input
881         if nb_to_use > 0:
882             input = input[:nb_to_use]
883         if desc is None:
884             desc = f"epoch-{split}"
885         for batch in tqdm.tqdm(
886             input.split(self.batch_size), dynamic_ncols=True, desc=desc
887         ):
888             yield batch
889
890     def vocabulary_size(self):
891         return self.nb_codes
892
893     def compute_error(
894         self, model, split="train", nb_to_use=-1, deterministic_synthesis=False
895     ):
896         model_device = next(model.parameters()).device
897         nb_total, nb_correct = 0, 0
898         count = torch.zeros(
899             self.width * self.height,
900             self.width * self.height,
901             device=model_device,
902             dtype=torch.int64,
903         )
904
905         for input in self.batches(split, nb_to_use):
906             input = input.to(model_device)
907             result = input.clone()
908             ar_mask = result.new_zeros(result.size())
909             ar_mask[:, self.height * self.width :] = 1
910             result *= 1 - ar_mask
911             masked_inplace_autoregression(
912                 model,
913                 self.batch_size,
914                 result,
915                 ar_mask,
916                 deterministic_synthesis,
917                 progress_bar_desc=None,
918                 device=self.device,
919             )
920             mazes, paths = self.seq2map(result)
921             path_correctness = maze.path_correctness(mazes, paths)
922             nb_correct += path_correctness.long().sum()
923             nb_total += mazes.size(0)
924
925             optimal_path_lengths = (
926                 (input[:, self.height * self.width :] == maze.v_path).long().sum(1)
927             )
928             predicted_path_lengths = (
929                 (result[:, self.height * self.width :] == maze.v_path).long().sum(1)
930             )
931             optimal_path_lengths = optimal_path_lengths[path_correctness]
932             predicted_path_lengths = predicted_path_lengths[path_correctness]
933             count.index_put_((optimal_path_lengths, predicted_path_lengths), torch.ones_like(optimal_path_lengths), accumulate=True)
934
935         if count.max() == 0:
936             count = None
937         else:
938             count = count[
939                 : count.sum(1).nonzero().max() + 1, : count.sum(0).nonzero().max() + 1
940             ]
941
942         return nb_total, nb_correct, count
943
944     def produce_results(
945         self, n_epoch, model, result_dir, logger, deterministic_synthesis
946     ):
947         train_nb_total, train_nb_correct, count = self.compute_error(
948             model,
949             "train",
950             nb_to_use=1000,
951             deterministic_synthesis=deterministic_synthesis,
952         )
953         logger(
954             f"accuracy_train {n_epoch} nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%"
955         )
956
957         test_nb_total, test_nb_correct, count = self.compute_error(
958             model,
959             "test",
960             nb_to_use=1000,
961             deterministic_synthesis=deterministic_synthesis,
962         )
963         logger(
964             f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
965         )
966
967         logger(f"main_test_accuracy {n_epoch} {test_nb_correct/test_nb_total}")
968
969         if count is not None:
970             proportion_optimal = count.diagonal().sum().float() / count.sum()
971             logger(f"proportion_optimal_test {proportion_optimal*100:.02f}%")
972             with open(
973                 os.path.join(result_dir, f"maze_result_{n_epoch:04d}.txt"), "w"
974             ) as f:
975                 for i in range(count.size(0)):
976                     for j in range(count.size(1)):
977                         eol = " " if j < count.size(1) - 1 else "\n"
978                         f.write(f"{count[i,j]}{eol}")
979
980         input = self.test_input[:48].to(next(model.parameters()).device)
981         result = input.clone()
982         ar_mask = result.new_zeros(result.size())
983         ar_mask[:, self.height * self.width :] = 1
984         result *= 1 - ar_mask
985         masked_inplace_autoregression(
986             model,
987             self.batch_size,
988             result,
989             ar_mask,
990             deterministic_synthesis,
991             device=self.device,
992         )
993
994         mazes, paths = self.seq2map(input)
995         _, predicted_paths = self.seq2map(result)
996
997         filename = os.path.join(result_dir, f"maze_result_{n_epoch:04d}.png")
998         maze.save_image(
999             filename,
1000             mazes=mazes,
1001             target_paths=paths,
1002             predicted_paths=predicted_paths,
1003             path_correct=maze.path_correctness(mazes, predicted_paths),
1004             path_optimal=maze.path_optimality(paths, predicted_paths),
1005         )
1006         logger(f"wrote {filename}")
1007
1008
1009 ######################################################################
1010
1011
1012 import snake
1013
1014
1015 class Snake(Task):
1016     def __init__(
1017         self,
1018         nb_train_samples,
1019         nb_test_samples,
1020         batch_size,
1021         height,
1022         width,
1023         nb_colors,
1024         length,
1025         prompt_length,
1026         device=torch.device("cpu"),
1027     ):
1028         super().__init__()
1029
1030         self.batch_size = batch_size
1031         self.height = height
1032         self.width = width
1033         self.device = device
1034         self.prompt_length = prompt_length
1035
1036         self.train_input, self.train_prior_visits, _, _ = snake.generate_sequences(
1037             nb_train_samples,
1038             height,
1039             width,
1040             nb_colors,
1041             length,
1042             prompt_length,
1043             self.device,
1044         )
1045         self.test_input, self.test_prior_visits, _, _ = snake.generate_sequences(
1046             nb_test_samples,
1047             height,
1048             width,
1049             nb_colors,
1050             length,
1051             prompt_length,
1052             self.device,
1053         )
1054
1055         self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
1056
1057     def batches(self, split="train", nb_to_use=-1, desc=None):
1058         assert split in {"train", "test"}
1059         input = self.train_input if split == "train" else self.test_input
1060         if nb_to_use > 0:
1061             input = input[:nb_to_use]
1062         if desc is None:
1063             desc = f"epoch-{split}"
1064         for batch in tqdm.tqdm(
1065             input.split(self.batch_size), dynamic_ncols=True, desc=desc
1066         ):
1067             yield batch
1068
1069     def vocabulary_size(self):
1070         return self.nb_codes
1071
1072     def produce_results(
1073         self, n_epoch, model, result_dir, logger, deterministic_synthesis
1074     ):
1075         def compute_nb_correct(input, prior_visits):
1076             result = input.clone()
1077             i = torch.arange(result.size(1), device=result.device)[None, :]
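            # Re-generate the even-indexed tokens located after the prompt;
            # only those positions whose cell had already been visited
            # (prior_visits > 0) are counted in the accuracy.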
1078             ar_mask = (
1079                 torch.logical_and(i >= self.prompt_length * 2, i % 2 == 0)
1080                 .long()
1081                 .expand_as(result)
1082             )
1083             result *= 1 - ar_mask
1084
1085             masked_inplace_autoregression(
1086                 model,
1087                 self.batch_size,
1088                 result,
1089                 ar_mask,
1090                 deterministic_synthesis,
1091                 device=self.device,
1092             )
1093
1094             nb_total = ((prior_visits > 0) * ar_mask).sum()
1095
1096             nb_correct = ((result == input).long() * (prior_visits > 0) * ar_mask).sum()
1097
1098             return nb_total, nb_correct
1099
1100         test_nb_total, test_nb_correct = compute_nb_correct(
1101             self.test_input[:1000], self.test_prior_visits[:1000]
1102         )
1103
1104         logger(
1105             f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
1106         )
1107
1108         logger(f"main_test_accuracy {n_epoch} {test_nb_correct/test_nb_total}")
1109
1110
1111 ######################################################################
1112
1113
1114 import stack
1115
1116
1117 class Stack(Task):
1118     def __init__(
1119         self,
1120         nb_train_samples,
1121         nb_test_samples,
1122         batch_size,
1123         logger,
1124         nb_steps,
1125         nb_stacks,
1126         nb_digits,
1127         fraction_values_for_train=None,
1128         device=torch.device("cpu"),
1129     ):
1130         super().__init__()
1131
1132         self.batch_size = batch_size
1133         self.nb_steps = nb_steps
1134         self.nb_stacks = nb_stacks
1135         self.nb_digits = nb_digits
1136         self.device = device
1137
1138         if fraction_values_for_train is None:
1139             values_for_train = None
1140             values_for_test = None
1141         else:
1142             all = torch.randperm(10**nb_digits)
1143             nb_for_train = int(all.size(0) * fraction_values_for_train)
1144             values_for_train = all[:nb_for_train]
1145             values_for_test = all[nb_for_train:]
1146
1147         self.train_input, self.train_stack_counts = stack.generate_sequences(
1148             nb_train_samples,
1149             nb_steps,
1150             nb_stacks,
1151             nb_digits,
1152             values_for_train,
1153             self.device,
1154         )
1155
1156         self.test_input, self.test_stack_counts = stack.generate_sequences(
1157             nb_test_samples,
1158             nb_steps,
1159             nb_stacks,
1160             nb_digits,
1161             values_for_test,
1162             self.device,
1163         )
1164
1165         i = torch.logical_and(self.test_input % 2 == 1, self.test_input < 2 * nb_stacks)
1166         counts = self.test_stack_counts.flatten()[i.flatten()]
1167         counts = F.one_hot(counts).sum(0)
1168         logger(f"test_pop_stack_counts {counts}")
1169
1170         self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
1171
1172     def batches(self, split="train", nb_to_use=-1, desc=None):
1173         assert split in {"train", "test"}
1174         input = self.train_input if split == "train" else self.test_input
1175         if nb_to_use > 0:
1176             input = input[:nb_to_use]
1177         if desc is None:
1178             desc = f"epoch-{split}"
1179         for batch in tqdm.tqdm(
1180             input.split(self.batch_size), dynamic_ncols=True, desc=desc
1181         ):
1182             yield batch
1183
1184     def vocabulary_size(self):
1185         return self.nb_codes
1186
1187     def produce_results(
1188         self, n_epoch, model, result_dir, logger, deterministic_synthesis
1189     ):
1190         def compute_nb_correct(input):
1191             result = input.clone()
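            # After stack.remove_popped_values, the positions that changed
            # are exactly the values the model must predict, hence the mask
            # below.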
1192             stack.remove_popped_values(result, self.nb_stacks, self.nb_digits)
1193             ar_mask = (result != input).long()
1194             masked_inplace_autoregression(
1195                 model,
1196                 self.batch_size,
1197                 result,
1198                 ar_mask,
1199                 deterministic_synthesis,
1200                 device=self.device,
1201             )
1202
1203             errors = ((result != input).long() * ar_mask).reshape(
1204                 -1, 1 + self.nb_digits
1205             )
1206             ar_mask = ar_mask.reshape(-1, 1 + self.nb_digits)
1207
1208             nb_total = ar_mask.max(1).values.sum()
1209             nb_correct = nb_total - errors.max(1).values.sum()
1210
1211             return nb_total, nb_correct
1212
1213         test_nb_total, test_nb_correct = compute_nb_correct(self.test_input[:1000])
1214
1215         logger(
1216             f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
1217         )
1218
1219         logger(f"main_test_accuracy {n_epoch} {test_nb_correct/test_nb_total}")
1220
1221         ##############################################################
1222         # Log a few generated sequences
1223         input = self.test_input[:10, : 12 * (1 + self.nb_digits)]
1224         result = input.clone()
1225         stack.remove_popped_values(result, self.nb_stacks, self.nb_digits)
1226         ar_mask = (result != input).long()
1227
1228         # for n in range(result.size(0)):
1229         # logger(
1230         # f"test_before {stack.seq_to_str(result[n],nb_stacks=self.nb_stacks,nb_digits=self.nb_digits)}"
1231         # )
1232
1233         masked_inplace_autoregression(
1234             model,
1235             self.batch_size,
1236             result,
1237             ar_mask,
1238             deterministic_synthesis,
1239             device=self.device,
1240         )
1241
1242         #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
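        # Debug dump: for a few train and test sequences, write next to each
        # token the probability the model assigns to it.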
1243         for label, input in [
1244             ("train", self.train_input[:32]),
1245             ("test", self.test_input[:32]),
1246         ]:
1247             output = model(BracketedSequence(input)).x
1248             output = output.log_softmax(dim=-1)
1249             filename = os.path.join(
1250                 result_dir, f"stack_with_crossentropy_{n_epoch:04d}_{label}.txt"
1251             )
1252             with open(filename, "w") as f:
1253                 for n in range(input.size(0)):
1254                     s = stack.seq_to_str(
1255                         input[n], nb_stacks=self.nb_stacks, nb_digits=self.nb_digits
1256                     )
1257                     for t, k, w in zip(range(input[n].size(0)), input[n], s.split(" ")):
1258                         u = (
1259                             " " * (10 - len(w))
1260                             + w
1261                             + " "
1262                             + str(output[n][t][k].exp().item())
1263                             + "\n"
1264                         )
1265                         f.write(u)
1266                     f.write("\n")
1267             logger(f"wrote {filename}")
1268         #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
1269
1270         for n in range(result.size(0)):
1271             logger(
1272                 f"test_after  {stack.seq_to_str(result[n],nb_stacks=self.nb_stacks,nb_digits=self.nb_digits)}"
1273             )
1274         ##############################################################
1275
1276
1277 ######################################################################
1278
1279 import rpl
1280
1281
1282 class RPL(Task):
1283     def tensorize(self, sequences):
1284         len_max = max([len(x) for x in sequences])
1285         return torch.cat(
1286             [
1287                 torch.tensor(
1288                     [
1289                         [
1290                             self.token2id[str(c)]
1291                             for c in s + ["<nul>"] * (len_max - len(s))
1292                         ]
1293                         for s in sequences
1294                     ]
1295                 )
1296             ],
1297             0,
1298         )
1299
1300     def seq2str(self, seq):
1301         return " ".join([self.id2token[i] for i in seq])
1302
1303     def __init__(
1304         self,
1305         nb_train_samples,
1306         nb_test_samples,
1307         batch_size,
1308         nb_starting_values=3,
1309         max_input=9,
1310         prog_len=6,
1311         nb_runs=5,
1312         no_prog=False,
1313         logger=None,
1314         device=torch.device("cpu"),
1315     ):
1316         super().__init__()
1317
1318         self.batch_size = batch_size
1319         self.device = device
1320         self.no_prog = no_prog
1321
1322         train_sequences = [
1323             rpl.generate(
1324                 nb_starting_values=nb_starting_values,
1325                 nb_result_values_max=4 * nb_starting_values,
1326                 max_input=max_input,
1327                 prog_len=prog_len,
1328                 nb_runs=nb_runs,
1329             )
1330             for _ in tqdm.tqdm(range(nb_train_samples), desc="train-data")
1331         ]
1332
1333         test_sequences = [
1334             rpl.generate(
1335                 nb_starting_values=nb_starting_values,
1336                 nb_result_values_max=4 * nb_starting_values,
1337                 max_input=max_input,
1338                 prog_len=prog_len,
1339                 nb_runs=nb_runs,
1340             )
1341             for _ in tqdm.tqdm(range(nb_test_samples), desc="test-data")
1342         ]
1343
1344         symbols = list(
1345             set(["<nul>"] + [x for l in train_sequences + test_sequences for x in l])
1346         )
1347         val_max = max([x if type(x) is int else 0 for x in symbols])
1348         symbols = list(filter(lambda x: type(x) is str, symbols))
1349         symbols.sort()
1350         symbols += [str(n) for n in range(val_max + 1)]
1351         self.token2id = dict([(c, n) for n, c in enumerate(symbols)])
1352         self.id2token = dict([(n, c) for c, n in self.token2id.items()])
1353
1354         self.t_nul = self.token2id["<nul>"]
1355         self.t_input = self.token2id["<in>"]
1356         self.t_output = self.token2id["<out>"]
1357         self.t_prog = self.token2id["<prg>"]
1358         self.t_end = self.token2id["<end>"]
1359
1360         self.train_input = self.tensorize(train_sequences)
1361         self.test_input = self.tensorize(test_sequences)
1362
1363         if no_prog:
1364             # Excise the program from every train and test example
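            # p is the column index of the <prg> token in each row: tokens up
            # to p are kept, position p + 1 becomes <end>, and everything
            # after that is set to <nul>.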
1365             k = torch.arange(self.train_input.size(1), device=self.train_input.device)[
1366                 None, :
1367             ]
1368             p = (
1369                 ((self.train_input == self.t_prog).long() * k)
1370                 .max(1, keepdim=True)
1371                 .values
1372             )
1373             self.train_input = (
1374                 self.train_input * (k <= p).long()
1375                 + self.t_end * (k == p + 1).long()
1376                 + self.t_nul * (k > p + 1).long()
1377             )
1378             k = torch.arange(self.test_input.size(1), device=self.test_input.device)[
1379                 None, :
1380             ]
1381             p = (
1382                 ((self.test_input == self.t_prog).long() * k)
1383                 .max(1, keepdim=True)
1384                 .values
1385             )
1386             self.test_input = (
1387                 self.test_input * (k <= p).long()
1388                 + self.t_end * (k == p + 1).long()
1389                 + self.t_nul * (k > p + 1).long()
1390             )
1391
1392         if logger is not None:
1393             logger(f"value_max {val_max}")
1394             for x in self.train_input[:25]:
1395                 end = (x != self.t_nul).nonzero().max().item() + 1
1396                 seq = [self.id2token[i.item()] for i in x[:end]]
1397                 s = " ".join(seq)
1398                 logger(f"example_seq {s}")
1399
1400         self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
1401
1402     def batches(self, split="train", nb_to_use=-1, desc=None):
1403         assert split in {"train", "test"}
1404         input = self.train_input if split == "train" else self.test_input
1405         if nb_to_use > 0:
1406             input = input[:nb_to_use]
1407         if desc is None:
1408             desc = f"epoch-{split}"
1409         for batch in tqdm.tqdm(
1410             input.split(self.batch_size), dynamic_ncols=True, desc=desc
1411         ):
1412             last = (batch != self.t_nul).max(0).values.nonzero().max() + 3
1413             batch = batch[:, :last].to(self.device)
1414             yield batch
1415
1416     def vocabulary_size(self):
1417         return self.nb_codes
1418
1419     def produce_results(
1420         self, n_epoch, model, result_dir, logger, deterministic_synthesis
1421     ):
1422         # --------------------------------------------------------------------
1423         def compute_nb_errors_prog(input, nb_to_log=0):
1424             result = input.clone()
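            # Mask everything strictly after the <prg> token: the input/output
            # runs are given, the program itself must be generated.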
1425             s = (result == self.t_prog).long()
1426             ar_mask = (s.cumsum(dim=1) - s).clamp(min=0, max=1)
1427             result = (1 - ar_mask) * result + ar_mask * self.t_nul
1428
1429             masked_inplace_autoregression(
1430                 model,
1431                 self.batch_size,
1432                 result,
1433                 ar_mask,
1434                 deterministic_synthesis,
1435                 device=self.device,
1436             )
1437
1438             sum_nb_total, sum_nb_errors = 0, 0
1439             for one_input, one_result in zip(input, result):
1440                 seq = [self.id2token[i.item()] for i in one_result]
1441                 nb_total, nb_errors, prog, stacks = rpl.compute_nb_errors(seq)
1442                 sum_nb_total += 1
1443                 sum_nb_errors += 0 if nb_errors == 0 else 1
1444                 if nb_to_log > 0:
1445                     gt_seq = [self.id2token[i.item()] for i in one_input]
1446                     _, _, gt_prog, _ = rpl.compute_nb_errors(gt_seq)
1447                     gt_prog = " ".join([str(x) for x in gt_prog])
1448                     prog = " ".join([str(x) for x in prog])
1449                     comment = "*" if nb_errors == 0 else "-"
1450                     logger(f"{comment} PROG [{gt_prog}] PREDICTED [{prog}]")
1451                     for start_stack, target_stack, result_stack, correct in stacks:
1452                         comment = "*" if correct else "-"
1453                         start_stack = " ".join([str(x) for x in start_stack])
1454                         target_stack = " ".join([str(x) for x in target_stack])
1455                         result_stack = " ".join([str(x) for x in result_stack])
1456                         logger(
1457                             f"  {comment} [{start_stack}] -> [{target_stack}] PREDICTED [{result_stack}]"
1458                         )
1459                     nb_to_log -= 1
1460
1461             return sum_nb_total, sum_nb_errors
1462
1463         # --------------------------------------------------------------------
1464         def compute_nb_errors_output(input, nb_to_log=0):
1465             result = input.clone()
1466             k = torch.arange(result.size(1), device=result.device)[None, :]
1467             last_output_idx = (
1468                 ((result == self.t_output) * k).max(dim=1, keepdim=True).values
1469             )
1470             first_prog_idx = (
1471                 ((result == self.t_prog) * k).max(dim=1, keepdim=True).values
1472             )
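            # Mask the tokens strictly between the last <out> marker and the
            # <prg> token, i.e. the final output values the model must
            # predict.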
1473             ar_mask = (k > last_output_idx).long() * (k < first_prog_idx).long()
1474             result = (1 - ar_mask) * result + ar_mask * self.t_nul
1475
1476             masked_inplace_autoregression(
1477                 model,
1478                 self.batch_size,
1479                 result,
1480                 ar_mask,
1481                 deterministic_synthesis,
1482                 device=self.device,
1483             )
1484
1485             sum_nb_total, sum_nb_errors = 0, 0
1486             for one_input, one_result, i, j in zip(
1487                 input, result, last_output_idx, first_prog_idx
1488             ):
1489                 seq = [self.id2token[i.item()] for i in one_result]
1490                 sum_nb_total += 1
1491                 correct = (one_input - one_result).abs().max() == 0
1492                 sum_nb_errors += 0 if correct else 1
1493                 if nb_to_log > 0:
1494                     result_stack = [
1495                         self.id2token[u.item()] for u in one_result[i : j + 1]
1496                     ]
1497                     target_stack = [
1498                         self.id2token[u.item()] for u in one_input[i : j + 1]
1499                     ]
1500                     comment = "*" if correct else "-"
1501                     result_stack = " ".join([str(x) for x in result_stack])
1502                     target_stack = " ".join([str(x) for x in target_stack])
1503                     logger(
1504                         f"output_test {comment} [{target_stack}] PREDICTED [{result_stack}]"
1505                     )
1506                     nb_to_log -= 1
1507
1508             return sum_nb_total, sum_nb_errors
1509
1510         # --------------------------------------------------------------------
1511
1512         if not self.no_prog:
1513             test_nb_total, test_nb_errors = compute_nb_errors_prog(
1514                 self.test_input[:1000].to(self.device), nb_to_log=10
1515             )
1516
1517             logger(
1518                 f"accuracy_prog_test {n_epoch} nb_total {test_nb_total} nb_errors {test_nb_errors} accuracy {100.0*(1-test_nb_errors/test_nb_total):.02f}%"
1519             )
1520
1521             logger(f"main_test_accuracy {n_epoch} {1-test_nb_errors/test_nb_total}")
1522
1523         test_nb_total, test_nb_errors = compute_nb_errors_output(
1524             self.test_input[:1000].to(self.device), nb_to_log=10
1525         )
1526
1527         logger(
1528             f"accuracy_output_test {n_epoch} nb_total {test_nb_total} nb_errors {test_nb_errors} accuracy {100.0*(1-test_nb_errors/test_nb_total):.02f}%"
1529         )
1530
1531         if save_attention_image is None:
1532             logger("no save_attention_image (is pycairo installed?)")
1533         else:
1534             ns = torch.randint(self.test_input.size(0), (1,)).item()
1535             input = self.test_input[ns : ns + 1].clone()
1536             last = (input != self.t_nul).max(0).values.nonzero().max() + 3
1537             input = input[:, :last].to(self.device)
1538
1539             with torch.autograd.no_grad():
1540                 t = model.training
1541                 model.eval()
1542                 model.record_attention(True)
1543                 model(BracketedSequence(input))
1544                 model.train(t)
1545                 ram = model.retrieve_attention()
1546                 model.record_attention(False)
1547
1548             tokens_output = [self.id2token[i.item()] for i in input[0]]
1549             tokens_input = ["n/a"] + tokens_output[:-1]
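                 # NOTE (added comment): tokens_input is the sequence shifted
                 # one position to the right (the token the model read when it
                 # emitted the corresponding output), with a placeholder for
                 # the very first position.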
1550             for n_head in range(ram[0].size(1)):
1551                 filename = os.path.join(
1552                     result_dir, f"rpl_attention_{n_epoch}_h{n_head}.pdf"
1553                 )
1554                 attention_matrices = [m[0, n_head] for m in ram]
1555                 save_attention_image(
1556                     filename,
1557                     tokens_input,
1558                     tokens_output,
1559                     attention_matrices,
1560                     k_top=10,
1561                     # min_total_attention=0.9,
1562                     token_gap=12,
1563                     layer_gap=50,
1564                 )
1565                 logger(f"wrote {filename}")
1566
1567
1568 ######################################################################
1569
1570
1571 import expr
1572
1573
1574 class Expr(Task):
1575     def tensorize(self, sequences):
1576         len_max = max([len(x) for x in sequences])
1577         return torch.cat(
1578             [
1579                 torch.tensor(
1580                     [
1581                         [self.char2id[c] for c in s + "#" * (len_max - len(s))]
1582                         for s in sequences
1583                     ]
1584                 )
1585             ],
1586             0,
1587         ).to(self.device)
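         # NOTE (added comment): sequences are right-padded with '#' up to the
         # longest one before being mapped through char2id, e.g. ["a=1", "bb=12"]
         # would become the rows "a=1##" and "bb=12" (illustrative strings only).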
1588
1589     def __init__(
1590         self,
1591         nb_train_samples,
1592         nb_test_samples,
1593         nb_variables,
1594         sequence_length,
1595         operand_max,
1596         result_max,
1597         batch_size,
1598         device=torch.device("cpu"),
1599     ):
1600         super().__init__()
1601
1602         self.batch_size = batch_size
1603         self.device = device
1604
1605         train_sequences = expr.generate_sequences(
1606             nb_train_samples,
1607             nb_variables=nb_variables,
1608             length=sequence_length,
1609             operand_max=operand_max,
1610             result_max=result_max,
1611         )
1612
1613         test_sequences = expr.generate_sequences(
1614             nb_test_samples,
1615             nb_variables=nb_variables,
1616             length=sequence_length,
1617             operand_max=operand_max,
1618             result_max=result_max,
1619         )
1620
1621         symbols = list(set("#" + "".join(train_sequences + test_sequences)))
1622         symbols.sort()
1623
1624         self.char2id = dict([(c, n) for n, c in enumerate(symbols)])
1625         self.id2char = dict([(n, c) for c, n in self.char2id.items()])
1626
1627         self.filler, self.space = self.char2id["#"], self.char2id[" "]
1628
1629         self.train_input = self.tensorize(train_sequences)
1630         self.test_input = self.tensorize(test_sequences)
1631
1632         self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
1633
1634     def batches(self, split="train", nb_to_use=-1, desc=None):
1635         assert split in {"train", "test"}
1636         input = self.train_input if split == "train" else self.test_input
1637         if nb_to_use > 0:
1638             input = input[:nb_to_use]
1639         if desc is None:
1640             desc = f"epoch-{split}"
1641         for batch in tqdm.tqdm(
1642             input.split(self.batch_size), dynamic_ncols=True, desc=desc
1643         ):
1644             last = (batch != self.filler).max(0).values.nonzero().max() + 3
1645             batch = batch[:, :last]
1646             yield batch
1647
1648     def vocabulary_size(self):
1649         return self.nb_codes
1650
1651     def seq2str(self, s):
1652         return "".join([self.id2char[k.item()] for k in s])
1653
1654     def produce_results(
1655         self,
1656         n_epoch,
1657         model,
1658         result_dir,
1659         logger,
1660         deterministic_synthesis,
1661         input_file=None,
1662     ):
1663         def compute_nb_correct(input):
1664             result = input.clone()
1665             s = (result == self.space).long()
1666             ar_mask = (s.cumsum(dim=1) - s).clamp(min=0, max=1)
1667             result = (1 - ar_mask) * result + ar_mask * self.filler
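                 # NOTE (added comment): the mask is 1 strictly after the first
                 # space token, so everything beyond it -- the part holding the
                 # values to predict -- is blanked with '#' and filled back in
                 # autoregressively, while the prefix up to the space is kept
                 # as the prompt.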
1668             masked_inplace_autoregression(
1669                 model,
1670                 self.batch_size,
1671                 result,
1672                 ar_mask,
1673                 deterministic_synthesis,
1674                 device=self.device,
1675             )
1676
1677             nb_total = input.size(0)
1678             nb_correct = (input == result).long().min(1).values.sum()
1679
1680             #######################################################################
1681             # Compute predicted vs. true variable values
1682
1683             nb_delta = torch.zeros(5, dtype=torch.int64)
1684             nb_missed = 0
1685
1686             values_input = expr.extract_results([self.seq2str(s) for s in input])
1687             values_result = expr.extract_results([self.seq2str(s) for s in result])
1688
1689             filename = os.path.join(result_dir, f"expr_result_{n_epoch:04d}.txt")
1690
1691             with open(filename, "w") as f:
1692                 for i, r in zip(values_input, values_result):
1693                     for n, vi in i.items():
1694                         vr = r.get(n)
1695                         f.write(f"{vi} {-1 if vr is None else vr}\n")
1696
1697                         if vr is None or vr < 0:
1698                             nb_missed += 1
1699                         else:
1700                             d = abs(vr - vi)
1701                             if d >= nb_delta.size(0):
1702                                 nb_missed += 1
1703                             else:
1704                                 nb_delta[d] += 1
1705
1706             ######################################################################
1707
1708             return nb_total, nb_correct, nb_delta, nb_missed
1709
1710         (
1711             test_nb_total,
1712             test_nb_correct,
1713             test_nb_delta,
1714             test_nb_missed,
1715         ) = compute_nb_correct(self.test_input[:10000])
1716
1717         logger(
1718             f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
1719         )
1720
1721         logger(f"main_test_accuracy {n_epoch} {test_nb_correct/test_nb_total}")
1722
1723         nb_total = test_nb_delta.sum() + test_nb_missed
1724         for d in range(test_nb_delta.size(0)):
1725             logger(
1726                 f"error_value {n_epoch} delta {d} {test_nb_delta[d]} {test_nb_delta[d]*100/nb_total:.02f}%"
1727             )
1728         logger(
1729             f"error_value {n_epoch} missed {test_nb_missed} {test_nb_missed*100/nb_total:.02f}%"
1730         )
1731
1732         ##############################################################
1733         # Log a few generated sequences
1734         if input_file is None:
1735             input = self.test_input[:10]
1736         else:
1737             with open(input_file, "r") as f:
1738                 sequences = [e.strip() for e in f.readlines()]
1739                 sequences = [s + " " + "#" * 50 for s in sequences]
1740                 input = self.tensorize(sequences)
1741
1742         result = input.clone()
1743         s = (result == self.space).long()
1744         ar_mask = (s.cumsum(dim=1) - s).clamp(min=0, max=1)
1745         result = (1 - ar_mask) * result + ar_mask * self.filler
1746
1747         for n in range(result.size(0)):
1748             logger(f"test_before {self.seq2str(result[n])}")
1749
1750         masked_inplace_autoregression(
1751             model,
1752             self.batch_size,
1753             result,
1754             ar_mask,
1755             deterministic_synthesis,
1756             device=self.device,
1757         )
1758
1759         correct = (1 - ar_mask) * self.space + ar_mask * input
1760         for n in range(result.size(0)):
1761             comment = "GOOD" if (result[n] - input[n]).abs().max() == 0 else ""
1762             logger(f"test_after  {self.seq2str(result[n])} {comment}")
1763             logger(f"truth       {self.seq2str(correct[n])}")
1764         ##############################################################
1765
1766
1767 ######################################################################
1768
1769 import grid
1770
1771
1772 class Grid(Task):
1773     # Make a tensor from a list of strings
1774     def str2tensor(self, descr):
1775         token_descr = [s.strip().split(" ") for s in descr]
1776         len_max = max([len(s) for s in token_descr])
1777         token_descr = [s + ["#"] * (len_max - len(s)) for s in token_descr]
1778         id_descr = [[self.token2id[u] for u in s] for s in token_descr]
1779         return torch.tensor(id_descr, device=self.device)
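         # NOTE (added comment): descriptions are split on spaces, shorter ones
         # are right-padded with "#" tokens up to the longest, and each token is
         # mapped through token2id, e.g. str2tensor(["a b c", "d e"]) pads the
         # second row with one "#" (tokens here are illustrative only; the real
         # vocabulary comes from grid.GridFactory).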
1780
1781     # Make a list of strings from a tensor
1782     def tensor2str(self, x):
1783         return [" ".join([self.id2token[t.item()] for t in r]) for r in x]
1784
1785     # Trim all the tensors in the tuple z to remove as many padding
1786     # tokens as possible from the left and right of the first tensor.
1787     # If z is a tuple, all its elements are trimmed the same way as the first one.
1788     def trim(self, z, token="#"):
1789         n = self.token2id[token]
1790         if type(z) == tuple:
1791             x = z[0]
1792             i = (1 - (F.pad(x, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
1793             a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
1794             return tuple([t[:, a:b] for t in z])
1795         else:
1796             i = (1 - (F.pad(z, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
1797             a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
1798             return z[:, a:b]
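         # NOTE (added comment): in trim(), the first tensor is padded with one
         # "#" column on each side, and i counts, column by column, how many
         # columns seen so far contain at least one non-padding token; thanks to
         # that padding, a ends up being the index of the first such column of z
         # and b one past the last, so each tensor in z is cut to z[:, a:b].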
1799
1800     ######################
1801
1802     def __init__(
1803         self,
1804         nb_train_samples,
1805         nb_test_samples,
1806         batch_size,
1807         size,
1808         fraction_play=0.0,
1809         logger=None,
1810         device=torch.device("cpu"),
1811     ):
1812         super().__init__()
1813
1814         self.device = device
1815         self.batch_size = batch_size
1816         self.grid_factory = grid.GridFactory(size=size)
1817         self.fraction_play = fraction_play
1818
1819         if logger is not None:
1820             logger(
1821                 f"generating {nb_train_samples+nb_test_samples} samples (can take some time)"
1822             )
1823
1824         self.train_descr = self.grid_factory.generate_samples(
1825             nb=nb_train_samples,
1826             fraction_play=fraction_play,
1827             progress_bar=lambda r: tqdm.tqdm(r),
1828         )
1829
1830         self.test_descr = self.grid_factory.generate_samples(
1831             nb=nb_test_samples, fraction_play=0.0, progress_bar=lambda r: tqdm.tqdm(r)
1832         )
1833
1834         if fraction_play > 0:
1835             self.play_descr = self.grid_factory.generate_samples(
1836                 nb=25, fraction_play=1.0, progress_bar=lambda r: tqdm.tqdm(r)
1837             )
1838         else:
1839             self.play_descr = []
1840
1841         # Build the tokenizer
1842         tokens = set()
1843         for d in [self.train_descr, self.test_descr, self.play_descr]:
1844             for s in d:
1845                 for t in s.strip().split(" "):
1846                     tokens.add(t)
1847         # turn the set into a sorted list so that the same descriptions
1848         # always yield the same tensors
1849         tokens = list(tokens)
1850         tokens.sort()
1851         tokens = ["#"] + tokens
1852         self.token2id = dict([(t, n) for n, t in enumerate(tokens)])
1853         self.id2token = dict([(n, t) for n, t in enumerate(tokens)])
1854         self.t_nul = self.token2id["#"]
1855         self.t_true = self.token2id["true"]
1856         self.t_false = self.token2id["false"]
1857         self.t_pipe = self.token2id.get("|")  # "|" only exists when play sequences are generated
1858
1859         # Tokenize the train and test sets
1860         self.train_input = self.str2tensor(self.train_descr)
1861         self.test_input = self.str2tensor(self.test_descr)
1862         self.play_input = (
1863             None if len(self.play_descr) == 0 else self.str2tensor(self.play_descr)
1864         )
1865
1866     def batches(self, split="train", nb_to_use=-1, desc=None):
1867         assert split in {"train", "test"}
1868         input = self.train_input if split == "train" else self.test_input
1869         for batch in tqdm.tqdm(
1870             input.split(self.batch_size), dynamic_ncols=True, desc=f"epoch-{split}"
1871         ):
1872             yield self.trim(batch)
1873
1874     def vocabulary_size(self):
1875         return len(self.token2id)
1876
1877     def produce_results(
1878         self, n_epoch, model, result_dir, logger, deterministic_synthesis
1879     ):
1880         correct = self.test_input[:1000]
1881         result = correct.clone()
1882         ar_mask = torch.logical_or(result == self.t_true, result == self.t_false).long()
1883         result *= 1 - ar_mask  # paranoia: erase the tokens to be regenerated
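             # NOTE (added comment): only the "true"/"false" answer tokens are
             # masked and zeroed, so the model regenerates just those positions
             # while the rest of the description serves as the prompt.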
1884
1885         logger(f"----------------------------------------------------------")
1886
1887         for e in self.tensor2str(result[:10]):
1888             logger(f"test_before {e}")
1889
1890         masked_inplace_autoregression(
1891             model,
1892             self.batch_size,
1893             result,
1894             ar_mask,
1895             deterministic_synthesis,
1896             device=self.device,
1897         )
1898
1899         logger(f"----------------------------------------------------------")
1900
1901         for e in self.tensor2str(result[:10]):
1902             logger(f"test_after  {e}")
1903
1904         logger(f"----------------------------------------------------------")
1905
1906         nb_total = ar_mask.sum().item()
1907         nb_correct = ((correct == result).long() * ar_mask).sum().item()
1908
1909         logger(f"test_performance {n_epoch} {nb_total=} {nb_correct=}")
1910         logger(f"main_test_accuracy {n_epoch} {nb_correct / nb_total}")
1911
1912         if self.play_input is not None:
1913             result = self.play_input.clone()
1914             ar_mask = (result == self.t_pipe).long().cumsum(dim=1).clamp(max=1)
1915             result *= 1 - ar_mask  # paranoia: erase the tokens to be regenerated
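                 # NOTE (added comment): for play sequences the mask is 1 from
                 # the first "|" marker onward (marker included), so that whole
                 # tail is regenerated; this relies on self.t_pipe being defined
                 # above, i.e. on the play descriptions containing a "|"
                 # separator.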
1916
1917             logger(f"----------------------------------------------------------")
1918
1919             for e in self.tensor2str(result[:10]):
1920                 logger(f"play_before {e}")
1921
1922             masked_inplace_autoregression(
1923                 model,
1924                 self.batch_size,
1925                 result,
1926                 ar_mask,
1927                 deterministic_synthesis,
1928                 device=self.device,
1929             )
1930
1931             logger(f"----------------------------------------------------------")
1932
1933             for e in self.tensor2str(result[:10]):
1934                 logger(f"play_after  {e}")
1935
1936             logger(f"----------------------------------------------------------")
1937
1938
1939 ######################################################################
1940
1941 import qmlp
1942
1943
1944 class QMLP(Task):
1945     ######################
1946
1947     def __init__(
1948         self,
1949         nb_train_samples,
1950         nb_test_samples,
1951         batch_size,
1952         result_dir,
1953         logger=None,
1954         device=torch.device("cpu"),
1955     ):
1956         super().__init__()
1957
1958         self.device = device
1959         self.batch_size = batch_size
1960         self.nb_samples_per_mlp = 256
1961
1962         if logger is not None:
1963             logger(
1964                 f"generating {nb_train_samples+nb_test_samples} samples (can take some time)"
1965             )
1966
1967         seq, q_test_set, test_error = qmlp.generate_sequence_and_test_set(
1968             nb_mlps=nb_train_samples + nb_test_samples,
1969             nb_samples=self.nb_samples_per_mlp,
1970             device=self.device,
1971             batch_size=64,
1972             nb_epochs=250,
1973             nb_mlps_per_batch=1024,
1974         )
1975
1976         self.train_input = seq[:nb_train_samples]
1977         self.train_q_test_set = q_test_set[:nb_train_samples]
1978         self.train_ref_test_errors = test_error[:nb_train_samples]
1979         self.test_input = seq[nb_train_samples:]
1980         self.test_q_test_set = q_test_set[nb_train_samples:]
1981         self.test_ref_test_errors = test_error[nb_train_samples:]
1982
1983         filename = os.path.join(result_dir, f"train_errors_ref.dat")
1984         with open(filename, "w") as f:
1985             for e in self.train_ref_test_errors:
1986                 f.write(f"{e}\n")
1987
1988         filename = os.path.join(result_dir, f"test_errors_ref.dat")
1989         with open(filename, "w") as f:
1990             for e in self.test_ref_test_errors:
1991                 f.write(f"{e}\n")
1992
1993         self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
1994
1995     def batches(self, split="train", nb_to_use=-1, desc=None):
1996         assert split in {"train", "test"}
1997         input = self.train_input if split == "train" else self.test_input
1998         for batch in tqdm.tqdm(
1999             input.split(self.batch_size), dynamic_ncols=True, desc=f"epoch-{split}"
2000         ):
2001             yield batch
2002
2003     def vocabulary_size(self):
2004         return self.nb_codes
2005
2006     def produce_results(
2007         self, n_epoch, model, result_dir, logger, deterministic_synthesis
2008     ):
2009         correct = self.test_input[:1000]
2010         result = correct.clone()
2011         ar_mask = (
2012             torch.arange(result.size(1), device=result.device)
2013             > self.nb_samples_per_mlp * 3 + 1
2014         ).long()[None, :]
2015         ar_mask = ar_mask.expand_as(result)
2016         result *= 1 - ar_mask  # paranoia: erase the tokens to be regenerated
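             # NOTE (added comment): the mask is 1 strictly after position
             # 3 * nb_samples_per_mlp + 1, so the prefix -- presumably the
             # quantized training set plus a separator token -- is kept as the
             # prompt, while the tail holding the encoded MLP parameters (read
             # back below as q_params) is regenerated.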
2017
2018         masked_inplace_autoregression(
2019             model,
2020             self.batch_size,
2021             result,
2022             ar_mask,
2023             deterministic_synthesis,
2024             device=self.device,
2025         )
2026
2027         q_train_set = result[:, : self.nb_samples_per_mlp * 3]
2028         q_params = result[:, self.nb_samples_per_mlp * 3 + 1 :]
2029         error_test = qmlp.evaluate_q_params(q_params, self.test_q_test_set)
2030
2031         filename = os.path.join(result_dir, f"test_errors_{n_epoch:04d}.dat")
2032         with open(filename, "w") as f:
2033             for e in error_test:
2034                 f.write(f"{e}\n")
2035
2036
2037 ######################################################################
2038
2039 import greed
2040
2041
2042 class Greed(Task):
2043     def __init__(
2044         self,
2045         nb_train_samples,
2046         nb_test_samples,
2047         batch_size,
2048         height,
2049         width,
2050         T,
2051         nb_walls,
2052         nb_coins,
2053         logger=None,
2054         device=torch.device("cpu"),
2055     ):
2056         super().__init__()
2057
2058         self.batch_size = batch_size
2059         self.device = device
2060
2061         self.world = greed.GreedWorld(height, width, T, nb_walls, nb_coins)
2062
2063         states, actions, rewards = self.world.generate_episodes(
2064             nb_train_samples + nb_test_samples
2065         )
2066         seq = self.world.episodes2seq(states, actions, rewards)
2067         self.train_input = seq[:nb_train_samples].to(self.device)
2068         self.test_input = seq[nb_train_samples:].to(self.device)
2069
2070     def wipe_lookahead_rewards(self, batch):
2071         t = torch.arange(batch.size(1), device=batch.device)[None, :]
2072         u = torch.randint(batch.size(1), (batch.size(0), 1), device=batch.device)
2073         lr_mask = (t <= u).long() * (
2074             t % self.world.it_len == self.world.index_lookahead_reward
2075         ).long()
2076
2077         return (
2078             lr_mask * self.world.lookahead_reward2code(greed.REWARD_UNKNOWN)
2079             + (1 - lr_mask) * batch
2080         )
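         # NOTE (added comment): wipe_lookahead_rewards() draws one random
         # cut-off u per sequence and replaces the lookahead_reward slot of
         # every iteration up to u (t <= u and t % it_len ==
         # index_lookahead_reward) with REWARD_UNKNOWN, leaving all other
         # positions untouched.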
2081
2082     def batches(self, split="train", nb_to_use=-1, desc=None):
2083         assert split in {"train", "test"}
2084         input = self.train_input if split == "train" else self.test_input
2085         if nb_to_use > 0:
2086             input = input[:nb_to_use]
2087         if desc is None:
2088             desc = f"epoch-{split}"
2089         for batch in tqdm.tqdm(
2090             input.split(self.batch_size), dynamic_ncols=True, desc=desc
2091         ):
2092             yield self.wipe_lookahead_rewards(batch)
2093
2094     def vocabulary_size(self):
2095         return self.world.nb_codes
2096
2097     def thinking_autoregression(
2098         self, n_epoch, model, result_dir, logger, deterministic_synthesis, nmax=1000
2099     ):
2100         snapshots = []
2101
2102         def ar(result, ar_mask, logit_biases=None):
2103             ar_mask = ar_mask.expand_as(result)
2104             result *= 1 - ar_mask
2105             masked_inplace_autoregression(
2106                 model,
2107                 self.batch_size,
2108                 result,
2109                 ar_mask,
2110                 deterministic_synthesis=deterministic_synthesis,
2111                 logit_biases=logit_biases,
2112                 device=self.device,
2113                 progress_bar_desc=None,
2114             )
2115             warnings.warn("keeping thinking snapshots", RuntimeWarning)
2116             snapshots.append(result[:100].detach().clone())
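                 # NOTE (added comment): ar() zeroes the masked positions, runs
                 # the shared autoregression in place, and keeps a snapshot of
                 # the first 100 partially generated sequences so the successive
                 # "thinking" steps can be written out below.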
2117
2118         # Generate iteration after iteration
2119
2120         result = self.test_input[:250].clone()
2121         # Erase all the content but that of the first iteration
2122         result[:, self.world.it_len :] = -1
2123         # Set the lookahead_reward of the first iteration to UNKNOWN
2124         result[:, self.world.index_lookahead_reward] = self.world.lookahead_reward2code(
2125             greed.REWARD_UNKNOWN
2126         )
2127
2128         t = torch.arange(result.size(1), device=result.device)[None, :]
2129
2130         for u in tqdm.tqdm(
2131             range(0, result.size(1), self.world.it_len),
2132             desc="thinking",
2133         ):
2134             # Generate the next state but keep the initial one; the
2135             # lookahead_reward of the previous iterations is set to
2136             # UNKNOWN
2137             if u > 0:
2138                 result[
2139                     :, u + self.world.index_lookahead_reward
2140                 ] = self.world.lookahead_reward2code(greed.REWARD_UNKNOWN)
2141                 ar_mask = (t >= u + self.world.index_states).long() * (
2142                     t < u + self.world.index_states + self.world.state_len
2143                 ).long()
2144                 ar(result, ar_mask)
2145
2146             # Generate the action and reward with lookahead_reward to +1
2147             result[
2148                 :, u + self.world.index_lookahead_reward
2149             ] = self.world.lookahead_reward2code(greed.REWARD_PLUS)
2150             ar_mask = (t >= u + self.world.index_reward).long() * (
2151                 t <= u + self.world.index_action
2152             ).long()
2153             ar(result, ar_mask)
2154
2155             # Set the lookahead_reward to UNKNOWN for the next iterations
2156             result[
2157                 :, u + self.world.index_lookahead_reward
2158             ] = self.world.lookahead_reward2code(greed.REWARD_UNKNOWN)
2159
2160         filename = os.path.join(result_dir, f"test_thinking_compute_{n_epoch:04d}.txt")
2161         with open(filename, "w") as f:
2162             for n in range(snapshots[0].size(0)):
2163                 for s in snapshots:
2164                     lr, s, a, r = self.world.seq2episodes(
2165                         s[n : n + 1],
2166                     )
2167                     str = self.world.episodes2str(
2168                         lr, s, a, r, unicode=True, ansi_colors=True
2169                     )
2170                     f.write(str)
2171                 f.write("\n\n")
2172
2173         # Saving the generated sequences
2174
2175         lr, s, a, r = self.world.seq2episodes(result)
2176         str = self.world.episodes2str(lr, s, a, r, unicode=True, ansi_colors=True)
2177
2178         filename = os.path.join(result_dir, f"test_thinking_seq_{n_epoch:04d}.txt")
2179         with open(filename, "w") as f:
2180             f.write(str)
2181             logger(f"wrote {filename}")
2182
2183     def produce_results(
2184         self, n_epoch, model, result_dir, logger, deterministic_synthesis, nmax=1000
2185     ):
2186         result = self.wipe_lookahead_rewards(self.test_input[:250].clone())
2187
2188         # Saving the ground truth
2189
2190         lr, s, a, r = self.world.seq2episodes(
2191             result,
2192         )
2193         str = self.world.episodes2str(lr, s, a, r, unicode=True, ansi_colors=True)
2194
2195         filename = os.path.join(result_dir, f"test_true_seq_{n_epoch:04d}.txt")
2196         with open(filename, "w") as f:
2197             f.write(str)
2198             logger(f"wrote {filename}")
2199
2200         # Re-generating from the first frame
2201
2202         ar_mask = (
2203             torch.arange(result.size(1), device=result.device) >= self.world.it_len
2204         ).long()[None, :]
2205         ar_mask = ar_mask.expand_as(result)
2206         result *= 1 - ar_mask  # paranoia: erase the tokens to be regenerated
2207
2208         masked_inplace_autoregression(
2209             model,
2210             self.batch_size,
2211             result,
2212             ar_mask,
2213             deterministic_synthesis,
2214             device=self.device,
2215         )
2216
2217         # Saving the generated sequences
2218
2219         lr, s, a, r = self.world.seq2episodes(
2220             result,
2221         )
2222         str = self.world.episodes2str(lr, s, a, r, unicode=True, ansi_colors=True)
2223
2224         filename = os.path.join(result_dir, f"test_seq_{n_epoch:04d}.txt")
2225         with open(filename, "w") as f:
2226             f.write(str)
2227             logger(f"wrote {filename}")
2228
2229         self.thinking_autoregression(
2230             n_epoch, model, result_dir, logger, deterministic_synthesis, nmax
2231         )
2232
2233
2234 ######################################################################