[picoclvr.git] / tasks.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import math, os, tqdm, warnings
9
10 import torch, torchvision
11
12 from torch import nn
13 from torch.nn import functional as F
14
15 from mygpt import BracketedSequence
16
17 # from graph import save_attention_image
18 save_attention_image = None
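# save_attention_image is an optional helper (it requires pycairo); with
# the import above commented out it stays None and the attention plots in
# the tasks below are simply skipped.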
19
20 ######################################################################
21
22
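# Generic sampling helper: splits `input` and `ar_mask` into chunks of
# batch_size and calls the model's own masked_inplace_autoregression() on
# each chunk, regenerating in place only the positions where ar_mask is 1,
# with gradients disabled and the model temporarily switched to eval mode.
#
# A minimal usage sketch (hypothetical tensors; assumes a trained
# mygpt.MyGPT-style model exposing masked_inplace_autoregression):
#
#   result = input.clone()
#   result *= 1 - ar_mask  # blank out the positions to be generated
#   masked_inplace_autoregression(
#       model, 25, result, ar_mask,
#       deterministic_synthesis=True, device=result.device,
#   )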
23 def masked_inplace_autoregression(
24     model,
25     batch_size,
26     input,
27     ar_mask,
28     deterministic_synthesis,
29     forbidden_tokens=None,
30     logit_biases=None,
31     progress_bar_desc="autoregression",
32     device=torch.device("cpu"),
33 ):
34     assert input.size() == ar_mask.size()
35
36     batches = zip(input.split(batch_size), ar_mask.split(batch_size))
37
38     if progress_bar_desc is not None:
39         batches = tqdm.tqdm(
40             batches,
41             dynamic_ncols=True,
42             desc=progress_bar_desc,
43             total=(input.size(0) + batch_size - 1) // batch_size,
44         )
45
46     with torch.autograd.no_grad():
47         t = model.training
48         model.eval()
49
50         for input, ar_mask in batches:
51             model.masked_inplace_autoregression(
52                 input,
53                 ar_mask,
54                 deterministic_synthesis,
55                 forbidden_tokens,
56                 logit_biases,
57             )
58
59         model.train(t)
60
61
62 ######################################################################
63
64
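# Abstract interface shared by all the tasks below: batches() yields the
# train or test batches, vocabulary_size() returns the number of distinct
# token values, and produce_results() runs the model and logs / writes the
# evaluation results.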
65 class Task:
66     def batches(self, split="train", nb_to_use=-1, desc=None):
67         pass
68
69     def vocabulary_size(self):
70         pass
71
72     def produce_results(
73         self, n_epoch, model, result_dir, logger, deterministic_synthesis
74     ):
75         pass
76
77
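# Task whose samples are read from text files. Each sample takes two lines:
# the character sequence itself, then a mask string of the same length over
# {0, 1, 2}, where 0 marks given context, 1 marks positions to generate,
# and 2 marks positions to generate that also count toward the test
# accuracy. Sequences are padded to a common length with "#".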
78 class TaskFromFile(Task):
79     def tensorize(self, pairs, shuffle):
80         len_max = max([len(x[0]) for x in pairs])
81
82         input = torch.cat(
83             [
84                 torch.tensor(
85                     [
86                         [self.char2id[c] for c in s[0] + "#" * (len_max - len(s[0]))]
87                         for s in pairs
88                     ]
89                 )
90             ],
91             0,
92         ).to("cpu")
93
94         pred_mask = torch.cat(
95             [
96                 torch.tensor(
97                     [
98                         [int(c) for c in s[1] + "0" * (len_max - len(s[1]))]
99                         for s in pairs
100                     ]
101                 )
102             ],
103             0,
104         ).to("cpu")
105
106         if shuffle:
107             i = torch.randperm(input.size(0))
108             input = input[i].contiguous()
109             pred_mask = pred_mask[i].contiguous()
110
111         return input, pred_mask
112
113     # Remove as many padding tokens as possible from the left and right
114     # of the tensor z. If z is a tuple, all its elements are trimmed
115     # according to the trimming computed for the first one.
116     def trim(self, z, token="#"):
117         n = self.char2id[token]
118         if type(z) == tuple:
119             x = z[0]
120             i = (1 - (F.pad(x, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
121             a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
122             return tuple([t[:, a:b] for t in z])
123         else:
124             i = (1 - (F.pad(z, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
125             a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
126             return z[:, a:b]
127
128     def __init__(
129         self,
130         train_filename,
131         test_filename,
132         nb_train_samples,
133         nb_test_samples,
134         batch_size,
135         shuffle=False,
136         device=torch.device("cpu"),
137     ):
138         self.batch_size = batch_size
139         self.device = device
140
141         def read_file(filename, nb=-1):
142             pairs = []
143             with open(filename, "r") as f:
144                 while True:
145                     sequence = f.readline().strip()
146                     if not sequence:
147                         break
148                     pred_mask = f.readline().strip()
149                     assert len(sequence) == len(pred_mask)
150                     assert set(pred_mask).issubset({"0", "1", "2"}), f"{set(pred_mask)}"
151                     pairs.append((sequence, pred_mask))
152                     if len(pairs) == nb:
153                         break
154
155             if nb > 0:
156                 pairs = pairs[:nb]
157                 assert len(pairs) == nb
158
159             return pairs
160
161         train_pairs = read_file(train_filename, nb_train_samples)
162         test_pairs = read_file(test_filename, nb_test_samples)
163
164         symbols = ["#"] + sorted(  # sorted so the vocabulary is deterministic
165             set("".join([x[0] for x in train_pairs + test_pairs])) - set(["#"])
166         )
167         self.char2id = dict([(c, n) for n, c in enumerate(symbols)])
168         self.id2char = dict([(n, c) for c, n in self.char2id.items()])
169
170         self.train_input, self.train_pred_masks = self.tensorize(
171             train_pairs, shuffle=shuffle
172         )
173         self.test_input, self.test_pred_masks = self.tensorize(
174             test_pairs, shuffle=shuffle
175         )
176
177     def batches(self, split="train", nb_to_use=-1, desc=None):
178         assert split in {"train", "test"}
179         input = self.train_input if split == "train" else self.test_input
180         if nb_to_use > 0:
181             input = input[:nb_to_use]
182         if desc is None:
183             desc = f"epoch-{split}"
184         for batch in tqdm.tqdm(
185             input.split(self.batch_size), dynamic_ncols=True, desc=desc
186         ):
187             yield self.trim(batch).to(self.device)
188
189     def vocabulary_size(self):
190         return len(self.char2id)
191
192     def tensor2str(self, t):
193         return ["".join([self.id2char[x.item()] for x in s]) for s in t]
194
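    # Blanks every position of the first test sequences whose mask is
    # nonzero, regenerates them, logs a few before / after examples, and
    # reports accuracy on the positions whose mask value is 2.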
195     def produce_results(
196         self, n_epoch, model, result_dir, logger, deterministic_synthesis
197     ):
198         correct = self.trim(self.test_input[:1000]).to(self.device)
199         result = correct.clone()
200         pred_mask = self.test_pred_masks[:1000, : result.size(1)].to(self.device)
201         ar_mask = (pred_mask > 0).long()
202         result *= 1 - ar_mask  # paranoia: blank the positions that will be regenerated
203
204         logger(f"----------------------------------------------------------")
205
206         for e in self.tensor2str(result[:50]):
207             logger(f"test_before {e}")
208
209         masked_inplace_autoregression(
210             model,
211             self.batch_size,
212             result,
213             ar_mask,
214             deterministic_synthesis,
215             device=self.device,
216         )
217
218         logger(f"----------------------------------------------------------")
219
220         for e, c in zip(self.tensor2str(result[:50]), self.tensor2str(correct[:50])):
221             logger(f"test_after  {e}")
222             logger(f"correct     {c}")
223
224         logger(f"----------------------------------------------------------")
225
226         err_mask = (pred_mask == 2).long()
227         nb_total = err_mask.sum().item()
228         nb_correct = ((correct == result).long() * err_mask).sum().item()
229
230         logger(f"test_performance {n_epoch} {nb_total=} {nb_correct=}")
231         logger(f"main_test_accuracy {n_epoch} {nb_correct / nb_total}")
232
233
234 ####################
235
236 import problems
237
238
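# Thin wrapper around a problems.Problem instance, which generates the
# train / test sequences and their autoregression masks, provides
# seq2str() for logging, and computes the number of correct predictions.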
239 class SandBox(Task):
240     def __init__(
241         self,
242         problem,
243         nb_train_samples,
244         nb_test_samples,
245         batch_size,
246         logger=None,
247         device=torch.device("cpu"),
248         max_nb_codes=1024,
249     ):
250         super().__init__()
251
252         self.batch_size = batch_size
253         self.device = device
254         self.problem = problem
255
256         self.train_input, self.train_ar_mask = self.problem.generate_sequences(
257             nb_train_samples
258         )
259         self.test_input, self.test_ar_mask = self.problem.generate_sequences(
260             nb_test_samples
261         )
262
263         self.train_input, self.train_ar_mask = self.train_input.to(
264             device
265         ), self.train_ar_mask.to(device)
266         self.test_input, self.test_ar_mask = self.test_input.to(
267             device
268         ), self.test_ar_mask.to(device)
269
270         self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
271
272         # A bit of paranoia never hurts
273         assert self.nb_codes <= max_nb_codes
274         assert self.train_input.min() >= 0
275         assert self.test_input.min() >= 0
276         assert tuple(x.item() for x in self.train_ar_mask.unique()) in {
277             (0,),
278             (1,),
279             (0, 1),
280         }
281         assert tuple(x.item() for x in self.test_ar_mask.unique()) in {
282             (0,),
283             (1,),
284             (0, 1),
285         }
286
287         if logger is not None:
288             for s, a in zip(self.train_input[:100], self.train_ar_mask[:100]):
289                 logger(f"train_sequences {self.problem.seq2str(s)}")
290                 a = "".join(["01"[x.item()] for x in a])
291                 logger(f"                {a}")
292
293     def batches(self, split="train", nb_to_use=-1, desc=None):
294         assert split in {"train", "test"}
295         input = self.train_input if split == "train" else self.test_input
296         if nb_to_use > 0:
297             input = input[:nb_to_use]
298         if desc is None:
299             desc = f"epoch-{split}"
300         for batch in tqdm.tqdm(
301             input.split(self.batch_size), dynamic_ncols=True, desc=desc
302         ):
303             yield batch
304
305     def vocabulary_size(self):
306         return self.nb_codes
307
308     def produce_results(
309         self, n_epoch, model, result_dir, logger, deterministic_synthesis, nmax=1000
310     ):
311         def compute_accuracy(input, ar_mask, logger=None):
312             input, ar_mask = input[:nmax], ar_mask[:nmax]
313             result = input.clone() * (1 - ar_mask)
314
315             masked_inplace_autoregression(
316                 model,
317                 self.batch_size,
318                 result,
319                 ar_mask,
320                 deterministic_synthesis,
321                 progress_bar_desc=None,
322                 device=self.device,
323             )
324
325             log_ground_truth = ar_mask.min() == 0
326
327             if logger is not None:
328                 for sp, st in zip(result[:10], input[:10]):
329                     logger(
330                         f"test_sequences {n_epoch} prediction   {self.problem.seq2str(sp)}"
331                     )
332                     if log_ground_truth:
333                         logger(
334                             f"               {n_epoch} ground truth {self.problem.seq2str(st)}"
335                         )
336
337             nb_total, nb_correct = self.problem.compute_nb_correct(
338                 input, ar_mask, result
339             )
340
341             # nb_total = ar_mask.sum().item()
342             # nb_correct = ((result == input).long() * ar_mask).sum().item()
343
344             return nb_total, nb_correct
345
346         train_nb_total, train_nb_correct = compute_accuracy(
347             self.train_input, self.train_ar_mask
348         )
349
350         logger(
351             f"accuracy_train {n_epoch} nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%"
352         )
353
354         test_nb_total, test_nb_correct = compute_accuracy(
355             self.test_input, self.test_ar_mask, logger
356         )
357
358         logger(
359             f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
360         )
361
362         logger(f"main_test_accuracy {n_epoch} {test_nb_correct/test_nb_total}")
363
364         if save_attention_image is not None:
365             for k in range(10):
366                 ns = torch.randint(self.test_input.size(0), (1,)).item()
367                 input = self.test_input[ns : ns + 1].clone()
368
369                 with torch.autograd.no_grad():
370                     t = model.training
371                     model.eval()
372                     # model.record_attention(True)
373                     model(BracketedSequence(input))
374                     model.train(t)
375                     # ram = model.retrieve_attention()
376                     # model.record_attention(False)
377
378                 # tokens_output = [c for c in self.problem.seq2str(input[0])]
379                 # tokens_input = ["n/a"] + tokens_output[:-1]
380                 # for n_head in range(ram[0].size(1)):
381                 # filename = os.path.join(
382                 # result_dir, f"sandbox_attention_{k}_h{n_head}.pdf"
383                 # )
384                 # attention_matrices = [m[0, n_head] for m in ram]
385                 # save_attention_image(
386                 # filename,
387                 # tokens_input,
388                 # tokens_output,
389                 # attention_matrices,
390                 # k_top=10,
391                 ##min_total_attention=0.9,
392                 # token_gap=12,
393                 # layer_gap=50,
394                 # )
395                 # logger(f"wrote {filename}")
396
397
398 ######################################################################
399
400 import picoclvr
401
402
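# Task on picoclvr scene descriptions: each sample is a list of spatial
# properties separated by <sep>, an <img> token, and the tokens of the
# corresponding image grid. The model is scored on how many of the
# requested properties are actually satisfied by the image it generates.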
403 class PicoCLVR(Task):
404     # Make a tensor from a list of strings
405     def tensorize(self, descr):
406         token_descr = [s.strip().split(" ") for s in descr]
407         l = max([len(s) for s in token_descr])
408         token_descr = [s + ["<nul>"] * (l - len(s)) for s in token_descr]
409         id_descr = [[self.token2id[u] for u in s] for s in token_descr]
410         return torch.tensor(id_descr, device=self.device)
411
412     # Make a list of strings from a tensor
413     def detensorize(self, x):
414         return [" ".join([self.id2token[t.item()] for t in r]) for r in x]
415
416     # Remove as many padding tokens as possible from the left and right
417     # of the tensor z. If z is a tuple, all its elements are trimmed
418     # according to the trimming computed for the first one.
419     def trim(self, z, token="<nul>"):
420         n = self.token2id[token]
421         if type(z) == tuple:
422             x = z[0]
423             i = (1 - (F.pad(x, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
424             a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
425             return tuple([t[:, a:b] for t in z])
426         else:
427             i = (1 - (F.pad(z, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
428             a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
429             return z[:, a:b]
430
431     ######################
432
433     def __init__(
434         self,
435         nb_train_samples,
436         nb_test_samples,
437         batch_size,
438         height,
439         width,
440         nb_colors=5,
441         logger=None,
442         device=torch.device("cpu"),
443         pruner_train=None,
444         pruner_eval=None,
445     ):
446         super().__init__()
447
448         def generate_descr(nb, cache_suffix, pruner):
449             return picoclvr.generate(
450                 nb,
451                 height=self.height,
452                 width=self.width,
453                 nb_colors=nb_colors,
454                 pruner=pruner,
455             )
456
457         self.height = height
458         self.width = width
459         self.batch_size = batch_size
460         self.device = device
461         self.pruner_train = pruner_train
462         self.pruner_eval = pruner_eval
463
464         if logger is not None:
465             logger(
466                 f"generating {nb_train_samples+nb_test_samples} samples (can take some time)"
467             )
468
469         self.train_descr = generate_descr(
470             nb_train_samples, "train", pruner=self.pruner_train
471         )
472         self.test_descr = generate_descr(nb_test_samples, "test", pruner=None)
473
474         # Build the tokenizer
475         tokens = {"<nul>", "<img>"}
476         for d in [self.train_descr, self.test_descr]:
477             for s in d:
478                 for t in s.strip().split(" "):
479                     tokens.add(t)
480         # turn this set into a sorted list so that the same descriptions
481         # always produce the same tensors
482         tokens = list(tokens)
483         tokens.sort()
484         self.token2id = dict([(t, n) for n, t in enumerate(tokens)])
485         self.id2token = dict([(n, t) for n, t in enumerate(tokens)])
486         self.t_img, self.t_nul = self.token2id["<img>"], self.token2id["<nul>"]
487
488         # Tokenize the train and test sets
489         self.train_input = self.tensorize(self.train_descr)
490         self.test_input = self.tensorize(self.test_descr)
491
492     def batches(self, split="train", nb_to_use=-1, desc=None):
493         assert split in {"train", "test"}
494         input = self.train_input if split == "train" else self.test_input
495         for batch in tqdm.tqdm(
496             input.split(self.batch_size), dynamic_ncols=True, desc=f"epoch-{split}"
497         ):
498             yield self.trim(batch)
499
500     def vocabulary_size(self):
501         return len(self.token2id)
502
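    # For every test description, erase everything from the <img> token
    # onward, let the model regenerate the image part, and count how many
    # of the requested properties are missing in the result (optionally
    # filtered by a pruner).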
503     def compute_missing_properties(
504         self, n_epoch, model, logger, deterministic_synthesis, pruner=None
505     ):
506         acc_nb_requested_properties = []
507         acc_nb_missing_properties = []
508         acc_nb_results = 0
509
510         for input in tqdm.tqdm(
511             self.test_input.split(self.batch_size),
512             dynamic_ncols=True,
513             desc=f"test-properties",
514         ):
515             result = input.clone()
516             ar_mask = (result == self.t_img).long().cumsum(dim=1).clamp(max=1)
517             result = (1 - ar_mask) * result + ar_mask * self.t_nul
518             masked_inplace_autoregression(
519                 model,
520                 self.batch_size,
521                 result,
522                 ar_mask,
523                 deterministic_synthesis,
524                 progress_bar_desc=None,
525                 device=self.device,
526             )
527
528             result_descr = self.detensorize(result)
529             np = picoclvr.nb_properties(
530                 result_descr,
531                 height=self.height,
532                 width=self.width,
533                 pruner=pruner,
534             )
535             nb_requested_properties, _, nb_missing_properties = zip(*np)
536             acc_nb_requested_properties += nb_requested_properties
537             acc_nb_missing_properties += nb_missing_properties
538             acc_nb_results += len(result_descr)
539
540         nb_requested_properties = sum(acc_nb_requested_properties)
541         nb_missing_properties = sum(acc_nb_missing_properties)
542
543         prefix = "" if pruner is None else "pruned_"
544         logger(f"nb_{prefix}samples {n_epoch} {acc_nb_results}")
545         logger(
546             f"property_{prefix}nb {n_epoch} requested {sum(acc_nb_requested_properties)} missing {sum(acc_nb_missing_properties)}"
547         )
548         logger(
549             f"property_{prefix}miss {n_epoch} {100*nb_missing_properties/nb_requested_properties:.02f}%"
550         )
551
552         logger(
553             f"main_test_accuracy {n_epoch} {1-nb_missing_properties/nb_requested_properties}"
554         )
555
556     ######################################################################
557
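    # Runs the property evaluation above, then generates images from four
    # hand-written primer descriptions (nb_per_primer samples each) and
    # writes them to a PNG grid in result_dir.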
558     def produce_results(
559         self, n_epoch, model, result_dir, logger, deterministic_synthesis
560     ):
561         self.compute_missing_properties(n_epoch, model, logger, deterministic_synthesis)
562
563         if self.pruner_eval is not None:
564             self.compute_missing_properties(n_epoch, model, logger, deterministic_synthesis, pruner=self.pruner_eval)
565
566         nb_tokens_to_generate = self.height * self.width + 3
567         result_descr = []
568         nb_per_primer = 8
569         primer = []
570
571         for primer_descr in [
572             "red above green <sep> green top <sep> blue right of red",
573             "there is red <sep> there is yellow <sep> there is blue",
574             "red below yellow <sep> yellow below green <sep> green below blue <sep> red right <sep> yellow left <sep> green right <sep> blue left",
575             "green bottom <sep> yellow bottom <sep> green left of blue <sep> yellow right of blue <sep> blue top",
576         ]:
577             primer += [primer_descr + " <img>"] * nb_per_primer
578
579         result = self.tensorize(primer)
580         fill = result.new_full(
581             result.size()[:-1] + (self.height * self.width + 1,), self.t_nul
582         )
583         result = torch.cat((result, fill), 1)
584         ar_mask = (result == self.t_nul).long()
585         masked_inplace_autoregression(
586             model,
587             self.batch_size,
588             result,
589             ar_mask,
590             deterministic_synthesis,
591             device=self.device,
592         )
593         result_descr = self.detensorize(result)
594
595         np = picoclvr.nb_properties(result_descr, height=self.height, width=self.width)
596
597         acc_nb_requested_properties, _, acc_nb_missing_properties = zip(*np)
598         acc_nb_results = len(result_descr)
599
600         nb_requested_properties = sum(acc_nb_requested_properties)
601         nb_missing_properties = sum(acc_nb_missing_properties)
602
603         prefix = "demo_"
604         logger(f"nb_{prefix}samples {n_epoch} {acc_nb_results}")
605         logger(
606             f"property_{prefix}nb {n_epoch} requested {sum(acc_nb_requested_properties)} missing {sum(acc_nb_missing_properties)}"
607         )
608         logger(
609             f"property_{prefix}miss {n_epoch} {100*nb_missing_properties/nb_requested_properties:.02f}%"
610         )
611
612         img = picoclvr.descr2img(result_descr, height=self.height, width=self.width)
613
614         if img.dim() == 5:
615             if img.size(1) == 1:
616                 img = F.pad(img.squeeze(1), pad=(1, 1, 1, 1), value=64)
617             else:
618                 img = torch.cat(
619                     [
620                         torchvision.utils.make_grid(x, padding=1, pad_value=64)[None]
621                         for x in img
622                     ],
623                     0,
624                 )
625
626         image_name = os.path.join(result_dir, f"picoclvr_result_{n_epoch:04d}.png")
627         torchvision.utils.save_image(
628             img / 255.0, image_name, nrow=nb_per_primer, padding=1, pad_value=0.0
629         )
630         logger(f"wrote {image_name}")
631
632
633 ######################################################################
634
635
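# Autoregressive modeling of raw MNIST digits: each image is flattened
# into a sequence of 28 * 28 pixel intensities (hence a vocabulary of 256
# values) and produce_results() samples 64 images from scratch.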
636 class MNIST(Task):
637     def __init__(
638         self, nb_train_samples, nb_test_samples, batch_size, device=torch.device("cpu")
639     ):
640         super().__init__()
641
642         self.nb_train_samples = nb_train_samples
643         self.nb_test_samples = nb_test_samples
644         self.batch_size = batch_size
645         self.device = device
646         data_set = torchvision.datasets.MNIST(root="./data", train=True, download=True)
647         self.train_input = data_set.data[:nb_train_samples].view(-1, 28 * 28).long()
648         data_set = torchvision.datasets.MNIST(root="./data", train=False, download=True)
649         self.test_input = data_set.data[:nb_test_samples].view(-1, 28 * 28).long()
650
651     def batches(self, split="train", nb_to_use=-1, desc=None):
652         assert split in {"train", "test"}
653         input = self.train_input if split == "train" else self.test_input
654         if nb_to_use > 0:
655             input = input[:nb_to_use]
656         if desc is None:
657             desc = f"epoch-{split}"
658         for batch in tqdm.tqdm(
659             input.split(self.batch_size), dynamic_ncols=True, desc=desc
660         ):
661             yield batch
662
663     def vocabulary_size(self):
664         return 256
665
666     def produce_results(
667         self, n_epoch, model, result_dir, logger, deterministic_synthesis
668     ):
669         results = torch.empty(64, 28 * 28, device=self.device, dtype=torch.int64)
670         ar_mask = torch.full_like(results, 1)
671         masked_inplace_autoregression(
672             model,
673             self.batch_size,
674             results,
675             ar_mask,
676             deterministic_synthesis,
677             device=self.device,
678         )
679         image_name = os.path.join(result_dir, f"mnist_result_{n_epoch:04d}.png")
680         torchvision.utils.save_image(
681             1 - results.reshape(-1, 1, 28, 28) / 255.0,
682             image_name,
683             nrow=16,
684             pad_value=0.8,
685         )
686         logger(f"wrote {image_name}")
687
688
689 ######################################################################
690
691 import maze
692
693
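# Maze task: each sequence is a flattened maze map concatenated with the
# corresponding flattened solution path. At evaluation time the path half
# is regenerated and checked with maze.path_correctness(), and the
# proportion of regenerated paths that are optimal is logged as well.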
694 class Maze(Task):
695     def map2seq(self, *m):
696         return torch.cat([x.flatten(1) for x in m], 1)
697
698     def seq2map(self, s):
699         s = s.reshape(s.size(0), -1, self.height, self.width)
700         return (s[:, k] for k in range(s.size(1)))
701
702     def __init__(
703         self,
704         nb_train_samples,
705         nb_test_samples,
706         batch_size,
707         height,
708         width,
709         nb_walls,
710         device=torch.device("cpu"),
711     ):
712         super().__init__()
713
714         self.batch_size = batch_size
715         self.height = height
716         self.width = width
717         self.device = device
718
719         train_mazes, train_paths, _ = maze.create_maze_data(
720             nb_train_samples,
721             height=height,
722             width=width,
723             nb_walls=nb_walls,
724             progress_bar=lambda x: tqdm.tqdm(x, dynamic_ncols=True, desc=f"data-train"),
725         )
726         self.train_input = self.map2seq(train_mazes.to(device), train_paths.to(device))
727
728         test_mazes, test_paths, _ = maze.create_maze_data(
729             nb_test_samples,
730             height=height,
731             width=width,
732             nb_walls=nb_walls,
733             progress_bar=lambda x: tqdm.tqdm(x, dynamic_ncols=True, desc=f"data-test"),
734         )
735         self.test_input = self.map2seq(test_mazes.to(device), test_paths.to(device))
736
737         self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
738
739     def batches(self, split="train", nb_to_use=-1, desc=None):
740         assert split in {"train", "test"}
741         input = self.train_input if split == "train" else self.test_input
742         if nb_to_use > 0:
743             input = input[:nb_to_use]
744         if desc is None:
745             desc = f"epoch-{split}"
746         for batch in tqdm.tqdm(
747             input.split(self.batch_size), dynamic_ncols=True, desc=desc
748         ):
749             yield batch
750
751     def vocabulary_size(self):
752         return self.nb_codes
753
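    # Regenerates the path part of every sequence of the split and returns
    # the number of sequences, the number whose predicted path is correct,
    # and a matrix counting (optimal length, predicted length) pairs for
    # the correct paths (None if no path was correct).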
754     def compute_error(
755         self, model, split="train", nb_to_use=-1, deterministic_synthesis=False
756     ):
757         model_device = next(model.parameters()).device
758         nb_total, nb_correct = 0, 0
759         count = torch.zeros(
760             self.width * self.height,
761             self.width * self.height,
762             device=model_device,
763             dtype=torch.int64,
764         )
765
766         for input in self.batches(split, nb_to_use):
767             input = input.to(model_device)
768             result = input.clone()
769             ar_mask = result.new_zeros(result.size())
770             ar_mask[:, self.height * self.width :] = 1
771             result *= 1 - ar_mask
772             masked_inplace_autoregression(
773                 model,
774                 self.batch_size,
775                 result,
776                 ar_mask,
777                 deterministic_synthesis,
778                 progress_bar_desc=None,
779                 device=self.device,
780             )
781             mazes, paths = self.seq2map(result)
782             path_correctness = maze.path_correctness(mazes, paths)
783             nb_correct += path_correctness.long().sum()
784             nb_total += mazes.size(0)
785
786             optimal_path_lengths = (
787                 (input[:, self.height * self.width :] == maze.v_path).long().sum(1)
788             )
789             predicted_path_lengths = (
790                 (result[:, self.height * self.width :] == maze.v_path).long().sum(1)
791             )
792             optimal_path_lengths = optimal_path_lengths[path_correctness]
793             predicted_path_lengths = predicted_path_lengths[path_correctness]
794             count[optimal_path_lengths, predicted_path_lengths] += 1
795
796         if count.max() == 0:
797             count = None
798         else:
799             count = count[
800                 : count.sum(1).nonzero().max() + 1, : count.sum(0).nonzero().max() + 1
801             ]
802
803         return nb_total, nb_correct, count
804
805     def produce_results(
806         self, n_epoch, model, result_dir, logger, deterministic_synthesis
807     ):
808         train_nb_total, train_nb_correct, count = self.compute_error(
809             model,
810             "train",
811             nb_to_use=1000,
812             deterministic_synthesis=deterministic_synthesis,
813         )
814         logger(
815             f"accuracy_train {n_epoch} nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%"
816         )
817
818         test_nb_total, test_nb_correct, count = self.compute_error(
819             model,
820             "test",
821             nb_to_use=1000,
822             deterministic_synthesis=deterministic_synthesis,
823         )
824         logger(
825             f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
826         )
827
828         logger(f"main_test_accuracy {n_epoch} {test_nb_correct/test_nb_total}")
829
830         if count is not None:
831             proportion_optimal = count.diagonal().sum().float() / count.sum()
832             logger(f"proportion_optimal_test {proportion_optimal*100:.02f}%")
833             with open(
834                 os.path.join(result_dir, f"maze_result_{n_epoch:04d}.txt"), "w"
835             ) as f:
836                 for i in range(count.size(0)):
837                     for j in range(count.size(1)):
838                         eol = " " if j < count.size(1) - 1 else "\n"
839                         f.write(f"{count[i,j]}{eol}")
840
841         input = self.test_input[:48].to(next(model.parameters()).device)
842         result = input.clone()
843         ar_mask = result.new_zeros(result.size())
844         ar_mask[:, self.height * self.width :] = 1
845         result *= 1 - ar_mask
846         masked_inplace_autoregression(
847             model,
848             self.batch_size,
849             result,
850             ar_mask,
851             deterministic_synthesis,
852             device=self.device,
853         )
854
855         mazes, paths = self.seq2map(input)
856         _, predicted_paths = self.seq2map(result)
857
858         filename = os.path.join(result_dir, f"maze_result_{n_epoch:04d}.png")
859         maze.save_image(
860             filename,
861             mazes=mazes,
862             target_paths=paths,
863             predicted_paths=predicted_paths,
864             path_correct=maze.path_correctness(mazes, predicted_paths),
865             path_optimal=maze.path_optimality(paths, predicted_paths),
866         )
867         logger(f"wrote {filename}")
868
869
870 ######################################################################
871
872
873 import snake
874
875
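# Snake task: a snake wanders over a grid of colored cells, and each time
# step contributes two tokens to the sequence. Beyond the prompt, every
# other token is regenerated, and accuracy is measured only at steps where
# the snake revisits a cell it has already seen, since only those values
# are determined by the past of the sequence.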
876 class Snake(Task):
877     def __init__(
878         self,
879         nb_train_samples,
880         nb_test_samples,
881         batch_size,
882         height,
883         width,
884         nb_colors,
885         length,
886         prompt_length,
887         device=torch.device("cpu"),
888     ):
889         super().__init__()
890
891         self.batch_size = batch_size
892         self.height = height
893         self.width = width
894         self.device = device
895         self.prompt_length = prompt_length
896
897         self.train_input, self.train_prior_visits, _, _ = snake.generate_sequences(
898             nb_train_samples,
899             height,
900             width,
901             nb_colors,
902             length,
903             prompt_length,
904             self.device,
905         )
906         self.test_input, self.test_prior_visits, _, _ = snake.generate_sequences(
907             nb_test_samples,
908             height,
909             width,
910             nb_colors,
911             length,
912             prompt_length,
913             self.device,
914         )
915
916         self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
917
918     def batches(self, split="train", nb_to_use=-1, desc=None):
919         assert split in {"train", "test"}
920         input = self.train_input if split == "train" else self.test_input
921         if nb_to_use > 0:
922             input = input[:nb_to_use]
923         if desc is None:
924             desc = f"epoch-{split}"
925         for batch in tqdm.tqdm(
926             input.split(self.batch_size), dynamic_ncols=True, desc=desc
927         ):
928             yield batch
929
930     def vocabulary_size(self):
931         return self.nb_codes
932
933     def produce_results(
934         self, n_epoch, model, result_dir, logger, deterministic_synthesis
935     ):
936         def compute_nb_correct(input, prior_visits):
937             result = input.clone()
938             i = torch.arange(result.size(1), device=result.device)[None, :]
939             ar_mask = (
940                 torch.logical_and(i >= self.prompt_length * 2, i % 2 == 0)
941                 .long()
942                 .expand_as(result)
943             )
944             result *= 1 - ar_mask
945
946             masked_inplace_autoregression(
947                 model,
948                 self.batch_size,
949                 result,
950                 ar_mask,
951                 deterministic_synthesis,
952                 device=self.device,
953             )
954
955             nb_total = ((prior_visits > 0) * ar_mask).sum()
956
957             nb_correct = ((result == input).long() * (prior_visits > 0) * ar_mask).sum()
958
959             return nb_total, nb_correct
960
961         test_nb_total, test_nb_correct = compute_nb_correct(
962             self.test_input[:1000], self.test_prior_visits[:1000]
963         )
964
965         logger(
966             f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
967         )
968
969         logger(f"main_test_accuracy {n_epoch} {test_nb_correct/test_nb_total}")
970
971
972 ######################################################################
973
974
975 import stack
976
977
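# Stack task: sequences record a series of push / pop operations on
# nb_stacks stacks holding nb_digits-digit values. For evaluation,
# stack.remove_popped_values() blanks out the values returned by the pops
# and the model has to regenerate them; accuracy is counted per popped
# value.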
978 class Stack(Task):
979     def __init__(
980         self,
981         nb_train_samples,
982         nb_test_samples,
983         batch_size,
984         logger,
985         nb_steps,
986         nb_stacks,
987         nb_digits,
988         fraction_values_for_train=None,
989         device=torch.device("cpu"),
990     ):
991         super().__init__()
992
993         self.batch_size = batch_size
994         self.nb_steps = nb_steps
995         self.nb_stacks = nb_stacks
996         self.nb_digits = nb_digits
997         self.device = device
998
999         if fraction_values_for_train is None:
1000             values_for_train = None
1001             values_for_test = None
1002         else:
1003             all = torch.randperm(10**nb_digits)
1004             nb_for_train = int(all.size(0) * fraction_values_for_train)
1005             values_for_train = all[:nb_for_train]
1006             values_for_test = all[nb_for_train:]
1007
1008         self.train_input, self.train_stack_counts = stack.generate_sequences(
1009             nb_train_samples,
1010             nb_steps,
1011             nb_stacks,
1012             nb_digits,
1013             values_for_train,
1014             self.device,
1015         )
1016
1017         self.test_input, self.test_stack_counts = stack.generate_sequences(
1018             nb_test_samples,
1019             nb_steps,
1020             nb_stacks,
1021             nb_digits,
1022             values_for_test,
1023             self.device,
1024         )
1025
1026         i = torch.logical_and(self.test_input % 2 == 1, self.test_input < 2 * nb_stacks)
1027         counts = self.test_stack_counts.flatten()[i.flatten()]
1028         counts = F.one_hot(counts).sum(0)
1029         logger(f"test_pop_stack_counts {counts}")
1030
1031         self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
1032
1033     def batches(self, split="train", nb_to_use=-1, desc=None):
1034         assert split in {"train", "test"}
1035         input = self.train_input if split == "train" else self.test_input
1036         if nb_to_use > 0:
1037             input = input[:nb_to_use]
1038         if desc is None:
1039             desc = f"epoch-{split}"
1040         for batch in tqdm.tqdm(
1041             input.split(self.batch_size), dynamic_ncols=True, desc=desc
1042         ):
1043             yield batch
1044
1045     def vocabulary_size(self):
1046         return self.nb_codes
1047
1048     def produce_results(
1049         self, n_epoch, model, result_dir, logger, deterministic_synthesis
1050     ):
1051         def compute_nb_correct(input):
1052             result = input.clone()
1053             stack.remove_popped_values(result, self.nb_stacks, self.nb_digits)
1054             ar_mask = (result != input).long()
1055             masked_inplace_autoregression(
1056                 model,
1057                 self.batch_size,
1058                 result,
1059                 ar_mask,
1060                 deterministic_synthesis,
1061                 device=self.device,
1062             )
1063
1064             errors = ((result != input).long() * ar_mask).reshape(
1065                 -1, 1 + self.nb_digits
1066             )
1067             ar_mask = ar_mask.reshape(-1, 1 + self.nb_digits)
1068
1069             nb_total = ar_mask.max(1).values.sum()
1070             nb_correct = nb_total - errors.max(1).values.sum()
1071
1072             return nb_total, nb_correct
1073
1074         test_nb_total, test_nb_correct = compute_nb_correct(self.test_input[:1000])
1075
1076         logger(
1077             f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
1078         )
1079
1080         logger(f"main_test_accuracy {n_epoch} {test_nb_correct/test_nb_total}")
1081
1082         ##############################################################
1083         # Log a few generated sequences
1084         input = self.test_input[:10, : 12 * (1 + self.nb_digits)]
1085         result = input.clone()
1086         stack.remove_popped_values(result, self.nb_stacks, self.nb_digits)
1087         ar_mask = (result != input).long()
1088
1089         # for n in range(result.size(0)):
1090         # logger(
1091         # f"test_before {stack.seq_to_str(result[n],nb_stacks=self.nb_stacks,nb_digits=self.nb_digits)}"
1092         # )
1093
1094         masked_inplace_autoregression(
1095             model,
1096             self.batch_size,
1097             result,
1098             ar_mask,
1099             deterministic_synthesis,
1100             device=self.device,
1101         )
1102
1103         # Log per-token prediction probabilities for a few train/test sequences
1104         for label, input in [
1105             ("train", self.train_input[:32]),
1106             ("test", self.test_input[:32]),
1107         ]:
1108             output = model(BracketedSequence(input)).x
1109             output = output.log_softmax(dim=-1)
1110             filename = os.path.join(
1111                 result_dir, f"stack_with_crossentropy_{n_epoch:04d}_{label}.txt"
1112             )
1113             with open(filename, "w") as f:
1114                 for n in range(input.size(0)):
1115                     s = stack.seq_to_str(
1116                         input[n], nb_stacks=self.nb_stacks, nb_digits=self.nb_digits
1117                     )
1118                     for t, k, w in zip(range(input[n].size(0)), input[n], s.split(" ")):
1119                         u = (
1120                             " " * (10 - len(w))
1121                             + w
1122                             + " "
1123                             + str(output[n][t][k].exp().item())
1124                             + "\n"
1125                         )
1126                         f.write(u)
1127                     f.write("\n")
1128             logger(f"wrote {filename}")
1129         ##############################################################
1130
1131         for n in range(result.size(0)):
1132             logger(
1133                 f"test_after  {stack.seq_to_str(result[n],nb_stacks=self.nb_stacks,nb_digits=self.nb_digits)}"
1134             )
1135         ##############################################################
1136
1137
1138 ######################################################################
1139
1140 import rpl
1141
1142
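# RPL task: each sample is made of several <in> ... <out> ... runs of a
# small stack program followed by the program itself between <prg> and
# <end>. Evaluation checks both whether the model can generate a program
# consistent with the example runs, and whether it can predict the last
# output stack; with no_prog=True the program part is excised from the
# data.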
1143 class RPL(Task):
1144     def tensorize(self, sequences):
1145         len_max = max([len(x) for x in sequences])
1146         return torch.cat(
1147             [
1148                 torch.tensor(
1149                     [
1150                         [
1151                             self.token2id[str(c)]
1152                             for c in s + ["<nul>"] * (len_max - len(s))
1153                         ]
1154                         for s in sequences
1155                     ]
1156                 )
1157             ],
1158             0,
1159         )
1160
1161     def seq2str(self, seq):
1162         return " ".join([self.id2token[i] for i in seq])
1163
1164     def __init__(
1165         self,
1166         nb_train_samples,
1167         nb_test_samples,
1168         batch_size,
1169         nb_starting_values=3,
1170         max_input=9,
1171         prog_len=6,
1172         nb_runs=5,
1173         no_prog=False,
1174         logger=None,
1175         device=torch.device("cpu"),
1176     ):
1177         super().__init__()
1178
1179         self.batch_size = batch_size
1180         self.device = device
1181         self.no_prog = no_prog
1182
1183         train_sequences = [
1184             rpl.generate(
1185                 nb_starting_values=nb_starting_values,
1186                 nb_result_values_max=4 * nb_starting_values,
1187                 max_input=max_input,
1188                 prog_len=prog_len,
1189                 nb_runs=nb_runs,
1190             )
1191             for _ in tqdm.tqdm(range(nb_train_samples), desc="train-data")
1192         ]
1193
1194         test_sequences = [
1195             rpl.generate(
1196                 nb_starting_values=nb_starting_values,
1197                 nb_result_values_max=4 * nb_starting_values,
1198                 max_input=max_input,
1199                 prog_len=prog_len,
1200                 nb_runs=nb_runs,
1201             )
1202             for _ in tqdm.tqdm(range(nb_test_samples), desc="test-data")
1203         ]
1204
1205         symbols = list(
1206             set(["<nul>"] + [x for l in train_sequences + test_sequences for x in l])
1207         )
1208         val_max = max([x if type(x) is int else 0 for x in symbols])
1209         symbols = list(filter(lambda x: type(x) is str, symbols))
1210         symbols.sort()
1211         symbols += [str(n) for n in range(val_max + 1)]
1212         self.token2id = dict([(c, n) for n, c in enumerate(symbols)])
1213         self.id2token = dict([(n, c) for c, n in self.token2id.items()])
1214
1215         self.t_nul = self.token2id["<nul>"]
1216         self.t_input = self.token2id["<in>"]
1217         self.t_output = self.token2id["<out>"]
1218         self.t_prog = self.token2id["<prg>"]
1219         self.t_end = self.token2id["<end>"]
1220
1221         self.train_input = self.tensorize(train_sequences)
1222         self.test_input = self.tensorize(test_sequences)
1223
1224         if no_prog:
1225             # Excise the program from every train and test example
1226             k = torch.arange(self.train_input.size(1), device=self.train_input.device)[
1227                 None, :
1228             ]
1229             p = (
1230                 ((self.train_input == self.t_prog).long() * k)
1231                 .max(1, keepdim=True)
1232                 .values
1233             )
1234             self.train_input = (
1235                 self.train_input * (k <= p).long()
1236                 + self.t_end * (k == p + 1).long()
1237                 + self.t_nul * (k > p + 1).long()
1238             )
1239             k = torch.arange(self.test_input.size(1), device=self.test_input.device)[
1240                 None, :
1241             ]
1242             p = (
1243                 ((self.test_input == self.t_prog).long() * k)
1244                 .max(1, keepdim=True)
1245                 .values
1246             )
1247             self.test_input = (
1248                 self.test_input * (k <= p).long()
1249                 + self.t_end * (k == p + 1).long()
1250                 + self.t_nul * (k > p + 1).long()
1251             )
1252
1253         if logger is not None:
1254             logger(f"value_max {val_max}")
1255             for x in self.train_input[:25]:
1256                 end = (x != self.t_nul).nonzero().max().item() + 1
1257                 seq = [self.id2token[i.item()] for i in x[:end]]
1258                 s = " ".join(seq)
1259                 logger(f"example_seq {s}")
1260
1261         self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
1262
1263     def batches(self, split="train", nb_to_use=-1, desc=None):
1264         assert split in {"train", "test"}
1265         input = self.train_input if split == "train" else self.test_input
1266         if nb_to_use > 0:
1267             input = input[:nb_to_use]
1268         if desc is None:
1269             desc = f"epoch-{split}"
1270         for batch in tqdm.tqdm(
1271             input.split(self.batch_size), dynamic_ncols=True, desc=desc
1272         ):
1273             last = (batch != self.t_nul).max(0).values.nonzero().max() + 3
1274             batch = batch[:, :last].to(self.device)
1275             yield batch
1276
1277     def vocabulary_size(self):
1278         return self.nb_codes
1279
1280     def produce_results(
1281         self, n_epoch, model, result_dir, logger, deterministic_synthesis
1282     ):
1283         # --------------------------------------------------------------------
1284         def compute_nb_errors_prog(input, nb_to_log=0):
1285             result = input.clone()
1286             s = (result == self.t_prog).long()
1287             ar_mask = (s.cumsum(dim=1) - s).clamp(min=0, max=1)
1288             result = (1 - ar_mask) * result + ar_mask * self.t_nul
1289
1290             masked_inplace_autoregression(
1291                 model,
1292                 self.batch_size,
1293                 result,
1294                 ar_mask,
1295                 deterministic_synthesis,
1296                 device=self.device,
1297             )
1298
1299             sum_nb_total, sum_nb_errors = 0, 0
1300             for one_input, one_result in zip(input, result):
1301                 seq = [self.id2token[i.item()] for i in one_result]
1302                 nb_total, nb_errors, prog, stacks = rpl.compute_nb_errors(seq)
1303                 sum_nb_total += 1
1304                 sum_nb_errors += 0 if nb_errors == 0 else 1
1305                 if nb_to_log > 0:
1306                     gt_seq = [self.id2token[i.item()] for i in one_input]
1307                     _, _, gt_prog, _ = rpl.compute_nb_errors(gt_seq)
1308                     gt_prog = " ".join([str(x) for x in gt_prog])
1309                     prog = " ".join([str(x) for x in prog])
1310                     comment = "*" if nb_errors == 0 else "-"
1311                     logger(f"{comment} PROG [{gt_prog}] PREDICTED [{prog}]")
1312                     for start_stack, target_stack, result_stack, correct in stacks:
1313                         comment = "*" if correct else "-"
1314                         start_stack = " ".join([str(x) for x in start_stack])
1315                         target_stack = " ".join([str(x) for x in target_stack])
1316                         result_stack = " ".join([str(x) for x in result_stack])
1317                         logger(
1318                             f"  {comment} [{start_stack}] -> [{target_stack}] PREDICTED [{result_stack}]"
1319                         )
1320                     nb_to_log -= 1
1321
1322             return sum_nb_total, sum_nb_errors
1323
1324         # --------------------------------------------------------------------
1325         def compute_nb_errors_output(input, nb_to_log=0):
1326             result = input.clone()
1327             k = torch.arange(result.size(1), device=result.device)[None, :]
1328             last_output_idx = (
1329                 ((result == self.t_output) * k).max(dim=1, keepdim=True).values
1330             )
1331             first_prog_idx = (
1332                 ((result == self.t_prog) * k).max(dim=1, keepdim=True).values
1333             )
1334             ar_mask = (k > last_output_idx).long() * (k < first_prog_idx).long()
1335             result = (1 - ar_mask) * result + ar_mask * self.t_nul
1336
1337             masked_inplace_autoregression(
1338                 model,
1339                 self.batch_size,
1340                 result,
1341                 ar_mask,
1342                 deterministic_synthesis,
1343                 device=self.device,
1344             )
1345
1346             sum_nb_total, sum_nb_errors = 0, 0
1347             for one_input, one_result, i, j in zip(
1348                 input, result, last_output_idx, first_prog_idx
1349             ):
1350                 seq = [self.id2token[i.item()] for i in one_result]
1351                 sum_nb_total += 1
1352                 correct = (one_input - one_result).abs().max() == 0
1353                 sum_nb_errors += 0 if correct else 1
1354                 if nb_to_log > 0:
1355                     result_stack = [
1356                         self.id2token[i.item()] for i in one_result[i : j + 1]
1357                     ]
1358                     target_stack = [
1359                         self.id2token[i.item()] for i in one_input[i : j + 1]
1360                     ]
1361                     comment = "*" if correct else "-"
1362                     result_stack = " ".join([str(x) for x in result_stack])
1363                     target_stack = " ".join([str(x) for x in target_stack])
1364                     logger(
1365                         f"output_test {comment} [{target_stack}] PREDICTED [{result_stack}]"
1366                     )
1367                     nb_to_log -= 1
1368
1369             return sum_nb_total, sum_nb_errors
1370
1371         # --------------------------------------------------------------------
1372
1373         if not self.no_prog:
1374             test_nb_total, test_nb_errors = compute_nb_errors_prog(
1375                 self.test_input[:1000].to(self.device), nb_to_log=10
1376             )
1377
1378             logger(
1379                 f"accuracy_prog_test {n_epoch} nb_total {test_nb_total} nb_errors {test_nb_errors} accuracy {100.0*(1-test_nb_errors/test_nb_total):.02f}%"
1380             )
1381
1382             logger(f"main_test_accuracy {n_epoch} {1-test_nb_errors/test_nb_total}")
1383
1384         test_nb_total, test_nb_errors = compute_nb_errors_output(
1385             self.test_input[:1000].to(self.device), nb_to_log=10
1386         )
1387
1388         logger(
1389             f"accuracy_output_test {n_epoch} nb_total {test_nb_total} nb_errors {test_nb_errors} accuracy {100.0*(1-test_nb_errors/test_nb_total):.02f}%"
1390         )
1391
1392         if save_attention_image is None:
1393             logger("no save_attention_image (is pycairo installed?)")
1394         else:
1395             ns = torch.randint(self.test_input.size(0), (1,)).item()
1396             input = self.test_input[ns : ns + 1].clone()
1397             last = (input != self.t_nul).max(0).values.nonzero().max() + 3
1398             input = input[:, :last].to(self.device)
1399
1400             with torch.autograd.no_grad():
1401                 t = model.training
1402                 model.eval()
1403                 model.record_attention(True)
1404                 model(BracketedSequence(input))
1405                 model.train(t)
1406                 ram = model.retrieve_attention()
1407                 model.record_attention(False)
1408
1409             tokens_output = [self.id2token[i.item()] for i in input[0]]
1410             tokens_input = ["n/a"] + tokens_output[:-1]
1411             for n_head in range(ram[0].size(1)):
1412                 filename = os.path.join(
1413                     result_dir, f"rpl_attention_{n_epoch}_h{n_head}.pdf"
1414                 )
1415                 attention_matrices = [m[0, n_head] for m in ram]
1416                 save_attention_image(
1417                     filename,
1418                     tokens_input,
1419                     tokens_output,
1420                     attention_matrices,
1421                     k_top=10,
1422                     # min_total_attention=0.9,
1423                     token_gap=12,
1424                     layer_gap=50,
1425                 )
1426                 logger(f"wrote {filename}")
1427
1428
1429 ######################################################################
1430
1431
1432 import expr
1433
1434
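# Expr task: character-level sequences of variable assignments together
# with the values they evaluate to. For evaluation, everything after the
# first space of each sequence is regenerated; a sequence counts as
# correct only if it matches the original exactly, and the recovered
# variable values are also compared with the true ones (written to a
# per-epoch result file, with a histogram of absolute errors).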
1435 class Expr(Task):
1436     def tensorize(self, sequences):
1437         len_max = max([len(x) for x in sequences])
1438         return torch.cat(
1439             [
1440                 torch.tensor(
1441                     [
1442                         [self.char2id[c] for c in s + "#" * (len_max - len(s))]
1443                         for s in sequences
1444                     ]
1445                 )
1446             ],
1447             0,
1448         ).to(self.device)
1449
1450     def __init__(
1451         self,
1452         nb_train_samples,
1453         nb_test_samples,
1454         nb_variables,
1455         sequence_length,
1456         operand_max,
1457         result_max,
1458         batch_size,
1459         device=torch.device("cpu"),
1460     ):
1461         super().__init__()
1462
1463         self.batch_size = batch_size
1464         self.device = device
1465
1466         train_sequences = expr.generate_sequences(
1467             nb_train_samples,
1468             nb_variables=nb_variables,
1469             length=sequence_length,
1470             operand_max=operand_max,
1471             result_max=result_max,
1472         )
1473
1474         test_sequences = expr.generate_sequences(
1475             nb_test_samples,
1476             nb_variables=nb_variables,
1477             length=sequence_length,
1478             operand_max=operand_max,
1479             result_max=result_max,
1480         )
1481
1482         symbols = list(set("#" + "".join(train_sequences + test_sequences)))
1483         symbols.sort()
1484
1485         self.char2id = dict([(c, n) for n, c in enumerate(symbols)])
1486         self.id2char = dict([(n, c) for c, n in self.char2id.items()])
1487
1488         self.filler, self.space = self.char2id["#"], self.char2id[" "]
1489
1490         self.train_input = self.tensorize(train_sequences)
1491         self.test_input = self.tensorize(test_sequences)
1492
1493         self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
1494
1495     def batches(self, split="train", nb_to_use=-1, desc=None):
1496         assert split in {"train", "test"}
1497         input = self.train_input if split == "train" else self.test_input
1498         if nb_to_use > 0:
1499             input = input[:nb_to_use]
1500         if desc is None:
1501             desc = f"epoch-{split}"
1502         for batch in tqdm.tqdm(
1503             input.split(self.batch_size), dynamic_ncols=True, desc=desc
1504         ):
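            # Drop the all-filler columns on the right: keep everything up to
            # the last column that holds a non-filler token in any row, plus a
            # small safety margin.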
1505             last = (batch != self.filler).max(0).values.nonzero().max() + 3
1506             batch = batch[:, :last]
1507             yield batch
1508
1509     def vocabulary_size(self):
1510         return self.nb_codes
1511
1512     def seq2str(self, s):
1513         return "".join([self.id2char[k.item()] for k in s])
1514
1515     def produce_results(
1516         self,
1517         n_epoch,
1518         model,
1519         result_dir,
1520         logger,
1521         deterministic_synthesis,
1522         input_file=None,
1523     ):
1524         def compute_nb_correct(input):
1525             result = input.clone()
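            # Build the autoregressive mask: ar_mask[j] is 1 iff at least one
            # space occurs strictly before position j, i.e. everything after
            # the first space is masked and regenerated in place.
            # Worked example: if s = [0, 0, 0, 1, 0, 0] then
            # s.cumsum(1) - s = [0, 0, 0, 0, 1, 1], so only the last two
            # positions are regenerated.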
1526             s = (result == self.space).long()
1527             ar_mask = (s.cumsum(dim=1) - s).clamp(min=0, max=1)
1528             result = (1 - ar_mask) * result + ar_mask * self.filler
1529             masked_inplace_autoregression(
1530                 model,
1531                 self.batch_size,
1532                 result,
1533                 ar_mask,
1534                 deterministic_synthesis,
1535                 device=self.device,
1536             )
1537
1538             nb_total = input.size(0)
1539             nb_correct = (input == result).long().min(1).values.sum()
1540
1541             #######################################################################
1542             # Compute predicted vs. true variable values
1543
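            # nb_delta[d] counts predicted variable values that are off by
            # exactly d (for d < 5); values that are missing, negative, or off
            # by 5 or more are counted in nb_missed.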
1544             nb_delta = torch.zeros(5, dtype=torch.int64)
1545             nb_missed = 0
1546
1547             values_input = expr.extract_results([self.seq2str(s) for s in input])
1548             values_result = expr.extract_results([self.seq2str(s) for s in result])
1549
1550             filename = os.path.join(result_dir, f"expr_result_{n_epoch:04d}.txt")
1551
1552             with open(filename, "w") as f:
1553                 for i, r in zip(values_input, values_result):
1554                     for n, vi in i.items():
1555                         vr = r.get(n)
1556                         f.write(f"{vi} {-1 if vr is None else vr}\n")
1557
1558                         if vr is None or vr < 0:
1559                             nb_missed += 1
1560                         else:
1561                             d = abs(vr - vi)
1562                             if d >= nb_delta.size(0):
1563                                 nb_missed += 1
1564                             else:
1565                                 nb_delta[d] += 1
1566
1567             ######################################################################
1568
1569             return nb_total, nb_correct, nb_delta, nb_missed
1570
1571         (
1572             test_nb_total,
1573             test_nb_correct,
1574             test_nb_delta,
1575             test_nb_missed,
1576         ) = compute_nb_correct(self.test_input[:10000])
1577
1578         logger(
1579             f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
1580         )
1581
1582         logger(f"main_test_accuracy {n_epoch} {test_nb_correct/test_nb_total}")
1583
1584         nb_total = test_nb_delta.sum() + test_nb_missed
1585         for d in range(test_nb_delta.size(0)):
1586             logger(
1587                 f"error_value {n_epoch} delta {d} {test_nb_delta[d]} {test_nb_delta[d]*100/nb_total:.02f}%"
1588             )
1589         logger(
1590             f"error_value {n_epoch} missed {test_nb_missed} {test_nb_missed*100/nb_total:.02f}%"
1591         )
1592
1593         ##############################################################
1594         # Log a few generated sequences
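        # Either take the first test sequences as prompts, or, if an
        # input_file is given, read one expression per line and append a space
        # plus 50 filler characters so the model has room to write the values.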
1595         if input_file is None:
1596             input = self.test_input[:10]
1597         else:
1598             with open(input_file, "r") as f:
1599                 sequences = [e.strip() for e in f.readlines()]
1600                 sequences = [s + " " + "#" * 50 for s in sequences]
1601                 input = self.tensorize(sequences)
1602
1603         result = input.clone()
1604         s = (result == self.space).long()
1605         ar_mask = (s.cumsum(dim=1) - s).clamp(min=0, max=1)
1606         result = (1 - ar_mask) * result + ar_mask * self.filler
1607
1608         for n in range(result.size(0)):
1609             logger(f"test_before {self.seq2str(result[n])}")
1610
1611         masked_inplace_autoregression(
1612             model,
1613             self.batch_size,
1614             result,
1615             ar_mask,
1616             deterministic_synthesis,
1617             device=self.device,
1618         )
1619
1620         correct = (1 - ar_mask) * self.space + ar_mask * input
1621         for n in range(result.size(0)):
1622             comment = "GOOD" if (result[n] - input[n]).abs().max() == 0 else ""
1623             logger(f"test_after  {self.seq2str(result[n])} {comment}")
1624             logger(f"truth       {self.seq2str(correct[n])}")
1625         ##############################################################
1626
1627
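# Illustrative sketch only (nothing in this file calls it): how an Expr task
# is typically driven by a training script. The `model` and `logger` arguments
# are placeholders (the model must implement masked_inplace_autoregression),
# and the numerical settings are arbitrary examples.
def _expr_usage_sketch(model, logger, result_dir, device=torch.device("cpu")):
    task = Expr(
        nb_train_samples=1000,
        nb_test_samples=100,
        nb_variables=5,
        sequence_length=40,
        operand_max=9,
        result_max=99,
        batch_size=25,
        device=device,
    )

    logger(f"vocabulary_size {task.vocabulary_size()}")

    for input in task.batches(split="train"):
        pass  # the real training loop would run a forward/backward pass here

    task.produce_results(
        n_epoch=0,
        model=model,
        result_dir=result_dir,
        logger=logger,
        deterministic_synthesis=True,
    )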
1628 ######################################################################
1629
1630 import grid
1631
1632
1633 class Grid(Task):
1634     # Make a tensor from a list of strings
1635     def str2tensor(self, descr):
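        # Split each description on spaces, right-pad with "#" tokens to the
        # longest one, and map tokens to ids; e.g. ["a b", "c d e"] yields a
        # 2x3 tensor encoding [["a", "b", "#"], ["c", "d", "e"]].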
1636         token_descr = [s.strip().split(" ") for s in descr]
1637         l = max([len(s) for s in token_descr])
1638         token_descr = [s + ["#"] * (l - len(s)) for s in token_descr]
1639         id_descr = [[self.token2id[u] for u in s] for s in token_descr]
1640         return torch.tensor(id_descr, device=self.device)
1641
1642     # Make a list of strings from a tensor
1643     def tensor2str(self, x):
1644         return [" ".join([self.id2token[t.item()] for t in r]) for r in x]
1645
1646     # Trim all tensors in z to remove as many filler tokens as possible on
1647     # the left and the right, based on the first tensor. If z is a tuple, all
1648     # its elements are trimmed according to the trimming computed for the first one.
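    # Worked example (single row, "#" = filler): for the row [#, #, A, B, #, #],
    # the padded mask of "some row holds a real token" is [0, 0, 0, 1, 1, 0, 0, 0],
    # its cumsum is [0, 0, 0, 1, 2, 2, 2, 2], hence a = 2 and b = 4, and
    # z[:, 2:4] keeps exactly the columns [A, B]. The one-column pad on each
    # side makes a and b line up with the unpadded tensor.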
1649     def trim(self, z, token="#"):
1650         n = self.token2id[token]
1651         if type(z) == tuple:
1652             x = z[0]
1653             i = (1 - (F.pad(x, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
1654             a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
1655             return tuple([t[:, a:b] for t in z])
1656         else:
1657             i = (1 - (F.pad(z, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
1658             a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
1659             return z[:, a:b]
1660
1661     ######################
1662
1663     def __init__(
1664         self,
1665         nb_train_samples,
1666         nb_test_samples,
1667         batch_size,
1668         size,
1669         fraction_play=0.0,
1670         logger=None,
1671         device=torch.device("cpu"),
1672     ):
1673         super().__init__()
1674
1675         self.device = device
1676         self.batch_size = batch_size
1677         self.grid_factory = grid.GridFactory(size=size)
1678         self.fraction_play = fraction_play
1679
1680         if logger is not None:
1681             logger(
1682                 f"generating {nb_train_samples+nb_test_samples} samples (can take some time)"
1683             )
1684
1685         self.train_descr = self.grid_factory.generate_samples(
1686             nb=nb_train_samples,
1687             fraction_play=fraction_play,
1688             progress_bar=lambda r: tqdm.tqdm(r),
1689         )
1690
1691         self.test_descr = self.grid_factory.generate_samples(
1692             nb=nb_test_samples, fraction_play=0.0, progress_bar=lambda r: tqdm.tqdm(r)
1693         )
1694
1695         if fraction_play > 0:
1696             self.play_descr = self.grid_factory.generate_samples(
1697                 nb=25, fraction_play=1.0, progress_bar=lambda r: tqdm.tqdm(r)
1698             )
1699         else:
1700             self.play_descr = []
1701
1702         # Build the tokenizer
1703         tokens = set()
1704         for d in [self.train_descr, self.test_descr, self.play_descr]:
1705             for s in d:
1706                 for t in s.strip().split(" "):
1707                     tokens.add(t)
1708         # make this set a sorted list to get the same tensors given
1709         # the same descr
1710         tokens = list(tokens)
1711         tokens.sort()
1712         tokens = ["#"] + tokens
1713         self.token2id = dict([(t, n) for n, t in enumerate(tokens)])
1714         self.id2token = dict([(n, t) for n, t in enumerate(tokens)])
1715         self.t_nul = self.token2id["#"]
1716         self.t_true = self.token2id["true"]
1717         self.t_false = self.token2id["false"]
1718         self.t_pipe = self.token2id.get("|")  # "|" may be absent (e.g. no play samples)
1719
1720         # Tokenize the train and test sets
1721         self.train_input = self.str2tensor(self.train_descr)
1722         self.test_input = self.str2tensor(self.test_descr)
1723         self.play_input = (
1724             None if len(self.play_descr) == 0 else self.str2tensor(self.play_descr)
1725         )
1726
1727     def batches(self, split="train", nb_to_use=-1, desc=None):
1728         assert split in {"train", "test"}
1729         input = self.train_input if split == "train" else self.test_input
1730         for batch in tqdm.tqdm(
1731             input.split(self.batch_size), dynamic_ncols=True, desc=f"epoch-{split}"
1732         ):
1733             yield self.trim(batch)
1734
1735     def vocabulary_size(self):
1736         return len(self.token2id)
1737
1738     def produce_results(
1739         self, n_epoch, model, result_dir, logger, deterministic_synthesis
1740     ):
1741         correct = self.test_input[:1000]
1742         result = correct.clone()
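        # The model only has to predict the "true" / "false" answer tokens, so
        # the autoregressive mask selects exactly those positions.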
1743         ar_mask = torch.logical_or(result == self.t_true, result == self.t_false).long()
1744         result *= 1 - ar_mask  # paranoia: wipe the tokens that will be regenerated
1745
1746         logger(f"----------------------------------------------------------")
1747
1748         for e in self.tensor2str(result[:10]):
1749             logger(f"test_before {e}")
1750
1751         masked_inplace_autoregression(
1752             model,
1753             self.batch_size,
1754             result,
1755             ar_mask,
1756             deterministic_synthesis,
1757             device=self.device,
1758         )
1759
1760         logger(f"----------------------------------------------------------")
1761
1762         for e in self.tensor2str(result[:10]):
1763             logger(f"test_after  {e}")
1764
1765         logger(f"----------------------------------------------------------")
1766
1767         nb_total = ar_mask.sum().item()
1768         nb_correct = ((correct == result).long() * ar_mask).sum().item()
1769
1770         logger(f"test_performance {n_epoch} {nb_total=} {nb_correct=}")
1771         logger(f"main_test_accuracy {n_epoch} {nb_correct / nb_total}")
1772
1773         if self.play_input is not None:
1774             result = self.play_input.clone()
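            # Assumption: in "play" samples the token "|" separates the prompt
            # from the continuation, so everything from the first "|" onward is
            # regenerated (cumsum of the "|" indicator, clamped to 1).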
1775             ar_mask = (result == self.t_pipe).long().cumsum(dim=1).clamp(max=1)
1776             result *= 1 - ar_mask  # paranoia: wipe the tokens that will be regenerated
1777
1778             logger(f"----------------------------------------------------------")
1779
1780             for e in self.tensor2str(result[:10]):
1781                 logger(f"play_before {e}")
1782
1783             masked_inplace_autoregression(
1784                 model,
1785                 self.batch_size,
1786                 result,
1787                 ar_mask,
1788                 deterministic_synthesis,
1789                 device=self.device,
1790             )
1791
1792             logger(f"----------------------------------------------------------")
1793
1794             for e in self.tensor2str(result[:10]):
1795                 logger(f"play_after  {e}")
1796
1797             logger(f"----------------------------------------------------------")
1798
1799
1800 ######################################################################
1801
1802 import qmlp
1803
1804
1805 class QMLP(Task):
1806     ######################
1807
1808     def __init__(
1809         self,
1810         nb_train_samples,
1811         nb_test_samples,
1812         batch_size,
1813         result_dir,
1814         logger=None,
1815         device=torch.device("cpu"),
1816     ):
1817         super().__init__()
1818
1819         self.device = device
1820         self.batch_size = batch_size
1821         self.nb_samples_per_mlp = 256
1822
1823         if logger is not None:
1824             logger(
1825                 f"generating {nb_train_samples+nb_test_samples} samples (can take some time)"
1826             )
1827
1828         seq, q_test_set, test_error = qmlp.generate_sequence_and_test_set(
1829             nb_mlps=nb_train_samples + nb_test_samples,
1830             nb_samples=self.nb_samples_per_mlp,
1831             device=self.device,
1832             batch_size=64,
1833             nb_epochs=250,
1834             nb_mlps_per_batch=1024,
1835         )
1836
1837         self.train_input = seq[:nb_train_samples]
1838         self.train_q_test_set = q_test_set[:nb_train_samples]
1839         self.train_ref_test_errors = test_error[:nb_train_samples]
1840         self.test_input = seq[nb_train_samples:]
1841         self.test_q_test_set = q_test_set[nb_train_samples:]
1842         self.test_ref_test_errors = test_error[nb_train_samples:]
1843
1844         filename = os.path.join(result_dir, f"train_errors_ref.dat")
1845         with open(filename, "w") as f:
1846             for e in self.train_ref_test_errors:
1847                 f.write(f"{e}\n")
1848
1849         filename = os.path.join(result_dir, f"test_errors_ref.dat")
1850         with open(filename, "w") as f:
1851             for e in self.test_ref_test_errors:
1852                 f.write(f"{e}\n")
1853
1854         self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
1855
1856     def batches(self, split="train", nb_to_use=-1, desc=None):
1857         assert split in {"train", "test"}
1858         input = self.train_input if split == "train" else self.test_input
1859         for batch in tqdm.tqdm(
1860             input.split(self.batch_size), dynamic_ncols=True, desc=f"epoch-{split}"
1861         ):
1862             yield batch
1863
1864     def vocabulary_size(self):
1865         return self.nb_codes
1866
1867     def produce_results(
1868         self, n_epoch, model, result_dir, logger, deterministic_synthesis
1869     ):
1870         correct = self.test_input[:1000]
1871         result = correct.clone()
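        # Layout suggested by the slicing below: the first
        # nb_samples_per_mlp * 3 tokens encode the quantized training set
        # (presumably 3 tokens per sample); everything strictly after position
        # nb_samples_per_mlp * 3 + 1 holds the quantized MLP parameters and is
        # the part regenerated here.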
1872         ar_mask = (
1873             torch.arange(result.size(1), device=result.device)
1874             > self.nb_samples_per_mlp * 3 + 1
1875         ).long()[None, :]
1876         ar_mask = ar_mask.expand_as(result)
1877         result *= 1 - ar_mask  # paranoia: wipe the tokens that will be regenerated
1878
1879         masked_inplace_autoregression(
1880             model,
1881             self.batch_size,
1882             result,
1883             ar_mask,
1884             deterministic_synthesis,
1885             device=self.device,
1886         )
1887
1888         q_train_set = result[:, : self.nb_samples_per_mlp * 3]
1889         q_params = result[:, self.nb_samples_per_mlp * 3 + 1 :]
1890         error_test = qmlp.evaluate_q_params(q_params, self.test_q_test_set)
1891
1892         filename = os.path.join(result_dir, f"test_errors_{n_epoch:04d}.dat")
1893         with open(filename, "w") as f:
1894             for e in error_test:
1895                 f.write(f"{e}\n")
1896
1897
1898 ######################################################################
1899
1900 import greed
1901
1902
1903 class Greed(Task):
1904     def __init__(
1905         self,
1906         nb_train_samples,
1907         nb_test_samples,
1908         batch_size,
1909         height,
1910         width,
1911         T,
1912         nb_walls,
1913         nb_coins,
1914         logger=None,
1915         device=torch.device("cpu"),
1916     ):
1917         super().__init__()
1918
1919         self.batch_size = batch_size
1920         self.device = device
1921
1922         self.world = greed.GreedWorld(height, width, T, nb_walls, nb_coins)
1923
1924         states, actions, rewards = self.world.generate_episodes(
1925             nb_train_samples + nb_test_samples
1926         )
1927         seq = self.world.episodes2seq(states, actions, rewards)
1928         self.train_input = seq[:nb_train_samples].to(self.device)
1929         self.test_input = seq[nb_train_samples:].to(self.device)
1930
1931     def wipe_lookahead_rewards(self, batch):
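        # For each sequence, draw a random position u and replace every
        # lookahead_reward slot located at or before u (one slot per iteration,
        # at offset index_lookahead_reward within each block of it_len tokens)
        # by the REWARD_UNKNOWN code; the rest of the batch is left unchanged.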
1932         t = torch.arange(batch.size(1), device=batch.device)[None, :]
1933         u = torch.randint(batch.size(1), (batch.size(0), 1), device=batch.device)
1934         lr_mask = (t <= u).long() * (
1935             t % self.world.it_len == self.world.index_lookahead_reward
1936         ).long()
1937
1938         return (
1939             lr_mask * self.world.lookahead_reward2code(greed.REWARD_UNKNOWN)
1940             + (1 - lr_mask) * batch
1941         )
1942
1943     def batches(self, split="train", nb_to_use=-1, desc=None):
1944         assert split in {"train", "test"}
1945         input = self.train_input if split == "train" else self.test_input
1946         if nb_to_use > 0:
1947             input = input[:nb_to_use]
1948         if desc is None:
1949             desc = f"epoch-{split}"
1950         for batch in tqdm.tqdm(
1951             input.split(self.batch_size), dynamic_ncols=True, desc=desc
1952         ):
1953             yield self.wipe_lookahead_rewards(batch)
1954
1955     def vocabulary_size(self):
1956         return self.world.nb_codes
1957
1958     def thinking_autoregression(
1959         self, n_epoch, model, result_dir, logger, deterministic_synthesis, nmax=1000
1960     ):
1961         snapshots = []
1962
1963         def ar(result, ar_mask, logit_biases=None):
1964             ar_mask = ar_mask.expand_as(result)
1965             result *= 1 - ar_mask
1966             masked_inplace_autoregression(
1967                 model,
1968                 self.batch_size,
1969                 result,
1970                 ar_mask,
1971                 deterministic_synthesis=deterministic_synthesis,
1972                 logit_biases=logit_biases,
1973                 device=self.device,
1974                 progress_bar_desc=None,
1975             )
1976             warnings.warn("keeping thinking snapshots", RuntimeWarning)
1977             snapshots.append(result[:100].detach().clone())
1978
1979         # Generate iteration after iteration
1980
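        # For every block of it_len tokens: (i) except for the first block,
        # regenerate the state given what precedes it, with the
        # lookahead_reward left UNKNOWN, (ii) force the lookahead_reward to
        # REWARD_PLUS and generate the reward and action, (iii) reset it to
        # UNKNOWN so later iterations do not condition on it.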
1981         result = self.test_input[:250].clone()
1982         # Erase all the content but that of the first iteration
1983         result[:, self.world.it_len :] = -1
1984         # Set the lookahead_reward of the first iteration to UNKNOWN
1985         result[:, self.world.index_lookahead_reward] = self.world.lookahead_reward2code(
1986             greed.REWARD_UNKNOWN
1987         )
1988
1989         t = torch.arange(result.size(1), device=result.device)[None, :]
1990
1991         for u in tqdm.tqdm(
1992             range(0, result.size(1), self.world.it_len),
1993             desc="thinking",
1994         ):
1995             # Generate the next state but keep the initial one; the
1996             # lookahead_reward of previous iterations is set to
1997             # UNKNOWN
1998             if u > 0:
1999                 result[
2000                     :, u + self.world.index_lookahead_reward
2001                 ] = self.world.lookahead_reward2code(greed.REWARD_UNKNOWN)
2002                 ar_mask = (t >= u + self.world.index_states).long() * (
2003                     t < u + self.world.index_states + self.world.state_len
2004                 ).long()
2005                 ar(result, ar_mask)
2006
2007             # Generate the action and reward with lookahead_reward to +1
2008             result[
2009                 :, u + self.world.index_lookahead_reward
2010             ] = self.world.lookahead_reward2code(greed.REWARD_PLUS)
2011             ar_mask = (t >= u + self.world.index_reward).long() * (
2012                 t <= u + self.world.index_action
2013             ).long()
2014             ar(result, ar_mask)
2015
2016             # Set the lookahead_reward to UNKNOWN for the next iterations
2017             result[
2018                 :, u + self.world.index_lookahead_reward
2019             ] = self.world.lookahead_reward2code(greed.REWARD_UNKNOWN)
2020
2021         filename = os.path.join(result_dir, f"test_thinking_compute_{n_epoch:04d}.txt")
2022         with open(filename, "w") as f:
2023             for n in range(snapshots[0].size(0)):
2024                 for s in snapshots:
2025                     lr, s, a, r = self.world.seq2episodes(
2026                         s[n : n + 1],
2027                     )
2028                     str = self.world.episodes2str(
2029                         lr, s, a, r, unicode=True, ansi_colors=True
2030                     )
2031                     f.write(str)
2032                 f.write("\n\n")
2033
2034         # Saving the generated sequences
2035
2036         lr, s, a, r = self.world.seq2episodes(result)
2037         str = self.world.episodes2str(lr, s, a, r, unicode=True, ansi_colors=True)
2038
2039         filename = os.path.join(result_dir, f"test_thinking_seq_{n_epoch:04d}.txt")
2040         with open(filename, "w") as f:
2041             f.write(str)
2042             logger(f"wrote {filename}")
2043
2044     def produce_results(
2045         self, n_epoch, model, result_dir, logger, deterministic_synthesis, nmax=1000
2046     ):
2047         result = self.wipe_lookahead_rewards(self.test_input[:250].clone())
2048
2049         # Saving the ground truth
2050
2051         lr, s, a, r = self.world.seq2episodes(
2052             result,
2053         )
2054         str = self.world.episodes2str(lr, s, a, r, unicode=True, ansi_colors=True)
2055
2056         filename = os.path.join(result_dir, f"test_true_seq_{n_epoch:04d}.txt")
2057         with open(filename, "w") as f:
2058             f.write(str)
2059             logger(f"wrote {filename}")
2060
2061         # Re-generating from the first frame
2062
2063         ar_mask = (
2064             torch.arange(result.size(1), device=result.device) >= self.world.it_len
2065         ).long()[None, :]
2066         ar_mask = ar_mask.expand_as(result)
2067         result *= 1 - ar_mask  # paranoia: wipe the tokens that will be regenerated
2068
2069         masked_inplace_autoregression(
2070             model,
2071             self.batch_size,
2072             result,
2073             ar_mask,
2074             deterministic_synthesis,
2075             device=self.device,
2076         )
2077
2078         # Saving the generated sequences
2079
2080         lr, s, a, r = self.world.seq2episodes(
2081             result,
2082         )
2083         str = self.world.episodes2str(lr, s, a, r, unicode=True, ansi_colors=True)
2084
2085         filename = os.path.join(result_dir, f"test_seq_{n_epoch:04d}.txt")
2086         with open(filename, "w") as f:
2087             f.write(str)
2088             logger(f"wrote {filename}")
2089
2090         self.thinking_autoregression(
2091             n_epoch, model, result_dir, logger, deterministic_synthesis, nmax
2092         )
2093
2094
2095 ######################################################################