Update.
[picoclvr.git] / main.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 # torch.backends.cuda.matmul.allow_tf32
9 # torch.autocast(torch.bfloat16)
10
11 import math, sys, argparse, time, tqdm, os
12
13 import torch, torchvision
14 from torch import nn
15 from torch.nn import functional as F
16
17 import mygpt, tensorstack
18
19 ######################################################################
20
# Run on the GPU when one is visible, otherwise on the CPU. TF32 matrix
# multiplications are only switched on in the CUDA case.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if device.type == "cuda":
    torch.backends.cuda.matmul.allow_tf32 = True
26
27 ######################################################################
28
# Command-line interface. Options left at None (e.g. --batch_size) get a
# per-task default after parsing.
parser = argparse.ArgumentParser(
    description="An implementation of GPT with cache.",
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)

parser.add_argument("--task", type=str, default="picoclvr")

parser.add_argument("--log_filename", type=str, default="train.log")

parser.add_argument("--result_dir", type=str, default="results_default")

parser.add_argument("--seed", type=int, default=0)

parser.add_argument("--nb_epochs", type=int, default=25)

parser.add_argument("--batch_size", type=int, default=None)

parser.add_argument("--nb_train_samples", type=int, default=250000)

parser.add_argument("--nb_test_samples", type=int, default=10000)

parser.add_argument("--optim", type=str, default="adam")

parser.add_argument("--learning_rate", type=float, default=1e-4)

parser.add_argument("--learning_rate_schedule", type=str, default="10: 2e-5,30: 4e-6")

parser.add_argument("--dim_model", type=int, default=512)

parser.add_argument("--dim_keys", type=int, default=64)

parser.add_argument("--dim_hidden", type=int, default=2048)

parser.add_argument("--nb_heads", type=int, default=8)

parser.add_argument("--nb_blocks", type=int, default=12)

parser.add_argument("--dropout", type=float, default=0.1)

parser.add_argument("--deterministic_synthesis", action="store_true", default=False)

parser.add_argument("--no_checkpoint", action="store_true", default=False)

parser.add_argument("--overwrite_results", action="store_true", default=False)

parser.add_argument("--checkpoint_name", type=str, default="checkpoint.pth")

##############################
# picoclvr options

parser.add_argument("--picoclvr_nb_colors", type=int, default=5)

parser.add_argument("--picoclvr_height", type=int, default=12)

parser.add_argument("--picoclvr_width", type=int, default=16)

# The historical spelling "picocvlr" is a typo. The correctly spelled option
# is now the primary one, and the old spelling is kept as an alias. `dest`
# is pinned to the old attribute name so that the rest of the script, which
# reads args.picocvlr_prune_properties, keeps working.
parser.add_argument(
    "--picoclvr_prune_properties",
    "--picocvlr_prune_properties",
    dest="picocvlr_prune_properties",
    type=str,
    default="none",
    choices=["none", "train+eval", "eval"],
)

##############################
# Maze options

parser.add_argument("--maze_height", type=int, default=13)

parser.add_argument("--maze_width", type=int, default=21)

parser.add_argument("--maze_nb_walls", type=int, default=15)

##############################
# Snake options

parser.add_argument("--snake_height", type=int, default=6)

parser.add_argument("--snake_width", type=int, default=8)

parser.add_argument("--snake_nb_colors", type=int, default=3)

parser.add_argument("--snake_length", type=int, default=400)
106
107 ######################################################################
108
args = parser.parse_args()

# Explicit validation: an `assert` would be stripped when running with -O.
if args.picocvlr_prune_properties not in {"none", "train+eval", "eval"}:
    raise ValueError(
        f"Unknown prune properties mode {args.picocvlr_prune_properties}"
    )

# Refuse to reuse an existing result directory unless explicitly allowed.
try:
    os.mkdir(args.result_dir)
except FileExistsError:
    if not args.overwrite_results:
        print(f"result directory {args.result_dir} already exists")
        # sys.exit rather than the `exit` builtin, which is only injected by
        # the `site` module and is not guaranteed to exist.
        sys.exit(1)

# Opened in append mode and deliberately kept open for the whole run; it is
# the handle used by the logging helper.
log_file = open(os.path.join(args.result_dir, args.log_filename), "a")

# A negative seed leaves the random generators unseeded.
if args.seed >= 0:
    # torch.backends.cudnn.deterministic = True
    # torch.backends.cudnn.benchmark = False
    # torch.use_deterministic_algorithms(True)
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)
129
130 ######################################################################
131
# Per-task fallback values, applied only to the options that are still None
# after parsing the command line.
default_args = {
    "picoclvr": {"batch_size": 25},
    "mnist": {"batch_size": 10},
    "maze": {"batch_size": 25},
    "snake": {"batch_size": 20},
}

for k, v in default_args.get(args.task, {}).items():
    if getattr(args, k) is not None:
        continue
    setattr(args, k, v)
151
152 ######################################################################
153
154
def log_string(s):
    """Emit `s` prefixed with a local-time stamp.

    The line goes to the module-level `log_file` (when it is not None) and to
    standard output; both are flushed immediately.
    """
    stamped = time.strftime("%Y%m%d-%H:%M:%S ", time.localtime()) + s

    if log_file is not None:
        log_file.write(f"{stamped}\n")
        log_file.flush()

    print(stamped, flush=True)
164
165
# Record the complete effective configuration at the start of the log.
for n, value in vars(args).items():
    log_string(f"args.{n} {value}")
