Update.
[picoclvr.git] / main.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
# torch.backends.cuda.matmul.allow_tf32
9 # torch.autocast(torch.bfloat16)
10
11 import math, sys, argparse, time, tqdm, os
12
13 import torch, torchvision
14 from torch import nn
15 from torch.nn import functional as F
16
17 import mygpt, tensorstack
18
19 ######################################################################
20
# Use the GPU when one is present. On CUDA, also let float32 matmuls use
# TF32, trading a little precision for speed.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    torch.backends.cuda.matmul.allow_tf32 = True
26
27 ######################################################################
28
# Command-line interface. ArgumentDefaultsHelpFormatter only appends
# "(default: ...)" to options that have a help string, hence the help texts.
parser = argparse.ArgumentParser(
    description="An implementation of GPT with cache.",
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)

parser.add_argument(
    "--task",
    type=str,
    default="picoclvr",
    choices=["picoclvr", "mnist", "maze", "snake"],
    help="task to train on",
)

parser.add_argument(
    "--log_filename",
    type=str,
    default="train.log",
    help="log file name, inside --result_dir",
)

parser.add_argument(
    "--result_dir",
    type=str,
    default="results_default",
    help="directory for the log, the checkpoint and the result images",
)

parser.add_argument(
    "--seed",
    type=int,
    default=0,
    help="random seed, a negative value disables seeding",
)

parser.add_argument(
    "--nb_epochs",
    type=int,
    default=None,
    help="number of epochs, a per-task value is used when not given",
)

parser.add_argument(
    "--batch_size",
    type=int,
    default=None,
    help="batch size, a per-task value is used when not given",
)

parser.add_argument("--nb_train_samples", type=int, default=250000)

parser.add_argument("--nb_test_samples", type=int, default=10000)

parser.add_argument("--optim", type=str, default="adam")

parser.add_argument("--learning_rate", type=float, default=1e-4)

parser.add_argument(
    "--learning_rate_schedule",
    type=str,
    default="10: 2e-5,30: 4e-6",
    help='either "cos" or comma-separated "epoch: learning_rate" pairs',
)

parser.add_argument("--dim_model", type=int, default=512)

parser.add_argument("--dim_keys", type=int, default=64)

parser.add_argument("--dim_hidden", type=int, default=2048)

parser.add_argument("--nb_heads", type=int, default=8)

parser.add_argument("--nb_blocks", type=int, default=12)

parser.add_argument("--dropout", type=float, default=0.1)

parser.add_argument(
    "--deterministic_synthesis",
    action="store_true",
    default=False,
    help="generate with argmax instead of sampling",
)

parser.add_argument(
    "--no_checkpoint",
    action="store_true",
    default=False,
    help="do not try to load a checkpoint",
)

parser.add_argument(
    "--overwrite_results",
    action="store_true",
    default=False,
    help="allow reusing an existing --result_dir",
)

parser.add_argument(
    "--checkpoint_name",
    type=str,
    default="checkpoint.pth",
    help="checkpoint file name, inside --result_dir",
)

##############################
# picoclvr options

parser.add_argument("--picoclvr_nb_colors", type=int, default=5)

parser.add_argument("--picoclvr_height", type=int, default=12)

parser.add_argument("--picoclvr_width", type=int, default=16)

# The historical misspelled flag is kept as an alias, and dest keeps the
# attribute name (args.picocvlr_prune_properties) the rest of the script reads.
parser.add_argument(
    "--picoclvr_prune_properties",
    "--picocvlr_prune_properties",
    dest="picocvlr_prune_properties",
    type=str,
    default="none",
    choices=["none", "train+eval", "eval"],
    help="where the horizontal-green property pruner is applied",
)

##############################
# Maze options

parser.add_argument("--maze_height", type=int, default=13)

parser.add_argument("--maze_width", type=int, default=21)

parser.add_argument("--maze_nb_walls", type=int, default=15)

##############################
# Snake options

parser.add_argument("--snake_height", type=int, default=6)

parser.add_argument("--snake_width", type=int, default=8)

parser.add_argument("--snake_nb_colors", type=int, default=5)

parser.add_argument("--snake_length", type=int, default=200)
106
107 ######################################################################
108
args = parser.parse_args()

# Explicit check rather than `assert`, which is stripped under `python -O`.
if args.picocvlr_prune_properties not in {"none", "train+eval", "eval"}:
    parser.error(
        f"invalid prune properties mode {args.picocvlr_prune_properties!r}"
        ' (expected "none", "train+eval" or "eval")'
    )

try:
    os.mkdir(args.result_dir)
except FileExistsError:
    if not args.overwrite_results:
        print(f"result directory {args.result_dir} already exists")
        sys.exit(1)

# Append mode, so reusing a result directory extends its existing log.
log_file = open(os.path.join(args.result_dir, args.log_filename), "a")

# A negative --seed leaves the random generators unseeded.
# NOTE: fully deterministic runs would additionally need
# torch.backends.cudnn.deterministic = True,
# torch.backends.cudnn.benchmark = False and
# torch.use_deterministic_algorithms(True); they are deliberately left off.
if args.seed >= 0:
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)
129
130 ######################################################################
131
# Per-task values for the options whose command-line default is None.
default_args = {
    "picoclvr": {"nb_epochs": 25, "batch_size": 25},
    "mnist": {"nb_epochs": 25, "batch_size": 10},
    "maze": {"nb_epochs": 25, "batch_size": 25},
    "snake": {"nb_epochs": 5, "batch_size": 25},
}

