[picoclvr.git] / tasks.py
#!/usr/bin/env python

import math, os, tqdm

import torch, torchvision

from torch import nn
from torch.nn import functional as F

######################################################################

# Generate in place the tokens of `input` flagged by `ar_mask`: the two
# tensors are processed in chunks of `batch_size`, the model is switched to
# eval mode for the duration of the generation, and its training mode is
# restored afterwards.
def masked_inplace_autoregression(
    model,
    batch_size,
    input,
    ar_mask,
    deterministic_synthesis,
    forbidden_tokens=None,
    progress_bar_desc="autoregression",
    device=torch.device("cpu"),
):
    batches = zip(input.split(batch_size), ar_mask.split(batch_size))

    if progress_bar_desc is not None:
        batches = tqdm.tqdm(
            batches,
            dynamic_ncols=True,
            desc=progress_bar_desc,
            total=input.size(0) // batch_size,
        )

    with torch.autograd.no_grad():
        t = model.training
        model.eval()

        for input, ar_mask in batches:
            model.masked_inplace_autoregression(
                input, ar_mask, forbidden_tokens, deterministic_synthesis
            )

        model.train(t)

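# Illustrative sketch, not part of the original file: the typical calling
# pattern used by the tasks below. `model`, `prompts` and `t_nul` are
# placeholders; `model` is assumed to expose the masked_inplace_autoregression()
# method the helper above relies on, and `t_nul` is the id of the padding token.
def _example_complete_prompts(model, prompts, t_nul, batch_size=25):
    result = prompts.clone()
    # 1 where the model must generate a token, 0 where the prompt is kept
    ar_mask = (result == t_nul).long()
    # zero out the positions to be generated; the model overwrites them anyway
    result *= 1 - ar_mask
    masked_inplace_autoregression(
        model,
        batch_size,
        result,
        ar_mask,
        deterministic_synthesis=False,
        device=result.device,
    )
    return result
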
######################################################################


class Task:
    def batches(self, split="train"):
        pass

    def vocabulary_size(self):
        pass

    def produce_results(
        self, n_epoch, model, result_dir, logger, deterministic_synthesis
    ):
        pass

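# Illustrative sketch, not part of the original file: the smallest concrete
# Task the surrounding training code could consume. Every task below follows
# the same contract: batches() yields (N, T) LongTensors of token indices,
# vocabulary_size() bounds those indices, and produce_results() evaluates the
# model and reports through `logger`.
class _ExampleRandomTask(Task):
    def __init__(self, nb_samples=1000, seq_len=16, nb_tokens=10, batch_size=25):
        self.batch_size = batch_size
        self.nb_tokens = nb_tokens
        self.train_input = torch.randint(nb_tokens, (nb_samples, seq_len))
        self.test_input = torch.randint(nb_tokens, (nb_samples // 10, seq_len))

    def batches(self, split="train"):
        input = self.train_input if split == "train" else self.test_input
        yield from input.split(self.batch_size)

    def vocabulary_size(self):
        return self.nb_tokens

    def produce_results(
        self, n_epoch, model, result_dir, logger, deterministic_synthesis
    ):
        logger(f"example_task {n_epoch} nothing to evaluate")
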
######################################################################

import picoclvr


class PicoCLVR(Task):
    # Make a tensor from a list of strings
    def tensorize(self, descr):
        token_descr = [s.strip().split(" ") for s in descr]
        l = max([len(s) for s in token_descr])
        token_descr = [s + ["<nul>"] * (l - len(s)) for s in token_descr]
        id_descr = [[self.token2id[u] for u in s] for s in token_descr]
        return torch.tensor(id_descr, device=self.device)

    # Make a list of strings from a tensor
    def detensorize(self, x):
        return [" ".join([self.id2token[t.item()] for t in r]) for r in x]

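    # Worked example, added for clarity (not in the original file): tensorize
    # pads every description with "<nul>" up to the longest one in the batch,
    # so tensorizing
    #   ["red top <img>", "there is blue <sep> there is green <img>"]
    # yields a (2, 8) id tensor whose first row ends with five "<nul>" ids;
    # detensorize maps ids back to tokens and returns the padded strings.
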
    # Trim the tensors in z by removing as many all-"<nul>" columns as
    # possible from the left and right, based on the first tensor. If z is a
    # tuple, all its elements are trimmed according to the trimming computed
    # for the first one.
    def trim(self, z, token="<nul>"):
        n = self.token2id[token]
        if type(z) == tuple:
            x = z[0]
            i = (1 - (F.pad(x, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
            a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
            return tuple([t[:, a:b] for t in z])
        else:
            i = (1 - (F.pad(z, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
            a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
            return z[:, a:b]

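    # Worked example, added for clarity (not in the original file): with
    #   x = [[0, 0, 5, 7, 0],
    #        [0, 3, 2, 0, 0]]   and   n = token2id["<nul>"] = 0,
    # only columns 0 and 4 are "<nul>" in every row, so a == 1, b == 4 and
    # trim returns x[:, 1:4]; columns that are "<nul>" for some rows only are
    # kept, so the rows stay aligned.
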
    ######################

    def __init__(
        self,
        nb_train_samples,
        nb_test_samples,
        batch_size,
        height,
        width,
        nb_colors=5,
        logger=None,
        device=torch.device("cpu"),
        pruner_train=None,
        pruner_eval=None,
    ):
        def generate_descr(nb, cache_suffix, pruner):
            return picoclvr.generate(
                nb,
                height=self.height,
                width=self.width,
                nb_colors=nb_colors,
                pruner=pruner,
            )

        self.height = height
        self.width = width
        self.batch_size = batch_size
        self.device = device
        self.pruner_train = pruner_train
        self.pruner_eval = pruner_eval

        if logger is not None:
            logger(
                f"generating {nb_train_samples+nb_test_samples} samples (can take some time)"
            )

        self.train_descr = generate_descr(
            nb_train_samples, "train", pruner=self.pruner_train
        )
        self.test_descr = generate_descr(nb_test_samples, "test", pruner=None)

        # Build the tokenizer
        tokens = {"<nul>", "<img>"}
        for d in [self.train_descr, self.test_descr]:
            for s in d:
                for t in s.strip().split(" "):
                    tokens.add(t)
        # make this set a sorted list to get the same tensors given
        # the same descr
        tokens = list(tokens)
        tokens.sort()
        self.token2id = dict([(t, n) for n, t in enumerate(tokens)])
        self.id2token = dict([(n, t) for n, t in enumerate(tokens)])
        self.t_img, self.t_nul = self.token2id["<img>"], self.token2id["<nul>"]

        # Tokenize the train and test sets
        self.train_input = self.tensorize(self.train_descr)
        self.test_input = self.tensorize(self.test_descr)

    def batches(self, split="train"):
        assert split in {"train", "test"}
        input = self.train_input if split == "train" else self.test_input
        for batch in tqdm.tqdm(
            input.split(self.batch_size), dynamic_ncols=True, desc=f"epoch-{split}"
        ):
            yield self.trim(batch)

    def vocabulary_size(self):
        return len(self.token2id)

    def compute_missing_properties(
        self, n_epoch, model, logger, deterministic_synthesis, pruner=None
    ):
        acc_nb_requested_properties = []
        acc_nb_missing_properties = []
        acc_nb_results = 0

        for input in tqdm.tqdm(
            self.test_input.split(self.batch_size),
            dynamic_ncols=True,
            desc=f"test-properties",
        ):
            result = input.clone()
            ar_mask = (result == self.t_img).long().cumsum(dim=1).clamp(max=1)
            result = (1 - ar_mask) * result + ar_mask * self.t_nul
            masked_inplace_autoregression(
                model,
                self.batch_size,
                result,
                ar_mask,
                deterministic_synthesis,
                progress_bar_desc=None,
                device=self.device,
            )

            result_descr = self.detensorize(result)
            np = picoclvr.nb_properties(
                result_descr,
                height=self.height,
                width=self.width,
                pruner=pruner,
            )
            nb_requested_properties, _, nb_missing_properties = zip(*np)
            acc_nb_requested_properties += nb_requested_properties
            acc_nb_missing_properties += nb_missing_properties
            acc_nb_results += len(result_descr)

        nb_requested_properties = sum(acc_nb_requested_properties)
        nb_missing_properties = sum(acc_nb_missing_properties)

        prefix = "" if pruner is None else "pruned_"
        logger(f"nb_{prefix}samples {n_epoch} {acc_nb_results}")
        logger(
            f"property_{prefix}nb {n_epoch} requested {sum(acc_nb_requested_properties)} missing {sum(acc_nb_missing_properties)}"
        )
        logger(
            f"property_{prefix}miss {n_epoch} {100*nb_missing_properties/nb_requested_properties:.02f}%"
        )

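    # Worked example, added for clarity (not in the original file): in
    # compute_missing_properties above, the mask
    #   (result == self.t_img).long().cumsum(dim=1).clamp(max=1)
    # is 0 before the "<img>" token and 1 from it onward, e.g.
    #   tokens  : red top <sep> blue left <img> [image tokens...]
    #   ar_mask :  0   0    0     0    0     1        1 1 1 ...
    # so the textual description is kept and everything from "<img>" on is
    # regenerated by the model before the properties are re-checked.
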
    ######################################################################

    def produce_results(
        self, n_epoch, model, result_dir, logger, deterministic_synthesis
    ):
        self.compute_missing_properties(n_epoch, model, logger, deterministic_synthesis)

        if self.pruner_eval is not None:
            self.compute_missing_properties(
                n_epoch, model, logger, deterministic_synthesis, self.pruner_eval
            )

        nb_tokens_to_generate = self.height * self.width + 3
        result_descr = []
        nb_per_primer = 8
        primer = []

        for primer_descr in [
            "red above green <sep> green top <sep> blue right of red",
            "there is red <sep> there is yellow <sep> there is blue",
            "red below yellow <sep> yellow below green <sep> green below blue <sep> red right <sep> yellow left <sep> green right <sep> blue left",
            "green bottom <sep> yellow bottom <sep> green left of blue <sep> yellow right of blue <sep> blue top",
        ]:
            primer += [primer_descr + " <img>"] * nb_per_primer

        result = self.tensorize(primer)
        fill = result.new_full(
            result.size()[:-1] + (self.height * self.width + 1,), self.t_nul
        )
        result = torch.cat((result, fill), 1)
        ar_mask = (result == self.t_nul).long()
        masked_inplace_autoregression(
            model,
            self.batch_size,
            result,
            ar_mask,
            deterministic_synthesis,
            device=self.device,
        )
        result_descr = self.detensorize(result)

        np = picoclvr.nb_properties(result_descr, height=self.height, width=self.width)

        acc_nb_requested_properties, _, acc_nb_missing_properties = zip(*np)
        acc_nb_results = len(result_descr)

        nb_requested_properties = sum(acc_nb_requested_properties)
        nb_missing_properties = sum(acc_nb_missing_properties)

        prefix = "demo_"
        logger(f"nb_{prefix}samples {n_epoch} {acc_nb_results}")
        logger(
            f"property_{prefix}nb {n_epoch} requested {sum(acc_nb_requested_properties)} missing {sum(acc_nb_missing_properties)}"
        )
        logger(
            f"property_{prefix}miss {n_epoch} {100*nb_missing_properties/nb_requested_properties:.02f}%"
        )

        img = picoclvr.descr2img(result_descr, height=self.height, width=self.width)

        if img.dim() == 5:
            if img.size(1) == 1:
                img = F.pad(img.squeeze(1), pad=(1, 1, 1, 1), value=64)
            else:
                img = torch.cat(
                    [
                        torchvision.utils.make_grid(x, padding=1, pad_value=64)[None]
                        for x in img
                    ],
                    0,
                )

        image_name = os.path.join(result_dir, f"picoclvr_result_{n_epoch:04d}.png")
        torchvision.utils.save_image(
            img / 255.0, image_name, nrow=nb_per_primer, padding=1, pad_value=0.0
        )
        logger(f"wrote {image_name}")

######################################################################


class MNIST(Task):
    def __init__(
        self, nb_train_samples, nb_test_samples, batch_size, device=torch.device("cpu")
    ):
        self.nb_train_samples = nb_train_samples
        self.nb_test_samples = nb_test_samples
        self.batch_size = batch_size
        self.device = device
        data_set = torchvision.datasets.MNIST(root="./data", train=True, download=True)
        self.train_input = data_set.data[:nb_train_samples].view(-1, 28 * 28).long()
        data_set = torchvision.datasets.MNIST(root="./data", train=False, download=True)
        self.test_input = data_set.data[:nb_test_samples].view(-1, 28 * 28).long()

    def batches(self, split="train", nb_to_use=-1, desc=None):
        assert split in {"train", "test"}
        input = self.train_input if split == "train" else self.test_input
        if nb_to_use > 0:
            input = input[:nb_to_use]
        if desc is None:
            desc = f"epoch-{split}"
        for batch in tqdm.tqdm(
            input.split(self.batch_size), dynamic_ncols=True, desc=desc
        ):
            yield batch

    def vocabulary_size(self):
        return 256

    def produce_results(
        self, n_epoch, model, result_dir, logger, deterministic_synthesis
    ):
        results = torch.empty(64, 28 * 28, device=self.device, dtype=torch.int64)
        ar_mask = torch.full_like(results, 1)
        masked_inplace_autoregression(
            model,
            self.batch_size,
            results,
            ar_mask,
            deterministic_synthesis,
            device=self.device,
        )
        image_name = os.path.join(result_dir, f"mnist_result_{n_epoch:04d}.png")
        torchvision.utils.save_image(
            1 - results.reshape(-1, 1, 28, 28) / 255.0,
            image_name,
            nrow=16,
            pad_value=0.8,
        )
        logger(f"wrote {image_name}")

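# Illustrative sketch, not part of the original file: how a task object is
# typically driven. `my_model` is a placeholder for any model implementing the
# interface assumed by masked_inplace_autoregression above, and `print` stands
# in for the logger callable the tasks expect.
def _example_run_mnist(my_model, device=torch.device("cpu")):
    task = MNIST(
        nb_train_samples=1000, nb_test_samples=100, batch_size=25, device=device
    )
    assert task.vocabulary_size() == 256  # one token per 8-bit pixel value
    for batch in task.batches(split="train"):
        pass  # one optimization step on `batch` would go here
    task.produce_results(
        n_epoch=0,
        model=my_model,
        result_dir=".",
        logger=print,
        deterministic_synthesis=False,
    )
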
######################################################################

import maze


class Maze(Task):
    def map2seq(self, *m):
        return torch.cat([x.flatten(1) for x in m], 1)

    def seq2map(self, s):
        s = s.reshape(s.size(0), -1, self.height, self.width)
        return (s[:, k] for k in range(s.size(1)))

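    # Worked example, added for clarity (not in the original file): for a batch
    # of B mazes and their paths, both of shape (B, height, width), map2seq
    # flattens each map and concatenates them into a (B, 2 * height * width)
    # sequence; seq2map reshapes such a sequence back to (B, 2, height, width)
    # and yields the maze and the path channels separately.
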
    def __init__(
        self,
        nb_train_samples,
        nb_test_samples,
        batch_size,
        height,
        width,
        nb_walls,
        device=torch.device("cpu"),
    ):
        self.batch_size = batch_size
        self.height = height
        self.width = width
        self.device = device

        train_mazes, train_paths, _ = maze.create_maze_data(
            nb_train_samples,
            height=height,
            width=width,
            nb_walls=nb_walls,
            progress_bar=lambda x: tqdm.tqdm(x, dynamic_ncols=True, desc=f"data-train"),
        )
        self.train_input = self.map2seq(train_mazes.to(device), train_paths.to(device))

        test_mazes, test_paths, _ = maze.create_maze_data(
            nb_test_samples,
            height=height,
            width=width,
            nb_walls=nb_walls,
            progress_bar=lambda x: tqdm.tqdm(x, dynamic_ncols=True, desc=f"data-test"),
        )
        self.test_input = self.map2seq(test_mazes.to(device), test_paths.to(device))

        self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1

    def batches(self, split="train", nb_to_use=-1, desc=None):
        assert split in {"train", "test"}
        input = self.train_input if split == "train" else self.test_input
        if nb_to_use > 0:
            input = input[:nb_to_use]
        if desc is None:
            desc = f"epoch-{split}"
        for batch in tqdm.tqdm(
            input.split(self.batch_size), dynamic_ncols=True, desc=desc
        ):
            yield batch

    def vocabulary_size(self):
        return self.nb_codes

    def compute_error(
        self, model, split="train", nb_to_use=-1, deterministic_synthesis=False
    ):
        nb_total, nb_correct = 0, 0
        count = torch.zeros(
            self.width * self.height,
            self.width * self.height,
            device=self.device,
            dtype=torch.int64,
        )

        for input in self.batches(split, nb_to_use):
            result = input.clone()
            ar_mask = result.new_zeros(result.size())
            ar_mask[:, self.height * self.width :] = 1
            result *= 1 - ar_mask
            masked_inplace_autoregression(
                model,
                self.batch_size,
                result,
                ar_mask,
                deterministic_synthesis,
                progress_bar_desc=None,
                device=self.device,
            )
            mazes, paths = self.seq2map(result)
            path_correctness = maze.path_correctness(mazes, paths)
            nb_correct += path_correctness.long().sum()
            nb_total += mazes.size(0)

            optimal_path_lengths = (
                (input[:, self.height * self.width :] == maze.v_path).long().sum(1)
            )
            predicted_path_lengths = (
                (result[:, self.height * self.width :] == maze.v_path).long().sum(1)
            )
            optimal_path_lengths = optimal_path_lengths[path_correctness]
            predicted_path_lengths = predicted_path_lengths[path_correctness]
            # accumulate=True so that several samples in the batch with the
            # same (optimal, predicted) lengths are all counted (a plain
            # "count[i, j] += 1" with tensor indices would count each pair once)
            count.index_put_(
                (optimal_path_lengths, predicted_path_lengths),
                torch.ones_like(optimal_path_lengths),
                accumulate=True,
            )

        if count.max() == 0:
            count = None
        else:
            count = count[
                : count.sum(1).nonzero().max() + 1, : count.sum(0).nonzero().max() + 1
            ]

        return nb_total, nb_correct, count

    def produce_results(
        self, n_epoch, model, result_dir, logger, deterministic_synthesis
    ):
        train_nb_total, train_nb_correct, count = self.compute_error(
            model,
            "train",
            nb_to_use=1000,
            deterministic_synthesis=deterministic_synthesis,
        )
        logger(
            f"accuracy_train {n_epoch} nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%"
        )

        test_nb_total, test_nb_correct, count = self.compute_error(
            model,
            "test",
            nb_to_use=1000,
            deterministic_synthesis=deterministic_synthesis,
        )
        logger(
            f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
        )

        if count is not None:
            proportion_optimal = count.diagonal().sum().float() / count.sum()
            logger(f"proportion_optimal_test {proportion_optimal*100:.02f}%")
            with open(
                os.path.join(result_dir, f"maze_result_{n_epoch:04d}.txt"), "w"
            ) as f:
                for i in range(count.size(0)):
                    for j in range(count.size(1)):
                        eol = " " if j < count.size(1) - 1 else "\n"
                        f.write(f"{count[i,j]}{eol}")

        input = self.test_input[:48]
        result = input.clone()
        ar_mask = result.new_zeros(result.size())
        ar_mask[:, self.height * self.width :] = 1
        result *= 1 - ar_mask
        masked_inplace_autoregression(
            model,
            self.batch_size,
            result,
            ar_mask,
            deterministic_synthesis,
            device=self.device,
        )

        mazes, paths = self.seq2map(input)
        _, predicted_paths = self.seq2map(result)

        filename = os.path.join(result_dir, f"maze_result_{n_epoch:04d}.png")
        maze.save_image(
            filename,
            mazes=mazes,
            target_paths=paths,
            predicted_paths=predicted_paths,
            path_correct=maze.path_correctness(mazes, predicted_paths),
            path_optimal=maze.path_optimality(paths, predicted_paths),
        )
        logger(f"wrote {filename}")


######################################################################


import snake


class Snake(Task):
    def __init__(
        self,
        nb_train_samples,
        nb_test_samples,
        batch_size,
        height,
        width,
        nb_colors,
        length,
        prompt_length,
        device=torch.device("cpu"),
    ):
        self.batch_size = batch_size
        self.height = height
        self.width = width
        self.device = device
        self.prompt_length = prompt_length

        self.train_input, self.train_prior_visits, _, _ = snake.generate_sequences(
            nb_train_samples,
            height,
            width,
            nb_colors,
            length,
            prompt_length,
            self.device,
        )
        self.test_input, self.test_prior_visits, _, _ = snake.generate_sequences(
            nb_test_samples,
            height,
            width,
            nb_colors,
            length,
            prompt_length,
            self.device,
        )

        self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1

    def batches(self, split="train", nb_to_use=-1, desc=None):
        assert split in {"train", "test"}
        input = self.train_input if split == "train" else self.test_input
        if nb_to_use > 0:
            input = input[:nb_to_use]
        if desc is None:
            desc = f"epoch-{split}"
        for batch in tqdm.tqdm(
            input.split(self.batch_size), dynamic_ncols=True, desc=desc
        ):
            yield batch

    def vocabulary_size(self):
        return self.nb_codes

    def produce_results(
        self, n_epoch, model, result_dir, logger, deterministic_synthesis
    ):
        def compute_nb_correct(input, prior_visits):
            result = input.clone()
            i = torch.arange(result.size(1), device=result.device)[None, :]
            ar_mask = (
                torch.logical_and(i >= self.prompt_length * 2, i % 2 == 0)
                .long()
                .expand_as(result)
            )
            result *= 1 - ar_mask

            # snake.solver(result,ar_mask)

            masked_inplace_autoregression(
                model,
                self.batch_size,
                result,
                ar_mask,
                deterministic_synthesis,
                device=self.device,
            )

            nb_total = ((prior_visits > 0) * ar_mask).sum()

            nb_correct = ((result == input).long() * (prior_visits > 0) * ar_mask).sum()

            # nb_total = result.size(0)
            # nb_correct = ((result - input).abs().sum(1) == 0).sum()

            return nb_total, nb_correct

        # train_nb_total, train_nb_correct = compute_nb_correct(
        # self.train_input, self.train_prior_visits
        # )

        # logger(
        # f"accuracy_train nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%"
        # )

        test_nb_total, test_nb_correct = compute_nb_correct(
            self.test_input[:1000], self.test_prior_visits[:1000]
        )

        logger(
            f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
        )

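# Worked example, added for clarity (not in the original file): with
# prompt_length == 3, the mask built in Snake.produce_results above is 1 at
# even positions from index 2 * prompt_length on, e.g. for a length-12 row
#   i       : 0 1 2 3 4 5 6 7 8 9 10 11
#   ar_mask : 0 0 0 0 0 0 1 0 1 0  1  0
# so only every other token after the prompt is regenerated, and only those
# positions whose cell was already visited (prior_visits > 0) are scored.
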
######################################################################


import stack


class Stack(Task):
    def __init__(
        self,
        nb_train_samples,
        nb_test_samples,
        batch_size,
        logger,
        nb_steps,
        nb_stacks,
        nb_digits,
        fraction_values_for_train=None,
        device=torch.device("cpu"),
    ):
        self.batch_size = batch_size
        self.nb_steps = nb_steps
        self.nb_stacks = nb_stacks
        self.nb_digits = nb_digits
        self.device = device

        if fraction_values_for_train is None:
            values_for_train = None
            values_for_test = None
        else:
            all = torch.randperm(10**nb_digits)
            nb_for_train = int(all.size(0) * fraction_values_for_train)
            values_for_train = all[:nb_for_train]
            values_for_test = all[nb_for_train:]

        self.train_input, self.train_stack_counts = stack.generate_sequences(
            nb_train_samples,
            nb_steps,
            nb_stacks,
            nb_digits,
            values_for_train,
            self.device,
        )

        self.test_input, self.test_stack_counts = stack.generate_sequences(
            nb_test_samples,
            nb_steps,
            nb_stacks,
            nb_digits,
            values_for_test,
            self.device,
        )

        i = torch.logical_and(self.test_input % 2 == 1, self.test_input < 2 * nb_stacks)
        counts = self.test_stack_counts.flatten()[i.flatten()]
        counts = F.one_hot(counts).sum(0)
        logger(f"test_pop_stack_counts {counts}")

        self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1

    def batches(self, split="train", nb_to_use=-1, desc=None):
        assert split in {"train", "test"}
        input = self.train_input if split == "train" else self.test_input
        if nb_to_use > 0:
            input = input[:nb_to_use]
        if desc is None:
            desc = f"epoch-{split}"
        for batch in tqdm.tqdm(
            input.split(self.batch_size), dynamic_ncols=True, desc=desc
        ):
            yield batch

    def vocabulary_size(self):
        return self.nb_codes

    def produce_results(
        self, n_epoch, model, result_dir, logger, deterministic_synthesis
    ):
        def compute_nb_correct(input):
            result = input.clone()
            stack.remove_popped_values(result, self.nb_stacks, self.nb_digits)
            ar_mask = (result != input).long()
            masked_inplace_autoregression(
                model,
                self.batch_size,
                result,
                ar_mask,
                deterministic_synthesis,
                device=self.device,
            )

            errors = ((result != input).long() * ar_mask).reshape(
                -1, 1 + self.nb_digits
            )
            ar_mask = ar_mask.reshape(-1, 1 + self.nb_digits)

            nb_total = ar_mask.max(1).values.sum()
            nb_correct = nb_total - errors.max(1).values.sum()

            return nb_total, nb_correct

        test_nb_total, test_nb_correct = compute_nb_correct(self.test_input[:1000])

        logger(
            f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
        )

        ##############################################################
        # Log a few generated sequences
        input = self.test_input[:10, : 12 * (1 + self.nb_digits)]
        result = input.clone()
        stack.remove_popped_values(result, self.nb_stacks, self.nb_digits)
        ar_mask = (result != input).long()

        # for n in range(result.size(0)):
        # logger(
        # f"test_before {stack.seq_to_str(result[n],nb_stacks=self.nb_stacks,nb_digits=self.nb_digits)}"
        # )

        masked_inplace_autoregression(
            model,
            self.batch_size,
            result,
            ar_mask,
            deterministic_synthesis,
            device=self.device,
        )

        for n in range(result.size(0)):
            logger(
                f"test_after  {stack.seq_to_str(result[n],nb_stacks=self.nb_stacks,nb_digits=self.nb_digits)}"
            )
        ##############################################################

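# Illustrative note, added for clarity (not in the original file): in
# Stack.produce_results above, the evaluation mask is obtained by diffing
# against the clean input -- stack.remove_popped_values() blanks the digits of
# popped values, so ar_mask = (result != input) is 1 exactly on the blanked
# positions. Reshaping errors and ar_mask to (-1, 1 + nb_digits) makes each
# row cover 1 + nb_digits consecutive tokens, so .max(1).values reads as
# "any token wrong" / "any token to predict", and nb_correct counts the
# blanked results that were regenerated digit for digit.
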
######################################################################


import expr


class Expr(Task):
    def tensorize(self, sequences):
        len_max = max([len(x) for x in sequences])
        return torch.cat(
            [
                torch.tensor(
                    [
                        [self.char2id[c] for c in s + "#" * (len_max - len(s))]
                        for s in sequences
                    ]
                )
            ],
            0,
        ).to(self.device)

    def __init__(
        self,
        nb_train_samples,
        nb_test_samples,
        nb_variables,
        sequence_length,
        operand_max,
        result_max,
        batch_size,
        device=torch.device("cpu"),
    ):
        self.batch_size = batch_size
        self.device = device

        train_sequences = expr.generate_sequences(
            nb_train_samples,
            nb_variables=nb_variables,
            length=sequence_length,
            operand_max=operand_max,
            result_max=result_max,
        )

        test_sequences = expr.generate_sequences(
            nb_test_samples,
            nb_variables=nb_variables,
            length=sequence_length,
            operand_max=operand_max,
            result_max=result_max,
        )

        symbols = list(set("#" + "".join(train_sequences + test_sequences)))
        symbols.sort()

        self.char2id = dict([(c, n) for n, c in enumerate(symbols)])
        self.id2char = dict([(n, c) for c, n in self.char2id.items()])

        self.filler, self.space = self.char2id["#"], self.char2id[" "]

        self.train_input = self.tensorize(train_sequences)
        self.test_input = self.tensorize(test_sequences)

        self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1

    def batches(self, split="train", nb_to_use=-1, desc=None):
        assert split in {"train", "test"}
        input = self.train_input if split == "train" else self.test_input
        if nb_to_use > 0:
            input = input[:nb_to_use]
        if desc is None:
            desc = f"epoch-{split}"
        for batch in tqdm.tqdm(
            input.split(self.batch_size), dynamic_ncols=True, desc=desc
        ):
            last = (batch != self.filler).max(0).values.nonzero().max() + 3
            batch = batch[:, :last]
            yield batch

    def vocabulary_size(self):
        return self.nb_codes

    def seq2str(self, s):
        return "".join([self.id2char[k.item()] for k in s])

    def produce_results(
        self,
        n_epoch,
        model,
        result_dir,
        logger,
        deterministic_synthesis,
        input_file=None,
    ):
        def compute_nb_correct(input):
            result = input.clone()
            s = (result == self.space).long()
            ar_mask = (s.cumsum(dim=1) - s).clamp(min=0, max=1)
            result = (1 - ar_mask) * result + ar_mask * self.filler
            masked_inplace_autoregression(
                model,
                self.batch_size,
                result,
                ar_mask,
                deterministic_synthesis,
                device=self.device,
            )

            nb_total = input.size(0)
            nb_correct = (input == result).long().min(1).values.sum()

            #######################################################################
            # Compute predicted vs. true variable values

            nb_delta = torch.zeros(5, dtype=torch.int64)
            nb_missed = 0

            values_input = expr.extract_results([self.seq2str(s) for s in input])
            values_result = expr.extract_results([self.seq2str(s) for s in result])

            for i, r in zip(values_input, values_result):
                for n, vi in i.items():
                    vr = r.get(n)
                    if vr is None or vr < 0:
                        nb_missed += 1
                    else:
                        d = abs(vr - vi)
                        if d >= nb_delta.size(0):
                            nb_missed += 1
                        else:
                            nb_delta[d] += 1

            ######################################################################

            return nb_total, nb_correct, nb_delta, nb_missed

        (
            test_nb_total,
            test_nb_correct,
            test_nb_delta,
            test_nb_missed,
        ) = compute_nb_correct(self.test_input[:10000])

        logger(
            f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
        )

        nb_total = test_nb_delta.sum() + test_nb_missed
        for d in range(test_nb_delta.size(0)):
            logger(
                f"error_value {n_epoch} delta {d} {test_nb_delta[d]} {test_nb_delta[d]*100/nb_total:.02f}%"
            )
        logger(
            f"error_value {n_epoch} missed {test_nb_missed} {test_nb_missed*100/nb_total:.02f}%"
        )

        ##############################################################
        # Log a few generated sequences
        if input_file is None:
            input = self.test_input[:10]
        else:
            with open(input_file, "r") as f:
                sequences = [e.strip() for e in f.readlines()]
                sequences = [s + " " + "#" * 50 for s in sequences]
                input = self.tensorize(sequences)

        result = input.clone()
        s = (result == self.space).long()
        ar_mask = (s.cumsum(dim=1) - s).clamp(min=0, max=1)
        result = (1 - ar_mask) * result + ar_mask * self.filler

        for n in range(result.size(0)):
            logger(f"test_before {self.seq2str(result[n])}")

        masked_inplace_autoregression(
            model,
            self.batch_size,
            result,
            ar_mask,
            deterministic_synthesis,
            device=self.device,
        )

        correct = (1 - ar_mask) * self.space + ar_mask * input
        for n in range(result.size(0)):
            comment = "GOOD" if (result[n] - input[n]).abs().max() == 0 else ""
            logger(f"test_after  {self.seq2str(result[n])} {comment}")
            logger(f"truth       {self.seq2str(correct[n])}")
        ##############################################################

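# Worked example, added for clarity (not in the original file): the mask used
# in Expr.produce_results above turns on strictly after the first space of
# each row. For the character ids of "A=12 B=A+3##" (first space at index 4):
#   s = (x == space)  : 0 0 0 0 1 0 0 0 0 0 0 0
#   s.cumsum(1) - s   : 0 0 0 0 0 1 1 1 1 1 1 1
# so the first statement is kept as the prompt and everything after the space
# is filled in by the model, with "#" as the neutral filler character.
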
######################################################################