168
169 ######################################################################
170
171
def masked_inplace_autoregression(
    model, batch_size, input, ar_mask, forbidden_tokens=None, device=torch.device("cpu")
):
    """Fill in place the positions of `input` flagged by `ar_mask`.

    `input` and `ar_mask` are processed in chunks of `batch_size` rows. For
    each chunk, the columns from the first to the last one holding a non-zero
    mask entry are visited left to right; at each column a token is drawn
    from the model's output and written where the mask is 1, while the other
    entries keep their value.

    Args:
        model: callable taking a `mygpt.BracketedSequence` and returning an
            object whose `.x` holds the logits.
        batch_size: number of rows processed together.
        input: 2d tensor of token ids, modified in place.
        ar_mask: tensor of 0/1 values with the same shape as `input`.
        forbidden_tokens: optional boolean mask over the vocabulary; flagged
            tokens get a logit of -inf and so are never produced.
        device: unused, kept for backward compatibility of the signature.

    Tokens are taken by argmax when the global `args.deterministic_synthesis`
    is set, and sampled from the categorical distribution otherwise.
    """
    # The chunks returned by split() are views, so writing into them updates
    # the caller's tensor.
    for batch_input, batch_mask in zip(
        input.split(batch_size), ar_mask.split(batch_size)
    ):
        positions = (batch_mask.sum(0) > 0).nonzero()
        if positions.numel() == 0:
            # Nothing to generate in this chunk; min()/max() would raise on
            # an empty tensor.
            continue
        first, last = positions.min().item(), positions.max().item()
        if first > 0:
            model(
                mygpt.BracketedSequence(batch_input, 0, first)
            )  # Needed to initialize the model's cache
        for s in range(first, last + 1):
            output = model(mygpt.BracketedSequence(batch_input, s, 1)).x
            logits = output[:, s]
            if forbidden_tokens is not None:
                logits = logits.masked_fill(forbidden_tokens, float("-inf"))
            if args.deterministic_synthesis:
                t_next = logits.argmax(1)
            else:
                dist = torch.distributions.categorical.Categorical(logits=logits)
                t_next = dist.sample()
            batch_input[:, s] = (
                batch_mask[:, s] * t_next + (1 - batch_mask[:, s]) * batch_input[:, s]
            )
192
193
194 ######################################################################
195
196
class Task:
    """Common interface of the tasks; every method here is a placeholder."""

    def batches(self, split="train"):
        """Placeholder for the iteration over the batches of `split`."""
        return None

    def vocabulary_size(self):
        """Placeholder for the number of distinct tokens."""
        return None

    def produce_results(self, n_epoch, model):
        """Placeholder for the evaluation / output of `model` at `n_epoch`."""
        return None