# Only options left unset on the command line take the task's value.
for option_name, task_value in default_args.get(args.task, {}).items():
    if getattr(args, option_name) is None:
        setattr(args, option_name, task_value)
155
156 ######################################################################
157
158
def log_string(s):
    """Print `s` behind a local "YYYYmmdd-HH:MM:SS " timestamp.

    The same timestamped line is also appended to the global `log_file`
    when it is not None. Both the log file and stdout are flushed right
    away.
    """
    line = time.strftime("%Y%m%d-%H:%M:%S ", time.localtime()) + s

    if log_file is not None:
        log_file.write(f"{line}\n")
        log_file.flush()

    print(line)
    sys.stdout.flush()
168
169
# Record every option, including the per-task defaults filled in above.
for name, value in vars(args).items():
    log_string(f"args.{name} {value}")
172
173 ######################################################################
174
175
def masked_inplace_autoregression(
    model,
    batch_size,
    input,
    ar_mask,
    forbidden_tokens=None,
    progress_bar_desc="autoregression",
    device=torch.device("cpu"),
):
    """Generate, in place, the tokens of `input` selected by `ar_mask`.

    `ar_mask` has the shape of `input` and holds 1 (or True) on the values to
    generate and 0 (or False) on the values to keep. Rows are processed
    `batch_size` at a time. Within a batch, every column from the first to
    the last one that is masked in some row is predicted left to right. Each
    prediction is written only where the mask is set.

    Args:
        model: called on mygpt.BracketedSequence windows; its output's `.x`
            is indexed as [:, s] to get the logits of column s.
        batch_size: number of rows per forward pass.
        input: integer token tensor, overwritten in place.
        ar_mask: 0/1 (or boolean) tensor with the same shape as `input`.
        forbidden_tokens: optional boolean tensor broadcast against the
            logits; tokens where it is True get a -inf logit and are never
            produced.
        progress_bar_desc: tqdm description, or None for no progress bar.
        device: unused, kept for backward compatibility.

    Tokens are chosen with argmax when args.deterministic_synthesis is set,
    and sampled from Categorical(logits=...) otherwise.
    """
    batches = zip(input.split(batch_size), ar_mask.split(batch_size))
    if progress_bar_desc is not None:
        # The wrapped iterator must be the one iterated. Previously the bar
        # was built and dropped, so it was never displayed.
        batches = tqdm.tqdm(
            batches,
            dynamic_ncols=True,
            desc=progress_bar_desc,
            total=(input.size(0) + batch_size - 1) // batch_size,
        )
    for input_batch, mask_batch in batches:
        # Columns where at least one row has something to generate
        columns = (mask_batch.sum(0) > 0).nonzero()
        if columns.numel() == 0:
            continue  # nothing to generate in this batch
        first, last = columns.min().item(), columns.max().item()
        if first > 0:
            # Needed to initialize the model's cache on the fixed prefix
            model(mygpt.BracketedSequence(input_batch, 0, first))
        for s in range(first, last + 1):
            output = model(mygpt.BracketedSequence(input_batch, s, 1)).x
            logits = output[:, s]
            if forbidden_tokens is not None:
                logits = logits.masked_fill(forbidden_tokens, float("-inf"))
            if args.deterministic_synthesis:
                t_next = logits.argmax(1)
            else:
                dist = torch.distributions.categorical.Categorical(logits=logits)
                t_next = dist.sample()
            # Overwrite only the masked entries of column s (works for
            # boolean as well as 0/1 integer masks)
            input_batch[:, s] = torch.where(
                mask_batch[:, s].bool(), t_next, input_batch[:, s]
            )
213
214
215 ######################################################################
216
217
class Task:
    """Interface of the tasks defined below: a token-sequence data source plus
    an evaluation hook. The base methods do nothing and return None;
    subclasses override them."""

    def batches(self, split="train"):
        """Subclasses yield the batches of `split` ("train" or "test")."""

    def vocabulary_size(self):
        """Subclasses return the number of distinct token ids."""

    def produce_results(self, n_epoch, model):
        """Subclasses evaluate `model` after epoch `n_epoch` and log/save results."""
