[picoclvr.git] / tasks.py
#!/usr/bin/env python

import math, os, tqdm

import torch, torchvision

from torch import nn
from torch.nn import functional as F

######################################################################


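# Generates tokens in-place in `input` at the positions where `ar_mask` is 1,
# keeping the other positions as given. The work is split into mini-batches of
# `batch_size`, the model is switched to eval mode for the duration of the
# generation, and its previous training mode is restored afterwards.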
def masked_inplace_autoregression(
    model,
    batch_size,
    input,
    ar_mask,
    deterministic_synthesis,
    forbidden_tokens=None,
    progress_bar_desc="autoregression",
    device=torch.device("cpu"),
):
    batches = zip(input.split(batch_size), ar_mask.split(batch_size))

    if progress_bar_desc is not None:
        batches = tqdm.tqdm(
            batches,
            dynamic_ncols=True,
            desc=progress_bar_desc,
            total=input.size(0) // batch_size,
        )

    with torch.autograd.no_grad():
        t = model.training
        model.eval()

        for input, ar_mask in batches:
            model.masked_inplace_autoregression(
                input, ar_mask, forbidden_tokens, deterministic_synthesis
            )

        model.train(t)


######################################################################


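# Base class of all tasks: a task provides training / test batches, the size
# of its vocabulary, and a produce_results() hook called at evaluation time.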
class Task:
    def batches(self, split="train"):
        pass

    def vocabulary_size(self):
        pass

    def produce_results(
        self, n_epoch, model, result_dir, logger, deterministic_synthesis
    ):
        pass


######################################################################

import picoclvr


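# picoCLVR task: every sample is a textual scene description, the "<img>"
# token, and the tokenized image, padded on the right with "<nul>". Evaluation
# checks that the generated image part satisfies the requested properties.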
class PicoCLVR(Task):
    # Make a tensor from a list of strings
    def tensorize(self, descr):
        token_descr = [s.strip().split(" ") for s in descr]
        l = max([len(s) for s in token_descr])
        token_descr = [s + ["<nul>"] * (l - len(s)) for s in token_descr]
        id_descr = [[self.token2id[u] for u in s] for s in token_descr]
        return torch.tensor(id_descr, device=self.device)

    # Make a list of strings from a tensor
    def detensorize(self, x):
        return [" ".join([self.id2token[t.item()] for t in r]) for r in x]

    # Trim the tensor z, or the first tensor of the tuple z, by removing as
    # many all-<nul> columns as possible on the left and on the right. If z
    # is a tuple, all its elements are trimmed according to the trimming
    # computed for the first one.
    def trim(self, z, token="<nul>"):
        n = self.token2id[token]
        if type(z) == tuple:
            x = z[0]
            i = (1 - (F.pad(x, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
            a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
            return tuple([t[:, a:b] for t in z])
        else:
            i = (1 - (F.pad(z, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
            a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
            return z[:, a:b]

    ######################

    def __init__(
        self,
        nb_train_samples,
        nb_test_samples,
        batch_size,
        height,
        width,
        nb_colors=5,
        logger=None,
        device=torch.device("cpu"),
        pruner_train=None,
        pruner_eval=None,
    ):
        def generate_descr(nb, cache_suffix, pruner):
            return picoclvr.generate(
                nb,
                height=self.height,
                width=self.width,
                nb_colors=nb_colors,
                pruner=pruner,
            )

        self.height = height
        self.width = width
        self.batch_size = batch_size
        self.device = device
        self.pruner_train = pruner_train
        self.pruner_eval = pruner_eval

        if logger is not None:
            logger(
                f"generating {nb_train_samples+nb_test_samples} samples (can take some time)"
            )

        self.train_descr = generate_descr(
            nb_train_samples, "train", pruner=self.pruner_train
        )
        self.test_descr = generate_descr(nb_test_samples, "test", pruner=None)

        # Build the tokenizer
        tokens = {"<nul>", "<img>"}
        for d in [self.train_descr, self.test_descr]:
            for s in d:
                for t in s.strip().split(" "):
                    tokens.add(t)
        # make this set a sorted list to get the same tensors given
        # the same descr
        tokens = list(tokens)
        tokens.sort()
        self.token2id = dict([(t, n) for n, t in enumerate(tokens)])
        self.id2token = dict([(n, t) for n, t in enumerate(tokens)])
        self.t_img, self.t_nul = self.token2id["<img>"], self.token2id["<nul>"]

        # Tokenize the train and test sets
        self.train_input = self.tensorize(self.train_descr)
        self.test_input = self.tensorize(self.test_descr)

    def batches(self, split="train"):
        assert split in {"train", "test"}
        input = self.train_input if split == "train" else self.test_input
        for batch in tqdm.tqdm(
            input.split(self.batch_size), dynamic_ncols=True, desc=f"epoch-{split}"
        ):
            yield self.trim(batch)

    def vocabulary_size(self):
        return len(self.token2id)

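    # Masks everything from the "<img>" token onward in the test sequences,
    # regenerates that part with the model, and logs how many of the requested
    # properties are missing from the resulting scenes.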
    def compute_missing_properties(
        self, n_epoch, model, logger, deterministic_synthesis, pruner=None
    ):
        acc_nb_requested_properties = []
        acc_nb_missing_properties = []
        acc_nb_results = 0

        for input in tqdm.tqdm(
            self.test_input.split(self.batch_size),
            dynamic_ncols=True,
            desc="test-properties",
        ):
            result = input.clone()
            ar_mask = (result == self.t_img).long().cumsum(dim=1).clamp(max=1)
            result = (1 - ar_mask) * result + ar_mask * self.t_nul
            masked_inplace_autoregression(
                model,
                self.batch_size,
                result,
                ar_mask,
                deterministic_synthesis,
                progress_bar_desc=None,
                device=self.device,
            )

            result_descr = self.detensorize(result)
            np = picoclvr.nb_properties(
                result_descr,
                height=self.height,
                width=self.width,
                pruner=pruner,
            )
            nb_requested_properties, _, nb_missing_properties = zip(*np)
            acc_nb_requested_properties += nb_requested_properties
            acc_nb_missing_properties += nb_missing_properties
            acc_nb_results += len(result_descr)

        nb_requested_properties = sum(acc_nb_requested_properties)
        nb_missing_properties = sum(acc_nb_missing_properties)

        prefix = "" if pruner is None else "pruned_"
        logger(f"nb_{prefix}samples {n_epoch} {acc_nb_results}")
        logger(
            f"property_{prefix}nb {n_epoch} requested {sum(acc_nb_requested_properties)} missing {sum(acc_nb_missing_properties)}"
        )
        logger(
            f"property_{prefix}miss {n_epoch} {100*nb_missing_properties/nb_requested_properties:.02f}%"
        )

    ######################################################################

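    # Logs the missing-property statistics on the test set, then generates
    # scenes from a few fixed textual prompts and saves them as an image grid.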
    def produce_results(
        self, n_epoch, model, result_dir, logger, deterministic_synthesis
    ):
        self.compute_missing_properties(n_epoch, model, logger, deterministic_synthesis)

        if self.pruner_eval is not None:
            self.compute_missing_properties(
                n_epoch, model, logger, deterministic_synthesis, pruner=self.pruner_eval
            )

        nb_tokens_to_generate = self.height * self.width + 3
        result_descr = []
        nb_per_primer = 8
        primer = []

        for primer_descr in [
            "red above green <sep> green top <sep> blue right of red",
            "there is red <sep> there is yellow <sep> there is blue",
            "red below yellow <sep> yellow below green <sep> green below blue <sep> red right <sep> yellow left <sep> green right <sep> blue left",
            "green bottom <sep> yellow bottom <sep> green left of blue <sep> yellow right of blue <sep> blue top",
        ]:
            primer += [primer_descr + " <img>"] * nb_per_primer

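        # Tensorize the prompts and append enough <nul> tokens for the model
        # to fill in the image part (height * width pixels plus one token).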
        result = self.tensorize(primer)
        fill = result.new_full(
            result.size()[:-1] + (self.height * self.width + 1,), self.t_nul
        )
        result = torch.cat((result, fill), 1)
        ar_mask = (result == self.t_nul).long()
        masked_inplace_autoregression(
            model,
            self.batch_size,
            result,
            ar_mask,
            deterministic_synthesis,
            device=self.device,
        )
        result_descr = self.detensorize(result)

        np = picoclvr.nb_properties(result_descr, height=self.height, width=self.width)

        acc_nb_requested_properties, _, acc_nb_missing_properties = zip(*np)
        acc_nb_results = len(result_descr)

        nb_requested_properties = sum(acc_nb_requested_properties)
        nb_missing_properties = sum(acc_nb_missing_properties)

        prefix = "demo_"
        logger(f"nb_{prefix}samples {n_epoch} {acc_nb_results}")
        logger(
            f"property_{prefix}nb {n_epoch} requested {sum(acc_nb_requested_properties)} missing {sum(acc_nb_missing_properties)}"
        )
        logger(
            f"property_{prefix}miss {n_epoch} {100*nb_missing_properties/nb_requested_properties:.02f}%"
        )

        img = picoclvr.descr2img(result_descr, height=self.height, width=self.width)

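        # descr2img may return several candidate images per description (a 5D
        # tensor); in that case each group is arranged into its own small grid.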
        if img.dim() == 5:
            if img.size(1) == 1:
                img = F.pad(img.squeeze(1), pad=(1, 1, 1, 1), value=64)
            else:
                img = torch.cat(
                    [
                        torchvision.utils.make_grid(x, padding=1, pad_value=64)[None]
                        for x in img
                    ],
                    0,
                )

        image_name = os.path.join(result_dir, f"picoclvr_result_{n_epoch:04d}.png")
        torchvision.utils.save_image(
            img / 255.0, image_name, nrow=nb_per_primer, padding=1, pad_value=0.0
        )
        logger(f"wrote {image_name}")


######################################################################


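# MNIST task: every sample is a flattened 28x28 image, i.e. a sequence of 784
# pixel intensities in [0, 255], modeled autoregressively.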
class MNIST(Task):
    def __init__(
        self, nb_train_samples, nb_test_samples, batch_size, device=torch.device("cpu")
    ):
        self.nb_train_samples = nb_train_samples
        self.nb_test_samples = nb_test_samples
        self.batch_size = batch_size
        self.device = device
        data_set = torchvision.datasets.MNIST(root="./data", train=True, download=True)
        self.train_input = data_set.data[:nb_train_samples].view(-1, 28 * 28).long()
        data_set = torchvision.datasets.MNIST(root="./data", train=False, download=True)
        self.test_input = data_set.data[:nb_test_samples].view(-1, 28 * 28).long()

    def batches(self, split="train", nb_to_use=-1, desc=None):
        assert split in {"train", "test"}
        input = self.train_input if split == "train" else self.test_input
        if nb_to_use > 0:
            input = input[:nb_to_use]
        if desc is None:
            desc = f"epoch-{split}"
        for batch in tqdm.tqdm(
            input.split(self.batch_size), dynamic_ncols=True, desc=desc
        ):
            yield batch

    def vocabulary_size(self):
        return 256

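    # Generates 64 images from scratch (the mask covers every position) and
    # saves them as a 16-column grid of inverted grayscale digits.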
    def produce_results(
        self, n_epoch, model, result_dir, logger, deterministic_synthesis
    ):
        results = torch.empty(64, 28 * 28, device=self.device, dtype=torch.int64)
        ar_mask = torch.full_like(results, 1)
        masked_inplace_autoregression(
            model,
            self.batch_size,
            results,
            ar_mask,
            deterministic_synthesis,
            device=self.device,
        )
        image_name = os.path.join(result_dir, f"mnist_result_{n_epoch:04d}.png")
        torchvision.utils.save_image(
            1 - results.reshape(-1, 1, 28, 28) / 255.0,
            image_name,
            nrow=16,
            pad_value=0.8,
        )
        logger(f"wrote {image_name}")


######################################################################

import maze


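# Maze task: every sequence is the flattened maze map followed by the
# flattened shortest path, each of height x width cells; the model has to
# generate the path given the map.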
class Maze(Task):
    def map2seq(self, *m):
        return torch.cat([x.flatten(1) for x in m], 1)

    def seq2map(self, s):
        s = s.reshape(s.size(0), -1, self.height, self.width)
        return (s[:, k] for k in range(s.size(1)))

    def __init__(
        self,
        nb_train_samples,
        nb_test_samples,
        batch_size,
        height,
        width,
        nb_walls,
        device=torch.device("cpu"),
    ):
        self.batch_size = batch_size
        self.height = height
        self.width = width
        self.device = device

        train_mazes, train_paths, _ = maze.create_maze_data(
            nb_train_samples,
            height=height,
            width=width,
            nb_walls=nb_walls,
            progress_bar=lambda x: tqdm.tqdm(x, dynamic_ncols=True, desc="data-train"),
        )
        self.train_input = self.map2seq(train_mazes.to(device), train_paths.to(device))

        test_mazes, test_paths, _ = maze.create_maze_data(
            nb_test_samples,
            height=height,
            width=width,
            nb_walls=nb_walls,
            progress_bar=lambda x: tqdm.tqdm(x, dynamic_ncols=True, desc="data-test"),
        )
        self.test_input = self.map2seq(test_mazes.to(device), test_paths.to(device))

        self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1

    def batches(self, split="train", nb_to_use=-1, desc=None):
        assert split in {"train", "test"}
        input = self.train_input if split == "train" else self.test_input
        if nb_to_use > 0:
            input = input[:nb_to_use]
        if desc is None:
            desc = f"epoch-{split}"
        for batch in tqdm.tqdm(
            input.split(self.batch_size), dynamic_ncols=True, desc=desc
        ):
            yield batch

    def vocabulary_size(self):
        return self.nb_codes

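    # Regenerates the path part of every sequence and counts exactly correct
    # paths. For the correct ones, count[i, j] accumulates the number of
    # samples with optimal path length i and predicted path length j.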
    def compute_error(
        self, model, split="train", nb_to_use=-1, deterministic_synthesis=False
    ):
        nb_total, nb_correct = 0, 0
        count = torch.zeros(
            self.width * self.height,
            self.width * self.height,
            device=self.device,
            dtype=torch.int64,
        )

        for input in self.batches(split, nb_to_use):
            result = input.clone()
            ar_mask = result.new_zeros(result.size())
            ar_mask[:, self.height * self.width :] = 1
            result *= 1 - ar_mask
            masked_inplace_autoregression(
                model,
                self.batch_size,
                result,
                ar_mask,
                deterministic_synthesis,
                progress_bar_desc=None,
                device=self.device,
            )
            mazes, paths = self.seq2map(result)
            path_correctness = maze.path_correctness(mazes, paths)
            nb_correct += path_correctness.long().sum()
            nb_total += mazes.size(0)

            optimal_path_lengths = (
                (input[:, self.height * self.width :] == maze.v_path).long().sum(1)
            )
            predicted_path_lengths = (
                (result[:, self.height * self.width :] == maze.v_path).long().sum(1)
            )
            optimal_path_lengths = optimal_path_lengths[path_correctness]
            predicted_path_lengths = predicted_path_lengths[path_correctness]
            # accumulate with index_put_ so that repeated (optimal, predicted)
            # pairs within a batch are all counted
            count.index_put_(
                (optimal_path_lengths, predicted_path_lengths),
                torch.ones_like(optimal_path_lengths),
                accumulate=True,
            )

        if count.max() == 0:
            count = None
        else:
            count = count[
                : count.sum(1).nonzero().max() + 1, : count.sum(0).nonzero().max() + 1
            ]

        return nb_total, nb_correct, count

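    # Logs train and test path accuracy, writes the (optimal, predicted)
    # path-length counts to a text file, and saves an image of 48 test mazes
    # with their target and predicted paths.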
    def produce_results(
        self, n_epoch, model, result_dir, logger, deterministic_synthesis
    ):
        train_nb_total, train_nb_correct, count = self.compute_error(
            model,
            "train",
            nb_to_use=1000,
            deterministic_synthesis=deterministic_synthesis,
        )
        logger(
            f"accuracy_train {n_epoch} nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%"
        )

        test_nb_total, test_nb_correct, count = self.compute_error(
            model,
            "test",
            nb_to_use=1000,
            deterministic_synthesis=deterministic_synthesis,
        )
        logger(
            f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
        )

        if count is not None:
            proportion_optimal = count.diagonal().sum().float() / count.sum()
            logger(f"proportion_optimal_test {proportion_optimal*100:.02f}%")
            with open(
                os.path.join(result_dir, f"maze_result_{n_epoch:04d}.txt"), "w"
            ) as f:
                for i in range(count.size(0)):
                    for j in range(count.size(1)):
                        eol = " " if j < count.size(1) - 1 else "\n"
                        f.write(f"{count[i,j]}{eol}")

        input = self.test_input[:48]
        result = input.clone()
        ar_mask = result.new_zeros(result.size())
        ar_mask[:, self.height * self.width :] = 1
        result *= 1 - ar_mask
        masked_inplace_autoregression(
            model,
            self.batch_size,
            result,
            ar_mask,
            deterministic_synthesis,
            device=self.device,
        )

        mazes, paths = self.seq2map(input)
        _, predicted_paths = self.seq2map(result)

        filename = os.path.join(result_dir, f"maze_result_{n_epoch:04d}.png")
        maze.save_image(
            filename,
            mazes=mazes,
            target_paths=paths,
            predicted_paths=predicted_paths,
            path_correct=maze.path_correctness(mazes, predicted_paths),
            path_optimal=maze.path_optimality(paths, predicted_paths),
        )
        logger(f"wrote {filename}")


######################################################################


import snake


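# Snake task: sequences come from snake.generate_sequences and, judging from
# the masking in produce_results, interleave cell-content and move tokens
# along a walk on a grid of colored cells; the prior-visit tensors mark the
# cells whose content is already determined by the past of the walk.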
class Snake(Task):
    def __init__(
        self,
        nb_train_samples,
        nb_test_samples,
        batch_size,
        height,
        width,
        nb_colors,
        length,
        prompt_length,
        device=torch.device("cpu"),
    ):
        self.batch_size = batch_size
        self.height = height
        self.width = width
        self.device = device
        self.prompt_length = prompt_length

        self.train_input, self.train_prior_visits, _, _ = snake.generate_sequences(
            nb_train_samples,
            height,
            width,
            nb_colors,
            length,
            prompt_length,
            self.device,
        )
        self.test_input, self.test_prior_visits, _, _ = snake.generate_sequences(
            nb_test_samples,
            height,
            width,
            nb_colors,
            length,
            prompt_length,
            self.device,
        )

        self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1

    def batches(self, split="train", nb_to_use=-1, desc=None):
        assert split in {"train", "test"}
        input = self.train_input if split == "train" else self.test_input
        if nb_to_use > 0:
            input = input[:nb_to_use]
        if desc is None:
            desc = f"epoch-{split}"
        for batch in tqdm.tqdm(
            input.split(self.batch_size), dynamic_ncols=True, desc=desc
        ):
            yield batch

    def vocabulary_size(self):
        return self.nb_codes

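    # Accuracy is measured only on the even positions after the prompt, and
    # only on cells with prior visits, i.e. on predictions that are fully
    # determined by the sequence so far.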
    def produce_results(
        self, n_epoch, model, result_dir, logger, deterministic_synthesis
    ):
        def compute_nb_correct(input, prior_visits):
            result = input.clone()
            i = torch.arange(result.size(1), device=result.device)[None, :]
            ar_mask = (
                torch.logical_and(i >= self.prompt_length * 2, i % 2 == 0)
                .long()
                .expand_as(result)
            )
            result *= 1 - ar_mask

            masked_inplace_autoregression(
                model,
                self.batch_size,
                result,
                ar_mask,
                deterministic_synthesis,
                device=self.device,
            )

            nb_total = ((prior_visits > 0) * ar_mask).sum()

            nb_correct = ((result == input).long() * (prior_visits > 0) * ar_mask).sum()

            return nb_total, nb_correct

        test_nb_total, test_nb_correct = compute_nb_correct(
            self.test_input[:1000], self.test_prior_visits[:1000]
        )

        logger(
            f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
        )


######################################################################


import stack


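# Stack task: sequences encode successive push / pop operations on several
# stacks, with operand values of nb_digits digits. When
# fraction_values_for_train is set, train and test use disjoint sets of values.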
class Stack(Task):
    def __init__(
        self,
        nb_train_samples,
        nb_test_samples,
        batch_size,
        logger,
        nb_steps,
        nb_stacks,
        nb_digits,
        fraction_values_for_train=None,
        device=torch.device("cpu"),
    ):
        self.batch_size = batch_size
        self.nb_steps = nb_steps
        self.nb_stacks = nb_stacks
        self.nb_digits = nb_digits
        self.device = device

        if fraction_values_for_train is None:
            values_for_train = None
            values_for_test = None
        else:
            all = torch.randperm(10**nb_digits)
            nb_for_train = int(all.size(0) * fraction_values_for_train)
            values_for_train = all[:nb_for_train]
            values_for_test = all[nb_for_train:]

        self.train_input, self.train_stack_counts = stack.generate_sequences(
            nb_train_samples,
            nb_steps,
            nb_stacks,
            nb_digits,
            values_for_train,
            self.device,
        )

        self.test_input, self.test_stack_counts = stack.generate_sequences(
            nb_test_samples,
            nb_steps,
            nb_stacks,
            nb_digits,
            values_for_test,
            self.device,
        )

        i = torch.logical_and(self.test_input % 2 == 1, self.test_input < 2 * nb_stacks)
        counts = self.test_stack_counts.flatten()[i.flatten()]
        counts = F.one_hot(counts).sum(0)
        logger(f"test_pop_stack_counts {counts}")

        self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1

    def batches(self, split="train", nb_to_use=-1, desc=None):
        assert split in {"train", "test"}
        input = self.train_input if split == "train" else self.test_input
        if nb_to_use > 0:
            input = input[:nb_to_use]
        if desc is None:
            desc = f"epoch-{split}"
        for batch in tqdm.tqdm(
            input.split(self.batch_size), dynamic_ncols=True, desc=desc
        ):
            yield batch

    def vocabulary_size(self):
        return self.nb_codes

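    # The values returned by pop operations are blanked out and regenerated; a
    # prediction counts as correct only if every digit of the popped value is
    # right (errors are pooled per marker + nb_digits group).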
    def produce_results(
        self, n_epoch, model, result_dir, logger, deterministic_synthesis
    ):
        def compute_nb_correct(input):
            result = input.clone()
            stack.remove_popped_values(result, self.nb_stacks, self.nb_digits)
            ar_mask = (result != input).long()
            masked_inplace_autoregression(
                model,
                self.batch_size,
                result,
                ar_mask,
                deterministic_synthesis,
                device=self.device,
            )

            errors = ((result != input).long() * ar_mask).reshape(
                -1, 1 + self.nb_digits
            )
            ar_mask = ar_mask.reshape(-1, 1 + self.nb_digits)

            nb_total = ar_mask.max(1).values.sum()
            nb_correct = nb_total - errors.max(1).values.sum()

            return nb_total, nb_correct

        test_nb_total, test_nb_correct = compute_nb_correct(self.test_input[:1000])

        logger(
            f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
        )

        ##############################################################
        # Log a few generated sequences
        input = self.test_input[:10, : 12 * (1 + self.nb_digits)]
        result = input.clone()
        stack.remove_popped_values(result, self.nb_stacks, self.nb_digits)
        ar_mask = (result != input).long()

        # for n in range(result.size(0)):
        # logger(
        # f"test_before {stack.seq_to_str(result[n],nb_stacks=self.nb_stacks,nb_digits=self.nb_digits)}"
        # )

        masked_inplace_autoregression(
            model,
            self.batch_size,
            result,
            ar_mask,
            deterministic_synthesis,
            device=self.device,
        )

        for n in range(result.size(0)):
            logger(
                f"test_after  {stack.seq_to_str(result[n],nb_stacks=self.nb_stacks,nb_digits=self.nb_digits)}"
            )
        ##############################################################


######################################################################


import expr


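# Expr task: character-level sequences of variable assignments produced by
# expr.generate_sequences; everything after the first space holds the variable
# values, and that is the part the model must predict at evaluation time.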
class Expr(Task):
    def tensorize(self, sequences):
        len_max = max([len(x) for x in sequences])
        return torch.cat(
            [
                torch.tensor(
                    [
                        [self.char2id[c] for c in s + "#" * (len_max - len(s))]
                        for s in sequences
                    ]
                )
            ],
            0,
        ).to(self.device)

    def __init__(
        self,
        nb_train_samples,
        nb_test_samples,
        nb_variables,
        sequence_length,
        operand_max,
        result_max,
        batch_size,
        device=torch.device("cpu"),
    ):
        self.batch_size = batch_size
        self.device = device

        train_sequences = expr.generate_sequences(
            nb_train_samples,
            nb_variables=nb_variables,
            length=sequence_length,
            operand_max=operand_max,
            result_max=result_max,
        )

        test_sequences = expr.generate_sequences(
            nb_test_samples,
            nb_variables=nb_variables,
            length=sequence_length,
            operand_max=operand_max,
            result_max=result_max,
        )

        symbols = list(set("#" + "".join(train_sequences + test_sequences)))
        symbols.sort()

        self.char2id = dict([(c, n) for n, c in enumerate(symbols)])
        self.id2char = dict([(n, c) for c, n in self.char2id.items()])

        self.filler, self.space = self.char2id["#"], self.char2id[" "]

        self.train_input = self.tensorize(train_sequences)
        self.test_input = self.tensorize(test_sequences)

        self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1

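    # Batches are cut on the right after the last column holding a non-filler
    # character (plus a small margin) to avoid wasting computation on padding.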
    def batches(self, split="train", nb_to_use=-1, desc=None):
        assert split in {"train", "test"}
        input = self.train_input if split == "train" else self.test_input
        if nb_to_use > 0:
            input = input[:nb_to_use]
        if desc is None:
            desc = f"epoch-{split}"
        for batch in tqdm.tqdm(
            input.split(self.batch_size), dynamic_ncols=True, desc=desc
        ):
            last = (batch != self.filler).max(0).values.nonzero().max() + 3
            batch = batch[:, :last]
            yield batch

    def vocabulary_size(self):
        return self.nb_codes

    def seq2str(self, s):
        return "".join([self.id2char[k.item()] for k in s])

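    # Masks everything after the first space, regenerates it, and logs both
    # the exact-match accuracy and the distribution of absolute errors on the
    # predicted variable values; the (true, predicted) pairs are also written
    # to a text file.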
    def produce_results(
        self,
        n_epoch,
        model,
        result_dir,
        logger,
        deterministic_synthesis,
        input_file=None,
    ):
        def compute_nb_correct(input):
            result = input.clone()
            s = (result == self.space).long()
            ar_mask = (s.cumsum(dim=1) - s).clamp(min=0, max=1)
            result = (1 - ar_mask) * result + ar_mask * self.filler
            masked_inplace_autoregression(
                model,
                self.batch_size,
                result,
                ar_mask,
                deterministic_synthesis,
                device=self.device,
            )

            nb_total = input.size(0)
            nb_correct = (input == result).long().min(1).values.sum()

            ######################################################################
            # Compute predicted vs. true variable values

            nb_delta = torch.zeros(5, dtype=torch.int64)
            nb_missed = 0

            values_input = expr.extract_results([self.seq2str(s) for s in input])
            values_result = expr.extract_results([self.seq2str(s) for s in result])

            filename = os.path.join(result_dir, f"expr_result_{n_epoch:04d}.txt")

            with open(filename, "w") as f:
                for i, r in zip(values_input, values_result):
                    for n, vi in i.items():
                        vr = r.get(n)
                        f.write(f"{vi} {-1 if vr is None else vr}\n")

                        if vr is None or vr < 0:
                            nb_missed += 1
                        else:
                            d = abs(vr - vi)
                            if d >= nb_delta.size(0):
                                nb_missed += 1
                            else:
                                nb_delta[d] += 1

            ######################################################################

            return nb_total, nb_correct, nb_delta, nb_missed

        (
            test_nb_total,
            test_nb_correct,
            test_nb_delta,
            test_nb_missed,
        ) = compute_nb_correct(self.test_input[:10000])

        logger(
            f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
        )

        nb_total = test_nb_delta.sum() + test_nb_missed
        for d in range(test_nb_delta.size(0)):
            logger(
                f"error_value {n_epoch} delta {d} {test_nb_delta[d]} {test_nb_delta[d]*100/nb_total:.02f}%"
            )
        logger(
            f"error_value {n_epoch} missed {test_nb_missed} {test_nb_missed*100/nb_total:.02f}%"
        )

        ##############################################################
        # Log a few generated sequences
        if input_file is None:
            input = self.test_input[:10]
        else:
            with open(input_file, "r") as f:
                sequences = [e.strip() for e in f.readlines()]
                sequences = [s + " " + "#" * 50 for s in sequences]
                input = self.tensorize(sequences)

        result = input.clone()
        s = (result == self.space).long()
        ar_mask = (s.cumsum(dim=1) - s).clamp(min=0, max=1)
        result = (1 - ar_mask) * result + ar_mask * self.filler

        for n in range(result.size(0)):
            logger(f"test_before {self.seq2str(result[n])}")

        masked_inplace_autoregression(
            model,
            self.batch_size,
            result,
            ar_mask,
            deterministic_synthesis,
            device=self.device,
        )

        correct = (1 - ar_mask) * self.space + ar_mask * input
        for n in range(result.size(0)):
            comment = "GOOD" if (result[n] - input[n]).abs().max() == 0 else ""
            logger(f"test_after  {self.seq2str(result[n])} {comment}")
            logger(f"truth       {self.seq2str(correct[n])}")
        ##############################################################


######################################################################
import world


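# World task: token sequences produced by world.create_data_and_processors
# from frames of a synthetic world; the frame2seq / seq2frame converters are
# kept as attributes, and produce_results is not implemented yet.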
class World(Task):
    def __init__(
        self,
        nb_train_samples,
        nb_test_samples,
        batch_size,
        device=torch.device("cpu"),
    ):
        self.batch_size = batch_size
        self.device = device

        (
            self.train_input,
            self.train_actions,
            self.test_input,
            self.test_actions,
            self.frame2seq,
            self.seq2frame,
        ) = world.create_data_and_processors(
            nb_train_samples,
            nb_test_samples,
            mode="first_last",
            nb_steps=30,
            nb_epochs=2,
        )

        self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1

    def batches(self, split="train", nb_to_use=-1, desc=None):
        assert split in {"train", "test"}
        input = self.train_input if split == "train" else self.test_input
        if nb_to_use > 0:
            input = input[:nb_to_use]
        if desc is None:
            desc = f"epoch-{split}"
        for batch in tqdm.tqdm(
            input.split(self.batch_size), dynamic_ncols=True, desc=desc
        ):
            yield batch

    def vocabulary_size(self):
        return self.nb_codes

    def produce_results(
        self, n_epoch, model, result_dir, logger, deterministic_synthesis
    ):
        pass


######################################################################