Update.
[picoclvr.git] / main.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
# torch.backends.cuda.matmul.allow_tf32
9 # torch.autocast(torch.bfloat16)
10
11 import math, sys, argparse, time, tqdm, os
12
13 import torch, torchvision
14 from torch import nn
15 from torch.nn import functional as F
16
17 import mygpt, tensorstack
18
19 ######################################################################
20
# Run on the GPU when one is visible; on CUDA also allow TF32 matmuls,
# trading a little precision for speed.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    torch.backends.cuda.matmul.allow_tf32 = True
26
27 ######################################################################
28
# Command-line interface. The order of the add_argument calls fixes the
# order of the attributes in ``vars(args)`` (and hence of the logged args).
parser = argparse.ArgumentParser(
    description="An implementation of GPT with cache.",
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)

parser.add_argument("--task", type=str, default="picoclvr")
parser.add_argument("--log_filename", type=str, default="train.log")
parser.add_argument("--result_dir", type=str, default="results_default")
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--nb_epochs", type=int, default=25)
# None means "use the per-task default" (filled in after parsing)
parser.add_argument("--batch_size", type=int, default=None)
parser.add_argument("--nb_train_samples", type=int, default=250000)
parser.add_argument("--nb_test_samples", type=int, default=10000)
parser.add_argument("--optim", type=str, default="adam")
parser.add_argument("--learning_rate", type=float, default=1e-4)
parser.add_argument("--learning_rate_schedule", type=str, default="10: 2e-5,30: 4e-6")
parser.add_argument("--dim_model", type=int, default=512)
parser.add_argument("--dim_keys", type=int, default=64)
parser.add_argument("--dim_hidden", type=int, default=2048)
parser.add_argument("--nb_heads", type=int, default=8)
parser.add_argument("--nb_blocks", type=int, default=12)
parser.add_argument("--dropout", type=float, default=0.1)
parser.add_argument("--deterministic_synthesis", action="store_true", default=False)
parser.add_argument("--no_checkpoint", action="store_true", default=False)
parser.add_argument("--overwrite_results", action="store_true", default=False)
parser.add_argument("--checkpoint_name", type=str, default="checkpoint.pth")

##############################
# picoclvr options

parser.add_argument("--picoclvr_nb_colors", type=int, default=5)
parser.add_argument("--picoclvr_height", type=int, default=12)
parser.add_argument("--picoclvr_width", type=int, default=16)

# The correctly spelled flag is accepted alongside the historical misspelled
# one. The attribute name stays ``picocvlr_prune_properties`` so that code
# reading ``args`` keeps working. ``choices`` rejects bad values at parse time.
parser.add_argument(
    "--picoclvr_prune_properties",
    "--picocvlr_prune_properties",
    dest="picocvlr_prune_properties",
    type=str,
    default="none",
    choices=["none", "train+eval", "eval"],
)

##############################
# Maze options

parser.add_argument("--maze_height", type=int, default=13)
parser.add_argument("--maze_width", type=int, default=21)
parser.add_argument("--maze_nb_walls", type=int, default=15)

##############################
# Snake options

parser.add_argument("--snake_height", type=int, default=6)
parser.add_argument("--snake_width", type=int, default=8)
parser.add_argument("--snake_nb_colors", type=int, default=3)
parser.add_argument("--snake_length", type=int, default=400)
106
107 ######################################################################
108
args = parser.parse_args()

# Validate with an explicit check rather than `assert`, which `python -O`
# strips; parser.error prints usage and exits with status 2.
if args.picocvlr_prune_properties not in {"none", "train+eval", "eval"}:
    parser.error(
        f"invalid --picocvlr_prune_properties value {args.picocvlr_prune_properties!r}"
    )

# Refuse to reuse an existing result directory unless explicitly asked to
try:
    os.mkdir(args.result_dir)
except FileExistsError:
    if not args.overwrite_results:
        print(f"result directory {args.result_dir} already exists")
        sys.exit(1)

# Kept open for the lifetime of the script; log_string appends to it
log_file = open(os.path.join(args.result_dir, args.log_filename), "a")

# A negative seed leaves the RNGs unseeded
if args.seed >= 0:
    # torch.backends.cudnn.deterministic = True
    # torch.backends.cudnn.benchmark = False
    # torch.use_deterministic_algorithms(True)
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)
129
130 ######################################################################
131
# Per-task fallback values for options left unset (None) on the command line
default_args = {
    "picoclvr": {"batch_size": 25},
    "mnist": {"batch_size": 10},
    "maze": {"batch_size": 25},
    "snake": {"batch_size": 20},
}

for option_name, fallback in default_args.get(args.task, {}).items():
    if getattr(args, option_name) is None:
        setattr(args, option_name, fallback)
151
152 ######################################################################
153
154
def log_string(s):
    """Print ``s`` with a local timestamp prefix, and append it to the log file.

    The line goes to the module-level ``log_file`` when it is not None and
    always to stdout; both are flushed right away.
    """
    stamped = time.strftime("%Y%m%d-%H:%M:%S ") + s

    if log_file is not None:
        log_file.write(stamped + "\n")
        log_file.flush()

    print(stamped, flush=True)
164
165
# Record every parsed option (after per-task defaults) in the log
for arg_name, arg_value in vars(args).items():
    log_string(f"args.{arg_name} {arg_value}")
