[picoclvr.git] / tasks.py
#!/usr/bin/env python

import math, os, tqdm

import torch, torchvision

from torch import nn
from torch.nn import functional as F

######################################################################

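# Generate, in place, the tokens of `input` for which `ar_mask` is 1, by
# running the model autoregressively batch per batch; tokens where the mask
# is 0 are left untouched. `deterministic_synthesis` and `forbidden_tokens`
# are simply forwarded to the model's own masked_inplace_autoregression.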
def masked_inplace_autoregression(
    model,
    batch_size,
    input,
    ar_mask,
    deterministic_synthesis,
    forbidden_tokens=None,
    progress_bar_desc="autoregression",
    device=torch.device("cpu"),
):
    batches = zip(input.split(batch_size), ar_mask.split(batch_size))

    if progress_bar_desc is not None:
        batches = tqdm.tqdm(
            batches,
            dynamic_ncols=True,
            desc=progress_bar_desc,
            total=input.size(0) // batch_size,
        )

    for input, ar_mask in batches:
        model.masked_inplace_autoregression(
            input, ar_mask, forbidden_tokens, deterministic_synthesis
        )

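# Base class of all tasks: a task provides token batches for training, the
# size of its vocabulary, and a procedure that evaluates a model and writes
# the results (logs, images) into result_dir.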
class Task:
    def batches(self, split="train"):
        pass

    def vocabulary_size(self):
        pass

    def produce_results(
        self, n_epoch, model, result_dir, logger, deterministic_synthesis
    ):
        pass


######################################################################

import picoclvr

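# In the PicoCLVR task, each sample is a textual scene description (a list of
# properties separated by <sep>), followed by the <img> token and the tokens
# encoding the corresponding image; sequences are padded with <nul> to a
# common length.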
class PicoCLVR(Task):
    # Make a tensor from a list of strings
    def tensorize(self, descr):
        token_descr = [s.strip().split(" ") for s in descr]
        l = max([len(s) for s in token_descr])
        token_descr = [s + ["<nul>"] * (l - len(s)) for s in token_descr]
        id_descr = [[self.token2id[u] for u in s] for s in token_descr]
        return torch.tensor(id_descr, device=self.device)

    # Make a list of strings from a tensor
    def detensorize(self, x):
        return [" ".join([self.id2token[t.item()] for t in r]) for r in x]

    # Trim the tensor z, removing the columns on the left and right that
    # contain only the token `token`. If z is a tuple of tensors, all of them
    # are trimmed according to the trimming computed for the first one.
    def trim(self, z, token="<nul>"):
        n = self.token2id[token]
        if type(z) == tuple:
            x = z[0]
            i = (1 - (F.pad(x, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
            a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
            return tuple([t[:, a:b] for t in z])
        else:
            i = (1 - (F.pad(z, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
            a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
            return z[:, a:b]

    ######################

    def __init__(
        self,
        nb_train_samples,
        nb_test_samples,
        batch_size,
        height,
        width,
        nb_colors=5,
        logger=None,
        device=torch.device("cpu"),
        pruner_train=None,
        pruner_eval=None,
    ):
        def generate_descr(nb, cache_suffix, pruner):
            return picoclvr.generate(
                nb,
                height=self.height,
                width=self.width,
                nb_colors=nb_colors,
                pruner=pruner,
            )

        self.height = height
        self.width = width
        self.batch_size = batch_size
        self.device = device
        self.pruner_train = pruner_train
        self.pruner_eval = pruner_eval

        if logger is not None:
            logger(
                f"generating {nb_train_samples+nb_test_samples} samples (can take some time)"
            )

        self.train_descr = generate_descr(
            nb_train_samples, "train", pruner=self.pruner_train
        )
        self.test_descr = generate_descr(nb_test_samples, "test", pruner=None)

        # Build the tokenizer
        tokens = {"<nul>", "<img>"}
        for d in [self.train_descr, self.test_descr]:
            for s in d:
                for t in s.strip().split(" "):
                    tokens.add(t)
        # make this set a sorted list to get the same tensors given
        # the same descr
        tokens = list(tokens)
        tokens.sort()
        self.token2id = dict([(t, n) for n, t in enumerate(tokens)])
        self.id2token = dict([(n, t) for n, t in enumerate(tokens)])
        self.t_img, self.t_nul = self.token2id["<img>"], self.token2id["<nul>"]

        # Tokenize the train and test sets
        self.train_input = self.tensorize(self.train_descr)
        self.test_input = self.tensorize(self.test_descr)

    def batches(self, split="train"):
        assert split in {"train", "test"}
        input = self.train_input if split == "train" else self.test_input
        for batch in tqdm.tqdm(
            input.split(self.batch_size), dynamic_ncols=True, desc=f"epoch-{split}"
        ):
            yield self.trim(batch)

    def vocabulary_size(self):
        return len(self.token2id)

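    # Keep the description part of each test sequence, erase everything from
    # the <img> token onward, let the model regenerate the image tokens, and
    # count how many of the requested properties are missing in the result.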
    def compute_missing_properties(
        self, n_epoch, model, logger, deterministic_synthesis, pruner=None
    ):
        acc_nb_requested_properties = []
        acc_nb_missing_properties = []
        acc_nb_results = 0

        for input in tqdm.tqdm(
            self.test_input.split(self.batch_size),
            dynamic_ncols=True,
            desc="test-properties",
        ):
            result = input.clone()
            ar_mask = (result == self.t_img).long().cumsum(dim=1).clamp(max=1)
            result = (1 - ar_mask) * result + ar_mask * self.t_nul
            masked_inplace_autoregression(
                model,
                self.batch_size,
                result,
                ar_mask,
                deterministic_synthesis,
                progress_bar_desc=None,
                device=self.device,
            )

            result_descr = self.detensorize(result)
            np = picoclvr.nb_properties(
                result_descr,
                height=self.height,
                width=self.width,
                pruner=pruner,
            )
            nb_requested_properties, _, nb_missing_properties = zip(*np)
            acc_nb_requested_properties += nb_requested_properties
            acc_nb_missing_properties += nb_missing_properties
            acc_nb_results += len(result_descr)

        nb_requested_properties = sum(acc_nb_requested_properties)
        nb_missing_properties = sum(acc_nb_missing_properties)

        prefix = "" if pruner is None else "pruned_"
        logger(f"nb_{prefix}samples {n_epoch} {acc_nb_results}")
        logger(
            f"property_{prefix}nb {n_epoch} requested {sum(acc_nb_requested_properties)} missing {sum(acc_nb_missing_properties)}"
        )
        logger(
            f"property_{prefix}miss {n_epoch} {100*nb_missing_properties/nb_requested_properties:.02f}%"
        )

    ######################################################################

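    # Evaluate the model on the held-out descriptions, then generate a few
    # scenes from hand-written prompts and save them as a PNG grid in
    # result_dir.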
    def produce_results(
        self, n_epoch, model, result_dir, logger, deterministic_synthesis
    ):
        self.compute_missing_properties(n_epoch, model, logger, deterministic_synthesis)

        if self.pruner_eval is not None:
            self.compute_missing_properties(
                n_epoch, model, logger, deterministic_synthesis, self.pruner_eval
            )

        nb_tokens_to_generate = self.height * self.width + 3
        result_descr = []
        nb_per_primer = 8
        primer = []

        for primer_descr in [
            "red above green <sep> green top <sep> blue right of red",
            "there is red <sep> there is yellow <sep> there is blue",
            "red below yellow <sep> yellow below green <sep> green below blue <sep> red right <sep> yellow left <sep> green right <sep> blue left",
            "green bottom <sep> yellow bottom <sep> green left of blue <sep> yellow right of blue <sep> blue top",
        ]:
            primer += [primer_descr + " <img>"] * nb_per_primer

        result = self.tensorize(primer)
        ar_mask = (result == self.t_nul).long()
        masked_inplace_autoregression(
            model,
            self.batch_size,
            result,
            ar_mask,
            deterministic_synthesis,
            device=self.device,
        )
        result_descr = self.detensorize(result)

        np = picoclvr.nb_properties(result_descr, height=self.height, width=self.width)

        acc_nb_requested_properties, _, acc_nb_missing_properties = zip(*np)
        acc_nb_results = len(result_descr)

        nb_requested_properties = sum(acc_nb_requested_properties)
        nb_missing_properties = sum(acc_nb_missing_properties)

        prefix = "demo_"
        logger(f"nb_{prefix}samples {n_epoch} {acc_nb_results}")
        logger(
            f"property_{prefix}nb {n_epoch} requested {sum(acc_nb_requested_properties)} missing {sum(acc_nb_missing_properties)}"
        )
        logger(
            f"property_{prefix}miss {n_epoch} {100*nb_missing_properties/nb_requested_properties:.02f}%"
        )

        img = picoclvr.descr2img(result_descr, height=self.height, width=self.width)

        if img.dim() == 5:
            if img.size(1) == 1:
                img = F.pad(img.squeeze(1), pad=(1, 1, 1, 1), value=64)
            else:
                img = torch.cat(
                    [
                        torchvision.utils.make_grid(x, padding=1, pad_value=64)[None]
                        for x in img
                    ],
                    0,
                )

        image_name = os.path.join(result_dir, f"picoclvr_result_{n_epoch:04d}.png")
        torchvision.utils.save_image(
            img / 255.0, image_name, nrow=nb_per_primer, padding=1, pad_value=0.0
        )
        logger(f"wrote {image_name}")


######################################################################

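# In the MNIST task, each sample is a 28x28 digit image flattened into a
# sequence of 784 pixel-intensity tokens in {0, ..., 255}.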
class MNIST(Task):
    def __init__(
        self, nb_train_samples, nb_test_samples, batch_size, device=torch.device("cpu")
    ):
        self.nb_train_samples = nb_train_samples
        self.nb_test_samples = nb_test_samples
        self.batch_size = batch_size
        self.device = device
        data_set = torchvision.datasets.MNIST(root="./data", train=True, download=True)
        self.train_input = data_set.data[:nb_train_samples].view(-1, 28 * 28).long()
        data_set = torchvision.datasets.MNIST(root="./data", train=False, download=True)
        self.test_input = data_set.data[:nb_test_samples].view(-1, 28 * 28).long()

    def batches(self, split="train", nb_to_use=-1, desc=None):
        assert split in {"train", "test"}
        input = self.train_input if split == "train" else self.test_input
        if nb_to_use > 0:
            input = input[:nb_to_use]
        if desc is None:
            desc = f"epoch-{split}"
        for batch in tqdm.tqdm(
            input.split(self.batch_size), dynamic_ncols=True, desc=desc
        ):
            yield batch

    def vocabulary_size(self):
        return 256

    def produce_results(
        self, n_epoch, model, result_dir, logger, deterministic_synthesis
    ):
        results = torch.empty(64, 28 * 28, device=self.device, dtype=torch.int64)
        ar_mask = torch.full_like(results, 1)
        masked_inplace_autoregression(
            model,
            self.batch_size,
            results,
            ar_mask,
            deterministic_synthesis,
            device=self.device,
        )
        image_name = os.path.join(result_dir, f"mnist_result_{n_epoch:04d}.png")
        torchvision.utils.save_image(
            1 - results.reshape(-1, 1, 28, 28) / 255.0,
            image_name,
            nrow=16,
            pad_value=0.8,
        )
        logger(f"wrote {image_name}")


######################################################################

import maze

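# In the Maze task, each sample is the flattened map of a maze followed by
# the flattened (optimal) solution path; at evaluation time the path part is
# erased and regenerated by the model.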
class Maze(Task):
    def map2seq(self, *m):
        return torch.cat([x.flatten(1) for x in m], 1)

    def seq2map(self, s):
        s = s.reshape(s.size(0), -1, self.height, self.width)
        return (s[:, k] for k in range(s.size(1)))

    def __init__(
        self,
        nb_train_samples,
        nb_test_samples,
        batch_size,
        height,
        width,
        nb_walls,
        device=torch.device("cpu"),
    ):
        self.batch_size = batch_size
        self.height = height
        self.width = width
        self.device = device

        train_mazes, train_paths, _ = maze.create_maze_data(
            nb_train_samples,
            height=height,
            width=width,
            nb_walls=nb_walls,
            progress_bar=lambda x: tqdm.tqdm(x, dynamic_ncols=True, desc="data-train"),
        )
        self.train_input = self.map2seq(train_mazes.to(device), train_paths.to(device))

        test_mazes, test_paths, _ = maze.create_maze_data(
            nb_test_samples,
            height=height,
            width=width,
            nb_walls=nb_walls,
            progress_bar=lambda x: tqdm.tqdm(x, dynamic_ncols=True, desc="data-test"),
        )
        self.test_input = self.map2seq(test_mazes.to(device), test_paths.to(device))

        self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1

    def batches(self, split="train", nb_to_use=-1, desc=None):
        assert split in {"train", "test"}
        input = self.train_input if split == "train" else self.test_input
        if nb_to_use > 0:
            input = input[:nb_to_use]
        if desc is None:
            desc = f"epoch-{split}"
        for batch in tqdm.tqdm(
            input.split(self.batch_size), dynamic_ncols=True, desc=desc
        ):
            yield batch

    def vocabulary_size(self):
        return self.nb_codes

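    # Regenerate the path part of each sequence with the model, check it with
    # maze.path_correctness, and tally, for the correct ones, the predicted
    # path length against the length of the reference path.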
    def compute_error(
        self, model, split="train", nb_to_use=-1, deterministic_synthesis=False
    ):
        nb_total, nb_correct = 0, 0
        count = torch.zeros(
            self.width * self.height,
            self.width * self.height,
            device=self.device,
            dtype=torch.int64,
        )

        for input in self.batches(split, nb_to_use):
            result = input.clone()
            ar_mask = result.new_zeros(result.size())
            ar_mask[:, self.height * self.width :] = 1
            result *= 1 - ar_mask
            masked_inplace_autoregression(
                model,
                self.batch_size,
                result,
                ar_mask,
                deterministic_synthesis,
                progress_bar_desc=None,
                device=self.device,
            )
            mazes, paths = self.seq2map(result)
            path_correctness = maze.path_correctness(mazes, paths)
            nb_correct += path_correctness.long().sum()
            nb_total += mazes.size(0)

            optimal_path_lengths = (
                (input[:, self.height * self.width :] == maze.v_path).long().sum(1)
            )
            predicted_path_lengths = (
                (result[:, self.height * self.width :] == maze.v_path).long().sum(1)
            )
            optimal_path_lengths = optimal_path_lengths[path_correctness]
            predicted_path_lengths = predicted_path_lengths[path_correctness]
            count[optimal_path_lengths, predicted_path_lengths] += 1

        if count.max() == 0:
            count = None
        else:
            count = count[
                : count.sum(1).nonzero().max() + 1, : count.sum(0).nonzero().max() + 1
            ]

        return nb_total, nb_correct, count

    def produce_results(
        self, n_epoch, model, result_dir, logger, deterministic_synthesis
    ):
        with torch.autograd.no_grad():
            t = model.training
            model.eval()

            train_nb_total, train_nb_correct, count = self.compute_error(
                model,
                "train",
                nb_to_use=1000,
                deterministic_synthesis=deterministic_synthesis,
            )
            logger(
                f"accuracy_train {n_epoch} nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%"
            )

            test_nb_total, test_nb_correct, count = self.compute_error(
                model,
                "test",
                nb_to_use=1000,
                deterministic_synthesis=deterministic_synthesis,
            )
            logger(
                f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
            )

            if count is not None:
                proportion_optimal = count.diagonal().sum().float() / count.sum()
                logger(f"proportion_optimal_test {proportion_optimal*100:.02f}%")
                with open(
                    os.path.join(result_dir, f"maze_result_{n_epoch:04d}.txt"), "w"
                ) as f:
                    for i in range(count.size(0)):
                        for j in range(count.size(1)):
                            eol = " " if j < count.size(1) - 1 else "\n"
                            f.write(f"{count[i,j]}{eol}")

            input = self.test_input[:48]
            result = input.clone()
            ar_mask = result.new_zeros(result.size())
            ar_mask[:, self.height * self.width :] = 1
            result *= 1 - ar_mask
            masked_inplace_autoregression(
                model,
                self.batch_size,
                result,
                ar_mask,
                deterministic_synthesis,
                device=self.device,
            )

            mazes, paths = self.seq2map(input)
            _, predicted_paths = self.seq2map(result)

            filename = os.path.join(result_dir, f"maze_result_{n_epoch:04d}.png")
            maze.save_image(
                filename,
                mazes=mazes,
                target_paths=paths,
                predicted_paths=predicted_paths,
                path_correct=maze.path_correctness(mazes, predicted_paths),
                path_optimal=maze.path_optimality(paths, predicted_paths),
            )
            logger(f"wrote {filename}")

            model.train(t)


######################################################################


import snake

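# In the Snake task, sequences alternate two token streams along a trajectory
# on a colored grid. Beyond the prompt, the model regenerates every other
# token, and accuracy is measured only on cells that had already been visited
# (prior_visits > 0).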
class Snake(Task):
    def __init__(
        self,
        nb_train_samples,
        nb_test_samples,
        batch_size,
        height,
        width,
        nb_colors,
        length,
        prompt_length,
        device=torch.device("cpu"),
    ):
        self.batch_size = batch_size
        self.height = height
        self.width = width
        self.device = device
        self.prompt_length = prompt_length

        self.train_input, self.train_prior_visits, _, _ = snake.generate_sequences(
            nb_train_samples,
            height,
            width,
            nb_colors,
            length,
            prompt_length,
            self.device,
        )
        self.test_input, self.test_prior_visits, _, _ = snake.generate_sequences(
            nb_test_samples,
            height,
            width,
            nb_colors,
            length,
            prompt_length,
            self.device,
        )

        self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1

    def batches(self, split="train", nb_to_use=-1, desc=None):
        assert split in {"train", "test"}
        input = self.train_input if split == "train" else self.test_input
        if nb_to_use > 0:
            input = input[:nb_to_use]
        if desc is None:
            desc = f"epoch-{split}"
        for batch in tqdm.tqdm(
            input.split(self.batch_size), dynamic_ncols=True, desc=desc
        ):
            yield batch

    def vocabulary_size(self):
        return self.nb_codes

    def produce_results(
        self, n_epoch, model, result_dir, logger, deterministic_synthesis
    ):
        with torch.autograd.no_grad():
            t = model.training
            model.eval()

            def compute_nb_correct(input, prior_visits):
                result = input.clone()
                i = torch.arange(result.size(1), device=result.device)[None, :]
                ar_mask = (
                    torch.logical_and(i >= self.prompt_length * 2, i % 2 == 0)
                    .long()
                    .expand_as(result)
                )
                result *= 1 - ar_mask

                # snake.solver(result,ar_mask)

                masked_inplace_autoregression(
                    model,
                    self.batch_size,
                    result,
                    ar_mask,
                    deterministic_synthesis,
                    device=self.device,
                )

                nb_total = ((prior_visits > 0) * ar_mask).sum()

                nb_correct = (
                    (result == input).long() * (prior_visits > 0) * ar_mask
                ).sum()

                # nb_total = result.size(0)
                # nb_correct = ((result - input).abs().sum(1) == 0).sum()

                return nb_total, nb_correct

            # train_nb_total, train_nb_correct = compute_nb_correct(
            # self.train_input, self.train_prior_visits
            # )

            # logger(
            # f"accuracy_train nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%"
            # )

            test_nb_total, test_nb_correct = compute_nb_correct(
                self.test_input[:1000], self.test_prior_visits[:1000]
            )

            logger(
                f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
            )

            model.train(t)


######################################################################


import stack

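# In the Stack task, sequences encode successive push/pop operations on
# nb_stacks stacks holding nb_digits-digit values. At evaluation time,
# stack.remove_popped_values erases the values returned by the pops and the
# model has to regenerate them.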
class Stack(Task):
    def __init__(
        self,
        nb_train_samples,
        nb_test_samples,
        batch_size,
        logger,
        nb_steps,
        nb_stacks,
        nb_digits,
        fraction_values_for_train=None,
        device=torch.device("cpu"),
    ):
        self.batch_size = batch_size
        self.nb_steps = nb_steps
        self.nb_stacks = nb_stacks
        self.nb_digits = nb_digits
        self.device = device

        if fraction_values_for_train is None:
            values_for_train = None
            values_for_test = None
        else:
            all = torch.randperm(10**nb_digits)
            nb_for_train = int(all.size(0) * fraction_values_for_train)
            values_for_train = all[:nb_for_train]
            values_for_test = all[nb_for_train:]

        self.train_input, self.train_stack_counts = stack.generate_sequences(
            nb_train_samples,
            nb_steps,
            nb_stacks,
            nb_digits,
            values_for_train,
            self.device,
        )

        self.test_input, self.test_stack_counts = stack.generate_sequences(
            nb_test_samples,
            nb_steps,
            nb_stacks,
            nb_digits,
            values_for_test,
            self.device,
        )

        i = torch.logical_and(self.test_input % 2 == 1, self.test_input < 2 * nb_stacks)
        counts = self.test_stack_counts.flatten()[i.flatten()]
        counts = F.one_hot(counts).sum(0)
        logger(f"test_pop_stack_counts {counts}")

        self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1

    def batches(self, split="train", nb_to_use=-1, desc=None):
        assert split in {"train", "test"}
        input = self.train_input if split == "train" else self.test_input
        if nb_to_use > 0:
            input = input[:nb_to_use]
        if desc is None:
            desc = f"epoch-{split}"
        for batch in tqdm.tqdm(
            input.split(self.batch_size), dynamic_ncols=True, desc=desc
        ):
            yield batch

    def vocabulary_size(self):
        return self.nb_codes

    def produce_results(
        self, n_epoch, model, result_dir, logger, deterministic_synthesis
    ):
        with torch.autograd.no_grad():
            t = model.training
            model.eval()

            def compute_nb_correct(input):
                result = input.clone()
                stack.remove_popped_values(result, self.nb_stacks, self.nb_digits)
                ar_mask = (result != input).long()
                masked_inplace_autoregression(
                    model,
                    self.batch_size,
                    result,
                    ar_mask,
                    deterministic_synthesis,
                    device=self.device,
                )

                errors = ((result != input).long() * ar_mask).reshape(
                    -1, 1 + self.nb_digits
                )
                ar_mask = ar_mask.reshape(-1, 1 + self.nb_digits)

                nb_total = ar_mask.max(1).values.sum()
                nb_correct = nb_total - errors.max(1).values.sum()

                return nb_total, nb_correct

            test_nb_total, test_nb_correct = compute_nb_correct(self.test_input[:1000])

            logger(
                f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
            )

            ##############################################################
            # Log a few generated sequences
            input = self.test_input[:10, : 12 * (1 + self.nb_digits)]
            result = input.clone()
            stack.remove_popped_values(result, self.nb_stacks, self.nb_digits)
            ar_mask = (result != input).long()
            for n in range(result.size(0)):
                logger(
                    f"test_before {stack.seq_to_str(result[n],nb_stacks=self.nb_stacks,nb_digits=self.nb_digits)}"
                )
            masked_inplace_autoregression(
                model,
                self.batch_size,
                result,
                ar_mask,
                deterministic_synthesis,
                device=self.device,
            )
            for n in range(result.size(0)):
                logger(
                    f"test_after  {stack.seq_to_str(result[n],nb_stacks=self.nb_stacks,nb_digits=self.nb_digits)}"
                )
            ##############################################################

            model.train(t)


######################################################################


import expr

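# In the Expr task, samples are character strings (one token per character)
# padded with the filler '#'. At evaluation time, everything after the first
# space is erased and regenerated, and the variable values recovered with
# expr.extract_results are compared to the true ones.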
class Expr(Task):
    def __init__(
        self,
        nb_train_samples,
        nb_test_samples,
        nb_variables,
        sequence_length,
        batch_size,
        device=torch.device("cpu"),
    ):
        self.batch_size = batch_size
        self.device = device

        train_sequences = expr.generate_sequences(
            nb_train_samples,
            nb_variables=nb_variables,
            length=sequence_length,
            # length=2 * sequence_length,
            # randomize_length=True,
        )
        test_sequences = expr.generate_sequences(
            nb_test_samples,
            nb_variables=nb_variables,
            length=sequence_length,
        )
        self.char2id = dict(
            [
                (c, n)
                for n, c in enumerate(
                    set("#" + "".join(train_sequences + test_sequences))
                )
            ]
        )
        self.id2char = dict([(n, c) for c, n in self.char2id.items()])

        self.filler, self.space = self.char2id["#"], self.char2id[" "]

        len_max = max([len(x) for x in train_sequences])
        self.train_input = torch.cat(
            [
                torch.tensor(
                    [
                        [self.char2id[c] for c in s + "#" * (len_max - len(s))]
                        for s in train_sequences
                    ]
                )
            ],
            0,
        ).to(device)

        len_max = max([len(x) for x in test_sequences])
        self.test_input = torch.cat(
            [
                torch.tensor(
                    [
                        [self.char2id[c] for c in s + "#" * (len_max - len(s))]
                        for s in test_sequences
                    ]
                )
            ],
            0,
        ).to(device)

        self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1

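    # For training batches, the trailing columns that contain only the filler
    # in every row are dropped (keeping a small margin) to shorten sequences.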
    def batches(self, split="train", nb_to_use=-1, desc=None):
        assert split in {"train", "test"}
        input = self.train_input if split == "train" else self.test_input
        if nb_to_use > 0:
            input = input[:nb_to_use]
        if desc is None:
            desc = f"epoch-{split}"
        for batch in tqdm.tqdm(
            input.split(self.batch_size), dynamic_ncols=True, desc=desc
        ):
            if split == "train":
                last = (batch != self.filler).max(0).values.nonzero().max() + 3
                batch = batch[:, :last]
            yield batch

    def vocabulary_size(self):
        return self.nb_codes

    def seq2str(self, s):
        return "".join([self.id2char[k.item()] for k in s])

    def produce_results(
        self, n_epoch, model, result_dir, logger, deterministic_synthesis
    ):
        with torch.autograd.no_grad():
            t = model.training
            model.eval()

            def compute_nb_correct(input):
                result = input.clone()
                ar_mask = (result == self.space).long().cumsum(dim=1).clamp(max=1)
                result = (1 - ar_mask) * result + ar_mask * self.filler
                masked_inplace_autoregression(
                    model,
                    self.batch_size,
                    result,
                    ar_mask,
                    deterministic_synthesis,
                    device=self.device,
                )

                nb_total = input.size(0)
                nb_correct = (input == result).long().min(1).values.sum()

                #######################################################################
                # Compute predicted vs. true variable values

                nb_delta = torch.zeros(5, dtype=torch.int64)
                nb_missed = 0

                values_input = expr.extract_results([self.seq2str(s) for s in input])
                values_result = expr.extract_results([self.seq2str(s) for s in result])

                for i, r in zip(values_input, values_result):
                    for n, vi in i.items():
                        vr = r.get(n)
                        if vr is None or vr < 0:
                            nb_missed += 1
                        else:
                            d = abs(vr - vi)
                            if d >= nb_delta.size(0):
                                nb_missed += 1
                            else:
                                nb_delta[d] += 1

                ######################################################################

                return nb_total, nb_correct, nb_delta, nb_missed

            (
                test_nb_total,
                test_nb_correct,
                test_nb_delta,
                test_nb_missed,
            ) = compute_nb_correct(self.test_input[:1000])

            logger(
                f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
            )

            nb_total = test_nb_delta.sum() + test_nb_missed
            for d in range(test_nb_delta.size(0)):
                logger(
                    f"error_value {n_epoch} delta {d} {test_nb_delta[d]} {test_nb_delta[d]*100/nb_total:.02f}%"
                )
            logger(
                f"error_value {n_epoch} missed {test_nb_missed} {test_nb_missed*100/nb_total:.02f}%"
            )

            ##############################################################
            # Log a few generated sequences
            input = self.test_input[:10]
            result = input.clone()
            ar_mask = (result == self.space).long().cumsum(dim=1).clamp(max=1)
            result = (1 - ar_mask) * result + ar_mask * self.filler
            for n in range(result.size(0)):
                logger(f"test_before {self.seq2str(result[n])}")
            masked_inplace_autoregression(
                model,
                self.batch_size,
                result,
                ar_mask,
                deterministic_synthesis,
                device=self.device,
            )
            correct = (1 - ar_mask) * self.space + ar_mask * input
            for n in range(result.size(0)):
                comment = "GOOD" if (result[n] - input[n]).abs().max() == 0 else ""
                logger(f"test_after  {self.seq2str(result[n])} {comment}")
                logger(f"correct     {self.seq2str(correct[n])}")
            ##############################################################

            model.train(t)


######################################################################