# Update.
# [picoclvr.git] / main.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
# torch.backends.cuda.matmul.allow_tf32
9 # torch.autocast(torch.bfloat16)
10
11 import math, sys, argparse, time, tqdm, os
12
13 import torch, torchvision
14 from torch import nn
15 from torch.nn import functional as F
16
17 import mygpt, tensorstack
18
19 ######################################################################
20
# Pick the compute device once at start-up; enable TF32 matmuls on CUDA.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
if use_cuda:
    # TF32 trades a little matmul precision for a large speed-up on Ampere+.
    torch.backends.cuda.matmul.allow_tf32 = True
26
27 ######################################################################
28
parser = argparse.ArgumentParser(
    description="An implementation of GPT with cache.",
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)

# Command-line options as (flag, add_argument keyword arguments), grouped and
# registered in the order they should appear in --help.
_cli_options = [
    # general / training
    ("--task", dict(type=str, default="picoclvr")),
    ("--log_filename", dict(type=str, default="train.log")),
    ("--result_dir", dict(type=str, default="results_default")),
    ("--seed", dict(type=int, default=0)),
    # None means "use the per-task default" filled in after parsing.
    ("--nb_epochs", dict(type=int, default=None)),
    ("--batch_size", dict(type=int, default=None)),
    ("--nb_train_samples", dict(type=int, default=250000)),
    ("--nb_test_samples", dict(type=int, default=10000)),
    ("--optim", dict(type=str, default="adam")),
    ("--learning_rate", dict(type=float, default=1e-4)),
    ("--learning_rate_schedule", dict(type=str, default="10: 2e-5,30: 4e-6")),
    # model geometry
    ("--dim_model", dict(type=int, default=512)),
    ("--dim_keys", dict(type=int, default=64)),
    ("--dim_hidden", dict(type=int, default=2048)),
    ("--nb_heads", dict(type=int, default=8)),
    ("--nb_blocks", dict(type=int, default=12)),
    ("--dropout", dict(type=float, default=0.1)),
    # run control
    ("--deterministic_synthesis", dict(action="store_true", default=False)),
    ("--no_checkpoint", dict(action="store_true", default=False)),
    ("--overwrite_results", dict(action="store_true", default=False)),
    ("--checkpoint_name", dict(type=str, default="checkpoint.pth")),
    # picoclvr options
    ("--picoclvr_nb_colors", dict(type=int, default=5)),
    ("--picoclvr_height", dict(type=int, default=12)),
    ("--picoclvr_width", dict(type=int, default=16)),
    # NOTE(review): "picocvlr" is a historical typo, kept because the rest of
    # this file (and existing scripts) reference the misspelled flag.
    ("--picocvlr_prune_properties", dict(type=str, default="none")),
    # maze options
    ("--maze_height", dict(type=int, default=13)),
    ("--maze_width", dict(type=int, default=21)),
    ("--maze_nb_walls", dict(type=int, default=15)),
    # snake options
    ("--snake_height", dict(type=int, default=6)),
    ("--snake_width", dict(type=int, default=8)),
    ("--snake_nb_colors", dict(type=int, default=5)),
    ("--snake_length", dict(type=int, default=200)),
]

for _flag, _kw in _cli_options:
    parser.add_argument(_flag, **_kw)
106
107 ######################################################################
108
args = parser.parse_args()

# Only these pruning modes are meaningful downstream (see the pruner setup
# near the bottom of the file).
assert args.picocvlr_prune_properties in {"none", "train+eval", "eval"}

try:
    os.mkdir(args.result_dir)
except FileExistsError:
    # Refuse to clobber an existing run unless explicitly allowed.
    if not args.overwrite_results:
        print(f"result directory {args.result_dir} already exists")
        exit(1)

# Append mode so a resumed run keeps the previous log history.
log_file = open(os.path.join(args.result_dir, args.log_filename), "a")

if args.seed >= 0:
    # torch.backends.cudnn.deterministic = True
    # torch.backends.cudnn.benchmark = False
    # torch.use_deterministic_algorithms(True)
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)
129
130 ######################################################################
131
# Per-task fallbacks for options the user left unset on the command line.
default_args = {
    "picoclvr": {
        "nb_epochs": 25,
        "batch_size": 25,
    },
    "mnist": {
        "nb_epochs": 25,
        "batch_size": 10,
    },
    "maze": {
        "nb_epochs": 25,
        "batch_size": 25,
    },
    "snake": {
        "nb_epochs": 5,
        "batch_size": 25,
    },
}

# Fill in only the attributes that are still None (i.e. not given explicitly).
for name, value in default_args.get(args.task, {}).items():
    if getattr(args, name) is None:
        setattr(args, name, value)
