[picoclvr.git] / tasks.py
1 #!/usr/bin/env python
2
3 import math, os, tqdm
4
5 import torch, torchvision
6
7 from torch import nn
8 from torch.nn import functional as F
9
10 ######################################################################
11
12
13 def masked_inplace_autoregression(
14     model,
15     batch_size,
16     input,
17     ar_mask,
18     deterministic_synthesis,
19     forbidden_tokens=None,
20     progress_bar_desc="autoregression",
21     device=torch.device("cpu"),
22 ):
23     assert input.size() == ar_mask.size()
24
25     batches = zip(input.split(batch_size), ar_mask.split(batch_size))
26
27     if progress_bar_desc is not None:
28         batches = tqdm.tqdm(
29             batches,
30             dynamic_ncols=True,
31             desc=progress_bar_desc,
32             # total=input.size(0) // batch_size,
33         )
34
35     with torch.autograd.no_grad():
36         t = model.training
37         model.eval()
38
39         for input, ar_mask in batches:
40             model.masked_inplace_autoregression(
41                 input, ar_mask, forbidden_tokens, deterministic_synthesis
42             )
43
44         model.train(t)
45
46
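# A minimal usage sketch for the helper above (argument values are
# illustrative; the model is assumed to provide a masked_inplace_autoregression
# method, as the loop above requires). Positions where ar_mask == 1 are
# generated in place, the others are kept as given:
#
#   result = input.clone()
#   result *= 1 - ar_mask  # blank out the tokens to be generated
#   masked_inplace_autoregression(
#       model, 25, result, ar_mask, deterministic_synthesis=True,
#       progress_bar_desc=None, device=torch.device("cuda"),
#   )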
47 ######################################################################
48
49
50 class Task:
51     def batches(self, split="train"):
52         pass
53
54     def vocabulary_size(self):
55         pass
56
57     def produce_results(
58         self, n_epoch, model, result_dir, logger, deterministic_synthesis
59     ):
60         pass
61
62
63 ######################################################################
64
65
66 class Problem:
67     def generate_sequences(self, nb):
68         pass
69
70     def log_performance(self, sequences, logger):
71         pass
72
73
74 class ProblemByheart(Problem):
75     def __init__(self):
76         nb_seq, len_prompt, len_result = 100, 5, 5
77         self.seq = torch.randint(10, (nb_seq, len_prompt + 1 + len_result))
78         self.seq[:, len_prompt] = 10
79
80     def generate_sequences(self, nb):
81         sequences = self.seq[torch.randint(self.seq.size(0), (nb,))]
82         ar_mask = (sequences==10).long()
83         ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
84         return sequences, ar_mask
85
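    # The mask built above marks every position strictly after the first
    # separator token (value 10), i.e. the result part of each sequence.
    # For a single row:
    #   seq                        = [ 3, 7, 10, 4, 9 ]
    #   (seq == 10)                = [ 0, 0,  1, 0, 0 ]
    #   cumsum - hit, clamp(max=1) = [ 0, 0,  0, 1, 1 ]  <- ar_mask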
86         # problems = [ProblemByheart()]
87         # nb_common_codes = 100
88
89         # def generate_sequences(nb_samples):
90             # problem_indexes = torch.randint(len(problems), (nb_samples,))
91             # nb_samples_per_problem = torch.one_hot(problem_indexes).sum(0)
92             # print(f"{nb_samples_per_problem}")
93             # all_seq = []
94             # for nb, p in zip(nb_samples_per_problem, problems):
95                 # all_seq.append(p.generate_sequences(nb_samples_per_problem[nb]))
96             # return all_seq
97
98         # for strain, stest in zip(train_seq, test_seq):
99             # s = torch.cat((strain, stest), 0)
100
101 class SandBox(Task):
102     def __init__(
103         self,
104         problem,
105         nb_train_samples,
106         nb_test_samples,
107         batch_size,
108         logger=None,
109         device=torch.device("cpu"),
110     ):
111         super().__init__()
112
113         self.batch_size = batch_size
114         self.device = device
115
116         self.train_input, self.train_ar_mask = problem.generate_sequences(nb_train_samples)
117         self.test_input, self.test_ar_mask = problem.generate_sequences(nb_test_samples)
118
119         self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
120
121     def batches(self, split="train", nb_to_use=-1, desc=None):
122         assert split in {"train", "test"}
123         input = self.train_input if split == "train" else self.test_input
124         if nb_to_use > 0:
125             input = input[:nb_to_use]
126         if desc is None:
127             desc = f"epoch-{split}"
128         for batch in tqdm.tqdm(
129             input.split(self.batch_size), dynamic_ncols=True, desc=desc
130         ):
131             yield batch
132
133     def vocabulary_size(self):
134         return self.nb_codes
135
136     def produce_results(
137         self, n_epoch, model, result_dir, logger, deterministic_synthesis
138     ):
139
140         def compute_accuracy(input, ar_mask):
141             result = input.clone() * (1-ar_mask)
142             masked_inplace_autoregression(
143                 model,
144                 self.batch_size,
145                 result,
146                 ar_mask,
147                 deterministic_synthesis,
148                 progress_bar_desc=None,
149                 device=self.device,
150             )
151
152             nb_total = ar_mask.sum().item()
153             nb_correct = ((result==input).long() * ar_mask).sum().item()
154
155             return nb_total, nb_correct
156
157         train_nb_total, train_nb_correct = compute_accuracy(self.train_input, self.train_ar_mask)
158
159         logger(
160             f"accuracy_train {n_epoch} nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%"
161         )
162
163         test_nb_total, test_nb_correct = compute_accuracy(self.test_input, self.test_ar_mask)
164
165         logger(
166             f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
167         )
168
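# A minimal sketch of how a SandBox task is typically wired up (sample counts,
# batch size and the logger are illustrative, not prescribed by this file):
#
#   task = SandBox(
#       ProblemByheart(), nb_train_samples=10000, nb_test_samples=1000,
#       batch_size=25, device=torch.device("cpu"),
#   )
#   for input in task.batches("train"):
#       ...  # one optimization step of next-token prediction on `input`
#   task.produce_results(0, model, "results", print, deterministic_synthesis=True)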
169 ######################################################################
170
171 import picoclvr
172
173
174 class PicoCLVR(Task):
175     # Make a tensor from a list of strings
176     def tensorize(self, descr):
177         token_descr = [s.strip().split(" ") for s in descr]
178         l = max([len(s) for s in token_descr])
179         token_descr = [s + ["<nul>"] * (l - len(s)) for s in token_descr]
180         id_descr = [[self.token2id[u] for u in s] for s in token_descr]
181         return torch.tensor(id_descr, device=self.device)
182
183     # Make a list of strings from a tensor
184     def detensorize(self, x):
185         return [" ".join([self.id2token[t.item()] for t in r]) for r in x]
186
187     # Trim the tensors to remove as many all-<nul> columns as possible
188     # from the left and right of the first tensor. If z is a tuple, all
189     # its elements are trimmed according to that same trimming
190     def trim(self, z, token="<nul>"):
191         n = self.token2id[token]
192         if type(z) == tuple:
193             x = z[0]
194             i = (1 - (F.pad(x, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
195             a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
196             return tuple([t[:, a:b] for t in z])
197         else:
198             i = (1 - (F.pad(z, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
199             a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
200             return z[:, a:b]
201
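    # For instance, with n standing for <nul>, the batch
    #   [[n, n, a, b, n, n],
    #    [n, c, d, n, n, n]]
    # is trimmed to
    #   [[n, a, b],
    #    [c, d, n]]
    # i.e. only columns that are entirely <nul> are dropped, and only at
    # the two ends.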
202     ######################
203
204     def __init__(
205         self,
206         nb_train_samples,
207         nb_test_samples,
208         batch_size,
209         height,
210         width,
211         nb_colors=5,
212         logger=None,
213         device=torch.device("cpu"),
214         pruner_train=None,
215         pruner_eval=None,
216     ):
217         super().__init__()
218
219         def generate_descr(nb, cache_suffix, pruner):
220             return picoclvr.generate(
221                 nb,
222                 height=self.height,
223                 width=self.width,
224                 nb_colors=nb_colors,
225                 pruner=pruner,
226             )
227
228         self.height = height
229         self.width = width
230         self.batch_size = batch_size
231         self.device = device
232         self.pruner_train = pruner_train
233         self.pruner_eval = pruner_eval
234
235         if logger is not None:
236             logger(
237                 f"generating {nb_train_samples+nb_test_samples} samples (can take some time)"
238             )
239
240         self.train_descr = generate_descr(
241             nb_train_samples, "train", pruner=self.pruner_train
242         )
243         self.test_descr = generate_descr(nb_test_samples, "test", pruner=None)
244
245         # Build the tokenizer
246         tokens = {"<nul>", "<img>"}
247         for d in [self.train_descr, self.test_descr]:
248             for s in d:
249                 for t in s.strip().split(" "):
250                     tokens.add(t)
251         # Turn the set into a sorted list so that identical descriptions
252         # always produce identical tensors
253         tokens = list(tokens)
254         tokens.sort()
255         self.token2id = dict([(t, n) for n, t in enumerate(tokens)])
256         self.id2token = dict([(n, t) for n, t in enumerate(tokens)])
257         self.t_img, self.t_nul = self.token2id["<img>"], self.token2id["<nul>"]
258
259         # Tokenize the train and test sets
260         self.train_input = self.tensorize(self.train_descr)
261         self.test_input = self.tensorize(self.test_descr)
262
263     def batches(self, split="train"):
264         assert split in {"train", "test"}
265         input = self.train_input if split == "train" else self.test_input
266         for batch in tqdm.tqdm(
267             input.split(self.batch_size), dynamic_ncols=True, desc=f"epoch-{split}"
268         ):
269             yield self.trim(batch)
270
271     def vocabulary_size(self):
272         return len(self.token2id)
273
274     def compute_missing_properties(
275         self, n_epoch, model, logger, deterministic_synthesis, pruner=None
276     ):
277         acc_nb_requested_properties = []
278         acc_nb_missing_properties = []
279         acc_nb_results = 0
280
281         for input in tqdm.tqdm(
282             self.test_input.split(self.batch_size),
283             dynamic_ncols=True,
284             desc="test-properties",
285         ):
286             result = input.clone()
287             ar_mask = (result == self.t_img).long().cumsum(dim=1).clamp(max=1)
288             result = (1 - ar_mask) * result + ar_mask * self.t_nul
289             masked_inplace_autoregression(
290                 model,
291                 self.batch_size,
292                 result,
293                 ar_mask,
294                 deterministic_synthesis,
295                 progress_bar_desc=None,
296                 device=self.device,
297             )
298
299             result_descr = self.detensorize(result)
300             np = picoclvr.nb_properties(
301                 result_descr,
302                 height=self.height,
303                 width=self.width,
304                 pruner=pruner,
305             )
306             nb_requested_properties, _, nb_missing_properties = zip(*np)
307             acc_nb_requested_properties += nb_requested_properties
308             acc_nb_missing_properties += nb_missing_properties
309             acc_nb_results += len(result_descr)
310
311         nb_requested_properties = sum(acc_nb_requested_properties)
312         nb_missing_properties = sum(acc_nb_missing_properties)
313
314         prefix = "" if pruner is None else "pruned_"
315         logger(f"nb_{prefix}samples {n_epoch} {acc_nb_results}")
316         logger(
317             f"property_{prefix}nb {n_epoch} requested {sum(acc_nb_requested_properties)} missing {sum(acc_nb_missing_properties)}"
318         )
319         logger(
320             f"property_{prefix}miss {n_epoch} {100*nb_missing_properties/nb_requested_properties:.02f}%"
321         )
322
323     ######################################################################
324
325     def produce_results(
326         self, n_epoch, model, result_dir, logger, deterministic_synthesis
327     ):
328         self.compute_missing_properties(n_epoch, model, logger, deterministic_synthesis)
329
330         if self.pruner_eval is not None:
331             self.compute_missing_properties(n_epoch, model, logger, deterministic_synthesis, self.pruner_eval)
332
333         nb_tokens_to_generate = self.height * self.width + 3
334         result_descr = []
335         nb_per_primer = 8
336         primer = []
337
338         for primer_descr in [
339             "red above green <sep> green top <sep> blue right of red",
340             "there is red <sep> there is yellow <sep> there is blue",
341             "red below yellow <sep> yellow below green <sep> green below blue <sep> red right <sep> yellow left <sep> green right <sep> blue left",
342             "green bottom <sep> yellow bottom <sep> green left of blue <sep> yellow right of blue <sep> blue top",
343         ]:
344             primer += [primer_descr + " <img>"] * nb_per_primer
345
346         result = self.tensorize(primer)
347         fill = result.new_full(
348             result.size()[:-1] + (self.height * self.width + 1,), self.t_nul
349         )
350         result = torch.cat((result, fill), 1)
351         ar_mask = (result == self.t_nul).long()
352         masked_inplace_autoregression(
353             model,
354             self.batch_size,
355             result,
356             ar_mask,
357             deterministic_synthesis,
358             device=self.device,
359         )
360         result_descr = self.detensorize(result)
361
362         np = picoclvr.nb_properties(result_descr, height=self.height, width=self.width)
363
364         acc_nb_requested_properties, _, acc_nb_missing_properties = zip(*np)
365         acc_nb_results = len(result_descr)
366
367         nb_requested_properties = sum(acc_nb_requested_properties)
368         nb_missing_properties = sum(acc_nb_missing_properties)
369
370         prefix = "demo_"
371         logger(f"nb_{prefix}samples {n_epoch} {acc_nb_results}")
372         logger(
373             f"property_{prefix}nb {n_epoch} requested {sum(acc_nb_requested_properties)} missing {sum(acc_nb_missing_properties)}"
374         )
375         logger(
376             f"property_{prefix}miss {n_epoch} {100*nb_missing_properties/nb_requested_properties:.02f}%"
377         )
378
379         img = picoclvr.descr2img(result_descr, height=self.height, width=self.width)
380
381         if img.dim() == 5:
382             if img.size(1) == 1:
383                 img = F.pad(img.squeeze(1), pad=(1, 1, 1, 1), value=64)
384             else:
385                 img = torch.cat(
386                     [
387                         torchvision.utils.make_grid(x, padding=1, pad_value=64)[None]
388                         for x in img
389                     ],
390                     0,
391                 )
392
393         image_name = os.path.join(result_dir, f"picoclvr_result_{n_epoch:04d}.png")
394         torchvision.utils.save_image(
395             img / 255.0, image_name, nrow=nb_per_primer, padding=1, pad_value=0.0
396         )
397         logger(f"wrote {image_name}")
398
399
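# A minimal sketch of setting up the PicoCLVR task (argument values are
# illustrative; sample generation can take a while for large counts):
#
#   task = PicoCLVR(
#       nb_train_samples=25000, nb_test_samples=1000, batch_size=25,
#       height=12, width=16, nb_colors=5, logger=print,
#   )
#   print(task.vocabulary_size())
#   for batch in task.batches("train"):
#       ...  # train the next-token predictor on `batch`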
400 ######################################################################
401
402
403 class MNIST(Task):
404     def __init__(
405         self, nb_train_samples, nb_test_samples, batch_size, device=torch.device("cpu")
406     ):
407         super().__init__()
408
409         self.nb_train_samples = nb_train_samples
410         self.nb_test_samples = nb_test_samples
411         self.batch_size = batch_size
412         self.device = device
413         data_set = torchvision.datasets.MNIST(root="./data", train=True, download=True)
414         self.train_input = data_set.data[:nb_train_samples].view(-1, 28 * 28).long()
415         data_set = torchvision.datasets.MNIST(root="./data", train=False, download=True)
416         self.test_input = data_set.data[:nb_test_samples].view(-1, 28 * 28).long()
417
418     def batches(self, split="train", nb_to_use=-1, desc=None):
419         assert split in {"train", "test"}
420         input = self.train_input if split == "train" else self.test_input
421         if nb_to_use > 0:
422             input = input[:nb_to_use]
423         if desc is None:
424             desc = f"epoch-{split}"
425         for batch in tqdm.tqdm(
426             input.split(self.batch_size), dynamic_ncols=True, desc=desc
427         ):
428             yield batch
429
430     def vocabulary_size(self):
431         return 256
432
433     def produce_results(
434         self, n_epoch, model, result_dir, logger, deterministic_synthesis
435     ):
436         results = torch.empty(64, 28 * 28, device=self.device, dtype=torch.int64)
437         ar_mask = torch.full_like(results, 1)
438         masked_inplace_autoregression(
439             model,
440             self.batch_size,
441             results,
442             ar_mask,
443             deterministic_synthesis,
444             device=self.device,
445         )
446         image_name = os.path.join(result_dir, f"mnist_result_{n_epoch:04d}.png")
447         torchvision.utils.save_image(
448             1 - results.reshape(-1, 1, 28, 28) / 255.0,
449             image_name,
450             nrow=16,
451             pad_value=0.8,
452         )
453         logger(f"wrote {image_name}")
454
455
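# In MNIST.produce_results above, the all-ones ar_mask makes the 28 * 28 pixel
# tokens of each image be generated unconditionally from scratch. A minimal
# setup sketch (sample counts are illustrative):
#
#   task = MNIST(nb_train_samples=60000, nb_test_samples=10000, batch_size=64)
#   assert task.vocabulary_size() == 256  # one token per 8-bit pixel value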
456 ######################################################################
457
458 import maze
459
460
461 class Maze(Task):
462     def map2seq(self, *m):
463         return torch.cat([x.flatten(1) for x in m], 1)
464
465     def seq2map(self, s):
466         s = s.reshape(s.size(0), -1, self.height, self.width)
467         return (s[:, k] for k in range(s.size(1)))
468
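    # A maze sample is the wall map flattened to height * width tokens,
    # followed by the corresponding path map flattened likewise; map2seq
    # concatenates such maps into one sequence, seq2map splits a sequence
    # back into its maps.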
469     def __init__(
470         self,
471         nb_train_samples,
472         nb_test_samples,
473         batch_size,
474         height,
475         width,
476         nb_walls,
477         device=torch.device("cpu"),
478     ):
479         super().__init__()
480
481         self.batch_size = batch_size
482         self.height = height
483         self.width = width
484         self.device = device
485
486         train_mazes, train_paths, _ = maze.create_maze_data(
487             nb_train_samples,
488             height=height,
489             width=width,
490             nb_walls=nb_walls,
491             progress_bar=lambda x: tqdm.tqdm(x, dynamic_ncols=True, desc="data-train"),
492         )
493         self.train_input = self.map2seq(train_mazes.to(device), train_paths.to(device))
494
495         test_mazes, test_paths, _ = maze.create_maze_data(
496             nb_test_samples,
497             height=height,
498             width=width,
499             nb_walls=nb_walls,
500             progress_bar=lambda x: tqdm.tqdm(x, dynamic_ncols=True, desc="data-test"),
501         )
502         self.test_input = self.map2seq(test_mazes.to(device), test_paths.to(device))
503
504         self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
505
506     def batches(self, split="train", nb_to_use=-1, desc=None):
507         assert split in {"train", "test"}
508         input = self.train_input if split == "train" else self.test_input
509         if nb_to_use > 0:
510             input = input[:nb_to_use]
511         if desc is None:
512             desc = f"epoch-{split}"
513         for batch in tqdm.tqdm(
514             input.split(self.batch_size), dynamic_ncols=True, desc=desc
515         ):
516             yield batch
517
518     def vocabulary_size(self):
519         return self.nb_codes
520
521     def compute_error(
522         self, model, split="train", nb_to_use=-1, deterministic_synthesis=False
523     ):
524         nb_total, nb_correct = 0, 0
525         count = torch.zeros(
526             self.width * self.height,
527             self.width * self.height,
528             device=self.device,
529             dtype=torch.int64,
530         )
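        # count[i, j] accumulates, over the correctly solved mazes, the number
        # of cases whose optimal path has length i and whose predicted path
        # has length j; its diagonal therefore counts the optimal predictions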
531
532         for input in self.batches(split, nb_to_use):
533             result = input.clone()
534             ar_mask = result.new_zeros(result.size())
535             ar_mask[:, self.height * self.width :] = 1
536             result *= 1 - ar_mask
537             masked_inplace_autoregression(
538                 model,
539                 self.batch_size,
540                 result,
541                 ar_mask,
542                 deterministic_synthesis,
543                 progress_bar_desc=None,
544                 device=self.device,
545             )
546             mazes, paths = self.seq2map(result)
547             path_correctness = maze.path_correctness(mazes, paths)
548             nb_correct += path_correctness.long().sum()
549             nb_total += mazes.size(0)
550
551             optimal_path_lengths = (
552                 (input[:, self.height * self.width :] == maze.v_path).long().sum(1)
553             )
554             predicted_path_lengths = (
555                 (result[:, self.height * self.width :] == maze.v_path).long().sum(1)
556             )
557             optimal_path_lengths = optimal_path_lengths[path_correctness]
558             predicted_path_lengths = predicted_path_lengths[path_correctness]
559             count[optimal_path_lengths, predicted_path_lengths] += 1
560
561         if count.max() == 0:
562             count = None
563         else:
564             count = count[
565                 : count.sum(1).nonzero().max() + 1, : count.sum(0).nonzero().max() + 1
566             ]
567
568         return nb_total, nb_correct, count
569
570     def produce_results(
571         self, n_epoch, model, result_dir, logger, deterministic_synthesis
572     ):
573         train_nb_total, train_nb_correct, count = self.compute_error(
574             model,
575             "train",
576             nb_to_use=1000,
577             deterministic_synthesis=deterministic_synthesis,
578         )
579         logger(
580             f"accuracy_train {n_epoch} nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%"
581         )
582
583         test_nb_total, test_nb_correct, count = self.compute_error(
584             model,
585             "test",
586             nb_to_use=1000,
587             deterministic_synthesis=deterministic_synthesis,
588         )
589         logger(
590             f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
591         )
592
593         if count is not None:
594             proportion_optimal = count.diagonal().sum().float() / count.sum()
595             logger(f"proportion_optimal_test {proportion_optimal*100:.02f}%")
596             with open(
597                 os.path.join(result_dir, f"maze_result_{n_epoch:04d}.txt"), "w"
598             ) as f:
599                 for i in range(count.size(0)):
600                     for j in range(count.size(1)):
601                         eol = " " if j < count.size(1) - 1 else "\n"
602                         f.write(f"{count[i,j]}{eol}")
603
604         input = self.test_input[:48]
605         result = input.clone()
606         ar_mask = result.new_zeros(result.size())
607         ar_mask[:, self.height * self.width :] = 1
608         result *= 1 - ar_mask
609         masked_inplace_autoregression(
610             model,
611             self.batch_size,
612             result,
613             ar_mask,
614             deterministic_synthesis,
615             device=self.device,
616         )
617
618         mazes, paths = self.seq2map(input)
619         _, predicted_paths = self.seq2map(result)
620
621         filename = os.path.join(result_dir, f"maze_result_{n_epoch:04d}.png")
622         maze.save_image(
623             filename,
624             mazes=mazes,
625             target_paths=paths,
626             predicted_paths=predicted_paths,
627             path_correct=maze.path_correctness(mazes, predicted_paths),
628             path_optimal=maze.path_optimality(paths, predicted_paths),
629         )
630         logger(f"wrote {filename}")
631
632
633 ######################################################################
634
635
636 import snake
637
638
639 class Snake(Task):
640     def __init__(
641         self,
642         nb_train_samples,
643         nb_test_samples,
644         batch_size,
645         height,
646         width,
647         nb_colors,
648         length,
649         prompt_length,
650         device=torch.device("cpu"),
651     ):
652         super().__init__()
653
654         self.batch_size = batch_size
655         self.height = height
656         self.width = width
657         self.device = device
658         self.prompt_length = prompt_length
659
660         self.train_input, self.train_prior_visits, _, _ = snake.generate_sequences(
661             nb_train_samples,
662             height,
663             width,
664             nb_colors,
665             length,
666             prompt_length,
667             self.device,
668         )
669         self.test_input, self.test_prior_visits, _, _ = snake.generate_sequences(
670             nb_test_samples,
671             height,
672             width,
673             nb_colors,
674             length,
675             prompt_length,
676             self.device,
677         )
678
679         self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
680
681     def batches(self, split="train", nb_to_use=-1, desc=None):
682         assert split in {"train", "test"}
683         input = self.train_input if split == "train" else self.test_input
684         if nb_to_use > 0:
685             input = input[:nb_to_use]
686         if desc is None:
687             desc = f"epoch-{split}"
688         for batch in tqdm.tqdm(
689             input.split(self.batch_size), dynamic_ncols=True, desc=desc
690         ):
691             yield batch
692
693     def vocabulary_size(self):
694         return self.nb_codes
695
696     def produce_results(
697         self, n_epoch, model, result_dir, logger, deterministic_synthesis
698     ):
699         def compute_nb_correct(input, prior_visits):
700             result = input.clone()
701             i = torch.arange(result.size(1), device=result.device)[None, :]
702             ar_mask = (
703                 torch.logical_and(i >= self.prompt_length * 2, i % 2 == 0)
704                 .long()
705                 .expand_as(result)
706             )
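            # keep the prompt (the first 2 * prompt_length tokens) and, past
            # it, regenerate every other token (the even positions), which are
            # the ones compared against the prior visits below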
707             result *= 1 - ar_mask
708
709             masked_inplace_autoregression(
710                 model,
711                 self.batch_size,
712                 result,
713                 ar_mask,
714                 deterministic_synthesis,
715                 device=self.device,
716             )
717
718             nb_total = ((prior_visits > 0) * ar_mask).sum()
719
720             nb_correct = ((result == input).long() * (prior_visits > 0) * ar_mask).sum()
721
722             return nb_total, nb_correct
723
724         test_nb_total, test_nb_correct = compute_nb_correct(
725             self.test_input[:1000], self.test_prior_visits[:1000]
726         )
727
728         logger(
729             f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
730         )
731
732
733 ######################################################################
734
735
736 import stack
737
738
739 class Stack(Task):
740     def __init__(
741         self,
742         nb_train_samples,
743         nb_test_samples,
744         batch_size,
745         logger,
746         nb_steps,
747         nb_stacks,
748         nb_digits,
749         fraction_values_for_train=None,
750         device=torch.device("cpu"),
751     ):
752         super().__init__()
753
754         self.batch_size = batch_size
755         self.nb_steps = nb_steps
756         self.nb_stacks = nb_stacks
757         self.nb_digits = nb_digits
758         self.device = device
759
760         if fraction_values_for_train is None:
761             values_for_train = None
762             values_for_test = None
763         else:
764             all = torch.randperm(10**nb_digits)
765             nb_for_train = int(all.size(0) * fraction_values_for_train)
766             values_for_train = all[:nb_for_train]
767             values_for_test = all[nb_for_train:]
768
769         self.train_input, self.train_stack_counts = stack.generate_sequences(
770             nb_train_samples,
771             nb_steps,
772             nb_stacks,
773             nb_digits,
774             values_for_train,
775             self.device,
776         )
777
778         self.test_input, self.test_stack_counts = stack.generate_sequences(
779             nb_test_samples,
780             nb_steps,
781             nb_stacks,
782             nb_digits,
783             values_for_test,
784             self.device,
785         )
786
787         i = torch.logical_and(self.test_input % 2 == 1, self.test_input < 2 * nb_stacks)
788         counts = self.test_stack_counts.flatten()[i.flatten()]
789         counts = F.one_hot(counts).sum(0)
790         logger(f"test_pop_stack_counts {counts}")
791
792         self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
793
794     def batches(self, split="train", nb_to_use=-1, desc=None):
795         assert split in {"train", "test"}
796         input = self.train_input if split == "train" else self.test_input
797         if nb_to_use > 0:
798             input = input[:nb_to_use]
799         if desc is None:
800             desc = f"epoch-{split}"
801         for batch in tqdm.tqdm(
802             input.split(self.batch_size), dynamic_ncols=True, desc=desc
803         ):
804             yield batch
805
806     def vocabulary_size(self):
807         return self.nb_codes
808
809     def produce_results(
810         self, n_epoch, model, result_dir, logger, deterministic_synthesis
811     ):
812         def compute_nb_correct(input):
813             result = input.clone()
814             stack.remove_popped_values(result, self.nb_stacks, self.nb_digits)
815             ar_mask = (result != input).long()
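            # remove_popped_values blanks out the digits of the popped values,
            # so the positions that now differ from the input are exactly the
            # ones the model has to regenerate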
816             masked_inplace_autoregression(
817                 model,
818                 self.batch_size,
819                 result,
820                 ar_mask,
821                 deterministic_synthesis,
822                 device=self.device,
823             )
824
825             errors = ((result != input).long() * ar_mask).reshape(
826                 -1, 1 + self.nb_digits
827             )
828             ar_mask = ar_mask.reshape(-1, 1 + self.nb_digits)
829
830             nb_total = ar_mask.max(1).values.sum()
831             nb_correct = nb_total - errors.max(1).values.sum()
832
833             return nb_total, nb_correct
834
835         test_nb_total, test_nb_correct = compute_nb_correct(self.test_input[:1000])
836
837         logger(
838             f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
839         )
840
841         ##############################################################
842         # Log a few generated sequences
843         input = self.test_input[:10, : 12 * (1 + self.nb_digits)]
844         result = input.clone()
845         stack.remove_popped_values(result, self.nb_stacks, self.nb_digits)
846         ar_mask = (result != input).long()
847
848         # for n in range(result.size(0)):
849         # logger(
850         # f"test_before {stack.seq_to_str(result[n],nb_stacks=self.nb_stacks,nb_digits=self.nb_digits)}"
851         # )
852
853         masked_inplace_autoregression(
854             model,
855             self.batch_size,
856             result,
857             ar_mask,
858             deterministic_synthesis,
859             device=self.device,
860         )
861
862         for n in range(result.size(0)):
863             logger(
864                 f"test_after  {stack.seq_to_str(result[n],nb_stacks=self.nb_stacks,nb_digits=self.nb_digits)}"
865             )
866         ##############################################################
867
868
869 ######################################################################
870
871
872 import expr
873
874
875 class Expr(Task):
876     def tensorize(self, sequences):
877         len_max = max([len(x) for x in sequences])
878         return torch.cat(
879             [
880                 torch.tensor(
881                     [
882                         [self.char2id[c] for c in s + "#" * (len_max - len(s))]
883                         for s in sequences
884                     ]
885                 )
886             ],
887             0,
888         ).to(self.device)
889
890     def __init__(
891         self,
892         nb_train_samples,
893         nb_test_samples,
894         nb_variables,
895         sequence_length,
896         operand_max,
897         result_max,
898         batch_size,
899         device=torch.device("cpu"),
900     ):
901         super().__init__()
902
903         self.batch_size = batch_size
904         self.device = device
905
906         train_sequences = expr.generate_sequences(
907             nb_train_samples,
908             nb_variables=nb_variables,
909             length=sequence_length,
910             operand_max=operand_max,
911             result_max=result_max,
912         )
913
914         test_sequences = expr.generate_sequences(
915             nb_test_samples,
916             nb_variables=nb_variables,
917             length=sequence_length,
918             operand_max=operand_max,
919             result_max=result_max,
920         )
921
922         symbols = list(set("#" + "".join(train_sequences + test_sequences)))
923         symbols.sort()
924
925         self.char2id = dict([(c, n) for n, c in enumerate(symbols)])
926         self.id2char = dict([(n, c) for c, n in self.char2id.items()])
927
928         self.filler, self.space = self.char2id["#"], self.char2id[" "]
929
930         self.train_input = self.tensorize(train_sequences)
931         self.test_input = self.tensorize(test_sequences)
932
933         self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
934
935     def batches(self, split="train", nb_to_use=-1, desc=None):
936         assert split in {"train", "test"}
937         input = self.train_input if split == "train" else self.test_input
938         if nb_to_use > 0:
939             input = input[:nb_to_use]
940         if desc is None:
941             desc = f"epoch-{split}"
942         for batch in tqdm.tqdm(
943             input.split(self.batch_size), dynamic_ncols=True, desc=desc
944         ):
945             last = (batch != self.filler).max(0).values.nonzero().max() + 3
946             batch = batch[:, :last]
947             yield batch
948
949     def vocabulary_size(self):
950         return self.nb_codes
951
952     def seq2str(self, s):
953         return "".join([self.id2char[k.item()] for k in s])
954
955     def produce_results(
956         self,
957         n_epoch,
958         model,
959         result_dir,
960         logger,
961         deterministic_synthesis,
962         input_file=None,
963     ):
964         def compute_nb_correct(input):
965             result = input.clone()
966             s = (result == self.space).long()
967             ar_mask = (s.cumsum(dim=1) - s).clamp(min=0, max=1)
968             result = (1 - ar_mask) * result + ar_mask * self.filler
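            # everything strictly after the first space is masked and replaced
            # by the filler; the model regenerates that part, from which the
            # predicted variable values are extracted below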
969             masked_inplace_autoregression(
970                 model,
971                 self.batch_size,
972                 result,
973                 ar_mask,
974                 deterministic_synthesis,
975                 device=self.device,
976             )
977
978             nb_total = input.size(0)
979             nb_correct = (input == result).long().min(1).values.sum()
980
981             #######################################################################
982             # Compute predicted vs. true variable values
983
984             nb_delta = torch.zeros(5, dtype=torch.int64)
985             nb_missed = 0
986
987             values_input = expr.extract_results([self.seq2str(s) for s in input])
988             values_result = expr.extract_results([self.seq2str(s) for s in result])
989
990             filename = os.path.join(result_dir, f"expr_result_{n_epoch:04d}.txt")
991
992             with open(filename, "w") as f:
993                 for i, r in zip(values_input, values_result):
994                     for n, vi in i.items():
995                         vr = r.get(n)
996                         f.write(f"{vi} {-1 if vr is None else vr}\n")
997
998                         if vr is None or vr < 0:
999                             nb_missed += 1
1000                         else:
1001                             d = abs(vr - vi)
1002                             if d >= nb_delta.size(0):
1003                                 nb_missed += 1
1004                             else:
1005                                 nb_delta[d] += 1
1006
1007             ######################################################################
1008
1009             return nb_total, nb_correct, nb_delta, nb_missed
1010
1011         (
1012             test_nb_total,
1013             test_nb_correct,
1014             test_nb_delta,
1015             test_nb_missed,
1016         ) = compute_nb_correct(self.test_input[:10000])
1017
1018         logger(
1019             f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
1020         )
1021
1022         nb_total = test_nb_delta.sum() + test_nb_missed
1023         for d in range(test_nb_delta.size(0)):
1024             logger(
1025                 f"error_value {n_epoch} delta {d} {test_nb_delta[d]} {test_nb_delta[d]*100/nb_total:.02f}%"
1026             )
1027         logger(
1028             f"error_value {n_epoch} missed {test_nb_missed} {test_nb_missed*100/nb_total:.02f}%"
1029         )
1030
1031         ##############################################################
1032         # Log a few generated sequences
1033         if input_file is None:
1034             input = self.test_input[:10]
1035         else:
1036             with open(input_file, "r") as f:
1037                 sequences = [e.strip() for e in f.readlines()]
1038                 sequences = [s + " " + "#" * 50 for s in sequences]
1039                 input = self.tensorize(sequences)
1040
1041         result = input.clone()
1042         s = (result == self.space).long()
1043         ar_mask = (s.cumsum(dim=1) - s).clamp(min=0, max=1)
1044         result = (1 - ar_mask) * result + ar_mask * self.filler
1045
1046         for n in range(result.size(0)):
1047             logger(f"test_before {self.seq2str(result[n])}")
1048
1049         masked_inplace_autoregression(
1050             model,
1051             self.batch_size,
1052             result,
1053             ar_mask,
1054             deterministic_synthesis,
1055             device=self.device,
1056         )
1057
1058         correct = (1 - ar_mask) * self.space + ar_mask * input
1059         for n in range(result.size(0)):
1060             comment = "GOOD" if (result[n] - input[n]).abs().max() == 0 else ""
1061             logger(f"test_after  {self.seq2str(result[n])} {comment}")
1062             logger(f"truth       {self.seq2str(correct[n])}")
1063         ##############################################################
1064
1065
1066 ######################################################################
1067
1068 import world
1069
1070
1071 class World(Task):
1072     def __init__(
1073         self,
1074         nb_train_samples,
1075         nb_test_samples,
1076         batch_size,
1077         vqae_nb_epochs,
1078         logger=None,
1079         device=torch.device("cpu"),
1080         device_storage=torch.device("cpu"),
1081     ):
1082         super().__init__()
1083
1084         self.batch_size = batch_size
1085         self.device = device
1086
1087         (
1088             train_frames,
1089             train_action_seq,
1090             test_frames,
1091             test_action_seq,
1092             self.frame2seq,
1093             self.seq2frame,
1094         ) = world.create_data_and_processors(
1095             nb_train_samples,
1096             nb_test_samples,
1097             mode="first_last",
1098             nb_steps=30,
1099             nb_epochs=vqae_nb_epochs,
1100             logger=logger,
1101             device=device,
1102             device_storage=device_storage,
1103         )
1104
1105         print(f"{train_action_seq.size()=}")
1106
1107         train_frame_seq = self.frame2seq(train_frames).to(device_storage)
1108         test_frame_seq = self.frame2seq(test_frames).to(device_storage)
1109
1110         nb_frame_codes = max(train_frame_seq.max(), test_frame_seq.max()) + 1
1111         nb_action_codes = max(train_action_seq.max(), test_action_seq.max()) + 1
1112
1113         self.len_frame_seq = train_frame_seq.size(1)
1114         self.len_action_seq = train_action_seq.size(1)
1115         self.nb_codes = nb_frame_codes + nb_action_codes
1116
1117         train_frame_seq = train_frame_seq.reshape(train_frame_seq.size(0) // 2, 2, -1)
1118         print(f"{train_action_seq.device=} {nb_frame_codes.device=}")
1119         train_action_seq += nb_frame_codes
1120         self.train_input = torch.cat(
1121             (train_frame_seq[:, 0, :], train_action_seq, train_frame_seq[:, 1, :]), 1
1122         )
1123
1124         test_frame_seq = test_frame_seq.reshape(test_frame_seq.size(0) // 2, 2, -1)
1125         test_action_seq += nb_frame_codes
1126         self.test_input = torch.cat(
1127             (test_frame_seq[:, 0, :], test_action_seq, test_frame_seq[:, 1, :]), 1
1128         )
1129
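    # Each sample is the first frame's codes, followed by the action codes
    # (offset by nb_frame_codes so that the two vocabularies do not collide),
    # followed by the last frame's codes.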
1130     def batches(self, split="train", nb_to_use=-1, desc=None):
1131         assert split in {"train", "test"}
1132         input = self.train_input if split == "train" else self.test_input
1133         if nb_to_use > 0:
1134             input = input[:nb_to_use]
1135         if desc is None:
1136             desc = f"epoch-{split}"
1137         for batch in tqdm.tqdm(
1138             input.split(self.batch_size), dynamic_ncols=True, desc=desc
1139         ):
1140             yield batch.to(self.device)
1141
1142     def vocabulary_size(self):
1143         return self.nb_codes
1144
1145     def produce_results(
1146         self, n_epoch, model, result_dir, logger, deterministic_synthesis
1147     ):
1148         k = torch.arange(
1149             2 * self.len_frame_seq + self.len_action_seq, device=self.device
1150         )[None, :]
1151
1152         input = self.test_input[:64].to(self.device)
1153         result = input.clone()
1154
1155         ar_mask = (
1156             (k >= self.len_frame_seq + self.len_action_seq).long().expand_as(result)
1157         )
1158         result *= 1 - ar_mask
1159
1160         masked_inplace_autoregression(
1161             model,
1162             self.batch_size,
1163             result,
1164             ar_mask,
1165             deterministic_synthesis,
1166             device=self.device,
1167         )
1168
1169         seq_start = input[:, : self.len_frame_seq]
1170         seq_end = input[:, self.len_frame_seq + self.len_action_seq :]
1171         seq_predicted = result[:, self.len_frame_seq + self.len_action_seq :]
1172
1173         result = torch.cat(
1174             (seq_start[:, None, :], seq_end[:, None, :], seq_predicted[:, None, :]), 1
1175         )
1176         result = result.reshape(-1, result.size(-1))
1177         print(f"{result.size()=}")
1178
1179         frames = self.seq2frame(result)
1180         image_name = os.path.join(result_dir, f"world_result_{n_epoch:04d}.png")
1181         torchvision.utils.save_image(
1182             frames.float() / (world.Box.nb_rgb_levels - 1),
1183             image_name,
1184             nrow=12,
1185             padding=1,
1186             pad_value=0.0,
1187         )
1188         logger(f"wrote {image_name}")
1189
1190
1191 ######################################################################