168
169 ######################################################################
170
171
def masked_inplace_autoregression(
    model, batch_size, input, ar_mask, forbidden_tokens=None, device=torch.device("cpu")
):
    """Sample, in place, the tokens of ``input`` where ``ar_mask`` is 1.

    ``input`` and ``ar_mask`` are processed in chunks of ``batch_size`` rows.
    ``Tensor.split`` returns views, so the writes land in the caller's tensor.
    In each chunk, the columns from the first to the last masked one are
    generated left to right, one position per model call. Positions where
    ``ar_mask`` is 0 keep their value.

    Args:
        model: called on ``mygpt.BracketedSequence`` objects; its ``.x``
            output is indexed as (batch, position, logits).
        batch_size: number of rows per chunk.
        input: token tensor, modified in place.
        ar_mask: same shape as ``input``; 1 marks positions to generate
            (assumed to be 0/1).
        forbidden_tokens: optional boolean mask over the vocabulary. Tokens
            flagged True get a -inf logit and are never sampled.
        device: unused, kept for interface compatibility.

    Sampling is greedy (argmax) when ``args.deterministic_synthesis`` is set,
    otherwise categorical.
    """
    # Ceil division so the progress bar also counts the last partial chunk
    nb_chunks = (input.size(0) + batch_size - 1) // batch_size

    for input_chunk, ar_mask_chunk in tqdm.tqdm(
        zip(input.split(batch_size), ar_mask.split(batch_size)),
        dynamic_ncols=True,
        desc="autoregression",
        total=nb_chunks,
    ):
        positions = (ar_mask_chunk.sum(0) > 0).nonzero()
        if positions.numel() == 0:
            # Nothing to generate in this chunk (min()/max() would fail)
            continue
        first, last = positions.min().item(), positions.max().item()

        if first > 0:
            # Needed to initialize the model's cache
            model(mygpt.BracketedSequence(input_chunk, 0, first))

        for s in range(first, last + 1):
            output = model(mygpt.BracketedSequence(input_chunk, s, 1)).x
            logits = output[:, s]
            if forbidden_tokens is not None:
                logits = logits.masked_fill(forbidden_tokens, float("-inf"))
            if args.deterministic_synthesis:
                t_next = logits.argmax(1)
            else:
                dist = torch.distributions.categorical.Categorical(logits=logits)
                t_next = dist.sample()
            # Only overwrite the masked positions of this column
            input_chunk[:, s] = (
                ar_mask_chunk[:, s] * t_next + (1 - ar_mask_chunk[:, s]) * input_chunk[:, s]
            )
197
198
199 ######################################################################
200
201
class Task:
    """Interface every task implements; the base methods are no-ops returning None."""

    def batches(self, split="train"):
        """Iterate over the token batches of ``split`` ("train" or "test")."""

    def vocabulary_size(self):
        """Return the number of distinct token ids the task uses."""

    def produce_results(self, n_epoch, model):
        """Evaluate ``model`` after epoch ``n_epoch`` and log/save the outcome."""