227
228
229 ######################################################################
230
231 import picoclvr
232
233
class TaskPicoCLVR(Task):
    """Task built on the `picoclvr` module.

    Samples are space-separated token strings produced by picoclvr.generate.
    An image is a "<img>" token followed by height * width tokens. Sequences
    are right-padded with "<nul>" when tensorized. Evaluation lets the model
    redraw the image of a description and counts the requested properties
    that the drawing misses.
    """

    # Make a tensor from a list of strings
    def tensorize(self, descr):
        """Turn a list of space-separated strings into a tensor of token ids on
        self.device, one row per string, right-padded with "<nul>"."""
        token_descr = [s.strip().split(" ") for s in descr]
        max_len = max([len(s) for s in token_descr])
        token_descr = [s + ["<nul>"] * (max_len - len(s)) for s in token_descr]
        id_descr = [[self.token2id[u] for u in s] for s in token_descr]
        return torch.tensor(id_descr, device=self.device)

    # Make a list of strings from a tensor
    def detensorize(self, x):
        """Inverse of tensorize: one space-joined string per row (padding kept)."""
        return [" ".join([self.id2token[t.item()] for t in r]) for r in x]

    def trim(self, z, token="<nul>"):
        """Drop the leading and trailing columns that contain only `token`.

        The columns to keep are computed from the first tensor when `z` is a
        tuple. In that case every tensor of the tuple is cropped to those
        columns and a tuple is returned; otherwise the single cropped tensor
        is returned.
        """
        n = self.token2id[token]
        x = z[0] if isinstance(z, tuple) else z
        # Pad one `token` column on each side. i[p] counts the padded columns
        # up to p that hold a non-`token` value, so the last 0 sits just
        # before the first useful column and the first max just at the last.
        i = (1 - (F.pad(x, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
        a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
        if isinstance(z, tuple):
            return tuple(t[:, a:b] for t in z)
        return z[:, a:b]

    ######################
    # Not the cleanest part of the code

    def excise_last_image(self, input):
        """Cut the last image of each row out of a copy of `input`.

        The image runs from its last "<img>" token (included) over
        height * width + 1 tokens.

        Returns (input, loss_masks, images):
          - input: the trimmed copy, with the image tokens set to "<nul>";
          - loss_masks: 1 before the last "<img>" of the row, 0 from it on;
          - images: the excised tokens, trimmed.
        """
        t_img, t_nul = self.token2id["<img>"], self.token2id["<nul>"]
        nb_img_tokens = self.height * self.width + 1

        input = input.clone()
        t = (input == t_img).long()
        # 1 from the last "<img>" of each row to the end of that row
        tail_masks = (t.cumsum(dim=1) == t.sum(dim=1, keepdim=True)).long()
        # (row, column) of the last "<img>" of each row
        i = (t * tail_masks).nonzero(as_tuple=True)
        # Indices of the nb_img_tokens tokens starting at that "<img>"
        j = (
            i[0][:, None],
            i[1][:, None] + torch.arange(nb_img_tokens, device=input.device)[None, :],
        )
        images = self.trim(input[j])
        input[j] = t_nul
        loss_masks = 1 - tail_masks
        input, loss_masks = self.trim((input, loss_masks))
        return input, loss_masks, images

    def add_true_image(self, input, images, loss_masks):
        """Write `images` into each row and set their loss mask to 1.

        The images go where the running count of "<nul>" equals 1, i.e. at
        the first "<nul>" when everything after it is padding. The result is
        trimmed.
        """
        t_nul = self.token2id["<nul>"]
        nb_img_tokens = self.height * self.width + 1
        input = F.pad(input, (0, nb_img_tokens), value=t_nul)
        loss_masks = F.pad(loss_masks, (0, nb_img_tokens), value=0)
        t = (input == t_nul).long()
        i = (t.cumsum(dim=1) == 1).nonzero(as_tuple=True)
        j = (
            i[0][:, None],
            i[1][:, None] + torch.arange(nb_img_tokens, device=input.device)[None, :],
        )
        input[j] = images
        loss_masks[j] = 1
        input, loss_masks = self.trim((input, loss_masks))
        return input, loss_masks

    def add_generated_image(self, input, loss_masks, model):
        """Append "<img>" at the first "<nul>" of each row and let `model`
        generate the height * width tokens that follow it ("<nul>" is
        forbidden).

        The model runs in eval mode without gradients, and its previous
        training mode is restored even if generation fails. Returns the
        trimmed (input, loss_masks); the new tokens keep a loss mask of 0.
        """
        t_img, t_nul = self.token2id["<img>"], self.token2id["<nul>"]
        nb_img_tokens = self.height * self.width + 1

        input = F.pad(input, (0, nb_img_tokens), value=t_nul)
        loss_masks = F.pad(loss_masks, (0, nb_img_tokens), value=0)
        t = (input == t_nul).long()
        i = (t.cumsum(dim=1) == 1).nonzero(as_tuple=True)
        input[i] = t_img

        # The nb_img_tokens - 1 positions after the new "<img>"
        j = (
            i[0][:, None],
            i[1][:, None]
            + 1
            + torch.arange(nb_img_tokens - 1, device=input.device)[None, :],
        )
        ar_masks = input.new_zeros(input.size(), dtype=torch.int64)
        ar_masks[j] = 1
        forbidden_tokens = (
            torch.arange(self.vocabulary_size(), device=input.device) == t_nul
        )
        with torch.autograd.no_grad():
            was_training = model.training
            model.eval()
            try:
                masked_inplace_autoregression(
                    model,
                    self.batch_size,
                    input,
                    ar_masks,
                    forbidden_tokens,
                    progress_bar_desc=None,
                    device=self.device,
                )
            finally:
                model.train(was_training)

        input, loss_masks = self.trim((input, loss_masks))

        return input, loss_masks

    ######################

    def __init__(
        self,
        nb_train_samples,
        nb_test_samples,
        batch_size,
        height,
        width,
        nb_colors=5,
        device=torch.device("cpu"),
        pruner_train=None,
        pruner_eval=None,
    ):
        self.height = height
        self.width = width
        self.batch_size = batch_size
        self.device = device
        self.pruner_train = pruner_train
        self.pruner_eval = pruner_eval

        def generate_descr(nb, pruner):
            return picoclvr.generate(
                nb,
                height=self.height,
                width=self.width,
                nb_colors=nb_colors,
                pruner=pruner,
            )

        log_string(
            f"generating {nb_train_samples+nb_test_samples} samples (can take some time)"
        )
        self.train_descr = generate_descr(nb_train_samples, pruner=self.pruner_train)
        # The test set is always generated without pruner (presumably so that
        # pruned properties can still be evaluated).
        self.test_descr = generate_descr(nb_test_samples, pruner=None)

        # Build the tokenizer. The vocabulary is sorted so the same
        # descriptions always map to the same tensors.
        tokens = {"<nul>", "<img>"}
        for d in (self.train_descr, self.test_descr):
            for s in d:
                tokens.update(s.strip().split(" "))
        tokens = sorted(tokens)
        self.token2id = {t: n for n, t in enumerate(tokens)}
        self.id2token = dict(enumerate(tokens))

        # Tokenize the train and test sets
        self.train_input = self.tensorize(self.train_descr)
        self.test_input = self.tensorize(self.test_descr)

    def batches(self, split="train"):
        """Yield trimmed batches of `split` ("train" or "test")."""
        assert split in {"train", "test"}
        input = self.train_input if split == "train" else self.test_input
        for batch in tqdm.tqdm(
            input.split(self.batch_size), dynamic_ncols=True, desc=f"epoch-{split}"
        ):
            yield self.trim(batch)

    def vocabulary_size(self):
        return len(self.token2id)

    def _log_property_stats(
        self, prefix, n_epoch, nb_results, nb_requested, nb_missing
    ):
        """Log the sample count, the requested/missing property counts and the
        miss rate ("nan" when no property was requested)."""
        log_string(f"nb_{prefix}samples {n_epoch} {nb_results}")
        log_string(
            f"property_{prefix}nb {n_epoch} requested {nb_requested} missing {nb_missing}"
        )
        miss_rate = (
            100 * nb_missing / nb_requested if nb_requested > 0 else float("nan")
        )
        log_string(f"property_{prefix}miss {n_epoch} {miss_rate:.02f}%")

    def compute_missing_properties(self, n_epoch, model, pruner=None):
        """Regenerate the last image of every test sequence with `model` and
        log how many requested properties the generated images miss.

        `pruner` is forwarded to picoclvr.nb_properties. When it is set, the
        log lines are prefixed with "pruned_".
        """
        acc_nb_requested_properties = []
        acc_nb_missing_properties = []
        acc_nb_results = 0

        for input in tqdm.tqdm(
            self.test_input.split(self.batch_size),
            dynamic_ncols=True,
            desc="test-properties",
        ):
            # Remove the true image and let the model draw one instead
            tape, loss_masks, _ = self.excise_last_image(input)
            tape, loss_masks = self.add_generated_image(tape, loss_masks, model)
            result_descr = self.detensorize(tape)
            properties = picoclvr.nb_properties(
                result_descr,
                height=self.height,
                width=self.width,
                pruner=pruner,
            )
            nb_requested_properties, _, nb_missing_properties = zip(*properties)
            acc_nb_requested_properties += nb_requested_properties
            acc_nb_missing_properties += nb_missing_properties
            acc_nb_results += len(result_descr)

        prefix = "" if pruner is None else "pruned_"
        self._log_property_stats(
            prefix,
            n_epoch,
            acc_nb_results,
            sum(acc_nb_requested_properties),
            sum(acc_nb_missing_properties),
        )

    ######################################################################

    def produce_results(self, n_epoch, model):
        """Log property statistics on the test set, plus on the pruned
        properties when pruner_eval is set.

        Then draw images for a few fixed descriptions, log their statistics
        with the "demo_" prefix, and save them as
        picoclvr_result_<epoch>.png in args.result_dir.
        """
        self.compute_missing_properties(n_epoch, model)

        if self.pruner_eval is not None:
            self.compute_missing_properties(n_epoch, model, self.pruner_eval)

        nb_per_primer = 8
        primer = []

        for primer_descr in [
            "red above green <sep> green top <sep> blue right of red",
            "there is red <sep> there is yellow <sep> there is blue",
            "red below yellow <sep> yellow below green <sep> green below blue <sep> red right <sep> yellow left <sep> green right <sep> blue left",
            "green bottom <sep> yellow bottom <sep> green left of blue <sep> yellow right of blue <sep> blue top",
        ]:
            primer += [primer_descr] * nb_per_primer

        tape = self.tensorize(primer)
        loss_masks = 1 - (tape == self.token2id["<nul>"]).long()
        tape, loss_masks = self.add_generated_image(tape, loss_masks, model)
        result_descr = self.detensorize(tape)

        properties = picoclvr.nb_properties(
            result_descr, height=self.height, width=self.width
        )
        nb_requested_properties, _, nb_missing_properties = zip(*properties)
        self._log_property_stats(
            "demo_",
            n_epoch,
            len(result_descr),
            sum(nb_requested_properties),
            sum(nb_missing_properties),
        )

        img = picoclvr.descr2img(result_descr, height=self.height, width=self.width)

        # When descr2img returns a 5-dim tensor (a stack of images per
        # description), reduce it to one padded image per description
        if img.dim() == 5:
            if img.size(1) == 1:
                img = F.pad(img.squeeze(1), pad=(1, 1, 1, 1), value=64)
            else:
                img = torch.cat(
                    [
                        torchvision.utils.make_grid(x, padding=1, pad_value=64)[None]
                        for x in img
                    ],
                    0,
                )

        image_name = os.path.join(args.result_dir, f"picoclvr_result_{n_epoch:04d}.png")
        torchvision.utils.save_image(
            img / 255.0, image_name, nrow=nb_per_primer, padding=1, pad_value=1.0
        )
        log_string(f"wrote {image_name}")
515
516
517 ######################################################################
518
519
class TaskMNIST(Task):
    """MNIST as a sequence task: each image is flattened into 28 * 28 tokens,
    one per pixel value (vocabulary of 256)."""

    def __init__(self, batch_size, device=torch.device("cpu")):
        self.device = device
        self.batch_size = batch_size

    def batches(self, split="train"):
        """Yield batches of flattened images of `split` on self.device.

        The train split is truncated to args.nb_train_samples and the test
        split to args.nb_test_samples, when those are not None.
        """
        assert split in {"train", "test"}
        data_set = torchvision.datasets.MNIST(
            root="./data", train=(split == "train"), download=True
        )
        data_input = data_set.data.view(-1, 28 * 28).long()
        # Each split is limited by its own sample count (the test split used
        # to be truncated with the train count)
        nb_samples = (
            args.nb_train_samples if split == "train" else args.nb_test_samples
        )
        if nb_samples is not None:
            data_input = data_input[:nb_samples]
        for batch in tqdm.tqdm(
            data_input.split(self.batch_size), desc=f"epoch-{split}"
        ):
            # Same device as the other tasks' batches
            yield batch.to(self.device)

    def vocabulary_size(self):
        return 256

    def produce_results(self, n_epoch, model):
        """Sample 64 images with `model` and save them to
        mnist_result_<epoch>.png in args.result_dir.

        Sampling runs in eval mode without gradients, as in the other tasks,
        and the previous training mode is restored afterwards.
        """
        with torch.autograd.no_grad():
            was_training = model.training
            model.eval()
            try:
                # Every position is generated; zeros just avoid uninitialized memory
                results = torch.zeros(
                    64, 28 * 28, device=self.device, dtype=torch.int64
                )
                ar_mask = torch.full_like(results, 1)
                masked_inplace_autoregression(
                    model, self.batch_size, results, ar_mask, device=self.device
                )
            finally:
                model.train(was_training)

        image_name = os.path.join(args.result_dir, f"mnist_result_{n_epoch:04d}.png")
        torchvision.utils.save_image(
            1 - results.reshape(-1, 1, 28, 28) / 255.0,
            image_name,
            nrow=16,
            pad_value=0.8,
        )
        log_string(f"wrote {image_name}")
555
556
557 ######################################################################
558
559 import maze
560
561
class TaskMaze(Task):
    """Maze solving.

    Each sample is a height x width maze map followed by its path map, both
    flattened into one token sequence (see map2seq). The model generates the
    path part.
    """

    def map2seq(self, *m):
        """Flatten each map tensor from dim 1 on and concatenate them along dim 1."""
        return torch.cat([x.flatten(1) for x in m], 1)

    def seq2map(self, s):
        """Inverse of map2seq: a generator over the height x width maps stored in s."""
        s = s.reshape(s.size(0), -1, self.height, self.width)
        return (s[:, k] for k in range(s.size(1)))

    def __init__(
        self,
        nb_train_samples,
        nb_test_samples,
        batch_size,
        height,
        width,
        nb_walls,
        device=torch.device("cpu"),
    ):
        self.batch_size = batch_size
        self.height = height
        self.width = width
        self.device = device

        train_mazes, train_paths, _ = maze.create_maze_data(
            nb_train_samples,
            height=height,
            width=width,
            nb_walls=nb_walls,
            progress_bar=lambda x: tqdm.tqdm(x, dynamic_ncols=True, desc=f"data-train"),
        )
        self.train_input = self.map2seq(train_mazes.to(device), train_paths.to(device))

        test_mazes, test_paths, _ = maze.create_maze_data(
            nb_test_samples,
            height=height,
            width=width,
            nb_walls=nb_walls,
            progress_bar=lambda x: tqdm.tqdm(x, dynamic_ncols=True, desc=f"data-test"),
        )
        self.test_input = self.map2seq(test_mazes.to(device), test_paths.to(device))

        # Plain int rather than a 0-dim tensor, so vocabulary_size() is a
        # regular integer
        self.nb_codes = int(max(self.train_input.max(), self.test_input.max())) + 1

    def batches(self, split="train", nb_to_use=-1, desc=None):
        """Yield batches of `split`, limited to the first `nb_to_use` samples
        when it is positive."""
        assert split in {"train", "test"}
        input = self.train_input if split == "train" else self.test_input
        if nb_to_use > 0:
            input = input[:nb_to_use]
        if desc is None:
            desc = f"epoch-{split}"
        for batch in tqdm.tqdm(
            input.split(self.batch_size), dynamic_ncols=True, desc=desc
        ):
            yield batch

    def vocabulary_size(self):
        return self.nb_codes

    def _generate_paths(self, model, input):
        """Return a copy of `input` whose path part (every token after the
        first height * width maze tokens) is erased and regenerated by `model`."""
        result = input.clone()
        ar_mask = result.new_zeros(result.size())
        ar_mask[:, self.height * self.width :] = 1
        result *= 1 - ar_mask
        masked_inplace_autoregression(
            model, self.batch_size, result, ar_mask, device=self.device
        )
        return result

    def compute_error(self, model, split="train", nb_to_use=-1):
        """Regenerate the paths of `split` and return (nb_total, nb_correct),
        counted with maze.path_correctness."""
        nb_total, nb_correct = 0, 0
        # Iterate this task's own data (this used to read the global `task`)
        for input in self.batches(split, nb_to_use):
            mazes, paths = self.seq2map(self._generate_paths(model, input))
            nb_correct += int(maze.path_correctness(mazes, paths).long().sum())
            nb_total += mazes.size(0)

        return nb_total, nb_correct

    def produce_results(self, n_epoch, model):
        """Log path accuracy on 1000 train and 1000 test samples, then save
        48 test mazes with target and predicted paths to
        maze_result_<epoch>.png in args.result_dir."""
        with torch.autograd.no_grad():
            was_training = model.training
            model.eval()
            try:
                for split in ("train", "test"):
                    nb_total, nb_correct = self.compute_error(
                        model, split, nb_to_use=1000
                    )
                    log_string(
                        f"accuracy_{split} nb_total {nb_total} nb_correct {nb_correct} accuracy {(100.0*nb_correct)/nb_total:.02f}%"
                    )

                input = self.test_input[:48]
                result = self._generate_paths(model, input)

                mazes, paths = self.seq2map(input)
                _, predicted_paths = self.seq2map(result)

                filename = os.path.join(
                    args.result_dir, f"maze_result_{n_epoch:04d}.png"
                )
                maze.save_image(
                    filename,
                    mazes=mazes,
                    target_paths=paths,
                    predicted_paths=predicted_paths,
                    path_correct=maze.path_correctness(mazes, predicted_paths),
                )
                log_string(f"wrote {filename}")
            finally:
                model.train(was_training)
678
679
680 ######################################################################
681
682
683 import snake
684
685
class TaskSnake(Task):
    """Snake task.

    Sequences and their `prior_visits` tensors come from
    snake.generate_sequences. The first 2 * prompt_length tokens of a
    sequence are kept as prompt.
    """

    def __init__(
        self,
        nb_train_samples,
        nb_test_samples,
        batch_size,
        height,
        width,
        nb_colors,
        length,
        prompt_length,
        device=torch.device("cpu"),
    ):
        self.batch_size = batch_size
        self.height = height
        self.width = width
        self.device = device
        self.prompt_length = prompt_length

        self.train_input, self.train_prior_visits, _, _ = snake.generate_sequences(
            nb_train_samples,
            height,
            width,
            nb_colors,
            length,
            prompt_length,
            self.device,
        )
        self.test_input, self.test_prior_visits, _, _ = snake.generate_sequences(
            nb_test_samples,
            height,
            width,
            nb_colors,
            length,
            prompt_length,
            self.device,
        )

        # Plain int rather than a 0-dim tensor, so vocabulary_size() is a
        # regular integer
        self.nb_codes = int(max(self.train_input.max(), self.test_input.max())) + 1

    def batches(self, split="train", nb_to_use=-1, desc=None):
        """Yield batches of `split`, limited to the first `nb_to_use` samples
        when it is positive."""
        assert split in {"train", "test"}
        input = self.train_input if split == "train" else self.test_input
        if nb_to_use > 0:
            input = input[:nb_to_use]
        if desc is None:
            desc = f"epoch-{split}"
        for batch in tqdm.tqdm(
            input.split(self.batch_size), dynamic_ncols=True, desc=desc
        ):
            yield batch

    def vocabulary_size(self):
        return self.nb_codes

    def produce_results(self, n_epoch, model):
        """Regenerate the even positions after the prompt of 1000 test
        sequences and log the accuracy on the positions with prior_visits > 0.

        Only the test set is evaluated (the train-set evaluation is disabled).
        """
        with torch.autograd.no_grad():
            was_training = model.training
            model.eval()
            try:

                def compute_nb_correct(input, prior_visits):
                    """Return (nb_total, nb_correct) as ints, counted over the
                    regenerated positions where prior_visits > 0."""
                    result = input.clone()
                    i = torch.arange(result.size(1), device=result.device)[None, :]
                    # Even positions from 2 * prompt_length on are generated
                    ar_mask = (
                        torch.logical_and(i >= self.prompt_length * 2, i % 2 == 0)
                        .long()
                        .expand_as(result)
                    )
                    result *= 1 - ar_mask

                    masked_inplace_autoregression(
                        model, self.batch_size, result, ar_mask, device=self.device
                    )

                    nb_total = ((prior_visits > 0) * ar_mask).sum()

                    nb_correct = (
                        (result == input).long() * (prior_visits > 0) * ar_mask
                    ).sum()

                    return int(nb_total), int(nb_correct)

                test_nb_total, test_nb_correct = compute_nb_correct(
                    self.test_input[:1000], self.test_prior_visits[:1000]
                )

                accuracy = (
                    (100.0 * test_nb_correct) / test_nb_total
                    if test_nb_total > 0
                    else float("nan")
                )
                log_string(
                    f"accuracy_test nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {accuracy:.02f}%"
                )
            finally:
                model.train(was_training)
790
791
792 ######################################################################
793
794
def picoclvr_pruner_horizontal_green(p):
    """Return False when `p` mentions "green" together with "left" or
    "right", True otherwise.

    `p` is presumably a picoclvr property (string or token list); only
    membership tests are applied to it.
    """
    if "green" not in p:
        return True
    return not ("left" in p or "right" in p)
797
798
# "train+eval": the pruner is handed to the training-set generation and its
# negation to the evaluation; "eval": evaluation only; "none": no pruner.
if args.picocvlr_prune_properties == "train+eval":
    picoclvr_pruner_train = picoclvr_pruner_horizontal_green
else:
    picoclvr_pruner_train = None

if args.picocvlr_prune_properties in ("train+eval", "eval"):
    picoclvr_pruner_eval = lambda p: not picoclvr_pruner_horizontal_green(p)
else:
    picoclvr_pruner_eval = None
810
811 ######################################################################
812
# One constructor per --task value; only the selected one is called.
_task_builders = {
    "picoclvr": lambda: TaskPicoCLVR(
        nb_train_samples=args.nb_train_samples,
        nb_test_samples=args.nb_test_samples,
        batch_size=args.batch_size,
        height=args.picoclvr_height,
        width=args.picoclvr_width,
        nb_colors=args.picoclvr_nb_colors,
        device=device,
        pruner_train=picoclvr_pruner_train,
        pruner_eval=picoclvr_pruner_eval,
    ),
    "mnist": lambda: TaskMNIST(
        batch_size=args.batch_size,
        device=device,
    ),
    "maze": lambda: TaskMaze(
        nb_train_samples=args.nb_train_samples,
        nb_test_samples=args.nb_test_samples,
        batch_size=args.batch_size,
        height=args.maze_height,
        width=args.maze_width,
        nb_walls=args.maze_nb_walls,
        device=device,
    ),
    "snake": lambda: TaskSnake(
        nb_train_samples=args.nb_train_samples,
        nb_test_samples=args.nb_test_samples,
        batch_size=args.batch_size,
        height=args.snake_height,
        width=args.snake_width,
        nb_colors=args.snake_nb_colors,
        length=args.snake_length,
        prompt_length=args.snake_length // 2,
        device=device,
    ),
}

if args.task not in _task_builders:
    raise ValueError(f"Unknown task {args.task}")

task = _task_builders[args.task]()
858
859 ######################################################################
860
log_string(f"device {device}")

vocabulary_size = task.vocabulary_size()
log_string(f"vocabulary_size {vocabulary_size}")

##############################

# Causal GPT sized by the command-line options; nn.Module.to returns the
# module itself, so the moved model is bound directly.
model = mygpt.MyGPT(
    vocabulary_size=vocabulary_size,
    dim_model=args.dim_model,
    dim_keys=args.dim_keys,
    dim_hidden=args.dim_hidden,
    nb_heads=args.nb_heads,
    nb_blocks=args.nb_blocks,
    causal=True,
    dropout=args.dropout,
).to(device)

nb_parameters = sum(map(torch.numel, model.parameters()))
log_string(f"nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)")
884
885 ######################################################################
886
nb_epochs_finished = 0

if args.no_checkpoint:
    log_string("not trying to load checkpoint.")

else:
    checkpoint_name = os.path.join(args.result_dir, args.checkpoint_name)
    try:
        # Load every tensor onto the CPU so that a checkpoint written on a GPU
        # machine can be resumed on a CPU-only one. load_state_dict copies the
        # values into the model's own parameters wherever they live, and the
        # RNG states are expected as CPU byte tensors anyway.
        checkpoint = torch.load(checkpoint_name, map_location="cpu")
        nb_epochs_finished = checkpoint["nb_epochs_finished"]
        model.load_state_dict(checkpoint["model_state"])
        torch.set_rng_state(checkpoint["rng_state"])
        # A checkpoint saved on a machine without CUDA carries no CUDA RNG
        # state; that must not be treated as a corrupted checkpoint.
        if torch.cuda.is_available() and "cuda_rng_state" in checkpoint:
            torch.cuda.set_rng_state(checkpoint["cuda_rng_state"])

        log_string(f"checkpoint loaded with {nb_epochs_finished} epochs finished.")

    except FileNotFoundError:
        log_string("starting from scratch.")

    except Exception as e:
        # Any other failure (corrupted file, incompatible state dict, ...) is
        # fatal: report what went wrong instead of hiding it, but let
        # KeyboardInterrupt / SystemExit propagate normally.
        log_string(f"error when loading the checkpoint ({type(e).__name__}: {e}).")
        sys.exit(1)
910
911 ######################################################################
912
# A missing (None) or non-positive --nb_epochs falls back to the default
# epoch count; comparing None with 0 would otherwise raise a TypeError.
if args.nb_epochs is not None and args.nb_epochs > 0:
    nb_epochs = args.nb_epochs
else:
    nb_epochs = nb_epochs_default

# Unigram token statistics of the training set. exp(entropy) is the
# perplexity of a predictor that ignores context entirely, logged as a
# reference point next to the model's train/test prediction perplexities.
# bincount counts tokens directly instead of materialising a
# (batch, seq, vocabulary) one-hot tensor for every batch.
token_count = 0
for input in task.batches(split="train"):
    token_count = token_count + torch.bincount(
        input.flatten(), minlength=vocabulary_size
    )
token_probas = token_count / token_count.sum()
# xlogy gives 0 * log(0) = 0 for tokens that never occur.
entropy = -torch.xlogy(token_probas, token_probas).sum()
train_set_perplexity = math.exp(entropy.item())
921
922 ##############################
923
# Per-epoch learning rates, keyed by epoch index. The schedule is built over
# nb_epochs (the resolved epoch count), not args.nb_epochs, so that it covers
# every epoch of the training loop even when --nb_epochs falls back to the
# default.
if args.learning_rate_schedule == "cos":
    # Cosine annealing: starts at --learning_rate and decays towards 0 over
    # nb_epochs epochs (the last epoch stays slightly above 0).
    learning_rate_schedule = {
        n_epoch: args.learning_rate
        * 0.5
        * (1 + math.cos(n_epoch / nb_epochs * math.pi))
        for n_epoch in range(nb_epochs)
    }
else:
    # Piecewise-constant schedule written as "epoch:lr,epoch:lr,..." (e.g. the
    # default "10: 2e-5,30: 4e-6"). --learning_rate is used until the first
    # listed epoch; each listed rate then holds until the next listed epoch.
    # Empty entries (e.g. an empty string) are ignored, giving a constant rate.
    lr_changes = {}
    for entry in args.learning_rate_schedule.split(","):
        if entry.strip():
            epoch_str, lr_str = entry.split(":")
            lr_changes[int(epoch_str)] = float(lr_str)

    learning_rate_schedule = {}
    learning_rate = args.learning_rate
    for n_epoch in range(nb_epochs):
        learning_rate = lr_changes.get(n_epoch, learning_rate)
        learning_rate_schedule[n_epoch] = learning_rate

log_string(f"learning_rate_schedule {learning_rate_schedule}")
945
946 ##############################
947
nb_samples_seen = 0

if nb_epochs_finished >= nb_epochs:
    # Nothing left to train (e.g. a finished checkpoint was loaded): only
    # regenerate the results. Do so under the same eval / no-grad regime used
    # after every epoch below. A freshly constructed module starts in training
    # mode, so otherwise dropout would be active during generation (unless
    # produce_results switches modes itself, which is not visible here).
    with torch.no_grad():
        model.eval()
        task.produce_results(nb_epochs_finished, model)

checkpoint_name = os.path.join(args.result_dir, args.checkpoint_name)

for n_epoch in range(nb_epochs_finished, nb_epochs):
    learning_rate = learning_rate_schedule[n_epoch]

    log_string(f"learning_rate {learning_rate}")

    # A new optimizer is built every epoch with that epoch's learning rate, so
    # its internal state (e.g. Adam moment estimates) is reset each epoch. The
    # checkpoint saved below does not include optimizer state, so a resumed
    # run loses nothing that an uninterrupted run would have kept.
    if args.optim == "sgd":
        optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
    elif args.optim == "adam":
        optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    elif args.optim == "adamw":
        optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
    else:
        raise ValueError(f"Unknown optimizer {args.optim}.")

    model.train()

    nb_train_samples, acc_train_loss = 0, 0.0

    for input in task.batches(split="train"):
        input = input.to(device)
        # output is (batch, seq, vocabulary); cross_entropy expects the class
        # dimension second, hence the transpose.
        output = model(mygpt.BracketedSequence(input)).x
        loss = F.cross_entropy(output.transpose(1, 2), input)
        # Weight each batch's mean loss by its size, so the epoch average is
        # correct even when the last batch is smaller.
        acc_train_loss += loss.item() * input.size(0)
        nb_train_samples += input.size(0)
        nb_samples_seen += input.size(0)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    with torch.no_grad():
        model.eval()

        nb_test_samples, acc_test_loss = 0, 0.0

        for input in task.batches(split="test"):
            input = input.to(device)

            output = model(mygpt.BracketedSequence(input)).x
            loss = F.cross_entropy(output.transpose(1, 2), input)
            acc_test_loss += loss.item() * input.size(0)
            nb_test_samples += input.size(0)

        # The average losses are clamped at 100 so that math.exp cannot
        # overflow when the model is still very poor.
        train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples))
        test_perplexity = math.exp(min(100, acc_test_loss / nb_test_samples))

        log_string(
            f"perplexity {n_epoch} train_set {train_set_perplexity} train_prediction {train_perplexity} test_prediction {test_perplexity}"
        )

        task.produce_results(n_epoch, model)

    checkpoint = {
        "nb_epochs_finished": n_epoch + 1,
        "model_state": model.state_dict(),
        "rng_state": torch.get_rng_state(),
    }

    if torch.cuda.is_available():
        checkpoint["cuda_rng_state"] = torch.cuda.get_rng_state()

    # Write to a temporary file first, then atomically rename it over the
    # previous checkpoint. An interruption during torch.save then leaves the
    # last good checkpoint intact instead of a truncated file that would fail
    # to load on the next run.
    tmp_checkpoint_name = checkpoint_name + ".tmp"
    torch.save(checkpoint, tmp_checkpoint_name)
    os.replace(tmp_checkpoint_name, checkpoint_name)
    log_string(f"saved checkpoint {checkpoint_name}")
1017
1018 ######################################################################