155
156 ######################################################################
157
158
def log_string(s):
    """Write s, timestamped, to the run's log file (if any) and to stdout."""
    stamp = time.strftime("%Y%m%d-%H:%M:%S ", time.localtime())
    line = stamp + s

    if log_file is not None:
        log_file.write(line + "\n")
        log_file.flush()

    print(line)
    sys.stdout.flush()
168
169
# Record the full configuration in the log for reproducibility.
for n in vars(args):
    log_string(f"args.{n} {getattr(args, n)}")
172
173 ######################################################################
174
175
def masked_inplace_autoregression(
    model, batch_size, input, ar_mask, forbidden_tokens=None, device=torch.device("cpu")
):
    """Complete `input` in place, auto-regressively, at the positions where
    `ar_mask` is 1.

    input: (N, T) int64 token tensor, modified in place.
    ar_mask: (N, T) 0/1 tensor; 1 marks positions to (re)generate.
    forbidden_tokens: optional boolean mask over the vocabulary; masked
        logits are set to -inf before sampling.

    Sampling is greedy when args.deterministic_synthesis is set, otherwise
    categorical over the logits. Relies on the model's internal cache via
    mygpt.BracketedSequence. Assumes every batch row has at least one
    position to generate (ar_mask all-zero would make i empty -- TODO confirm
    callers guarantee this).
    """
    for input, ar_mask in tqdm.tqdm(
        zip(input.split(batch_size), ar_mask.split(batch_size)),
        dynamic_ncols=True,
        desc="autoregression",
        total=input.size(0) // batch_size,
    ):
        # Columns where at least one row has a token to generate.
        i = (ar_mask.sum(0) > 0).nonzero()
        if i.min() > 0:
            model(
                mygpt.BracketedSequence(input, 0, i.min())
            )  # Needed to initialize the model's cache
        for s in range(i.min(), i.max() + 1):
            output = model(mygpt.BracketedSequence(input, s, 1)).x
            logits = output[:, s]
            if forbidden_tokens is not None:
                logits = logits.masked_fill(forbidden_tokens, float("-inf"))
            if args.deterministic_synthesis:
                t_next = logits.argmax(1)
            else:
                dist = torch.distributions.categorical.Categorical(logits=logits)
                t_next = dist.sample()
            # Only overwrite the positions the mask asks for.
            input[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * input[:, s]
201
202
203 ######################################################################
204
205
class Task:
    """Minimal interface every task must implement."""

    def batches(self, split="train"):
        # Yield batches of token tensors for the given split.
        pass

    def vocabulary_size(self):
        # Number of distinct token values used by this task's sequences.
        pass

    def produce_results(self, n_epoch, model):
        # Generate, evaluate and log samples/metrics for the given epoch.
        pass
215
216
217 ######################################################################
218
219 import picoclvr
220
221
class TaskPicoCLVR(Task):
    """Text-to-image task over picoclvr scene descriptions.

    Sequences are space-separated tokens: property phrases, an <img> marker,
    then height*width pixel tokens; <nul> is the padding token.
    """

    # Make a tensor from a list of strings
    def tensorize(self, descr):
        # Right-pad every token sequence with <nul> up to the longest one so
        # they stack into a rectangular (N, T) tensor of token ids.
        token_descr = [s.strip().split(" ") for s in descr]
        l = max([len(s) for s in token_descr])
        token_descr = [s + ["<nul>"] * (l - len(s)) for s in token_descr]
        id_descr = [[self.token2id[u] for u in s] for s in token_descr]
        return torch.tensor(id_descr, device=self.device)

    # Make a list of strings from a tensor
    def detensorize(self, x):
        return [" ".join([self.id2token[t.item()] for t in r]) for r in x]

    # Trim all the tensors in the tuple z to remove as many <nul> tokens as
    # possible from the left and right of the first tensor. If z is a tuple,
    # all its elements are trimmed according to the trimming of the first.
    def trim(self, z, token="<nul>"):
        n = self.token2id[token]
        if type(z) == tuple:
            x = z[0]
            # i stays 0 over the all-padding left border and reaches its max
            # at the all-padding right border; the F.pad sentinels guarantee
            # both borders exist, so a and b are well defined.
            i = (1 - (F.pad(x, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
            a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
            return tuple([t[:, a:b] for t in z])
        else:
            i = (1 - (F.pad(z, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
            a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
            return z[:, a:b]

    ######################
    # Not the cleanest part of the code

    # Extract the last image of each sequence, from the last <img>
    # included, and set to <nul> all the tokens from the beginning of
    # that image to the end
    def excise_last_image(self, input):
        t_img, t_nul = self.token2id["<img>"], self.token2id["<nul>"]
        # An image is the <img> marker followed by height*width pixel tokens.
        nb_img_tokens = self.height * self.width + 1

        input = input.clone()
        t = (input == t_img).long()
        # tail_masks is 1 from the last <img> of each row to the end of row.
        tail_masks = (t.cumsum(dim=1) == t.sum(dim=1, keepdim=True)).long()
        i = (t * tail_masks).nonzero(as_tuple=True)
        # j indexes, per row, the nb_img_tokens positions of the last image.
        j = (
            i[0][:, None],
            i[1][:, None] + torch.arange(nb_img_tokens, device=input.device)[None, :],
        )
        images = self.trim(input[j])
        input[j] = t_nul
        # The excised tail must not contribute to the loss.
        loss_masks = 1 - tail_masks
        input, loss_masks = self.trim((input, loss_masks))
        return input, loss_masks, images

    # Write the ground-truth images back, starting at the first <nul> of
    # each row, and include them in the loss.
    def add_true_image(self, input, images, loss_masks):
        t_nul = self.token2id["<nul>"]
        nb_img_tokens = self.height * self.width + 1
        input = F.pad(input, (0, nb_img_tokens), value=t_nul)
        loss_masks = F.pad(loss_masks, (0, nb_img_tokens), value=0)
        t = (input == t_nul).long()
        # First <nul> of each row is where the image is written.
        i = (t.cumsum(dim=1) == 1).nonzero(as_tuple=True)
        j = (
            i[0][:, None],
            i[1][:, None] + torch.arange(nb_img_tokens, device=input.device)[None, :],
        )
        input[j] = images
        loss_masks[j] = 1
        input, loss_masks = self.trim((input, loss_masks))
        return input, loss_masks

    # Append an <img> marker at the first <nul> of each row and have the
    # model generate the following height*width pixel tokens.
    def add_generated_image(self, input, loss_masks, model):
        t_img, t_nul = self.token2id["<img>"], self.token2id["<nul>"]
        nb_img_tokens = self.height * self.width + 1

        input = F.pad(input, (0, nb_img_tokens), value=t_nul)
        loss_masks = F.pad(loss_masks, (0, nb_img_tokens), value=0)
        t = (input == t_nul).long()
        i = (t.cumsum(dim=1) == 1).nonzero(as_tuple=True)
        input[i] = t_img

        # The nb_img_tokens - 1 pixel positions right after the <img> marker.
        j = (
            i[0][:, None],
            i[1][:, None]
            + 1
            + torch.arange(nb_img_tokens - 1, device=input.device)[None, :],
        )
        ar_masks = input.new_zeros(input.size(), dtype=torch.int64)
        ar_masks[j] = 1
        # Never sample the padding token inside an image.
        forbidden_tokens = (
            torch.arange(self.vocabulary_size(), device=input.device) == t_nul
        )
        # Generate in eval mode, restoring the previous training mode after.
        with torch.autograd.no_grad():
            t = model.training
            model.eval()
            masked_inplace_autoregression(
                model,
                self.batch_size,
                input,
                ar_masks,
                forbidden_tokens,
                device=self.device,
            )
            model.train(t)

        input, loss_masks = self.trim((input, loss_masks))

        return input, loss_masks

    ######################

    def __init__(
        self,
        nb_train_samples,
        nb_test_samples,
        batch_size,
        height,
        width,
        nb_colors=5,
        device=torch.device("cpu"),
        pruner_train=None,
        pruner_eval=None,
    ):
        def generate_descr(nb, cache_suffix, pruner):
            # NOTE(review): cache_suffix is currently unused.
            return picoclvr.generate(
                nb,
                height=self.height,
                width=self.width,
                nb_colors=nb_colors,
                pruner=pruner,
            )

        self.height = height
        self.width = width
        self.batch_size = batch_size
        self.device = device
        self.pruner_train = pruner_train
        self.pruner_eval = pruner_eval

        # NOTE(review): param is built (including the RNG state) but never
        # used below -- presumably intended for caching the generated data.
        param = {
            "nb_train_samples": nb_train_samples,
            "nb_test_samples": nb_test_samples,
            "height": height,
            "width": width,
            "nb_colors": nb_colors,
            "batch_size": batch_size,
            "rng_state": list(torch.get_rng_state()),
        }

        log_string(
            f"generating {nb_train_samples+nb_test_samples} samples (can take some time)"
        )
        # The test set is generated unpruned; pruning only applies to train.
        self.train_descr = generate_descr(
            nb_train_samples, "train", pruner=self.pruner_train
        )
        self.test_descr = generate_descr(nb_test_samples, "test", pruner=None)

        # Build the tokenizer
        tokens = {"<nul>", "<img>"}
        for d in [self.train_descr, self.test_descr]:
            for s in d:
                for t in s.strip().split(" "):
                    tokens.add(t)
        # make this set a sorted list to get the same tensors given
        # the same descr
        tokens = list(tokens)
        tokens.sort()
        self.token2id = dict([(t, n) for n, t in enumerate(tokens)])
        self.id2token = dict([(n, t) for n, t in enumerate(tokens)])

        # Tokenize the train and test sets
        self.train_input = self.tensorize(self.train_descr)
        self.test_input = self.tensorize(self.test_descr)

    def batches(self, split="train"):
        assert split in {"train", "test"}
        input = self.train_input if split == "train" else self.test_input
        for batch in tqdm.tqdm(
            input.split(self.batch_size), dynamic_ncols=True, desc=f"epoch-{split}"
        ):
            # Trim shared padding so each batch is as short as possible.
            yield self.trim(batch)

    def vocabulary_size(self):
        return len(self.token2id)

    # Regenerate the last image of every test sequence and count how many of
    # the requested properties the generated image fails to satisfy.
    def compute_missing_properties(self, n_epoch, model, pruner=None):
        acc_nb_requested_properties = []
        acc_nb_missing_properties = []
        acc_nb_results = 0

        for input in tqdm.tqdm(
            self.test_input.split(self.batch_size),
            dynamic_ncols=True,
            desc=f"test-properties",
        ):
            tape, loss_masks, _ = self.excise_last_image(input)
            tape, loss_masks = self.add_generated_image(tape, loss_masks, model)
            result_descr = self.detensorize(tape)
            np = picoclvr.nb_properties(
                result_descr,
                height=self.height,
                width=self.width,
                pruner=pruner,
            )
            nb_requested_properties, _, nb_missing_properties = zip(*np)
            acc_nb_requested_properties += nb_requested_properties
            acc_nb_missing_properties += nb_missing_properties
            acc_nb_results += len(result_descr)

        nb_requested_properties = sum(acc_nb_requested_properties)
        nb_missing_properties = sum(acc_nb_missing_properties)

        prefix = "" if pruner is None else "pruned_"
        log_string(f"nb_{prefix}samples {n_epoch} {acc_nb_results}")
        log_string(
            f"property_{prefix}nb {n_epoch} requested {sum(acc_nb_requested_properties)} missing {sum(acc_nb_missing_properties)}"
        )
        log_string(
            f"property_{prefix}miss {n_epoch} {100*nb_missing_properties/nb_requested_properties:.02f}%"
        )

    ######################################################################

    def produce_results(self, n_epoch, model):
        # Quantitative evaluation on the test set (optionally pruned).
        self.compute_missing_properties(n_epoch, model)

        if self.pruner_eval is not None:
            self.compute_missing_properties(n_epoch, model, self.pruner_eval)

        # Qualitative demo: generate images for a few hand-written prompts.
        nb_tokens_to_generate = self.height * self.width + 3
        result_descr = []
        nb_per_primer = 8
        primer = []

        for primer_descr in [
            "red above green <sep> green top <sep> blue right of red",
            "there is red <sep> there is yellow <sep> there is blue",
            "red below yellow <sep> yellow below green <sep> green below blue <sep> red right <sep> yellow left <sep> green right <sep> blue left",
            "green bottom <sep> yellow bottom <sep> green left of blue <sep> yellow right of blue <sep> blue top",
        ]:
            primer += [primer_descr] * nb_per_primer

        tape = self.tensorize(primer)
        loss_masks = 1 - (tape == self.token2id["<nul>"]).long()
        tape, loss_masks = self.add_generated_image(tape, loss_masks, model)
        result_descr = self.detensorize(tape)

        np = picoclvr.nb_properties(result_descr, height=self.height, width=self.width)

        acc_nb_requested_properties, _, acc_nb_missing_properties = zip(*np)
        acc_nb_results = len(result_descr)

        nb_requested_properties = sum(acc_nb_requested_properties)
        nb_missing_properties = sum(acc_nb_missing_properties)

        prefix = "demo_"
        log_string(f"nb_{prefix}samples {n_epoch} {acc_nb_results}")
        log_string(
            f"property_{prefix}nb {n_epoch} requested {sum(acc_nb_requested_properties)} missing {sum(acc_nb_missing_properties)}"
        )
        log_string(
            f"property_{prefix}miss {n_epoch} {100*nb_missing_properties/nb_requested_properties:.02f}%"
        )

        img = picoclvr.descr2img(result_descr, height=self.height, width=self.width)

        # descr2img may return one image per description or several (5D);
        # flatten the latter into a grid per description.
        if img.dim() == 5:
            if img.size(1) == 1:
                img = F.pad(img.squeeze(1), pad=(1, 1, 1, 1), value=64)
            else:
                img = torch.cat(
                    [
                        torchvision.utils.make_grid(x, padding=1, pad_value=64)[None]
                        for x in img
                    ],
                    0,
                )

        image_name = os.path.join(args.result_dir, f"picoclvr_result_{n_epoch:04d}.png")
        torchvision.utils.save_image(
            img / 255.0, image_name, nrow=nb_per_primer, padding=1, pad_value=1.0
        )
        log_string(f"wrote {image_name}")
502
503
504 ######################################################################
505
506
class TaskMNIST(Task):
    """Auto-regressive modeling of MNIST digits as 784-token sequences of
    raw 8-bit pixel values."""

    def __init__(self, batch_size, device=torch.device("cpu")):
        self.device = device
        self.batch_size = batch_size

    def batches(self, split="train"):
        """Yield flattened 28x28 images as (B, 784) int64 tensors."""
        assert split in {"train", "test"}
        data_set = torchvision.datasets.MNIST(
            root="./data", train=(split == "train"), download=True
        )
        data_input = data_set.data.view(-1, 28 * 28).long()
        # Bug fix: truncate with the sample budget of the *matching* split;
        # the original applied nb_train_samples to the test split as well.
        nb_max = args.nb_train_samples if split == "train" else args.nb_test_samples
        if nb_max is not None:
            data_input = data_input[:nb_max]
        for batch in tqdm.tqdm(
            data_input.split(self.batch_size), desc=f"epoch-{split}"
        ):
            yield batch

    def vocabulary_size(self):
        # One token per possible 8-bit pixel value.
        return 256

    def produce_results(self, n_epoch, model):
        """Sample 64 digits from scratch and save them as a PNG grid."""
        # Generate under no_grad / eval mode, as the other tasks do, and
        # restore the previous training mode afterwards.
        with torch.autograd.no_grad():
            t = model.training
            model.eval()
            results = torch.empty(64, 28 * 28, device=self.device, dtype=torch.int64)
            ar_mask = torch.full_like(results, 1)
            masked_inplace_autoregression(
                model, self.batch_size, results, ar_mask, device=self.device
            )
            model.train(t)
        image_name = os.path.join(args.result_dir, f"mnist_result_{n_epoch:04d}.png")
        torchvision.utils.save_image(
            1 - results.reshape(-1, 1, 28, 28) / 255.0,
            image_name,
            nrow=16,
            pad_value=0.8,
        )
        log_string(f"wrote {image_name}")
542
543
544 ######################################################################
545
546 import maze
547
548
class TaskMaze(Task):
    """Maze solving as sequence modeling: each sequence is the flattened maze
    map followed by the flattened path, each of length height*width."""

    def map2seq(self, *m):
        # Concatenate flattened 2D maps along the sequence dimension.
        return torch.cat([x.flatten(1) for x in m], 1)

    def seq2map(self, s):
        # Inverse of map2seq: split a sequence back into height x width maps.
        s = s.reshape(s.size(0), -1, self.height, self.width)
        return (s[:, k] for k in range(s.size(1)))

    def __init__(
        self,
        nb_train_samples,
        nb_test_samples,
        batch_size,
        height,
        width,
        nb_walls,
        device=torch.device("cpu"),
    ):
        self.batch_size = batch_size
        self.height = height
        self.width = width
        self.device = device

        train_mazes, train_paths, _ = maze.create_maze_data(
            nb_train_samples,
            height=height,
            width=width,
            nb_walls=nb_walls,
            progress_bar=lambda x: tqdm.tqdm(x, dynamic_ncols=True, desc=f"data-train"),
        )
        self.train_input = self.map2seq(train_mazes.to(device), train_paths.to(device))

        test_mazes, test_paths, _ = maze.create_maze_data(
            nb_test_samples,
            height=height,
            width=width,
            nb_walls=nb_walls,
            progress_bar=lambda x: tqdm.tqdm(x, dynamic_ncols=True, desc=f"data-test"),
        )
        self.test_input = self.map2seq(test_mazes.to(device), test_paths.to(device))

        # Vocabulary covers whatever token values the generator produced.
        self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1

    def batches(self, split="train", nb_to_use=-1, desc=None):
        assert split in {"train", "test"}
        input = self.train_input if split == "train" else self.test_input
        if nb_to_use > 0:
            input = input[:nb_to_use]
        if desc is None:
            desc = f"epoch-{split}"
        for batch in tqdm.tqdm(
            input.split(self.batch_size), dynamic_ncols=True, desc=desc
        ):
            yield batch

    def vocabulary_size(self):
        return self.nb_codes

    def compute_error(self, model, split="train", nb_to_use=-1):
        """Regenerate the path half of each sequence and count how many
        generated paths are correct according to maze.path_correctness."""
        nb_total, nb_correct = 0, 0
        # Bug fix: iterate self.batches, not the module-level global `task`,
        # so the method works on whichever instance it is called on.
        for input in self.batches(split, nb_to_use):
            result = input.clone()
            ar_mask = result.new_zeros(result.size())
            # The second half of the sequence (the path) is generated.
            ar_mask[:, self.height * self.width :] = 1
            result *= 1 - ar_mask
            masked_inplace_autoregression(
                model, self.batch_size, result, ar_mask, device=self.device
            )
            mazes, paths = self.seq2map(result)
            nb_correct += maze.path_correctness(mazes, paths).long().sum()
            nb_total += mazes.size(0)

        return nb_total, nb_correct

    def produce_results(self, n_epoch, model):
        """Log train/test path accuracy and save a visualization of 48
        predicted paths."""
        with torch.autograd.no_grad():
            t = model.training
            model.eval()

            train_nb_total, train_nb_correct = self.compute_error(
                model, "train", nb_to_use=1000
            )
            log_string(
                f"accuracy_train nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%"
            )

            test_nb_total, test_nb_correct = self.compute_error(
                model, "test", nb_to_use=1000
            )
            log_string(
                f"accuracy_test nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
            )

            input = self.test_input[:48]
            result = input.clone()
            ar_mask = result.new_zeros(result.size())
            ar_mask[:, self.height * self.width :] = 1
            result *= 1 - ar_mask
            masked_inplace_autoregression(
                model, self.batch_size, result, ar_mask, device=self.device
            )

            mazes, paths = self.seq2map(input)
            _, predicted_paths = self.seq2map(result)

            filename = os.path.join(args.result_dir, f"maze_result_{n_epoch:04d}.png")
            maze.save_image(
                filename,
                mazes=mazes,
                target_paths=paths,
                predicted_paths=predicted_paths,
                path_correct=maze.path_correctness(mazes, predicted_paths),
            )
            # Bug fix: log the actual filename (the original logged the
            # literal placeholder "(unknown)").
            log_string(f"wrote {filename}")

            model.train(t)
665
666
667 ######################################################################
668
669
670 import snake
671
672
class TaskSnake(Task):
    """Sequence task built on snake trajectories.

    This class only relies on the tensors returned by
    snake.generate_sequences; the exact token layout lives in snake.py.
    """

    def __init__(
        self,
        nb_train_samples,
        nb_test_samples,
        batch_size,
        height,
        width,
        nb_colors,
        length,
        prompt_length,
        device=torch.device("cpu"),
    ):
        self.batch_size = batch_size
        self.height = height
        self.width = width
        self.device = device
        self.prompt_length = prompt_length

        self.train_input, self.train_prior_visits, _, _ = snake.generate_sequences(
            nb_train_samples,
            height,
            width,
            nb_colors,
            length,
            prompt_length,
            self.device,
        )
        self.test_input, self.test_prior_visits, _, _ = snake.generate_sequences(
            nb_test_samples,
            height,
            width,
            nb_colors,
            length,
            prompt_length,
            self.device,
        )

        # Vocabulary covers whatever token values the generator produced.
        self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1

    def batches(self, split="train", nb_to_use=-1, desc=None):
        assert split in {"train", "test"}
        input = self.train_input if split == "train" else self.test_input
        if nb_to_use > 0:
            input = input[:nb_to_use]
        if desc is None:
            desc = f"epoch-{split}"
        for batch in tqdm.tqdm(
            input.split(self.batch_size), dynamic_ncols=True, desc=desc
        ):
            yield batch

    def vocabulary_size(self):
        return self.nb_codes

    def produce_results(self, n_epoch, model):
        """Log the model's accuracy at re-generating tokens on cells that
        were already visited (prior_visits > 0)."""
        with torch.autograd.no_grad():
            t = model.training
            model.eval()

            def compute_nb_correct(input, prior_visits):
                result = input.clone()
                i = torch.arange(result.size(1), device=result.device)[None, :]
                # Regenerate the even-indexed positions after the prompt;
                # presumably the sequence interleaves two token streams --
                # confirm against snake.py.
                ar_mask = (
                    torch.logical_and(i >= self.prompt_length * 2, i % 2 == 0)
                    .long()
                    .expand_as(result)
                )
                result *= 1 - ar_mask

                # snake.solver(result,ar_mask)

                masked_inplace_autoregression(
                    model, self.batch_size, result, ar_mask, device=self.device
                )

                # Only score generated positions on previously-visited cells.
                nb_total = ((prior_visits > 0) * ar_mask).sum()

                nb_correct = (
                    (result == input).long() * (prior_visits > 0) * ar_mask
                ).sum()

                # nb_total = result.size(0)
                # nb_correct = ((result - input).abs().sum(1) == 0).sum()

                return nb_total, nb_correct

            # train_nb_total, train_nb_correct = compute_nb_correct(
            # self.train_input, self.train_prior_visits
            # )

            # log_string(
            # f"accuracy_train nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%"
            # )

            test_nb_total, test_nb_correct = compute_nb_correct(
                self.test_input[:1000], self.test_prior_visits[:1000]
            )

            log_string(
                f"accuracy_test nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
            )

            model.train(t)
777
778
779 ######################################################################
780
781
def picoclvr_pruner_horizontal_green(p):
    """Reject (return False for) descriptions that constrain green
    horizontally, i.e. mention both "green" and "left" or "right"."""
    mentions_green = "green" in p
    mentions_horizontal = "left" in p or "right" in p
    return not (mentions_green and mentions_horizontal)
784
785
# Training prunes horizontal-green descriptions only in "train+eval" mode;
# evaluation then keeps exactly the pruned-away kind, to test generalization.
if args.picocvlr_prune_properties == "train+eval":
    picoclvr_pruner_train = picoclvr_pruner_horizontal_green
else:
    picoclvr_pruner_train = None

if args.picocvlr_prune_properties in {"train+eval", "eval"}:

    def picoclvr_pruner_eval(p):
        return not picoclvr_pruner_horizontal_green(p)

else:
    picoclvr_pruner_eval = None
797
798 ######################################################################
799
# Instantiate the selected task; each constructor generates (or downloads)
# its data immediately.
if args.task == "picoclvr":
    task = TaskPicoCLVR(
        nb_train_samples=args.nb_train_samples,
        nb_test_samples=args.nb_test_samples,
        batch_size=args.batch_size,
        height=args.picoclvr_height,
        width=args.picoclvr_width,
        nb_colors=args.picoclvr_nb_colors,
        device=device,
        pruner_train=picoclvr_pruner_train,
        pruner_eval=picoclvr_pruner_eval,
    )

elif args.task == "mnist":
    task = TaskMNIST(
        batch_size=args.batch_size,
        device=device,
    )

elif args.task == "maze":
    task = TaskMaze(
        nb_train_samples=args.nb_train_samples,
        nb_test_samples=args.nb_test_samples,
        batch_size=args.batch_size,
        height=args.maze_height,
        width=args.maze_width,
        nb_walls=args.maze_nb_walls,
        device=device,
    )

elif args.task == "snake":
    task = TaskSnake(
        nb_train_samples=args.nb_train_samples,
        nb_test_samples=args.nb_test_samples,
        batch_size=args.batch_size,
        height=args.snake_height,
        width=args.snake_width,
        nb_colors=args.snake_nb_colors,
        length=args.snake_length,
        # The first half of each sequence serves as the prompt.
        prompt_length=args.snake_length // 2,
        device=device,
    )

else:
    raise ValueError(f"Unknown task {args.task}")
845
846 ######################################################################
847
log_string(f"device {device}")

# The vocabulary depends on the task's generated data.
vocabulary_size = task.vocabulary_size()

log_string(f"vocabulary_size {vocabulary_size}")

##############################

model = mygpt.MyGPT(
    vocabulary_size=vocabulary_size,
    dim_model=args.dim_model,
    dim_keys=args.dim_keys,
    dim_hidden=args.dim_hidden,
    nb_heads=args.nb_heads,
    nb_blocks=args.nb_blocks,
    causal=True,
    dropout=args.dropout,
)

model.to(device)

nb_parameters = sum(p.numel() for p in model.parameters())
log_string(f"nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)")
871
872 ######################################################################
873
# Resume from a checkpoint in the result directory unless disabled.
nb_epochs_finished = 0

if args.no_checkpoint:
    log_string("not trying to load checkpoint.")

else:
    try:
        checkpoint_name = os.path.join(args.result_dir, args.checkpoint_name)
        # map_location lets a checkpoint saved on GPU be resumed on CPU.
        checkpoint = torch.load(checkpoint_name, map_location=device)
        nb_epochs_finished = checkpoint["nb_epochs_finished"]
        model.load_state_dict(checkpoint["model_state"])
        # Restore the RNG state so sampling continues deterministically.
        torch.set_rng_state(checkpoint["rng_state"])
        if torch.cuda.is_available():
            torch.cuda.set_rng_state(checkpoint["cuda_rng_state"])

        log_string(f"checkpoint loaded with {nb_epochs_finished} epochs finished.")

    except FileNotFoundError:
        log_string("starting from scratch.")

    except Exception as e:
        # A corrupt or incompatible checkpoint: report why instead of hiding
        # it (the original bare `except:` swallowed the reason and also
        # caught KeyboardInterrupt/SystemExit).
        log_string(f"error when loading the checkpoint. ({e})")
        exit(1)
897
898 ######################################################################
899
900 nb_epochs = args.nb_epochs if args.nb_epochs > 0 else nb_epochs_default
901
# Empirical entropy of the token distribution over the training set, used to
# report the training set's intrinsic perplexity alongside the model's.
token_count = 0
for batch in task.batches(split="train"):
    token_count = token_count + F.one_hot(
        batch, num_classes=task.vocabulary_size()
    ).sum((0, 1))
token_probas = token_count / token_count.sum()
# xlogy handles zero-probability tokens (0 * log 0 == 0) without NaNs.
entropy = -torch.xlogy(token_probas, token_probas).sum()
train_set_perplexity = math.exp(entropy)
908
909 ##############################
910
# Build the per-epoch learning-rate table. BUG FIX: the loops below ranged
# over args.nb_epochs, which can be None (TypeError) or smaller than the
# effective nb_epochs computed above, leaving KeyErrors when the training
# loop looks up learning_rate_schedule[n_epoch]. Use nb_epochs instead.
if args.learning_rate_schedule == "cos":
    # Cosine decay from the base learning rate down to 0 over the whole run.
    learning_rate_schedule = {}
    for n_epoch in range(nb_epochs):
        u = n_epoch / nb_epochs * math.pi
        learning_rate_schedule[n_epoch] = args.learning_rate * 0.5 * (1 + math.cos(u))
else:
    # Piecewise-constant schedule given as "epoch:lr,epoch:lr,..." (e.g. the
    # default "10: 2e-5,30: 4e-6"); int()/float() tolerate the spaces.
    u = {
        int(k): float(v)
        for k, v in (x.split(":") for x in args.learning_rate_schedule.split(","))
    }

    # Epochs not listed keep the most recent rate, starting from the base one.
    learning_rate_schedule = {}
    learning_rate = args.learning_rate
    for n_epoch in range(nb_epochs):
        if n_epoch in u:
            learning_rate = u[n_epoch]
        learning_rate_schedule[n_epoch] = learning_rate

log_string(f"learning_rate_schedule {learning_rate_schedule}")
932
933 ##############################
934
# Running count of training samples processed across all epochs.
nb_samples_seen = 0

# If the checkpoint already covers the requested number of epochs, there is
# nothing left to train: regenerate the result samples once (the training
# loop below then runs zero iterations).
if nb_epochs_finished >= nb_epochs:
    task.produce_results(nb_epochs_finished, model)
939
# Main training loop: one optimizer pass + one evaluation pass + one
# checkpoint per epoch, resuming after any epochs already finished.
for n_epoch in range(nb_epochs_finished, nb_epochs):
    learning_rate = learning_rate_schedule[n_epoch]

    log_string(f"learning_rate {learning_rate}")

    # The optimizer is rebuilt every epoch so the scheduled learning rate is
    # picked up. NOTE(review): this also resets Adam/AdamW moment estimates
    # at each epoch boundary — presumably intentional, but worth confirming.
    if args.optim == "sgd":
        optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
    elif args.optim == "adam":
        optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    elif args.optim == "adamw":
        optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
    else:
        raise ValueError(f"Unknown optimizer {args.optim}.")

    model.train()

    nb_train_samples, acc_train_loss = 0, 0.0

    # Training pass: the loss is the cross-entropy between the model's output
    # logits and the input tokens themselves (next-token prediction; the
    # model was built with causal=True). `batch` renamed from `input`, which
    # shadowed the builtin.
    for batch in task.batches(split="train"):
        batch = batch.to(device)
        output = model(mygpt.BracketedSequence(batch)).x
        # cross_entropy wants (N, C, T); the model emits (N, T, C).
        loss = F.cross_entropy(output.transpose(1, 2), batch)
        acc_train_loss += loss.item() * batch.size(0)
        nb_train_samples += batch.size(0)
        nb_samples_seen += batch.size(0)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Evaluation pass. torch.no_grad() is the public spelling of the
    # torch.autograd.no_grad() alias used before; gradients are not needed.
    with torch.no_grad():
        model.eval()

        nb_test_samples, acc_test_loss = 0, 0.0

        for batch in task.batches(split="test"):
            batch = batch.to(device)
            output = model(mygpt.BracketedSequence(batch)).x
            loss = F.cross_entropy(output.transpose(1, 2), batch)
            acc_test_loss += loss.item() * batch.size(0)
            nb_test_samples += batch.size(0)

        # Clamp the exponent so math.exp cannot overflow for a diverged model.
        train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples))
        test_perplexity = math.exp(min(100, acc_test_loss / nb_test_samples))

        log_string(
            f"perplexity {n_epoch} train_set {train_set_perplexity} train_prediction {train_perplexity} test_prediction {test_perplexity}"
        )

        task.produce_results(n_epoch, model)

    # Checkpoint, including the RNG states so a resumed run continues the
    # same random sequence.
    checkpoint = {
        "nb_epochs_finished": n_epoch + 1,
        "model_state": model.state_dict(),
        "rng_state": torch.get_rng_state(),
    }

    if torch.cuda.is_available():
        checkpoint["cuda_rng_state"] = torch.cuda.get_rng_state()

    checkpoint_name = os.path.join(args.result_dir, args.checkpoint_name)
    torch.save(checkpoint, checkpoint_name)
    log_string(f"saved checkpoint {checkpoint_name}")
1007
1008 ######################################################################