211
212
213 ######################################################################
214
215 import picoclvr
216
217
class TaskPicoCLVR(Task):
    """Description-to-image task on synthetic picoclvr scenes.

    Samples are space-separated token strings produced by ``picoclvr.generate``.
    An image is an ``<img>`` token followed by ``height * width`` image tokens.
    ``<nul>`` is the padding token.
    """

    # Make a tensor from a list of strings
    def tensorize(self, descr):
        """Split each string on single spaces, right-pad with ``<nul>`` and map to ids.

        Returns a (len(descr), longest length) tensor on ``self.device``.
        """
        token_descr = [s.strip().split(" ") for s in descr]
        l = max(len(s) for s in token_descr)
        token_descr = [s + ["<nul>"] * (l - len(s)) for s in token_descr]
        id_descr = [[self.token2id[u] for u in s] for s in token_descr]
        return torch.tensor(id_descr, device=self.device)

    # Make a list of strings from a tensor
    def detensorize(self, x):
        """Map each row of token ids back to a space-joined string."""
        return [" ".join([self.id2token[t.item()] for t in r]) for r in x]

    def trim(self, z, token="<nul>"):
        """Drop the leading and trailing columns that contain only ``token``.

        ``z`` is either a 2d tensor, or a tuple of 2d tensors with the same
        number of columns. For a tuple, the column range is computed from
        ``z[0]`` and applied to every element, and a tuple is returned.
        """
        n = self.token2id[token]
        x = z[0] if isinstance(z, tuple) else z
        # After padding x with one `token` column on each side, i[k] counts the
        # columns up to k that are not entirely `token`
        i = (1 - (F.pad(x, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
        # a: last all-token padded column before the content (= first kept
        # original column); b: first padded column where the count peaks
        # (= one past the last kept original column)
        a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
        if isinstance(z, tuple):
            return tuple(t[:, a:b] for t in z)
        return z[:, a:b]

    ######################
    # Not the cleanest part of the code

    # Extract the last image of each sequence, from the last <img>
    # included, and set to <nul> all the tokens from the beginning of
    # that image to the end
    def excise_last_image(self, input):
        """Return ``(input, loss_masks, images)`` with each row's last image removed.

        ``images`` holds the ``<img>`` token and the image tokens that follow
        it. ``loss_masks`` is 1 on the tokens before that last ``<img>``.
        NOTE(review): assumes every row contains at least one ``<img>``.
        """
        t_img, t_nul = self.token2id["<img>"], self.token2id["<nul>"]
        nb_img_tokens = self.height * self.width + 1

        input = input.clone()
        t = (input == t_img).long()
        # 1 from the last <img> of each row to the end of the row
        tail_masks = (t.cumsum(dim=1) == t.sum(dim=1, keepdim=True)).long()
        i = (t * tail_masks).nonzero(as_tuple=True)
        j = (
            i[0][:, None],
            i[1][:, None] + torch.arange(nb_img_tokens, device=input.device)[None, :],
        )
        images = self.trim(input[j])
        input[j] = t_nul
        loss_masks = 1 - tail_masks
        input, loss_masks = self.trim((input, loss_masks))
        return input, loss_masks, images

    def add_true_image(self, input, images, loss_masks):
        """Write ``images`` starting at the first ``<nul>`` of each row.

        The written positions get loss mask 1.
        """
        t_nul = self.token2id["<nul>"]
        nb_img_tokens = self.height * self.width + 1
        input = F.pad(input, (0, nb_img_tokens), value=t_nul)
        loss_masks = F.pad(loss_masks, (0, nb_img_tokens), value=0)
        t = (input == t_nul).long()
        i = (t.cumsum(dim=1) == 1).nonzero(as_tuple=True)
        j = (
            i[0][:, None],
            i[1][:, None] + torch.arange(nb_img_tokens, device=input.device)[None, :],
        )
        input[j] = images
        loss_masks[j] = 1
        input, loss_masks = self.trim((input, loss_masks))
        return input, loss_masks

    def add_generated_image(self, input, loss_masks, model):
        """Append ``<img>`` after each description and let ``model`` sample the image.

        Sampling runs without gradients and in eval mode. ``<nul>`` is
        forbidden during sampling. The model's previous training mode is
        restored even if sampling raises.
        """
        t_img, t_nul = self.token2id["<img>"], self.token2id["<nul>"]
        nb_img_tokens = self.height * self.width + 1

        input = F.pad(input, (0, nb_img_tokens), value=t_nul)
        loss_masks = F.pad(loss_masks, (0, nb_img_tokens), value=0)
        t = (input == t_nul).long()
        i = (t.cumsum(dim=1) == 1).nonzero(as_tuple=True)
        input[i] = t_img

        # The image tokens right after the inserted <img> are generated
        j = (
            i[0][:, None],
            i[1][:, None]
            + 1
            + torch.arange(nb_img_tokens - 1, device=input.device)[None, :],
        )
        ar_masks = input.new_zeros(input.size(), dtype=torch.int64)
        ar_masks[j] = 1
        forbidden_tokens = (
            torch.arange(self.vocabulary_size(), device=input.device) == t_nul
        )
        with torch.autograd.no_grad():
            t = model.training
            model.eval()
            try:
                masked_inplace_autoregression(
                    model,
                    self.batch_size,
                    input,
                    ar_masks,
                    forbidden_tokens,
                    device=self.device,
                )
            finally:
                model.train(t)

        input, loss_masks = self.trim((input, loss_masks))

        return input, loss_masks

    ######################

    def __init__(
        self,
        nb_train_samples,
        nb_test_samples,
        batch_size,
        height,
        width,
        nb_colors=5,
        device=torch.device("cpu"),
        pruner_train=None,
        pruner_eval=None,
    ):
        self.height = height
        self.width = width
        self.batch_size = batch_size
        self.device = device
        self.pruner_train = pruner_train
        self.pruner_eval = pruner_eval

        def generate_descr(nb, pruner):
            return picoclvr.generate(
                nb,
                height=self.height,
                width=self.width,
                nb_colors=nb_colors,
                pruner=pruner,
            )

        log_string(
            f"generating {nb_train_samples+nb_test_samples} samples (can take some time)"
        )
        self.train_descr = generate_descr(nb_train_samples, pruner=self.pruner_train)
        # The test set is generated without pruning; pruner_eval is only used
        # when counting properties in compute_missing_properties
        self.test_descr = generate_descr(nb_test_samples, pruner=None)

        # Build the tokenizer. Sorting the token set makes the ids, and hence
        # the tensors, depend only on the descriptions
        tokens = {"<nul>", "<img>"}
        for d in (self.train_descr, self.test_descr):
            for s in d:
                tokens.update(s.strip().split(" "))
        tokens = sorted(tokens)
        self.token2id = {t: n for n, t in enumerate(tokens)}
        self.id2token = dict(enumerate(tokens))

        # Tokenize the train and test sets
        self.train_input = self.tensorize(self.train_descr)
        self.test_input = self.tensorize(self.test_descr)

    def batches(self, split="train"):
        """Yield ``batch_size``-row chunks of ``split``, each trimmed of padding columns."""
        assert split in {"train", "test"}
        input = self.train_input if split == "train" else self.test_input
        for batch in tqdm.tqdm(
            input.split(self.batch_size), dynamic_ncols=True, desc=f"epoch-{split}"
        ):
            yield self.trim(batch)

    def vocabulary_size(self):
        return len(self.token2id)

    def _log_property_stats(self, n_epoch, prefix, nb_results, nb_requested, nb_missing):
        """Log the sample count, the requested/missing property counts and the miss rate."""
        log_string(f"nb_{prefix}samples {n_epoch} {nb_results}")
        log_string(
            f"property_{prefix}nb {n_epoch} requested {nb_requested} missing {nb_missing}"
        )
        # Guard the rate against a zero denominator (logged as nan)
        miss_rate = (
            100 * nb_missing / nb_requested if nb_requested > 0 else float("nan")
        )
        log_string(f"property_{prefix}miss {n_epoch} {miss_rate:.02f}%")

    def compute_missing_properties(self, n_epoch, model, pruner=None):
        """Regenerate the last image of every test sequence and log missing properties.

        ``pruner`` is forwarded to ``picoclvr.nb_properties``. When it is set,
        the log keys get a ``pruned_`` prefix.
        """
        acc_nb_requested_properties = []
        acc_nb_missing_properties = []
        acc_nb_results = 0

        for input in tqdm.tqdm(
            self.test_input.split(self.batch_size),
            dynamic_ncols=True,
            desc="test-properties",
        ):
            tape, loss_masks, _ = self.excise_last_image(input)
            tape, loss_masks = self.add_generated_image(tape, loss_masks, model)
            result_descr = self.detensorize(tape)
            properties = picoclvr.nb_properties(
                result_descr,
                height=self.height,
                width=self.width,
                pruner=pruner,
            )
            nb_requested_properties, _, nb_missing_properties = zip(*properties)
            acc_nb_requested_properties += nb_requested_properties
            acc_nb_missing_properties += nb_missing_properties
            acc_nb_results += len(result_descr)

        prefix = "" if pruner is None else "pruned_"
        self._log_property_stats(
            n_epoch,
            prefix,
            acc_nb_results,
            sum(acc_nb_requested_properties),
            sum(acc_nb_missing_properties),
        )

    ######################################################################

    def produce_results(self, n_epoch, model):
        """Log property statistics and save a grid of images generated from fixed primers."""
        self.compute_missing_properties(n_epoch, model)

        if self.pruner_eval is not None:
            self.compute_missing_properties(n_epoch, model, self.pruner_eval)

        nb_per_primer = 8
        primer = []

        for primer_descr in [
            "red above green <sep> green top <sep> blue right of red",
            "there is red <sep> there is yellow <sep> there is blue",
            "red below yellow <sep> yellow below green <sep> green below blue <sep> red right <sep> yellow left <sep> green right <sep> blue left",
            "green bottom <sep> yellow bottom <sep> green left of blue <sep> yellow right of blue <sep> blue top",
        ]:
            primer += [primer_descr] * nb_per_primer

        tape = self.tensorize(primer)
        loss_masks = 1 - (tape == self.token2id["<nul>"]).long()
        tape, loss_masks = self.add_generated_image(tape, loss_masks, model)
        result_descr = self.detensorize(tape)

        properties = picoclvr.nb_properties(
            result_descr, height=self.height, width=self.width
        )
        nb_requested_properties, _, nb_missing_properties = zip(*properties)
        self._log_property_stats(
            n_epoch,
            "demo_",
            len(result_descr),
            sum(nb_requested_properties),
            sum(nb_missing_properties),
        )

        img = picoclvr.descr2img(result_descr, height=self.height, width=self.width)

        if img.dim() == 5:
            if img.size(1) == 1:
                img = F.pad(img.squeeze(1), pad=(1, 1, 1, 1), value=64)
            else:
                img = torch.cat(
                    [
                        torchvision.utils.make_grid(x, padding=1, pad_value=64)[None]
                        for x in img
                    ],
                    0,
                )

        image_name = os.path.join(args.result_dir, f"picoclvr_result_{n_epoch:04d}.png")
        torchvision.utils.save_image(
            img / 255.0, image_name, nrow=nb_per_primer, padding=1, pad_value=1.0
        )
        log_string(f"wrote {image_name}")
498
499
500 ######################################################################
501
502
class TaskMNIST(Task):
    """Unconditional generation of 28x28 MNIST digits, one token per pixel value (0-255)."""

    def __init__(self, batch_size, device=torch.device("cpu")):
        self.device = device
        self.batch_size = batch_size

    def batches(self, split="train"):
        """Yield ``batch_size``-row chunks of flattened MNIST images for ``split``.

        Each split is capped by its own sample-count option: the test split
        used to be truncated with ``nb_train_samples``.
        """
        assert split in {"train", "test"}
        data_set = torchvision.datasets.MNIST(
            root="./data", train=(split == "train"), download=True
        )
        data_input = data_set.data.view(-1, 28 * 28).long()
        nb_samples = (
            args.nb_train_samples if split == "train" else args.nb_test_samples
        )
        if nb_samples is not None:
            data_input = data_input[:nb_samples]
        for batch in tqdm.tqdm(
            data_input.split(self.batch_size), dynamic_ncols=True, desc=f"epoch-{split}"
        ):
            yield batch

    def vocabulary_size(self):
        return 256

    def produce_results(self, n_epoch, model):
        """Sample 64 digits from scratch and save them as an image grid.

        Like the other tasks, sampling runs without gradients and in eval
        mode (no dropout). The previous training mode is restored even on error.
        """
        # Zero-initialized (not torch.empty) so no uninitialized memory is ever read;
        # every position is regenerated anyway since ar_mask is all ones
        results = torch.zeros(64, 28 * 28, device=self.device, dtype=torch.int64)
        ar_mask = torch.full_like(results, 1)
        with torch.autograd.no_grad():
            t = model.training
            model.eval()
            try:
                masked_inplace_autoregression(
                    model, self.batch_size, results, ar_mask, device=self.device
                )
            finally:
                model.train(t)
        image_name = os.path.join(args.result_dir, f"mnist_result_{n_epoch:04d}.png")
        torchvision.utils.save_image(
            1 - results.reshape(-1, 1, 28, 28) / 255.0,
            image_name,
            nrow=16,
            pad_value=0.8,
        )
        log_string(f"wrote {image_name}")
538
539
540 ######################################################################
541
542 import maze
543
544
class TaskMaze(Task):
    """Maze path-finding task.

    A sample is a maze map followed by its path map, both ``height x width``
    grids, flattened and concatenated. At evaluation the path part is
    regenerated from the maze part.
    """

    def map2seq(self, *m):
        """Flatten each map from dimension 1 on and concatenate them along dimension 1."""
        return torch.cat([x.flatten(1) for x in m], 1)

    def seq2map(self, s):
        """Inverse of map2seq: return a generator over the (N, height, width) maps in ``s``."""
        s = s.reshape(s.size(0), -1, self.height, self.width)
        return (s[:, k] for k in range(s.size(1)))

    def __init__(
        self,
        nb_train_samples,
        nb_test_samples,
        batch_size,
        height,
        width,
        nb_walls,
        device=torch.device("cpu"),
    ):
        self.batch_size = batch_size
        self.height = height
        self.width = width
        self.device = device

        train_mazes, train_paths, _ = maze.create_maze_data(
            nb_train_samples,
            height=height,
            width=width,
            nb_walls=nb_walls,
            progress_bar=lambda x: tqdm.tqdm(x, dynamic_ncols=True, desc=f"data-train"),
        )
        self.train_input = self.map2seq(train_mazes.to(device), train_paths.to(device))

        test_mazes, test_paths, _ = maze.create_maze_data(
            nb_test_samples,
            height=height,
            width=width,
            nb_walls=nb_walls,
            progress_bar=lambda x: tqdm.tqdm(x, dynamic_ncols=True, desc=f"data-test"),
        )
        self.test_input = self.map2seq(test_mazes.to(device), test_paths.to(device))

        self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1

    def batches(self, split="train", nb_to_use=-1, desc=None):
        """Yield ``batch_size``-row chunks of ``split``, limited to the first
        ``nb_to_use`` samples when ``nb_to_use > 0``."""
        assert split in {"train", "test"}
        input = self.train_input if split == "train" else self.test_input
        if nb_to_use > 0:
            input = input[:nb_to_use]
        if desc is None:
            desc = f"epoch-{split}"
        for batch in tqdm.tqdm(
            input.split(self.batch_size), dynamic_ncols=True, desc=desc
        ):
            yield batch

    def vocabulary_size(self):
        return self.nb_codes

    def _generate_paths(self, model, input):
        """Return a copy of ``input`` whose path part has been regenerated by ``model``."""
        result = input.clone()
        ar_mask = result.new_zeros(result.size())
        # Everything after the first height*width tokens (the maze) is the path
        ar_mask[:, self.height * self.width :] = 1
        result *= 1 - ar_mask
        masked_inplace_autoregression(
            model, self.batch_size, result, ar_mask, device=self.device
        )
        return result

    def compute_error(self, model, split="train", nb_to_use=-1):
        """Return ``(nb_total, nb_correct)`` over the first ``nb_to_use`` samples of ``split``.

        A sample counts as correct when ``maze.path_correctness`` accepts the
        generated path.
        """
        nb_total, nb_correct = 0, 0
        # Iterate this task's own data (the code previously went through the
        # module-level `task` global)
        for input in self.batches(split, nb_to_use):
            result = self._generate_paths(model, input)
            mazes, paths = self.seq2map(result)
            nb_correct += maze.path_correctness(mazes, paths).long().sum()
            nb_total += mazes.size(0)

        return nb_total, nb_correct

    def produce_results(self, n_epoch, model):
        """Log train/test path accuracy on 1000 samples each and save 48 test predictions.

        Runs without gradients and in eval mode. The previous training mode
        is restored even on error.
        """
        with torch.autograd.no_grad():
            t = model.training
            model.eval()
            try:
                train_nb_total, train_nb_correct = self.compute_error(
                    model, "train", nb_to_use=1000
                )
                log_string(
                    f"accuracy_train nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%"
                )

                test_nb_total, test_nb_correct = self.compute_error(
                    model, "test", nb_to_use=1000
                )
                log_string(
                    f"accuracy_test nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
                )

                input = self.test_input[:48]
                result = self._generate_paths(model, input)

                mazes, paths = self.seq2map(input)
                _, predicted_paths = self.seq2map(result)

                filename = os.path.join(
                    args.result_dir, f"maze_result_{n_epoch:04d}.png"
                )
                maze.save_image(
                    filename,
                    mazes=mazes,
                    target_paths=paths,
                    predicted_paths=predicted_paths,
                    path_correct=maze.path_correctness(mazes, predicted_paths),
                )
                log_string(f"wrote {filename}")
            finally:
                model.train(t)
661
662
663 ######################################################################
664
665
def generate_snake_sequences(
    nb, height, width, nb_colors, length, prompt_length, device=torch.device("cpu")
):
    """Simulate ``nb`` snakes walking on random colored grids.

    Each of the ``length`` steps emits two tokens: the color of the current
    cell plus 4, then the direction chosen for the next move (0-3). At each
    step the snake goes forward, left or right; leaving the grid is never
    chosen while an in-grid move exists, and forward moves get twice the
    random weight.

    Returns ``(sequences, sequences_prior_visits)``, both (nb, 2 * length)
    int64. On color tokens, the second tensor holds how many times the cell
    had been visited during the first ``prompt_length`` steps before that
    step; on direction tokens it is 0.
    """
    worlds = torch.randint(nb_colors, (nb, height, width), device=device)
    visit_counts = torch.zeros(nb, height, width, device=device)

    start_rows = torch.randint(height, (nb, 1), device=device)
    start_cols = torch.randint(width, (nb, 1), device=device)
    position = torch.cat((start_rows, start_cols), 1)  # nb x 2
    direction = torch.randint(4, (nb,), device=device)

    sequences = torch.empty(nb, 2 * length, device=device, dtype=torch.int64)
    sequences_prior_visits = torch.zeros(
        nb, 2 * length, device=device, dtype=torch.int64
    )
    rows_idx = torch.arange(nb, device=device)
    turns = torch.tensor([-1, 0, 1], device=device)
    # Weights for (turn one way, go straight, turn the other way)
    forward_bias = torch.tensor([[1.0, 2.0, 1.0]], device=device)

    for step in range(length):
        candidates = (direction[:, None] + turns[None, :]) % 4  # nb x 3
        d_row = (candidates + 1) % 2 * (candidates - 1)
        d_col = candidates % 2 * (candidates - 2)
        targets = position[:, None, :] + torch.stack((d_row, d_col), dim=2)  # nb x 3 x 2

        inside = torch.logical_and(
            torch.logical_and(targets[:, :, 0] >= 0, targets[:, :, 0] < height),
            torch.logical_and(targets[:, :, 1] >= 0, targets[:, :, 1] < width),
        ).float()
        scores = torch.rand_like(inside) * inside * forward_bias

        choice = scores.argmax(1)
        direction = candidates[rows_idx, choice]

        here = (rows_idx, position[:, 0], position[:, 1])
        sequences[:, 2 * step] = worlds[here] + 4
        sequences_prior_visits[:, 2 * step] = visit_counts[here]
        if step < prompt_length:
            visit_counts[here] += 1
        sequences[:, 2 * step + 1] = direction

        position = targets[rows_idx, choice]

    return sequences, sequences_prior_visits
738
739
740 # generate_snake_sequences(nb=1, height=4, width=6, nb_colors=3, length=20)
741 # exit(0)
742
743
def snake_solver(input, ar_mask):
    """Fill masked color tokens in place from what was seen earlier on the walk.

    For every row the walk starts at relative cell (0, 0) and follows the
    direction tokens. A masked color position gets the color last recorded
    for that cell, or -1 if the cell was never seen. Unmasked colors are
    recorded the first time a cell is seen. A later unmasked color that
    disagrees triggers an AssertionError.
    """
    nb_steps = input.size(1) // 2
    for n in range(input.size(0)):
        seen = {}
        row, col = 0, 0
        for l in range(nb_steps):
            known = seen.get((row, col))
            if ar_mask[n, 2 * l] == 1:
                input[n, 2 * l] = -1 if known is None else known
            elif known is None:
                seen[(row, col)] = input[n, 2 * l]
            else:
                assert known == input[n, 2 * l], f"n={n} l={l}"
            move = input[n, 2 * l + 1].item()
            row += (move + 1) % 2 * (move - 1)
            col += move % 2 * (move - 2)
766
767
class TaskSnake(Task):
    """Snake-walk memory task.

    Sequences alternate a cell color token and a move direction (see
    ``generate_snake_sequences``). At evaluation the color tokens after the
    prompt are regenerated. Accuracy is measured only on cells visited
    before, as recorded in the prior-visit counts.
    """

    def __init__(
        self,
        nb_train_samples,
        nb_test_samples,
        batch_size,
        height,
        width,
        nb_colors,
        length,
        prompt_length,
        device=torch.device("cpu"),
    ):
        self.batch_size = batch_size
        self.height = height
        self.width = width
        self.device = device
        self.prompt_length = prompt_length

        self.train_input, self.train_prior_visits = generate_snake_sequences(
            nb_train_samples,
            height,
            width,
            nb_colors,
            length,
            prompt_length,
            self.device,
        )
        self.test_input, self.test_prior_visits = generate_snake_sequences(
            nb_test_samples,
            height,
            width,
            nb_colors,
            length,
            prompt_length,
            self.device,
        )

        self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1

    def batches(self, split="train", nb_to_use=-1, desc=None):
        """Yield ``batch_size``-row chunks of ``split``, limited to the first
        ``nb_to_use`` samples when ``nb_to_use > 0``."""
        assert split in {"train", "test"}
        input = self.train_input if split == "train" else self.test_input
        if nb_to_use > 0:
            input = input[:nb_to_use]
        if desc is None:
            desc = f"epoch-{split}"
        for batch in tqdm.tqdm(
            input.split(self.batch_size), dynamic_ncols=True, desc=desc
        ):
            yield batch

    def vocabulary_size(self):
        return self.nb_codes

    def _compute_nb_correct(self, model, input, prior_visits):
        """Regenerate the post-prompt color tokens of ``input`` and count matches.

        Returns ``(nb_total, nb_correct)`` as tensors. Only regenerated
        positions with ``prior_visits > 0`` are counted.
        """
        result = input.clone()
        i = torch.arange(result.size(1), device=result.device)[None, :]
        # Even positions (color tokens) after the prompt are generated
        ar_mask = (
            torch.logical_and(i >= self.prompt_length * 2, i % 2 == 0)
            .long()
            .expand_as(result)
        )
        result *= 1 - ar_mask

        masked_inplace_autoregression(
            model, self.batch_size, result, ar_mask, device=self.device
        )

        nb_total = ((prior_visits > 0) * ar_mask).sum()
        nb_correct = ((result == input).long() * (prior_visits > 0) * ar_mask).sum()

        return nb_total, nb_correct

    def produce_results(self, n_epoch, model):
        """Log the color accuracy on the first 1000 test sequences.

        Runs without gradients and in eval mode. The previous training mode
        is restored even on error.
        """
        with torch.autograd.no_grad():
            t = model.training
            model.eval()
            try:
                test_nb_total, test_nb_correct = self._compute_nb_correct(
                    model, self.test_input[:1000], self.test_prior_visits[:1000]
                )

                log_string(
                    f"accuracy_test nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
                )
            finally:
                model.train(t)
872
873
874 ######################################################################
875
876
def picoclvr_pruner_horizontal_green(p):
    """Reject descriptions that mention green together with left or right.

    ``p`` is probed with ``in`` for "green", "left" and "right". The result is
    True (keep) unless ``p`` contains "green" and at least one of "left" or
    "right".
    """
    if "green" not in p:
        return True
    return "left" not in p and "right" not in p
879
880
# Optional property pruning for the picoclvr task, chosen by the
# picocvlr_prune_properties option:
#   "train+eval"   -> train pruner = picoclvr_pruner_horizontal_green,
#                     eval pruner  = its negation;
#   "eval"         -> no train pruner, eval pruner = the negation;
#   anything else  -> no pruning at all.
# How TaskPicoCLVR applies pruner_train / pruner_eval is not visible here;
# presumably samples for which the pruner returns False are dropped.
# NOTE(review): the attribute is spelled "picocvlr" (not "picoclvr"); it must
# match the argparse definition, so renaming it requires changing both sites.
picoclvr_pruner_train = None
if args.picocvlr_prune_properties == "train+eval":
    picoclvr_pruner_train = picoclvr_pruner_horizontal_green

picoclvr_pruner_eval = None
if args.picocvlr_prune_properties in ("train+eval", "eval"):
    picoclvr_pruner_eval = lambda p: not picoclvr_pruner_horizontal_green(p)
892
893 ######################################################################
894
def _make_task(name):
    """Instantiate the task selected by ``--task``.

    Every task reads its hyper-parameters from the module-level ``args`` and
    is placed on ``device``. Raises ValueError for an unknown task name.
    """
    if name == "mnist":
        return TaskMNIST(
            batch_size=args.batch_size,
            device=device,
        )

    if name == "picoclvr":
        return TaskPicoCLVR(
            nb_train_samples=args.nb_train_samples,
            nb_test_samples=args.nb_test_samples,
            batch_size=args.batch_size,
            height=args.picoclvr_height,
            width=args.picoclvr_width,
            nb_colors=args.picoclvr_nb_colors,
            device=device,
            pruner_train=picoclvr_pruner_train,
            pruner_eval=picoclvr_pruner_eval,
        )

    if name == "maze":
        return TaskMaze(
            nb_train_samples=args.nb_train_samples,
            nb_test_samples=args.nb_test_samples,
            batch_size=args.batch_size,
            height=args.maze_height,
            width=args.maze_width,
            nb_walls=args.maze_nb_walls,
            device=device,
        )

    if name == "snake":
        return TaskSnake(
            nb_train_samples=args.nb_train_samples,
            nb_test_samples=args.nb_test_samples,
            batch_size=args.batch_size,
            height=args.snake_height,
            width=args.snake_width,
            nb_colors=args.snake_nb_colors,
            length=args.snake_length,
            # The prompt covers the first half of the snake sequence.
            prompt_length=args.snake_length // 2,
            device=device,
        )

    raise ValueError(f"Unknown task {name}")


task = _make_task(args.task)
940
941 ######################################################################
942
log_string(f"device {device}")

vocabulary_size = task.vocabulary_size()
log_string(f"vocabulary_size {vocabulary_size}")

##############################

# Causal GPT whose hyper-parameters all come from the command line; its
# vocabulary is the one reported by the task.
model = mygpt.MyGPT(
    vocabulary_size=vocabulary_size,
    nb_blocks=args.nb_blocks,
    nb_heads=args.nb_heads,
    dim_model=args.dim_model,
    dim_keys=args.dim_keys,
    dim_hidden=args.dim_hidden,
    dropout=args.dropout,
    causal=True,
)
model.to(device)

# Total number of scalar parameters, also reported in (truncated) millions.
nb_parameters = sum(map(torch.numel, model.parameters()))
log_string(f"nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)")
966
967 ######################################################################
968
# Resume from <result_dir>/<checkpoint_name> unless --no_checkpoint is given.
# A missing file means a fresh start; any other loading failure is fatal.
nb_epochs_finished = 0

if args.no_checkpoint:
    log_string("not trying to load checkpoint.")

else:
    try:
        checkpoint_name = os.path.join(args.result_dir, args.checkpoint_name)
        # map_location="cpu" lets a checkpoint written on a GPU machine be
        # resumed on a CPU-only one. load_state_dict copies the weights into
        # the model's parameters on whatever device they live, and the saved
        # RNG states are CPU byte tensors as returned by torch.get_rng_state.
        checkpoint = torch.load(checkpoint_name, map_location="cpu")
        nb_epochs_finished = checkpoint["nb_epochs_finished"]
        model.load_state_dict(checkpoint["model_state"])
        torch.set_rng_state(checkpoint["rng_state"])
        # A checkpoint written without CUDA has no CUDA RNG state to restore.
        if torch.cuda.is_available() and "cuda_rng_state" in checkpoint:
            torch.cuda.set_rng_state(checkpoint["cuda_rng_state"])

        log_string(f"checkpoint loaded with {nb_epochs_finished} epochs finished.")

    except FileNotFoundError:
        log_string("starting from scratch.")

    except Exception as e:
        # Corrupt file, incompatible model, ...: abort, but report the cause.
        # Catching Exception (not a bare except) lets KeyboardInterrupt and
        # SystemExit propagate normally.
        log_string("error when loading the checkpoint.")
        log_string(f"{type(e).__name__}: {e}")
        sys.exit(1)
992
993 ######################################################################
994
# A non-positive --nb_epochs falls back to nb_epochs_default (defined elsewhere
# in this file).
nb_epochs = args.nb_epochs if args.nb_epochs > 0 else nb_epochs_default

# Unigram token distribution of the training set. exp(entropy) is the
# perplexity of a context-free predictor; it is logged next to the model's
# train/test perplexities as a reference point.
# torch.bincount counts tokens directly instead of materializing a
# (batch, length, vocabulary) one-hot tensor for every batch.
token_count = 0
for input in task.batches(split="train"):
    token_count += torch.bincount(input.flatten(), minlength=task.vocabulary_size())
token_probas = token_count / token_count.sum()
entropy = -torch.xlogy(token_probas, token_probas).sum()
train_set_perplexity = math.exp(entropy)
1003
1004 ##############################
1005
# Per-epoch learning rates, indexed 0 .. nb_epochs-1. The loop below uses the
# effective nb_epochs (not args.nb_epochs) so that every epoch the training
# loop runs has an entry, including when --nb_epochs <= 0 falls back to
# nb_epochs_default.
if args.learning_rate_schedule == "cos":
    # Cosine decay starting at --learning_rate and heading towards 0.
    learning_rate_schedule = {}
    for n_epoch in range(nb_epochs):
        angle = n_epoch / nb_epochs * math.pi
        learning_rate_schedule[n_epoch] = (
            args.learning_rate * 0.5 * (1 + math.cos(angle))
        )
else:
    # Piecewise-constant schedule "epoch:lr,epoch:lr,...". Starting at each
    # listed epoch the given lr is used until the next change. Epochs before
    # the first change use --learning_rate. Empty entries (e.g. an empty
    # string) are ignored, meaning a constant learning rate.
    lr_changes = {}
    for entry in args.learning_rate_schedule.split(","):
        if entry.strip():
            k, v = entry.split(":")
            lr_changes[int(k)] = float(v)

    learning_rate_schedule = {}
    learning_rate = args.learning_rate
    for n_epoch in range(nb_epochs):
        if n_epoch in lr_changes:
            learning_rate = lr_changes[n_epoch]
        learning_rate_schedule[n_epoch] = learning_rate

log_string(f"learning_rate_schedule {learning_rate_schedule}")
1027
1028 ##############################
1029
nb_samples_seen = 0

# Resumed run that already reached nb_epochs: only regenerate the results.
if nb_epochs_finished >= nb_epochs:
    task.produce_results(nb_epochs_finished, model)

for n_epoch in range(nb_epochs_finished, nb_epochs):
    learning_rate = learning_rate_schedule[n_epoch]

    log_string(f"learning_rate {learning_rate}")

    # NOTE(review): the optimizer is re-created every epoch and its state is
    # not checkpointed, so Adam/AdamW moment estimates restart at each epoch.
    # This may be deliberate (it is how the per-epoch lr is applied); confirm
    # before switching to a single optimizer with param_group["lr"] updates.
    if args.optim == "sgd":
        optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
    elif args.optim == "adam":
        optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    elif args.optim == "adamw":
        optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
    else:
        raise ValueError(f"Unknown optimizer {args.optim}.")

    model.train()

    nb_train_samples, acc_train_loss = 0, 0.0

    # Next-token training: the model output at each position is scored
    # against the input sequence itself with cross-entropy (classes on dim 1).
    for input in task.batches(split="train"):
        input = input.to(device)
        output = model(mygpt.BracketedSequence(input)).x
        loss = F.cross_entropy(output.transpose(1, 2), input)
        acc_train_loss += loss.item() * input.size(0)
        nb_train_samples += input.size(0)
        nb_samples_seen += input.size(0)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    with torch.autograd.no_grad():
        model.eval()

        nb_test_samples, acc_test_loss = 0, 0.0

        for input in task.batches(split="test"):
            input = input.to(device)

            # input, loss_masks, true_images = task.excise_last_image(input)
            # input, loss_masks = task.add_true_image(input, true_images, loss_masks)

            output = model(mygpt.BracketedSequence(input)).x
            loss = F.cross_entropy(output.transpose(1, 2), input)
            acc_test_loss += loss.item() * input.size(0)
            nb_test_samples += input.size(0)

        # Mean per-sample loss, clamped at 100 so math.exp cannot overflow.
        train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples))
        test_perplexity = math.exp(min(100, acc_test_loss / nb_test_samples))

        log_string(
            f"perplexity {n_epoch} train_set {train_set_perplexity} train_prediction {train_perplexity} test_prediction {test_perplexity}"
        )

        task.produce_results(n_epoch, model)

    checkpoint = {
        "nb_epochs_finished": n_epoch + 1,
        "model_state": model.state_dict(),
        "rng_state": torch.get_rng_state(),
    }

    if torch.cuda.is_available():
        checkpoint["cuda_rng_state"] = torch.cuda.get_rng_state()

    checkpoint_name = os.path.join(args.result_dir, args.checkpoint_name)
    # Write to a temporary file in the same directory, then atomically move
    # it into place. An interruption during torch.save then leaves the
    # previous checkpoint intact, rather than a truncated file that makes the
    # next run abort when loading it.
    tmp_checkpoint_name = checkpoint_name + ".tmp"
    torch.save(checkpoint, tmp_checkpoint_name)
    os.replace(tmp_checkpoint_name, checkpoint_name)
    log_string(f"saved checkpoint {checkpoint_name}")
1102
1103 ######################################################################