picoclvr.git / main.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 # torch.backends.cuda.matmul.allow_tf32
9 # torch.autocast(torch.bfloat16)
10
11 import math, sys, argparse, time, tqdm, os
12
13 import torch, torchvision
14 from torch import nn
15 from torch.nn import functional as F
16
17 import mygpt, tensorstack
18
19 ######################################################################
20
21 if torch.cuda.is_available():
22     device = torch.device("cuda")
23     torch.backends.cuda.matmul.allow_tf32 = True
24 else:
25     device = torch.device("cpu")
26
27 ######################################################################
28
29 parser = argparse.ArgumentParser(
30     description="An implementation of GPT with cache.",
31     formatter_class=argparse.ArgumentDefaultsHelpFormatter,
32 )
33
34 parser.add_argument(
35     "--task",
36     type=str,
37     default="picoclvr",
38     help="picoclvr, mnist, maze, snake, stack, expr",
39 )
40
41 parser.add_argument("--log_filename", type=str, default="train.log", help=" ")
42
43 parser.add_argument("--result_dir", type=str, default=None)
44
45 parser.add_argument("--seed", type=int, default=0)
46
47 parser.add_argument("--nb_epochs", type=int, default=None)
48
49 parser.add_argument("--batch_size", type=int, default=None)
50
51 parser.add_argument("--nb_train_samples", type=int, default=None)
52
53 parser.add_argument("--nb_test_samples", type=int, default=None)
54
55 parser.add_argument("--optim", type=str, default="adam")
56
57 parser.add_argument("--learning_rate", type=float, default=1e-4)
58
59 parser.add_argument("--learning_rate_schedule", type=str, default="10: 2e-5,30: 4e-6")
60
61 parser.add_argument("--dim_model", type=int, default=512)
62
63 parser.add_argument("--dim_keys", type=int, default=64)
64
65 parser.add_argument("--dim_hidden", type=int, default=2048)
66
67 parser.add_argument("--nb_heads", type=int, default=8)
68
69 parser.add_argument("--nb_blocks", type=int, default=12)
70
71 parser.add_argument("--dropout", type=float, default=0.1)
72
73 parser.add_argument("--deterministic_synthesis", action="store_true", default=False)
74
75 parser.add_argument("--no_checkpoint", action="store_true", default=False)
76
77 parser.add_argument("--overwrite_results", action="store_true", default=False)
78
79 parser.add_argument("--checkpoint_name", type=str, default="checkpoint.pth")
80
81 ##############################
82 # picoclvr options
83
84 parser.add_argument("--picoclvr_nb_colors", type=int, default=5)
85
86 parser.add_argument("--picoclvr_height", type=int, default=12)
87
88 parser.add_argument("--picoclvr_width", type=int, default=16)
89
90 parser.add_argument("--picocvlr_prune_properties", type=str, default="none")
91
92 ##############################
93 # Maze options
94
95 parser.add_argument("--maze_height", type=int, default=13)
96
97 parser.add_argument("--maze_width", type=int, default=21)
98
99 parser.add_argument("--maze_nb_walls", type=int, default=15)
100
101 ##############################
102 # Snake options
103
104 parser.add_argument("--snake_height", type=int, default=6)
105
106 parser.add_argument("--snake_width", type=int, default=8)
107
108 parser.add_argument("--snake_nb_colors", type=int, default=5)
109
110 parser.add_argument("--snake_length", type=int, default=200)
111
112 ##############################
113 # Stack options
114
115 parser.add_argument("--stack_nb_steps", type=int, default=100)
116
117 parser.add_argument("--stack_nb_stacks", type=int, default=1)
118
119 parser.add_argument("--stack_nb_digits", type=int, default=3)
120
121 parser.add_argument("--stack_fraction_values_for_train", type=float, default=None)
122
123 ######################################################################
124
125 args = parser.parse_args()
126
127 assert args.picocvlr_prune_properties in {"none", "train+eval", "eval"}
128
129 if args.result_dir is None:
130     args.result_dir = f"results_{args.task}"
131
132 ######################################################################
133
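# Per-task default values for nb_epochs, batch_size, nb_train_samples and
# nb_test_samples; they are applied below only to the arguments left at None
# on the command line.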
134 default_args = {
135     "picoclvr": {
136         "nb_epochs": 25,
137         "batch_size": 25,
138         "nb_train_samples": 250000,
139         "nb_test_samples": 10000,
140     },
141     "mnist": {
142         "nb_epochs": 25,
143         "batch_size": 10,
144         "nb_train_samples": 250000,
145         "nb_test_samples": 10000,
146     },
147     "maze": {
148         "nb_epochs": 25,
149         "batch_size": 25,
150         "nb_train_samples": 250000,
151         "nb_test_samples": 10000,
152     },
153     "snake": {
154         "nb_epochs": 5,
155         "batch_size": 25,
156         "nb_train_samples": 250000,
157         "nb_test_samples": 10000,
158     },
159     "stack": {
160         "nb_epochs": 5,
161         "batch_size": 25,
162         "nb_train_samples": 100000,
163         "nb_test_samples": 1000,
164     },
165     "expr": {
166         "nb_epochs": 5,
167         "batch_size": 25,
168         "nb_train_samples": 100000,
169         "nb_test_samples": 1000,
170     },
171 }
172
173 if args.task in default_args:
174     for k, v in default_args[args.task].items():
175         if getattr(args, k) is None:
176             setattr(args, k, v)
177
178 ######################################################################
179
180 try:
181     os.mkdir(args.result_dir)
182 except FileExistsError:
183     if not args.overwrite_results:
184         print(f"result directory {args.result_dir} already exists")
185         exit(1)
186
187 log_file = open(os.path.join(args.result_dir, args.log_filename), "a")
188
189 if args.seed >= 0:
190     # torch.backends.cudnn.deterministic = True
191     # torch.backends.cudnn.benchmark = False
192     # torch.use_deterministic_algorithms(True)
193     torch.manual_seed(args.seed)
194     if torch.cuda.is_available():
195         torch.cuda.manual_seed_all(args.seed)
196
197 ######################################################################
198
199
200 def log_string(s):
201     t = time.strftime("%Y%m%d-%H:%M:%S ", time.localtime())
202
203     if log_file is not None:
204         log_file.write(t + s + "\n")
205         log_file.flush()
206
207     print(t + s)
208     sys.stdout.flush()
209
210
211 for n in vars(args):
212     log_string(f"args.{n} {getattr(args, n)}")
213
214 ######################################################################
215
216
217 # ar_mask is boolean, with 1s at the positions whose values must be generated
218
219
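# Generation is done in place: `input` is modified column by column, and only
# at the positions where ar_mask is 1; everywhere else the original tokens are
# kept. Sequences are processed in chunks of batch_size. For each chunk the
# model is first run on the prefix preceding the first masked column (which,
# as used here, lets mygpt.BracketedSequence(x, first, nb) reuse its cached
# activations), then one column at a time: the logits at column s are
# optionally masked with forbidden_tokens, and the next token is the argmax
# when --deterministic_synthesis is set, a sample from the categorical
# distribution otherwise.
#
# A hypothetical usage sketch, to let the model fill in the second half of a
# batch of sequences `seq`:
#
#   ar_mask = torch.zeros_like(seq)
#   ar_mask[:, seq.size(1) // 2 :] = 1
#   masked_inplace_autoregression(model, 25, seq, ar_mask, device=device)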
220 def masked_inplace_autoregression(
221     model,
222     batch_size,
223     input,
224     ar_mask,
225     forbidden_tokens=None,
226     progress_bar_desc="autoregression",
227     device=torch.device("cpu"),
228 ):
229     batches = zip(input.split(batch_size), ar_mask.split(batch_size))
230
231     if progress_bar_desc is not None:
232         batches = tqdm.tqdm(
233             batches,
234             dynamic_ncols=True,
235             desc=progress_bar_desc,
236             total=input.size(0) // batch_size,
237         )
238
239     for input, ar_mask in batches:
240         i = (ar_mask.sum(0) > 0).nonzero()
241         if i.min() > 0:
242             model(
243                 mygpt.BracketedSequence(input, 0, i.min())
244             )  # Needed to initialize the model's cache
245         for s in range(i.min(), i.max() + 1):
246             output = model(mygpt.BracketedSequence(input, s, 1)).x
247             logits = output[:, s]
248             if forbidden_tokens is not None:
249                 logits = logits.masked_fill(forbidden_tokens, float("-inf"))
250             if args.deterministic_synthesis:
251                 t_next = logits.argmax(1)
252             else:
253                 dist = torch.distributions.categorical.Categorical(logits=logits)
254                 t_next = dist.sample()
255             input[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * input[:, s]
256
257
258 ######################################################################
259
260
261 class Task:
262     def batches(self, split="train"):
263         pass
264
265     def vocabulary_size(self):
266         pass
267
268     def produce_results(self, n_epoch, model):
269         pass
270
271
272 ######################################################################
273
274 import picoclvr
275
276
277 class TaskPicoCLVR(Task):
278     # Make a tensor from a list of strings
279     def tensorize(self, descr):
280         token_descr = [s.strip().split(" ") for s in descr]
281         l = max([len(s) for s in token_descr])
282         token_descr = [s + ["<nul>"] * (l - len(s)) for s in token_descr]
283         id_descr = [[self.token2id[u] for u in s] for s in token_descr]
284         return torch.tensor(id_descr, device=self.device)
285
286     # Make a list of strings from a tensor
287     def detensorize(self, x):
288         return [" ".join([self.id2token[t.item()] for t in r]) for r in x]
289
290     # Trim all the tensors in z by removing as many columns equal to
291     # `token` as possible from the left and right, based on the first
292     # tensor. If z is a tuple, all its elements are trimmed identically.
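    # A small worked example, with a single row and n standing for <nul>:
    #   [[n, n, a, b, n]]  ->  [[a, b]]
    # The one-column padding on both sides makes the cumulative sum computed
    # on the padded tensor line up with slice bounds on the original tensor.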
293     def trim(self, z, token="<nul>"):
294         n = self.token2id[token]
295         if isinstance(z, tuple):
296             x = z[0]
297             i = (1 - (F.pad(x, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
298             a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
299             return tuple([t[:, a:b] for t in z])
300         else:
301             i = (1 - (F.pad(z, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
302             a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
303             return z[:, a:b]
304
305     ######################
306     # Not the cleanest part of the code
307
308     # Extract the last image of each sequence, starting from the last
309     # <img> token (included), and set all the tokens from the beginning
310     # of that image to the end of the sequence to <nul>
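    # In practice: locate the last <img> token of each row, copy out the
    # height * width + 1 tokens of that image (trimmed of <nul> padding),
    # overwrite them with <nul> in the input, and return loss masks that are
    # 1 strictly before that last <img> token and 0 from it onward.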
311     def excise_last_image(self, input):
312         t_img, t_nul = self.token2id["<img>"], self.token2id["<nul>"]
313         nb_img_tokens = self.height * self.width + 1
314
315         input = input.clone()
316         t = (input == t_img).long()
317         tail_masks = (t.cumsum(dim=1) == t.sum(dim=1, keepdim=True)).long()
318         i = (t * tail_masks).nonzero(as_tuple=True)
319         j = (
320             i[0][:, None],
321             i[1][:, None] + torch.arange(nb_img_tokens, device=input.device)[None, :],
322         )
323         images = self.trim(input[j])
324         input[j] = t_nul
325         loss_masks = 1 - tail_masks
326         input, loss_masks = self.trim((input, loss_masks))
327         return input, loss_masks, images
328
329     def add_true_image(self, input, images, loss_masks):
330         t_nul = self.token2id["<nul>"]
331         nb_img_tokens = self.height * self.width + 1
332         input = F.pad(input, (0, nb_img_tokens), value=t_nul)
333         loss_masks = F.pad(loss_masks, (0, nb_img_tokens), value=0)
334         t = (input == t_nul).long()
335         i = (t.cumsum(dim=1) == 1).nonzero(as_tuple=True)
336         j = (
337             i[0][:, None],
338             i[1][:, None] + torch.arange(nb_img_tokens, device=input.device)[None, :],
339         )
340         input[j] = images
341         loss_masks[j] = 1
342         input, loss_masks = self.trim((input, loss_masks))
343         return input, loss_masks
344
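    # Append an image generated by the model: pad the sequences with <nul>,
    # write an <img> token at the first <nul> position of each row, mark the
    # following height * width positions as the autoregressive region, forbid
    # the <nul> token there, and let masked_inplace_autoregression fill them
    # in, with the model temporarily switched to eval mode.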
345     def add_generated_image(self, input, loss_masks, model):
346         t_img, t_nul = self.token2id["<img>"], self.token2id["<nul>"]
347         nb_img_tokens = self.height * self.width + 1
348
349         input = F.pad(input, (0, nb_img_tokens), value=t_nul)
350         loss_masks = F.pad(loss_masks, (0, nb_img_tokens), value=0)
351         t = (input == t_nul).long()
352         i = (t.cumsum(dim=1) == 1).nonzero(as_tuple=True)
353         input[i] = t_img
354
355         j = (
356             i[0][:, None],
357             i[1][:, None]
358             + 1
359             + torch.arange(nb_img_tokens - 1, device=input.device)[None, :],
360         )
361         ar_masks = input.new_zeros(input.size(), dtype=torch.int64)
362         ar_masks[j] = 1
363         forbidden_tokens = (
364             torch.arange(self.vocabulary_size(), device=input.device) == t_nul
365         )
366         with torch.autograd.no_grad():
367             t = model.training
368             model.eval()
369             masked_inplace_autoregression(
370                 model,
371                 self.batch_size,
372                 input,
373                 ar_masks,
374                 forbidden_tokens,
375                 progress_bar_desc=None,
376                 device=self.device,
377             )
378             model.train(t)
379
380         input, loss_masks = self.trim((input, loss_masks))
381
382         return input, loss_masks
383
384     ######################
385
386     def __init__(
387         self,
388         nb_train_samples,
389         nb_test_samples,
390         batch_size,
391         height,
392         width,
393         nb_colors=5,
394         device=torch.device("cpu"),
395         pruner_train=None,
396         pruner_eval=None,
397     ):
398         def generate_descr(nb, cache_suffix, pruner):
399             return picoclvr.generate(
400                 nb,
401                 height=self.height,
402                 width=self.width,
403                 nb_colors=nb_colors,
404                 pruner=pruner,
405             )
406
407         self.height = height
408         self.width = width
409         self.batch_size = batch_size
410         self.device = device
411         self.pruner_train = pruner_train
412         self.pruner_eval = pruner_eval
413
414         param = {
415             "nb_train_samples": nb_train_samples,
416             "nb_test_samples": nb_test_samples,
417             "height": height,
418             "width": width,
419             "nb_colors": nb_colors,
420             "batch_size": batch_size,
421             "rng_state": list(torch.get_rng_state()),
422         }
423
424         log_string(
425             f"generating {nb_train_samples+nb_test_samples} samples (can take some time)"
426         )
427         self.train_descr = generate_descr(
428             nb_train_samples, "train", pruner=self.pruner_train
429         )
430         self.test_descr = generate_descr(nb_test_samples, "test", pruner=None)
431
432         # Build the tokenizer
433         tokens = {"<nul>", "<img>"}
434         for d in [self.train_descr, self.test_descr]:
435             for s in d:
436                 for t in s.strip().split(" "):
437                     tokens.add(t)
438         # make this set a sorted list to get the same tensors given
439         # the same descr
440         tokens = list(tokens)
441         tokens.sort()
442         self.token2id = dict([(t, n) for n, t in enumerate(tokens)])
443         self.id2token = dict([(n, t) for n, t in enumerate(tokens)])
444
445         # Tokenize the train and test sets
446         self.train_input = self.tensorize(self.train_descr)
447         self.test_input = self.tensorize(self.test_descr)
448
449     def batches(self, split="train"):
450         assert split in {"train", "test"}
451         input = self.train_input if split == "train" else self.test_input
452         for batch in tqdm.tqdm(
453             input.split(self.batch_size), dynamic_ncols=True, desc=f"epoch-{split}"
454         ):
455             yield self.trim(batch)
456
457     def vocabulary_size(self):
458         return len(self.token2id)
459
460     def compute_missing_properties(self, n_epoch, model, pruner=None):
461         acc_nb_requested_properties = []
462         acc_nb_missing_properties = []
463         acc_nb_results = 0
464
465         for input in tqdm.tqdm(
466             self.test_input.split(self.batch_size),
467             dynamic_ncols=True,
468             desc=f"test-properties",
469         ):
470             tape, loss_masks, _ = self.excise_last_image(input)
471             tape, loss_masks = self.add_generated_image(tape, loss_masks, model)
472             result_descr = self.detensorize(tape)
473             np = picoclvr.nb_properties(
474                 result_descr,
475                 height=self.height,
476                 width=self.width,
477                 pruner=pruner,
478             )
479             nb_requested_properties, _, nb_missing_properties = zip(*np)
480             acc_nb_requested_properties += nb_requested_properties
481             acc_nb_missing_properties += nb_missing_properties
482             acc_nb_results += len(result_descr)
483
484         nb_requested_properties = sum(acc_nb_requested_properties)
485         nb_missing_properties = sum(acc_nb_missing_properties)
486
487         prefix = "" if pruner is None else "pruned_"
488         log_string(f"nb_{prefix}samples {n_epoch} {acc_nb_results}")
489         log_string(
490             f"property_{prefix}nb {n_epoch} requested {sum(acc_nb_requested_properties)} missing {sum(acc_nb_missing_properties)}"
491         )
492         log_string(
493             f"property_{prefix}miss {n_epoch} {100*nb_missing_properties/nb_requested_properties:.02f}%"
494         )
495
496     ######################################################################
497
498     def produce_results(self, n_epoch, model):
499         self.compute_missing_properties(n_epoch, model)
500
501         if self.pruner_eval is not None:
502             self.compute_missing_properties(n_epoch, model, self.pruner_eval)
503
504         nb_tokens_to_generate = self.height * self.width + 3
505         result_descr = []
506         nb_per_primer = 8
507         primer = []
508
509         for primer_descr in [
510             "red above green <sep> green top <sep> blue right of red",
511             "there is red <sep> there is yellow <sep> there is blue",
512             "red below yellow <sep> yellow below green <sep> green below blue <sep> red right <sep> yellow left <sep> green right <sep> blue left",
513             "green bottom <sep> yellow bottom <sep> green left of blue <sep> yellow right of blue <sep> blue top",
514         ]:
515             primer += [primer_descr] * nb_per_primer
516
517         tape = self.tensorize(primer)
518         loss_masks = 1 - (tape == self.token2id["<nul>"]).long()
519         tape, loss_masks = self.add_generated_image(tape, loss_masks, model)
520         result_descr = self.detensorize(tape)
521
522         np = picoclvr.nb_properties(result_descr, height=self.height, width=self.width)
523
524         acc_nb_requested_properties, _, acc_nb_missing_properties = zip(*np)
525         acc_nb_results = len(result_descr)
526
527         nb_requested_properties = sum(acc_nb_requested_properties)
528         nb_missing_properties = sum(acc_nb_missing_properties)
529
530         prefix = "demo_"
531         log_string(f"nb_{prefix}samples {n_epoch} {acc_nb_results}")
532         log_string(
533             f"property_{prefix}nb {n_epoch} requested {sum(acc_nb_requested_properties)} missing {sum(acc_nb_missing_properties)}"
534         )
535         log_string(
536             f"property_{prefix}miss {n_epoch} {100*nb_missing_properties/nb_requested_properties:.02f}%"
537         )
538
539         img = picoclvr.descr2img(result_descr, height=self.height, width=self.width)
540
541         if img.dim() == 5:
542             if img.size(1) == 1:
543                 img = F.pad(img.squeeze(1), pad=(1, 1, 1, 1), value=64)
544             else:
545                 img = torch.cat(
546                     [
547                         torchvision.utils.make_grid(x, padding=1, pad_value=64)[None]
548                         for x in img
549                     ],
550                     0,
551                 )
552
553         image_name = os.path.join(args.result_dir, f"picoclvr_result_{n_epoch:04d}.png")
554         torchvision.utils.save_image(
555             img / 255.0, image_name, nrow=nb_per_primer, padding=1, pad_value=0.0
556         )
557         log_string(f"wrote {image_name}")
558
559
560 ######################################################################
561
562
563 class TaskMNIST(Task):
564     def __init__(self, batch_size, device=torch.device("cpu")):
565         self.device = device
566         self.batch_size = batch_size
567
568     def batches(self, split="train"):
569         assert split in {"train", "test"}
570         data_set = torchvision.datasets.MNIST(
571             root="./data", train=(split == "train"), download=True
572         )
573         data_input = data_set.data.view(-1, 28 * 28).long()
574         if args.nb_train_samples is not None:
575             data_input = data_input[: args.nb_train_samples]
576         for batch in tqdm.tqdm(
577             data_input.split(self.batch_size), desc=f"epoch-{split}"
578         ):
579             yield batch
580
581     def vocabulary_size(self):
582         return 256
583
584     def produce_results(self, n_epoch, model):
585         results = torch.empty(64, 28 * 28, device=self.device, dtype=torch.int64)
586         ar_mask = torch.full_like(results, 1)
587         masked_inplace_autoregression(
588             model, self.batch_size, results, ar_mask, device=self.device
589         )
590         image_name = os.path.join(args.result_dir, f"mnist_result_{n_epoch:04d}.png")
591         torchvision.utils.save_image(
592             1 - results.reshape(-1, 1, 28, 28) / 255.0,
593             image_name,
594             nrow=16,
595             pad_value=0.8,
596         )
597         log_string(f"wrote {image_name}")
598
599
600 ######################################################################
601
602 import maze
603
604
605 class TaskMaze(Task):
606     def map2seq(self, *m):
607         return torch.cat([x.flatten(1) for x in m], 1)
608
609     def seq2map(self, s):
610         s = s.reshape(s.size(0), -1, self.height, self.width)
611         return (s[:, k] for k in range(s.size(1)))
612
613     def __init__(
614         self,
615         nb_train_samples,
616         nb_test_samples,
617         batch_size,
618         height,
619         width,
620         nb_walls,
621         device=torch.device("cpu"),
622     ):
623         self.batch_size = batch_size
624         self.height = height
625         self.width = width
626         self.device = device
627
628         train_mazes, train_paths, _ = maze.create_maze_data(
629             nb_train_samples,
630             height=height,
631             width=width,
632             nb_walls=nb_walls,
633             progress_bar=lambda x: tqdm.tqdm(x, dynamic_ncols=True, desc=f"data-train"),
634         )
635         self.train_input = self.map2seq(train_mazes.to(device), train_paths.to(device))
636
637         test_mazes, test_paths, _ = maze.create_maze_data(
638             nb_test_samples,
639             height=height,
640             width=width,
641             nb_walls=nb_walls,
642             progress_bar=lambda x: tqdm.tqdm(x, dynamic_ncols=True, desc=f"data-test"),
643         )
644         self.test_input = self.map2seq(test_mazes.to(device), test_paths.to(device))
645
646         self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
647
648     def batches(self, split="train", nb_to_use=-1, desc=None):
649         assert split in {"train", "test"}
650         input = self.train_input if split == "train" else self.test_input
651         if nb_to_use > 0:
652             input = input[:nb_to_use]
653         if desc is None:
654             desc = f"epoch-{split}"
655         for batch in tqdm.tqdm(
656             input.split(self.batch_size), dynamic_ncols=True, desc=desc
657         ):
658             yield batch
659
660     def vocabulary_size(self):
661         return self.nb_codes
662
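    # Each sequence is the flattened maze (height * width tokens) followed by
    # the flattened path (height * width tokens). compute_error regenerates
    # only the path part, checks it with maze.path_correctness, and fills a
    # histogram of (optimal path length, predicted path length) for the
    # correctly predicted paths.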
663     def compute_error(self, model, split="train", nb_to_use=-1):
664         nb_total, nb_correct = 0, 0
665         count = torch.zeros(
666             self.width * self.height,
667             self.width * self.height,
668             device=self.device,
669             dtype=torch.int64,
670         )
671         for input in tqdm.tqdm(
672             self.batches(split, nb_to_use),
673             dynamic_ncols=True,
674             desc=f"test-mazes",
675         ):
676             result = input.clone()
677             ar_mask = result.new_zeros(result.size())
678             ar_mask[:, self.height * self.width :] = 1
679             result *= 1 - ar_mask
680             masked_inplace_autoregression(
681                 model,
682                 self.batch_size,
683                 result,
684                 ar_mask,
685                 progress_bar_desc=None,
686                 device=self.device,
687             )
688             mazes, paths = self.seq2map(result)
689             path_correctness = maze.path_correctness(mazes, paths)
690             nb_correct += path_correctness.long().sum()
691             nb_total += mazes.size(0)
692
693             optimal_path_lengths = (
694                 (input[:, self.height * self.width :] == maze.v_path).long().sum(1)
695             )
696             predicted_path_lengths = (
697                 (result[:, self.height * self.width :] == maze.v_path).long().sum(1)
698             )
699             optimal_path_lengths = optimal_path_lengths[path_correctness]
700             predicted_path_lengths = predicted_path_lengths[path_correctness]
701             count[optimal_path_lengths, predicted_path_lengths] += 1
702
703         if count.max() == 0:
704             count = None
705         else:
706             count = count[
707                 : count.sum(1).nonzero().max() + 1, : count.sum(0).nonzero().max() + 1
708             ]
709
710         return nb_total, nb_correct, count
711
712     def produce_results(self, n_epoch, model):
713         with torch.autograd.no_grad():
714             t = model.training
715             model.eval()
716
717             train_nb_total, train_nb_correct, count = self.compute_error(
718                 model, "train", nb_to_use=1000
719             )
720             log_string(
721                 f"accuracy_train {n_epoch} nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%"
722             )
723
724             test_nb_total, test_nb_correct, count = self.compute_error(
725                 model, "test", nb_to_use=1000
726             )
727             log_string(
728                 f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
729             )
730
731             if count is not None:
732                 proportion_optimal = count.diagonal().sum().float() / count.sum()
733                 log_string(f"proportion_optimal_test {proportion_optimal*100:.02f}%")
734                 with open(
735                     os.path.join(args.result_dir, f"maze_result_{n_epoch:04d}.txt"), "w"
736                 ) as f:
737                     for i in range(count.size(0)):
738                         for j in range(count.size(1)):
739                             eol = " " if j < count.size(1) - 1 else "\n"
740                             f.write(f"{count[i,j]}{eol}")
741
742             input = self.test_input[:48]
743             result = input.clone()
744             ar_mask = result.new_zeros(result.size())
745             ar_mask[:, self.height * self.width :] = 1
746             result *= 1 - ar_mask
747             masked_inplace_autoregression(
748                 model, self.batch_size, result, ar_mask, device=self.device
749             )
750
751             mazes, paths = self.seq2map(input)
752             _, predicted_paths = self.seq2map(result)
753
754             filename = os.path.join(args.result_dir, f"maze_result_{n_epoch:04d}.png")
755             maze.save_image(
756                 filename,
757                 mazes=mazes,
758                 target_paths=paths,
759                 predicted_paths=predicted_paths,
760                 path_correct=maze.path_correctness(mazes, predicted_paths),
761                 path_optimal=maze.path_optimality(paths, predicted_paths),
762             )
763             log_string(f"wrote {filename}")
764
765             model.train(t)
766
767
768 ######################################################################
769
770
771 import snake
772
773
774 class TaskSnake(Task):
775     def __init__(
776         self,
777         nb_train_samples,
778         nb_test_samples,
779         batch_size,
780         height,
781         width,
782         nb_colors,
783         length,
784         prompt_length,
785         device=torch.device("cpu"),
786     ):
787         self.batch_size = batch_size
788         self.height = height
789         self.width = width
790         self.device = device
791         self.prompt_length = prompt_length
792
793         self.train_input, self.train_prior_visits, _, _ = snake.generate_sequences(
794             nb_train_samples,
795             height,
796             width,
797             nb_colors,
798             length,
799             prompt_length,
800             self.device,
801         )
802         self.test_input, self.test_prior_visits, _, _ = snake.generate_sequences(
803             nb_test_samples,
804             height,
805             width,
806             nb_colors,
807             length,
808             prompt_length,
809             self.device,
810         )
811
812         self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
813
814     def batches(self, split="train", nb_to_use=-1, desc=None):
815         assert split in {"train", "test"}
816         input = self.train_input if split == "train" else self.test_input
817         if nb_to_use > 0:
818             input = input[:nb_to_use]
819         if desc is None:
820             desc = f"epoch-{split}"
821         for batch in tqdm.tqdm(
822             input.split(self.batch_size), dynamic_ncols=True, desc=desc
823         ):
824             yield batch
825
826     def vocabulary_size(self):
827         return self.nb_codes
828
829     def produce_results(self, n_epoch, model):
830         with torch.autograd.no_grad():
831             t = model.training
832             model.eval()
833
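            # The snake sequences interleave two token streams (two tokens
            # per time step). Past the prompt (the first prompt_length
            # pairs), only the even-indexed token of each pair is
            # regenerated, and accuracy is measured only at positions with
            # prior_visits > 0, i.e., presumably, cells the snake has
            # already visited and whose value is therefore known.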
834             def compute_nb_correct(input, prior_visits):
835                 result = input.clone()
836                 i = torch.arange(result.size(1), device=result.device)[None, :]
837                 ar_mask = (
838                     torch.logical_and(i >= self.prompt_length * 2, i % 2 == 0)
839                     .long()
840                     .expand_as(result)
841                 )
842                 result *= 1 - ar_mask
843
844                 # snake.solver(result,ar_mask)
845
846                 masked_inplace_autoregression(
847                     model, self.batch_size, result, ar_mask, device=self.device
848                 )
849
850                 nb_total = ((prior_visits > 0) * ar_mask).sum()
851
852                 nb_correct = (
853                     (result == input).long() * (prior_visits > 0) * ar_mask
854                 ).sum()
855
856                 # nb_total = result.size(0)
857                 # nb_correct = ((result - input).abs().sum(1) == 0).sum()
858
859                 return nb_total, nb_correct
860
861             # train_nb_total, train_nb_correct = compute_nb_correct(
862             # self.train_input, self.train_prior_visits
863             # )
864
865             # log_string(
866             # f"accuracy_train nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%"
867             # )
868
869             test_nb_total, test_nb_correct = compute_nb_correct(
870                 self.test_input[:1000], self.test_prior_visits[:1000]
871             )
872
873             log_string(
874                 f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
875             )
876
877             model.train(t)
878
879
880 ######################################################################
881
882
883 import stack
884
885
886 class TaskStack(Task):
887     def __init__(
888         self,
889         nb_train_samples,
890         nb_test_samples,
891         batch_size,
892         nb_steps,
893         nb_stacks,
894         nb_digits,
895         fraction_values_for_train=None,
896         device=torch.device("cpu"),
897     ):
898         self.batch_size = batch_size
899         self.nb_steps = nb_steps
900         self.nb_stacks = nb_stacks
901         self.nb_digits = nb_digits
902         self.device = device
903
904         if fraction_values_for_train is None:
905             values_for_train = None
906             values_for_test = None
907         else:
908             all = torch.randperm(10**nb_digits)
909             nb_for_train = int(all.size(0) * fraction_values_for_train)
910             values_for_train = all[:nb_for_train]
911             values_for_test = all[nb_for_train:]
912
913         self.train_input, self.train_stack_counts = stack.generate_sequences(
914             nb_train_samples,
915             nb_steps,
916             nb_stacks,
917             nb_digits,
918             values_for_train,
919             self.device,
920         )
921
922         self.test_input, self.test_stack_counts = stack.generate_sequences(
923             nb_test_samples,
924             nb_steps,
925             nb_stacks,
926             nb_digits,
927             values_for_test,
928             self.device,
929         )
930
931         i = torch.logical_and(self.test_input % 2 == 1, self.test_input < 2 * nb_stacks)
932         counts = self.test_stack_counts.flatten()[i.flatten()]
933         counts = F.one_hot(counts).sum(0)
934         log_string(f"test_pop_stack_counts {counts}")
935
936         self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
937
938     def batches(self, split="train", nb_to_use=-1, desc=None):
939         assert split in {"train", "test"}
940         input = self.train_input if split == "train" else self.test_input
941         if nb_to_use > 0:
942             input = input[:nb_to_use]
943         if desc is None:
944             desc = f"epoch-{split}"
945         for batch in tqdm.tqdm(
946             input.split(self.batch_size), dynamic_ncols=True, desc=desc
947         ):
948             yield batch
949
950     def vocabulary_size(self):
951         return self.nb_codes
952
953     def produce_results(self, n_epoch, model):
954         with torch.autograd.no_grad():
955             t = model.training
956             model.eval()
957
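            # stack.remove_popped_values presumably blanks the value tokens
            # produced by the pop operations; ar_mask then marks exactly the
            # positions that were blanked. Errors are counted per group of
            # 1 + nb_digits tokens, so a popped value is correct only if all
            # of its tokens match the reference.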
958             def compute_nb_correct(input):
959                 result = input.clone()
960                 stack.remove_popped_values(result, self.nb_stacks, self.nb_digits)
961                 ar_mask = (result != input).long()
962                 masked_inplace_autoregression(
963                     model, self.batch_size, result, ar_mask, device=self.device
964                 )
965
966                 errors = ((result != input).long() * ar_mask).reshape(
967                     -1, 1 + self.nb_digits
968                 )
969                 ar_mask = ar_mask.reshape(-1, 1 + self.nb_digits)
970
971                 nb_total = ar_mask.max(1).values.sum()
972                 nb_correct = nb_total - errors.max(1).values.sum()
973
974                 return nb_total, nb_correct
975
976             test_nb_total, test_nb_correct = compute_nb_correct(self.test_input[:1000])
977
978             log_string(
979                 f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
980             )
981
982             ##############################################################
983             # Log a few generated sequences
984             input = self.test_input[:10, : 12 * (1 + self.nb_digits)]
985             result = input.clone()
986             stack.remove_popped_values(result, self.nb_stacks, self.nb_digits)
987             ar_mask = (result != input).long()
988             for n in range(result.size(0)):
989                 log_string(
990                     f"test_before {stack.seq_to_str(result[n],nb_stacks=self.nb_stacks,nb_digits=self.nb_digits)}"
991                 )
992             masked_inplace_autoregression(
993                 model, self.batch_size, result, ar_mask, device=self.device
994             )
995             for n in range(result.size(0)):
996                 log_string(
997                     f"test_after  {stack.seq_to_str(result[n],nb_stacks=self.nb_stacks,nb_digits=self.nb_digits)}"
998                 )
999             ##############################################################
1000
1001             model.train(t)
1002
1003
1004 ######################################################################
1005
1006
1007 import expr
1008
1009
1010 class TaskExpr(Task):
1011     def __init__(
1012         self,
1013         nb_train_samples,
1014         nb_test_samples,
1015         batch_size,
1016         device=torch.device("cpu"),
1017     ):
1018         self.batch_size = batch_size
1019         self.device = device
1020
1021         train_sequences = expr.generate_sequences(nb_train_samples)
1022         test_sequences = expr.generate_sequences(nb_test_samples)
1023         self.char2id = dict(
1024             [
1025                 (c, n)
1026                 for n, c in enumerate(set("#" + "".join(train_sequences + test_sequences)))
1027             ]
1028         )
1029         self.id2char = dict([(n, c) for c, n in self.char2id.items()])
1030         len_max = max([len(x) for x in train_sequences + test_sequences])
1031         self.train_input = torch.cat(
1032             [
1033                 torch.tensor(
1034                     [
1035                         [self.char2id[c] for c in s + "#" * (len_max - len(s))]
1036                         for s in train_sequences
1037                     ]
1038                 )
1039             ],
1040             0,
1041         ).to(device)
1042         self.test_input = torch.cat(
1043             [
1044                 torch.tensor(
1045                     [
1046                         [self.char2id[c] for c in s + "#" * (len_max - len(s))]
1047                         for s in test_sequences
1048                     ]
1049                 )
1050             ],
1051             0,
1052         ).to(device)
1053         self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
1054
1055     def batches(self, split="train", nb_to_use=-1, desc=None):
1056         assert split in {"train", "test"}
1057         input = self.train_input if split == "train" else self.test_input
1058         if nb_to_use > 0:
1059             input = input[:nb_to_use]
1060         if desc is None:
1061             desc = f"epoch-{split}"
1062         for batch in tqdm.tqdm(
1063             input.split(self.batch_size), dynamic_ncols=True, desc=desc
1064         ):
1065             yield batch
1066
1067     def vocabulary_size(self):
1068         return self.nb_codes
1069
1070     def produce_results(self, n_epoch, model):
1071         with torch.autograd.no_grad():
1072             t = model.training
1073             model.eval()
1074
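            # ar_mask is 1 from the first '#' character of each row to the
            # end; those positions are reset to '#', regenerated by the
            # model, and then compared token by token to the reference input.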
1075             def compute_nb_correct(input):
1076                 result = input.clone()
1077                 space = self.char2id["#"]
1078                 ar_mask = (result == space).long().cumsum(dim=1).clamp(max=1)
1079                 result = (1 - ar_mask) * result + space * ar_mask
1080                 masked_inplace_autoregression(
1081                     model, self.batch_size, result, ar_mask, device=self.device
1082                 )
1083
1084                 nb_total = ar_mask.sum()
1085                 nb_correct = ((input == result).long() * ar_mask).sum()
1086
1087                 return nb_total, nb_correct
1088
1089             test_nb_total, test_nb_correct = compute_nb_correct(self.test_input[:1000])
1090
1091             log_string(
1092                 f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
1093             )
1094
1095             ##############################################################
1096             # Log a few generated sequences
1097             input = self.test_input[:10]
1098             result = input.clone()
1099             space = self.char2id["#"]
1100             ar_mask = (result == space).long().cumsum(dim=1).clamp(max=1)
1101             result = (1 - ar_mask) * result + space * ar_mask
1102             for n in range(result.size(0)):
1103                 s = "".join([self.id2char[k.item()] for k in result[n]])
1104                 log_string(f"test_before {s}")
1105             masked_inplace_autoregression(
1106                 model, self.batch_size, result, ar_mask, device=self.device
1107             )
1108             for n in range(result.size(0)):
1109                 s = "".join([self.id2char[k.item()] for k in result[n]])
1110                 log_string(f"test_after  {s}")
1111             ##############################################################
1112
1113             model.train(t)
1114
1115
1116 ######################################################################
1117
1118
1119 def picoclvr_pruner_horizontal_green(p):
1120     return not ("green" in p and ("left" in p or "right" in p))
1121
1122
1123 picoclvr_pruner_train = (
1124     picoclvr_pruner_horizontal_green
1125     if args.picocvlr_prune_properties in {"train+eval"}
1126     else None
1127 )
1128
1129 picoclvr_pruner_eval = (
1130     (lambda p: not picoclvr_pruner_horizontal_green(p))
1131     if args.picocvlr_prune_properties in {"train+eval", "eval"}
1132     else None
1133 )
1134
1135 ######################################################################
1136
1137 if args.task == "picoclvr":
1138     task = TaskPicoCLVR(
1139         nb_train_samples=args.nb_train_samples,
1140         nb_test_samples=args.nb_test_samples,
1141         batch_size=args.batch_size,
1142         height=args.picoclvr_height,
1143         width=args.picoclvr_width,
1144         nb_colors=args.picoclvr_nb_colors,
1145         device=device,
1146         pruner_train=picoclvr_pruner_train,
1147         pruner_eval=picoclvr_pruner_eval,
1148     )
1149
1150 elif args.task == "mnist":
1151     task = TaskMNIST(
1152         batch_size=args.batch_size,
1153         device=device,
1154     )
1155
1156 elif args.task == "maze":
1157     task = TaskMaze(
1158         nb_train_samples=args.nb_train_samples,
1159         nb_test_samples=args.nb_test_samples,
1160         batch_size=args.batch_size,
1161         height=args.maze_height,
1162         width=args.maze_width,
1163         nb_walls=args.maze_nb_walls,
1164         device=device,
1165     )
1166
1167 elif args.task == "snake":
1168     task = TaskSnake(
1169         nb_train_samples=args.nb_train_samples,
1170         nb_test_samples=args.nb_test_samples,
1171         batch_size=args.batch_size,
1172         height=args.snake_height,
1173         width=args.snake_width,
1174         nb_colors=args.snake_nb_colors,
1175         length=args.snake_length,
1176         prompt_length=args.snake_length // 2,
1177         device=device,
1178     )
1179
1180 elif args.task == "stack":
1181     task = TaskStack(
1182         nb_train_samples=args.nb_train_samples,
1183         nb_test_samples=args.nb_test_samples,
1184         batch_size=args.batch_size,
1185         nb_steps=args.stack_nb_steps,
1186         nb_stacks=args.stack_nb_stacks,
1187         nb_digits=args.stack_nb_digits,
1188         fraction_values_for_train=args.stack_fraction_values_for_train,
1189         device=device,
1190     )
1191
1192 elif args.task == "expr":
1193     task = TaskExpr(
1194         nb_train_samples=args.nb_train_samples,
1195         nb_test_samples=args.nb_test_samples,
1196         batch_size=args.batch_size,
1197         device=device,
1198     )
1199
1200 else:
1201     raise ValueError(f"Unknown task {args.task}")
1202
1203 ######################################################################
1204
1205 log_string(f"device {device}")
1206
1207 vocabulary_size = task.vocabulary_size()
1208
1209 log_string(f"vocabulary_size {vocabulary_size}")
1210
1211 ##############################
1212
1213 model = mygpt.MyGPT(
1214     vocabulary_size=vocabulary_size,
1215     dim_model=args.dim_model,
1216     dim_keys=args.dim_keys,
1217     dim_hidden=args.dim_hidden,
1218     nb_heads=args.nb_heads,
1219     nb_blocks=args.nb_blocks,
1220     causal=True,
1221     dropout=args.dropout,
1222 )
1223
1224 model.to(device)
1225
1226 nb_parameters = sum(p.numel() for p in model.parameters())
1227 log_string(f"nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)")
1228
1229 ######################################################################
1230
1231 nb_epochs_finished = 0
1232
1233 if args.no_checkpoint:
1234     log_string(f"not trying to load checkpoint.")
1235
1236 else:
1237     try:
1238         checkpoint_name = os.path.join(args.result_dir, args.checkpoint_name)
1239         checkpoint = torch.load(checkpoint_name)
1240         nb_epochs_finished = checkpoint["nb_epochs_finished"]
1241         model.load_state_dict(checkpoint["model_state"])
1242         torch.set_rng_state(checkpoint["rng_state"])
1243         if torch.cuda.is_available():
1244             torch.cuda.set_rng_state(checkpoint["cuda_rng_state"])
1245
1246         log_string(f"checkpoint loaded with {nb_epochs_finished} epochs finished.")
1247
1248     except FileNotFoundError:
1249         log_string("starting from scratch.")
1250
1251     except:
1252         log_string("error when loading the checkpoint.")
1253         exit(1)
1254
1255 ######################################################################
1256
1257 nb_epochs = args.nb_epochs  # always set at this point, either explicitly or via the per-task defaults
1258
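# Empirical unigram statistics of the training tokens: exp(entropy) is the
# perplexity that a memoryless model predicting the token frequencies would
# reach, logged below alongside the trained model's perplexities as a
# reference point.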
1259 token_count = 0
1260 for input in task.batches(split="train"):
1261     token_count += F.one_hot(input, num_classes=task.vocabulary_size()).sum((0, 1))
1262 token_probas = token_count / token_count.sum()
1263 entropy = -torch.xlogy(token_probas, token_probas).sum()
1264 train_set_perplexity = math.exp(entropy)
1265
1266 ##############################
1267
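# The schedule is either "cos" (half-cosine decay from --learning_rate over
# --nb_epochs) or a comma-separated list of epoch:lr pairs such as the
# default "10: 2e-5,30: 4e-6": the rate starts at --learning_rate and
# switches to the given value at each listed epoch.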
1268 if args.learning_rate_schedule == "cos":
1269     learning_rate_schedule = {}
1270     for n_epoch in range(args.nb_epochs):
1271         u = n_epoch / args.nb_epochs * math.pi
1272         learning_rate_schedule[n_epoch] = args.learning_rate * 0.5 * (1 + math.cos(u))
1273 else:
1274     u = {
1275         int(k): float(v)
1276         for k, v in [
1277             tuple(x.split(":")) for x in args.learning_rate_schedule.split(",")
1278         ]
1279     }
1280
1281     learning_rate_schedule = {}
1282     learning_rate = args.learning_rate
1283     for n_epoch in range(args.nb_epochs):
1284         if n_epoch in u:
1285             learning_rate = u[n_epoch]
1286         learning_rate_schedule[n_epoch] = learning_rate
1287
1288 log_string(f"learning_rate_schedule {learning_rate_schedule}")
1289
1290 ##############################
1291
1292 nb_samples_seen = 0
1293
1294 if nb_epochs_finished >= nb_epochs:
1295     task.produce_results(nb_epochs_finished, model)
1296
1297 for n_epoch in range(nb_epochs_finished, nb_epochs):
1298     learning_rate = learning_rate_schedule[n_epoch]
1299
1300     log_string(f"learning_rate {learning_rate}")
1301
1302     if args.optim == "sgd":
1303         optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
1304     elif args.optim == "adam":
1305         optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
1306     elif args.optim == "adamw":
1307         optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
1308     else:
1309         raise ValueError(f"Unknown optimizer {args.optim}.")
1310
1311     model.train()
1312
1313     nb_train_samples, acc_train_loss = 0, 0.0
1314
1315     for input in task.batches(split="train"):
1316         input = input.to(device)
1317         output = model(mygpt.BracketedSequence(input)).x
1318         loss = F.cross_entropy(output.transpose(1, 2), input)
1319         acc_train_loss += loss.item() * input.size(0)
1320         nb_train_samples += input.size(0)
1321         nb_samples_seen += input.size(0)
1322
1323         optimizer.zero_grad()
1324         loss.backward()
1325         optimizer.step()
1326
1327     with torch.autograd.no_grad():
1328         model.eval()
1329
1330         nb_test_samples, acc_test_loss = 0, 0.0
1331
1332         for input in task.batches(split="test"):
1333             input = input.to(device)
1334
1335             output = model(mygpt.BracketedSequence(input)).x
1336             loss = F.cross_entropy(output.transpose(1, 2), input)
1337             acc_test_loss += loss.item() * input.size(0)
1338             nb_test_samples += input.size(0)
1339
1340         train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples))
1341         test_perplexity = math.exp(min(100, acc_test_loss / nb_test_samples))
1342
1343         log_string(
1344             f"perplexity {n_epoch} train_set {train_set_perplexity} train_prediction {train_perplexity} test_prediction {test_perplexity}"
1345         )
1346
1347         task.produce_results(n_epoch, model)
1348
1349     checkpoint = {
1350         "nb_epochs_finished": n_epoch + 1,
1351         "model_state": model.state_dict(),
1352         "rng_state": torch.get_rng_state(),
1353     }
1354
1355     if torch.cuda.is_available():
1356         checkpoint["cuda_rng_state"] = torch.cuda.get_rng_state()
1357
1358     checkpoint_name = os.path.join(args.result_dir, args.checkpoint_name)
1359     torch.save(checkpoint, checkpoint_name)
1360     log_string(f"saved checkpoint {checkpoint_name}")
1361
1362 ######################################################################