206
207
208 ######################################################################
209
210 import picoclvr
211
212
class TaskPicoCLVR(Task):
    """Task whose sequences are token descriptions produced by `picoclvr`.

    The descriptions are tokenized on spaces, padded with "<nul>", and stored
    as id tensors on `device`. Evaluation removes the last image of each test
    sequence, lets the model regenerate it, and counts through
    `picoclvr.nb_properties` how many requested properties are missing.
    """

    # Make a tensor from a list of strings
    def tensorize(self, descr):
        """Return the 2d id tensor of the strings in `descr`.

        Sequences shorter than the longest one are right-padded with "<nul>".
        """
        token_descr = [s.strip().split(" ") for s in descr]
        l = max(len(s) for s in token_descr)
        token_descr = [s + ["<nul>"] * (l - len(s)) for s in token_descr]
        id_descr = [[self.token2id[u] for u in s] for s in token_descr]
        return torch.tensor(id_descr, device=self.device)

    # Make a list of strings from a tensor
    def detensorize(self, x):
        """Return one space-joined token string per row of `x`."""
        return [" ".join([self.id2token[t.item()] for t in r]) for r in x]

    # trim all the tensors in the tuple z to remove as much token from
    # left and right in the first tensor. If z is a tuple, all its
    # elements are trimed according to the triming for the first
    def trim(self, z, token="<nul>"):
        """Drop the leading and trailing columns made only of `token`.

        `z` is either a 2d tensor or a tuple of 2d tensors. For a tuple, the
        column range is computed on the first tensor and applied to all of
        them, and a tuple is returned.
        """
        n = self.token2id[token]
        is_tuple = isinstance(z, tuple)
        x = z[0] if is_tuple else z
        # After padding one `token` column on each side, i counts, column by
        # column, how many columns so far hold something other than `token`.
        i = (1 - (F.pad(x, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
        a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
        if is_tuple:
            return tuple(t[:, a:b] for t in z)
        return x[:, a:b]

    ######################
    # Not the cleanest part of the code

    # Extract the last image of each sequence, from the last <img>
    # included, and set to <nul> all the tokens from the beginning of
    # that image to the end
    def excise_last_image(self, input):
        """Return `(input, loss_masks, images)` with the last image cut out.

        `input` is cloned, so the argument is left untouched. `loss_masks` is
        0 from the last "<img>" of each row onward and 1 before it.
        """
        t_img, t_nul = self.token2id["<img>"], self.token2id["<nul>"]
        nb_img_tokens = self.height * self.width + 1

        input = input.clone()
        t = (input == t_img).long()
        # 1 from the last <img> of the row (included) to the end of the row
        tail_masks = (t.cumsum(dim=1) == t.sum(dim=1, keepdim=True)).long()
        i = (t * tail_masks).nonzero(as_tuple=True)
        # Index grid covering the nb_img_tokens positions of the last image
        j = (
            i[0][:, None],
            i[1][:, None] + torch.arange(nb_img_tokens, device=input.device)[None, :],
        )
        images = self.trim(input[j])
        input[j] = t_nul
        loss_masks = 1 - tail_masks
        input, loss_masks = self.trim((input, loss_masks))
        return input, loss_masks, images

    def add_true_image(self, input, images, loss_masks):
        """Write `images` where each row of `input` has its first "<nul>".

        Both tensors are first extended on the right by one image worth of
        "<nul>" / 0; the written positions get a loss mask of 1.
        """
        t_nul = self.token2id["<nul>"]
        nb_img_tokens = self.height * self.width + 1
        input = F.pad(input, (0, nb_img_tokens), value=t_nul)
        loss_masks = F.pad(loss_masks, (0, nb_img_tokens), value=0)
        t = (input == t_nul).long()
        i = (t.cumsum(dim=1) == 1).nonzero(as_tuple=True)
        j = (
            i[0][:, None],
            i[1][:, None] + torch.arange(nb_img_tokens, device=input.device)[None, :],
        )
        input[j] = images
        loss_masks[j] = 1
        input, loss_masks = self.trim((input, loss_masks))
        return input, loss_masks

    def add_generated_image(self, input, loss_masks, model):
        """Append an "<img>" token to each row and let `model` fill the image.

        The height*width positions after the new "<img>" are generated with
        "<nul>" forbidden, under no_grad and in eval mode; the model's
        previous training mode is restored afterwards.
        """
        t_img, t_nul = self.token2id["<img>"], self.token2id["<nul>"]
        nb_img_tokens = self.height * self.width + 1

        input = F.pad(input, (0, nb_img_tokens), value=t_nul)
        loss_masks = F.pad(loss_masks, (0, nb_img_tokens), value=0)
        t = (input == t_nul).long()
        i = (t.cumsum(dim=1) == 1).nonzero(as_tuple=True)
        input[i] = t_img

        j = (
            i[0][:, None],
            i[1][:, None]
            + 1
            + torch.arange(nb_img_tokens - 1, device=input.device)[None, :],
        )
        ar_masks = input.new_zeros(input.size(), dtype=torch.int64)
        ar_masks[j] = 1
        forbidden_tokens = (
            torch.arange(self.vocabulary_size(), device=input.device) == t_nul
        )
        with torch.autograd.no_grad():
            t = model.training
            model.eval()
            masked_inplace_autoregression(
                model,
                self.batch_size,
                input,
                ar_masks,
                forbidden_tokens,
                device=self.device,
            )
            model.train(t)

        input, loss_masks = self.trim((input, loss_masks))

        return input, loss_masks

    ######################

    def __init__(
        self,
        nb_train_samples,
        nb_test_samples,
        batch_size,
        height,
        width,
        nb_colors=5,
        device=torch.device("cpu"),
        pruner_train=None,
        pruner_eval=None,
    ):
        """Generate the train / test descriptions and build the tokenizer.

        `pruner_train` is forwarded to `picoclvr.generate` for the training
        set only; `pruner_eval` is stored and used by `produce_results`.
        """
        self.height = height
        self.width = width
        self.batch_size = batch_size
        self.device = device
        self.pruner_train = pruner_train
        self.pruner_eval = pruner_eval

        log_string(
            f"generating {nb_train_samples+nb_test_samples} samples (can take some time)"
        )
        self.train_descr = picoclvr.generate(
            nb_train_samples,
            height=self.height,
            width=self.width,
            nb_colors=nb_colors,
            pruner=self.pruner_train,
        )
        self.test_descr = picoclvr.generate(
            nb_test_samples,
            height=self.height,
            width=self.width,
            nb_colors=nb_colors,
            pruner=None,
        )

        # Build the tokenizer
        tokens = {"<nul>", "<img>"}
        for d in [self.train_descr, self.test_descr]:
            for s in d:
                tokens.update(s.strip().split(" "))
        # make this set a sorted list to get the same tensors given
        # the same descr
        tokens = sorted(tokens)
        self.token2id = {t: n for n, t in enumerate(tokens)}
        self.id2token = dict(enumerate(tokens))

        # Tokenize the train and test sets
        self.train_input = self.tensorize(self.train_descr)
        self.test_input = self.tensorize(self.test_descr)

    def batches(self, split="train"):
        """Yield the trimmed batches of the "train" or "test" split."""
        assert split in {"train", "test"}
        input = self.train_input if split == "train" else self.test_input
        for batch in tqdm.tqdm(
            input.split(self.batch_size), dynamic_ncols=True, desc=f"epoch-{split}"
        ):
            yield self.trim(batch)

    def vocabulary_size(self):
        """Return the number of tokens known to the tokenizer."""
        return len(self.token2id)

    # Log the sample count, the requested / missing property counts and the
    # miss rate, with `prefix` inserted in the log keys.
    def _log_property_stats(
        self, prefix, n_epoch, nb_results, nb_requested, nb_missing
    ):
        # With no requested property the rate is undefined; log nan rather
        # than crash on a division by zero.
        if nb_requested > 0:
            miss_rate = 100 * nb_missing / nb_requested
        else:
            miss_rate = float("nan")
        log_string(f"nb_{prefix}samples {n_epoch} {nb_results}")
        log_string(
            f"property_{prefix}nb {n_epoch} requested {nb_requested} missing {nb_missing}"
        )
        log_string(f"property_{prefix}miss {n_epoch} {miss_rate:.02f}%")

    def compute_missing_properties(self, n_epoch, model, pruner=None):
        """Regenerate the last image of every test sample and log the misses.

        `pruner` is forwarded to `picoclvr.nb_properties`; when it is given,
        the log keys carry the "pruned_" prefix.
        """
        acc_nb_requested_properties = []
        acc_nb_missing_properties = []
        acc_nb_results = 0

        for input in tqdm.tqdm(
            self.test_input.split(self.batch_size),
            dynamic_ncols=True,
            desc="test-properties",
        ):
            tape, loss_masks, _ = self.excise_last_image(input)
            tape, loss_masks = self.add_generated_image(tape, loss_masks, model)
            result_descr = self.detensorize(tape)
            properties = picoclvr.nb_properties(
                result_descr,
                height=self.height,
                width=self.width,
                pruner=pruner,
            )
            nb_requested_properties, _, nb_missing_properties = zip(*properties)
            acc_nb_requested_properties += nb_requested_properties
            acc_nb_missing_properties += nb_missing_properties
            acc_nb_results += len(result_descr)

        self._log_property_stats(
            "" if pruner is None else "pruned_",
            n_epoch,
            acc_nb_results,
            sum(acc_nb_requested_properties),
            sum(acc_nb_missing_properties),
        )

    ######################################################################

    def produce_results(self, n_epoch, model):
        """Log the property statistics and save a grid of generated images.

        The statistics are computed on the test set (and again with
        `pruner_eval` when it is set), then on a few hand-written primers
        whose generated images are written to the result directory.
        """
        self.compute_missing_properties(n_epoch, model)

        if self.pruner_eval is not None:
            self.compute_missing_properties(n_epoch, model, self.pruner_eval)

        nb_per_primer = 8
        primer = []

        for primer_descr in [
            "red above green <sep> green top <sep> blue right of red",
            "there is red <sep> there is yellow <sep> there is blue",
            "red below yellow <sep> yellow below green <sep> green below blue <sep> red right <sep> yellow left <sep> green right <sep> blue left",
            "green bottom <sep> yellow bottom <sep> green left of blue <sep> yellow right of blue <sep> blue top",
        ]:
            primer += [primer_descr] * nb_per_primer

        tape = self.tensorize(primer)
        loss_masks = 1 - (tape == self.token2id["<nul>"]).long()
        tape, loss_masks = self.add_generated_image(tape, loss_masks, model)
        result_descr = self.detensorize(tape)

        properties = picoclvr.nb_properties(
            result_descr, height=self.height, width=self.width
        )

        acc_nb_requested_properties, _, acc_nb_missing_properties = zip(*properties)

        self._log_property_stats(
            "demo_",
            n_epoch,
            len(result_descr),
            sum(acc_nb_requested_properties),
            sum(acc_nb_missing_properties),
        )

        img = picoclvr.descr2img(result_descr, height=self.height, width=self.width)

        # A 5d result holds several images per description: a single one is
        # simply framed, several are tiled into one grid per description.
        if img.dim() == 5:
            if img.size(1) == 1:
                img = F.pad(img.squeeze(1), pad=(1, 1, 1, 1), value=64)
            else:
                img = torch.cat(
                    [
                        torchvision.utils.make_grid(x, padding=1, pad_value=64)[None]
                        for x in img
                    ],
                    0,
                )

        image_name = os.path.join(args.result_dir, f"picoclvr_result_{n_epoch:04d}.png")
        torchvision.utils.save_image(
            img / 255.0, image_name, nrow=nb_per_primer, padding=1, pad_value=1.0
        )
        log_string(f"wrote {image_name}")
493
494
495 ######################################################################
496
497
class TaskMNIST(Task):
    """Task whose sequences are the 28x28 MNIST images flattened to 784 pixels.

    Each pixel intensity is one token, hence a vocabulary of 256.
    """

    def __init__(self, batch_size, device=torch.device("cpu")):
        self.device = device
        self.batch_size = batch_size

    def batches(self, split="train"):
        """Yield batches of flattened images from the "train" or "test" split.

        The split is truncated to `args.nb_train_samples` or
        `args.nb_test_samples` respectively, when that value is not None. The
        batches are not moved to `self.device` here.
        """
        assert split in {"train", "test"}
        data_set = torchvision.datasets.MNIST(
            root="./data", train=(split == "train"), download=True
        )
        data_input = data_set.data.view(-1, 28 * 28).long()
        # Each split is limited by its own sample count; the training count
        # used to be applied to the test split as well.
        nb_samples = (
            args.nb_train_samples if split == "train" else args.nb_test_samples
        )
        if nb_samples is not None:
            data_input = data_input[:nb_samples]
        for batch in tqdm.tqdm(
            data_input.split(self.batch_size), desc=f"epoch-{split}"
        ):
            yield batch

    def vocabulary_size(self):
        """Return 256, the number of possible pixel values."""
        return 256

    def produce_results(self, n_epoch, model):
        """Generate 64 images from scratch and save them as a PNG grid."""
        # Zero-initialized rather than torch.empty, so that no arbitrary
        # memory content can be read as a token id.
        results = torch.zeros(64, 28 * 28, device=self.device, dtype=torch.int64)
        ar_mask = torch.full_like(results, 1)
        # Same protocol as the other tasks: no gradient tracking, eval mode
        # during the synthesis, and the previous mode restored afterwards.
        with torch.autograd.no_grad():
            t = model.training
            model.eval()
            masked_inplace_autoregression(
                model, self.batch_size, results, ar_mask, device=self.device
            )
            model.train(t)
        image_name = os.path.join(args.result_dir, f"mnist_result_{n_epoch:04d}.png")
        torchvision.utils.save_image(
            1 - results.reshape(-1, 1, 28, 28) / 255.0,
            image_name,
            nrow=16,
            pad_value=0.8,
        )
        log_string(f"wrote {image_name}")
533
534
535 ######################################################################
536
537 import maze
538
539
class TaskMaze(Task):
    """Task whose sequences are a flattened maze followed by a flattened path.

    The mazes and paths come from `maze.create_maze_data`; the model is
    evaluated on regenerating the path part given the maze part.
    """

    def map2seq(self, *m):
        """Flatten each map of `m` per sample and concatenate them along dim 1."""
        return torch.cat([x.flatten(1) for x in m], 1)

    def seq2map(self, s):
        """Split sequences back into a generator of (N, height, width) maps."""
        s = s.reshape(s.size(0), -1, self.height, self.width)
        return (s[:, k] for k in range(s.size(1)))

    def __init__(
        self,
        nb_train_samples,
        nb_test_samples,
        batch_size,
        height,
        width,
        nb_walls,
        device=torch.device("cpu"),
    ):
        """Generate the train and test mazes and store them on `device`."""
        self.batch_size = batch_size
        self.height = height
        self.width = width
        self.device = device

        train_mazes, train_paths, _ = maze.create_maze_data(
            nb_train_samples,
            height=height,
            width=width,
            nb_walls=nb_walls,
            progress_bar=lambda x: tqdm.tqdm(x, dynamic_ncols=True, desc="data-train"),
        )
        self.train_input = self.map2seq(train_mazes.to(device), train_paths.to(device))

        test_mazes, test_paths, _ = maze.create_maze_data(
            nb_test_samples,
            height=height,
            width=width,
            nb_walls=nb_walls,
            progress_bar=lambda x: tqdm.tqdm(x, dynamic_ncols=True, desc="data-test"),
        )
        self.test_input = self.map2seq(test_mazes.to(device), test_paths.to(device))

        # Plain int rather than a 0-dim tensor, so that vocabulary_size()
        # returns an ordinary Python number.
        self.nb_codes = int(max(self.train_input.max(), self.test_input.max())) + 1

    def batches(self, split="train", nb_to_use=-1, desc=None):
        """Yield batches of the given split.

        Only the first `nb_to_use` samples are used when it is positive;
        `desc` overrides the progress-bar label.
        """
        assert split in {"train", "test"}
        input = self.train_input if split == "train" else self.test_input
        if nb_to_use > 0:
            input = input[:nb_to_use]
        if desc is None:
            desc = f"epoch-{split}"
        for batch in tqdm.tqdm(
            input.split(self.batch_size), dynamic_ncols=True, desc=desc
        ):
            yield batch

    def vocabulary_size(self):
        """Return the largest token id seen in the data, plus one."""
        return self.nb_codes

    def compute_error(self, model, split="train", nb_to_use=-1):
        """Return `(nb_total, nb_correct)` over the regenerated paths.

        For each batch the path part is zeroed, regenerated by `model`, and
        checked with `maze.path_correctness`.
        """
        nb_total, nb_correct = 0, 0
        # self.batches, not the module-level `task`: the method must work on
        # the instance it is called on.
        for input in self.batches(split, nb_to_use):
            result = input.clone()
            ar_mask = result.new_zeros(result.size())
            # The first height*width tokens are the maze, the rest the path
            ar_mask[:, self.height * self.width :] = 1
            result *= 1 - ar_mask
            masked_inplace_autoregression(
                model, self.batch_size, result, ar_mask, device=self.device
            )
            mazes, paths = self.seq2map(result)
            nb_correct += maze.path_correctness(mazes, paths).long().sum().item()
            nb_total += mazes.size(0)

        return nb_total, nb_correct

    def produce_results(self, n_epoch, model):
        """Log train / test accuracies and save a picture of predicted paths.

        The accuracies use at most 1000 samples per split, the picture the
        first 48 test samples. Everything runs under no_grad and in eval
        mode; the model's previous training mode is restored at the end.
        """
        with torch.autograd.no_grad():
            t = model.training
            model.eval()

            train_nb_total, train_nb_correct = self.compute_error(
                model, "train", nb_to_use=1000
            )
            log_string(
                f"accuracy_train nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%"
            )

            test_nb_total, test_nb_correct = self.compute_error(
                model, "test", nb_to_use=1000
            )
            log_string(
                f"accuracy_test nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
            )

            input = self.test_input[:48]
            result = input.clone()
            ar_mask = result.new_zeros(result.size())
            ar_mask[:, self.height * self.width :] = 1
            result *= 1 - ar_mask
            masked_inplace_autoregression(
                model, self.batch_size, result, ar_mask, device=self.device
            )

            mazes, paths = self.seq2map(input)
            _, predicted_paths = self.seq2map(result)

            filename = os.path.join(args.result_dir, f"maze_result_{n_epoch:04d}.png")
            maze.save_image(
                filename,
                mazes=mazes,
                target_paths=paths,
                predicted_paths=predicted_paths,
                path_correct=maze.path_correctness(mazes, predicted_paths),
            )
            log_string(f"wrote {filename}")

            model.train(t)
656
657
658 ######################################################################
659
660
def generate_snake_sequences(
    nb, height, width, nb_colors, length, device=torch.device("cpu")
):
    """Simulate `nb` snakes moving for `length` steps on random colored grids.

    Returns `(sequences, sequences_prior_visits)`, both int64 tensors of
    shape nb x (2 * length). In `sequences`, even positions hold the color of
    the current cell plus 4 and odd positions the direction (0..3) chosen
    next. `sequences_prior_visits` holds, at even positions, how many times
    the current cell had been visited before; odd positions stay 0.
    """
    worlds = torch.randint(nb_colors, (nb, height, width), device=device)
    visit_counts = torch.zeros(nb, height, width, device=device)

    # nb x 2, (row, column) of every snake
    start_rows = torch.randint(height, (nb, 1), device=device)
    start_cols = torch.randint(width, (nb, 1), device=device)
    position = torch.cat((start_rows, start_cols), 1)
    direction = torch.randint(4, (nb,), device=device)

    sequences = torch.empty(nb, 2 * length, device=device, dtype=torch.int64)
    sequences_prior_visits = torch.zeros(
        nb, 2 * length, device=device, dtype=torch.int64
    )
    batch_idx = torch.arange(nb, device=device)
    # The multiplicative factors bias toward moving forward
    forward_bias = torch.tensor([[1.0, 2.0, 1.0]], device=device)

    for step in range(length):
        # nb x 3: previous direction in the cycle, same direction, next one
        candidates = torch.stack(
            ((direction - 1) % 4, direction, (direction + 1) % 4), 1
        )

        # nb x 3, row and column increments of each candidate direction
        d_row = (candidates + 1) % 2 * (candidates - 1)
        d_col = candidates % 2 * (candidates - 2)

        # nb x 3 x 2
        displacement = torch.stack((d_row, d_col), 2)
        next_position = position[:, None, :] + displacement

        # nb x 3, 1.0 where the candidate move stays inside the grid
        next_row, next_col = next_position[:, :, 0], next_position[:, :, 1]
        inside = (
            (next_row >= 0) & (next_row < height) & (next_col >= 0) & (next_col < width)
        ).float()
        score = torch.rand_like(inside) * inside * forward_bias

        # nb
        choice = score.argmax(1)
        direction = candidates[batch_idx, choice]

        row, col = position[:, 0], position[:, 1]
        sequences[:, 2 * step] = worlds[batch_idx, row, col] + 4
        sequences_prior_visits[:, 2 * step] = visit_counts[batch_idx, row, col]
        visit_counts[batch_idx, row, col] += 1
        sequences[:, 2 * step + 1] = direction

        # nb x 2
        position = next_position[batch_idx, choice]

    return sequences, sequences_prior_visits
732
733
734 # generate_snake_sequences(nb=1, height=4, width=6, nb_colors=3, length=20)
735 # exit(0)
736
737
class TaskSnake(Task):
    """Task whose sequences alternate cell colors and snake directions.

    The data come from `generate_snake_sequences`. The model is evaluated on
    predicting the color of cells that the snake has already visited.
    """

    def __init__(
        self,
        nb_train_samples,
        nb_test_samples,
        batch_size,
        height,
        width,
        nb_colors,
        length,
        device=torch.device("cpu"),
    ):
        """Generate the train and test sequences on `device`."""
        self.batch_size = batch_size
        self.height = height
        self.width = width
        self.device = device

        self.train_input, self.train_prior_visits = generate_snake_sequences(
            nb_train_samples, height, width, nb_colors, length, self.device
        )
        self.test_input, self.test_prior_visits = generate_snake_sequences(
            nb_test_samples, height, width, nb_colors, length, self.device
        )

        # Plain int rather than a 0-dim tensor, so that vocabulary_size()
        # returns an ordinary Python number.
        self.nb_codes = int(max(self.train_input.max(), self.test_input.max())) + 1

    def batches(self, split="train", nb_to_use=-1, desc=None):
        """Yield batches of the given split.

        Only the first `nb_to_use` samples are used when it is positive;
        `desc` overrides the progress-bar label.
        """
        assert split in {"train", "test"}
        input = self.train_input if split == "train" else self.test_input
        if nb_to_use > 0:
            input = input[:nb_to_use]
        if desc is None:
            desc = f"epoch-{split}"
        for batch in tqdm.tqdm(
            input.split(self.batch_size), dynamic_ncols=True, desc=desc
        ):
            yield batch

    def vocabulary_size(self):
        """Return the largest token id seen in the data, plus one."""
        return self.nb_codes

    def produce_results(self, n_epoch, model):
        """Log the accuracy on already-visited cells for both splits.

        The even (color) positions of the second half of every sequence are
        erased and regenerated by `model`; only the positions whose cell had
        been visited before are counted.
        """
        with torch.autograd.no_grad():
            t = model.training
            model.eval()

            def compute_nb_correct(input, prior_visits):
                result = input.clone()
                i = torch.arange(result.size(1), device=result.device)[None, :]
                # Two fixes here. The threshold is half the sequence length:
                # i.size(0) is always 1, so the test was always true. The
                # mask is expanded to one row per sample: with a single row,
                # the autoregression zips it against the batches of `result`
                # and stops after the first batch.
                ar_mask = (
                    torch.logical_and(i >= result.size(1) // 2, i % 2 == 0)
                    .long()
                    .expand_as(result)
                )
                result *= 1 - ar_mask
                masked_inplace_autoregression(
                    model, self.batch_size, result, ar_mask, device=self.device
                )

                nb_total = ((prior_visits > 0) * ar_mask).sum()

                nb_correct = (
                    (result == input).long() * (prior_visits > 0) * ar_mask
                ).sum()

                # nb_total = result.size(0)
                # nb_correct = ((result - input).abs().sum(1) == 0).sum()

                return nb_total, nb_correct

            train_nb_total, train_nb_correct = compute_nb_correct(
                self.train_input, self.train_prior_visits
            )

            log_string(
                f"accuracy_train nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%"
            )

            test_nb_total, test_nb_correct = compute_nb_correct(
                self.test_input, self.test_prior_visits
            )

            log_string(
                f"accuracy_test nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
            )

            model.train(t)
823
824
825 ######################################################################
826
827
def picoclvr_pruner_horizontal_green(p):
    """Return False exactly when `p` contains "green" and "left" or "right"."""
    if "green" not in p:
        return True
    return "left" not in p and "right" not in p
830
831
# The training pruner is only active in "train+eval" mode.
if args.picocvlr_prune_properties == "train+eval":
    picoclvr_pruner_train = picoclvr_pruner_horizontal_green
else:
    picoclvr_pruner_train = None

# The evaluation pruner is the negation of the training predicate, and is
# active in both "train+eval" and "eval" modes.
if args.picocvlr_prune_properties in ("train+eval", "eval"):
    picoclvr_pruner_eval = lambda p: not picoclvr_pruner_horizontal_green(p)
else:
    picoclvr_pruner_eval = None
843
844 ######################################################################
845
# One deferred constructor per supported task: only the selected one is
# called, so only its data gets generated.
_task_factories = {
    "picoclvr": lambda: TaskPicoCLVR(
        nb_train_samples=args.nb_train_samples,
        nb_test_samples=args.nb_test_samples,
        batch_size=args.batch_size,
        height=args.picoclvr_height,
        width=args.picoclvr_width,
        nb_colors=args.picoclvr_nb_colors,
        device=device,
        pruner_train=picoclvr_pruner_train,
        pruner_eval=picoclvr_pruner_eval,
    ),
    "mnist": lambda: TaskMNIST(
        batch_size=args.batch_size,
        device=device,
    ),
    "maze": lambda: TaskMaze(
        nb_train_samples=args.nb_train_samples,
        nb_test_samples=args.nb_test_samples,
        batch_size=args.batch_size,
        height=args.maze_height,
        width=args.maze_width,
        nb_walls=args.maze_nb_walls,
        device=device,
    ),
    "snake": lambda: TaskSnake(
        nb_train_samples=args.nb_train_samples,
        nb_test_samples=args.nb_test_samples,
        batch_size=args.batch_size,
        height=args.snake_height,
        width=args.snake_width,
        nb_colors=args.snake_nb_colors,
        length=args.snake_length,
        device=device,
    ),
}

if args.task not in _task_factories:
    raise ValueError(f"Unknown task {args.task}")

task = _task_factories[args.task]()
890
891 ######################################################################
892
log_string(f"device {device}")

# The size of the token vocabulary is dictated by the selected task.
vocabulary_size = task.vocabulary_size()

log_string(f"vocabulary_size {vocabulary_size}")

##############################

# Constructor arguments of the model: everything but the vocabulary size
# and the causal flag comes straight from the command line.
model_kwargs = dict(
    vocabulary_size=vocabulary_size,
    dim_model=args.dim_model,
    dim_keys=args.dim_keys,
    dim_hidden=args.dim_hidden,
    nb_heads=args.nb_heads,
    nb_blocks=args.nb_blocks,
    causal=True,
    dropout=args.dropout,
)

model = mygpt.MyGPT(**model_kwargs)

model.to(device)

# Total number of scalar parameters, logged both exactly and in millions.
nb_parameters = sum(parameter.numel() for parameter in model.parameters())
log_string(f"nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)")
916
917 ######################################################################
918
# Number of epochs already completed; stays at 0 unless a checkpoint is
# successfully loaded below.
nb_epochs_finished = 0

if args.no_checkpoint:
    log_string("not trying to load checkpoint.")

else:
    try:
        checkpoint_name = os.path.join(args.result_dir, args.checkpoint_name)
        checkpoint = torch.load(checkpoint_name)
        nb_epochs_finished = checkpoint["nb_epochs_finished"]
        model.load_state_dict(checkpoint["model_state"])
        # Restore the random generators so that a resumed run continues
        # from the state saved with the checkpoint.
        torch.set_rng_state(checkpoint["rng_state"])
        if torch.cuda.is_available():
            torch.cuda.set_rng_state(checkpoint["cuda_rng_state"])

        log_string(f"checkpoint loaded with {nb_epochs_finished} epochs finished.")

    except FileNotFoundError:
        # No checkpoint file yet: this is a fresh run.
        log_string("starting from scratch.")

    except Exception as e:
        # Any other failure (corrupted file, missing key, incompatible
        # model, ...) is fatal. Catching Exception instead of using a bare
        # except lets KeyboardInterrupt and SystemExit propagate, and the
        # actual cause is logged instead of being discarded.
        log_string("error when loading the checkpoint.")
        log_string(f"{type(e).__name__}: {e}")
        sys.exit(1)
942
943 ######################################################################
944
# NOTE(review): nb_epochs_default is not defined in the visible part of
# this file; it is only evaluated when --nb_epochs is not positive --
# confirm it exists.
nb_epochs = nb_epochs_default if args.nb_epochs <= 0 else args.nb_epochs

# Count the occurrences of every token value over the whole training
# split: each batch is one-hot encoded and summed over its first two
# dimensions, leaving one count per class.
token_count = sum(
    F.one_hot(batch, num_classes=task.vocabulary_size()).sum((0, 1))
    for batch in task.batches(split="train")
)

# Perplexity of the empirical token distribution, i.e. exp of its
# entropy. xlogy(p, p) is p * log(p) and evaluates to 0 where p is 0.
token_probas = token_count / token_count.sum()
entropy = -torch.xlogy(token_probas, token_probas).sum()
train_set_perplexity = math.exp(entropy)
953
954 ##############################
955
# Build learning_rate_schedule, a dict mapping every epoch index in
# range(nb_epochs) to the learning rate used during that epoch.
#
# The schedule spans nb_epochs (the effective number of epochs), not
# args.nb_epochs: the training loop indexes the schedule with epochs up
# to nb_epochs - 1, and the two values differ when --nb_epochs is not
# positive.
if args.learning_rate_schedule == "cos":
    # Half-cosine decay starting at args.learning_rate.
    learning_rate_schedule = {}
    for n_epoch in range(nb_epochs):
        u = n_epoch / nb_epochs * math.pi
        learning_rate_schedule[n_epoch] = args.learning_rate * 0.5 * (1 + math.cos(u))
else:
    # The option is a comma-separated list of "epoch: rate" pairs, each
    # giving the rate that applies from that epoch onwards. Blank entries
    # (empty option, trailing comma) are ignored.
    u = {}
    for entry in args.learning_rate_schedule.split(","):
        if not entry.strip():
            continue
        k, v = entry.split(":")
        u[int(k)] = float(v)

    # Piecewise-constant schedule: start from args.learning_rate and
    # switch to the listed rate whenever a listed epoch is reached.
    learning_rate_schedule = {}
    learning_rate = args.learning_rate
    for n_epoch in range(nb_epochs):
        if n_epoch in u:
            learning_rate = u[n_epoch]
        learning_rate_schedule[n_epoch] = learning_rate

log_string(f"learning_rate_schedule {learning_rate_schedule}")
977
978 ##############################
979
# Running count of training samples processed in this run.
nb_samples_seen = 0

# Every requested epoch is already in the checkpoint: the loop below will
# not execute, so only produce the results with the loaded model.
if nb_epochs_finished >= nb_epochs:
    task.produce_results(nb_epochs_finished, model)

for n_epoch in range(nb_epochs_finished, nb_epochs):
    learning_rate = learning_rate_schedule[n_epoch]

    log_string(f"learning_rate {learning_rate}")

    # A new optimizer is created at every epoch with the scheduled rate,
    # so no optimizer state is carried over between epochs.
    optimizer_class = {
        "sgd": torch.optim.SGD,
        "adam": torch.optim.Adam,
        "adamw": torch.optim.AdamW,
    }.get(args.optim)

    if optimizer_class is None:
        raise ValueError(f"Unknown optimizer {args.optim}.")

    optimizer = optimizer_class(model.parameters(), lr=learning_rate)

    ##############################
    # Training pass

    model.train()

    nb_train_samples, acc_train_loss = 0, 0.0

    for batch in task.batches(split="train"):
        batch = batch.to(device)
        batch_size = batch.size(0)

        logits = model(mygpt.BracketedSequence(batch)).x
        # cross_entropy expects the class dimension in second position.
        loss = F.cross_entropy(logits.transpose(1, 2), batch)

        # The loss is a per-batch mean: weight it by the batch size so
        # that the accumulated value can be averaged over all samples.
        acc_train_loss += loss.item() * batch_size
        nb_train_samples += batch_size
        nb_samples_seen += batch_size

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    ##############################
    # Test pass and result generation, without gradient tracking

    with torch.no_grad():
        model.eval()

        nb_test_samples, acc_test_loss = 0, 0.0

        for batch in task.batches(split="test"):
            batch = batch.to(device)
            batch_size = batch.size(0)

            # input, loss_masks, true_images = task.excise_last_image(input)
            # input, loss_masks = task.add_true_image(input, true_images, loss_masks)

            logits = model(mygpt.BracketedSequence(batch)).x
            loss = F.cross_entropy(logits.transpose(1, 2), batch)
            acc_test_loss += loss.item() * batch_size
            nb_test_samples += batch_size

        # The mean loss is clamped at 100 before exponentiation to keep
        # math.exp from overflowing.
        train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples))
        test_perplexity = math.exp(min(100, acc_test_loss / nb_test_samples))

        log_string(
            f"perplexity {n_epoch} train_set {train_set_perplexity} train_prediction {train_perplexity} test_prediction {test_perplexity}"
        )

        task.produce_results(n_epoch, model)

    ##############################
    # Checkpoint: model weights, epoch counter and random generator states

    checkpoint = dict(
        nb_epochs_finished=n_epoch + 1,
        model_state=model.state_dict(),
        rng_state=torch.get_rng_state(),
    )

    if torch.cuda.is_available():
        checkpoint.update(cuda_rng_state=torch.cuda.get_rng_state())

    checkpoint_name = os.path.join(args.result_dir, args.checkpoint_name)
    torch.save(checkpoint, checkpoint_name)
    log_string(f"saved checkpoint {checkpoint_name}")
1052
1053 ######################################################################