Update.
[picoclvr.git] / main.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
# torch.backends.cuda.matmul.allow_tf32
9 # torch.autocast(torch.bfloat16)
10
11 import math, sys, argparse, time, tqdm, itertools, os
12
13 import torch, torchvision
14 from torch import nn
15 from torch.nn import functional as F
16
17 import mygpt, tensorstack
18
19 ######################################################################
20
21 if torch.cuda.is_available():
22     device = torch.device("cuda")
23     torch.backends.cuda.matmul.allow_tf32 = True
24 else:
25     device = torch.device("cpu")
26
27 ######################################################################
28
29 parser = argparse.ArgumentParser(
30     description="An implementation of GPT with cache to solve a toy geometric reasoning task."
31 )
32
33 parser.add_argument("--task", type=str, default="picoclvr")
34
35 parser.add_argument("--log_filename", type=str, default="train.log")
36
37 parser.add_argument("--result_dir", type=str, default="results_default")
38
39 parser.add_argument("--seed", type=int, default=0)
40
41 parser.add_argument("--nb_epochs", type=int, default=25)
42
43 parser.add_argument("--batch_size", type=int, default=25)
44
45 parser.add_argument("--nb_train_samples", type=int, default=250000)
46
47 parser.add_argument("--nb_test_samples", type=int, default=10000)
48
49 parser.add_argument("--optim", type=str, default="adam")
50
51 parser.add_argument("--learning_rate", type=float, default=1e-4)
52
53 parser.add_argument("--learning_rate_schedule", type=str, default="10: 2e-5,30: 4e-6")
54
55 parser.add_argument("--dim_model", type=int, default=512)
56
57 parser.add_argument("--dim_keys", type=int, default=64)
58
59 parser.add_argument("--dim_hidden", type=int, default=2048)
60
61 parser.add_argument("--nb_heads", type=int, default=8)
62
63 parser.add_argument("--nb_blocks", type=int, default=12)
64
65 parser.add_argument("--dropout", type=float, default=0.1)
66
67 parser.add_argument("--deterministic_synthesis", action="store_true", default=False)
68
69 parser.add_argument("--no_checkpoint", action="store_true", default=False)
70
71 parser.add_argument("--overwrite_results", action="store_true", default=False)
72
73 parser.add_argument("--checkpoint_name", type=str, default="checkpoint.pth")
74
75 ##############################
76 # picoclvr options
77
78 parser.add_argument("--picoclvr_nb_colors", type=int, default=5)
79
80 parser.add_argument("--picoclvr_height", type=int, default=12)
81
82 parser.add_argument("--picoclvr_width", type=int, default=16)
83
84 parser.add_argument("--picocvlr_prune_properties", type=str, default="none")
85
86 ##############################
87 # Maze options
88
89 parser.add_argument("--maze_height", type=int, default=13)
90
91 parser.add_argument("--maze_width", type=int, default=21)
92
93 parser.add_argument("--maze_nb_walls", type=int, default=15)
94
95 ##############################
96 # Snake options
97
98 parser.add_argument("--snake_height", type=int, default=6)
99
100 parser.add_argument("--snake_width", type=int, default=8)
101
102 parser.add_argument("--snake_nb_colors", type=int, default=3)
103
104 parser.add_argument("--snake_length", type=int, default=400)
105
106 ######################################################################
107
108 args = parser.parse_args()
109
110 assert args.picocvlr_prune_properties in {"none", "train+eval", "eval"}
111
112 try:
113     os.mkdir(args.result_dir)
114 except FileExistsError:
115     if not args.overwrite_results:
116         print(f"result directory {args.result_dir} already exists")
117         exit(1)
118
119 log_file = open(os.path.join(args.result_dir, args.log_filename), "a")
120
121 if args.seed >= 0:
122     # torch.backends.cudnn.deterministic = True
123     # torch.backends.cudnn.benchmark = False
124     # torch.use_deterministic_algorithms(True)
125     torch.manual_seed(args.seed)
126     if torch.cuda.is_available():
127         torch.cuda.manual_seed_all(args.seed)
128
129 ######################################################################
130
131
132 def log_string(s):
133     t = time.strftime("%Y%m%d-%H:%M:%S ", time.localtime())
134
135     if log_file is not None:
136         log_file.write(t + s + "\n")
137         log_file.flush()
138
139     print(t + s)
140     sys.stdout.flush()
141
142
143 for n in vars(args):
144     log_string(f"args.{n} {getattr(args, n)}")
145
146 ######################################################################
147
148
149 def masked_inplace_autoregression(
150     model, batch_size, input, ar_mask, forbidden_tokens=None, device=torch.device("cpu")
151 ):
152     for input, ar_mask in zip(input.split(batch_size), ar_mask.split(batch_size)):
153         i = (ar_mask.sum(0) > 0).nonzero()
154         if i.min() > 0:
155             model(
156                 mygpt.BracketedSequence(input, 0, i.min())
157             )  # Needed to initialize the model's cache
158         for s in range(i.min(), i.max() + 1):
159             output = model(mygpt.BracketedSequence(input, s, 1)).x
160             logits = output[:, s]
161             if forbidden_tokens is not None:
162                 logits = logits.masked_fill(forbidden_tokens, float("-inf"))
163             if args.deterministic_synthesis:
164                 t_next = logits.argmax(1)
165             else:
166                 dist = torch.distributions.categorical.Categorical(logits=logits)
167                 t_next = dist.sample()
168             input[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * input[:, s]
169
170
171 ######################################################################
172
173
174 class Task:
175     def batches(self, split="train"):
176         pass
177
178     def vocabulary_size(self):
179         pass
180
181     def produce_results(self, n_epoch, model):
182         pass
183
184
185 ######################################################################
186
187 import picoclvr
188
189
190 class TaskPicoCLVR(Task):
191     # Make a tensor from a list of strings
192     def tensorize(self, descr):
193         token_descr = [s.strip().split(" ") for s in descr]
194         l = max([len(s) for s in token_descr])
195         token_descr = [s + ["<nul>"] * (l - len(s)) for s in token_descr]
196         id_descr = [[self.token2id[u] for u in s] for s in token_descr]
197         return torch.tensor(id_descr, device=self.device)
198
199     # Make a list of strings from a tensor
200     def detensorize(self, x):
201         return [" ".join([self.id2token[t.item()] for t in r]) for r in x]
202
203     # trim all the tensors in the tuple z to remove as much token from
204     # left and right in the first tensor. If z is a tuple, all its
205     # elements are trimed according to the triming for the first
206     def trim(self, z, token="<nul>"):
207         n = self.token2id[token]
208         if type(z) == tuple:
209             x = z[0]
210             i = (1 - (F.pad(x, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
211             a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
212             return tuple([t[:, a:b] for t in z])
213         else:
214             i = (1 - (F.pad(z, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
215             a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
216             return z[:, a:b]
217
218     ######################
219     # Not the cleanest part of the code
220
221     # Extract the last image of each sequence, from the last <img>
222     # included, and set to <nul> all the tokens from the beginning of
223     # that image to the end
224     def excise_last_image(self, input):
225         t_img, t_nul = self.token2id["<img>"], self.token2id["<nul>"]
226         nb_img_tokens = self.height * self.width + 1
227
228         input = input.clone()
229         t = (input == t_img).long()
230         tail_masks = (t.cumsum(dim=1) == t.sum(dim=1, keepdim=True)).long()
231         i = (t * tail_masks).nonzero(as_tuple=True)
232         j = (
233             i[0][:, None],
234             i[1][:, None] + torch.arange(nb_img_tokens, device=input.device)[None, :],
235         )
236         images = self.trim(input[j])
237         input[j] = t_nul
238         loss_masks = 1 - tail_masks
239         input, loss_masks = self.trim((input, loss_masks))
240         return input, loss_masks, images
241
242     def add_true_image(self, input, images, loss_masks):
243         t_nul = self.token2id["<nul>"]
244         nb_img_tokens = self.height * self.width + 1
245         input = F.pad(input, (0, nb_img_tokens), value=t_nul)
246         loss_masks = F.pad(loss_masks, (0, nb_img_tokens), value=0)
247         t = (input == t_nul).long()
248         i = (t.cumsum(dim=1) == 1).nonzero(as_tuple=True)
249         j = (
250             i[0][:, None],
251             i[1][:, None] + torch.arange(nb_img_tokens, device=input.device)[None, :],
252         )
253         input[j] = images
254         loss_masks[j] = 1
255         input, loss_masks = self.trim((input, loss_masks))
256         return input, loss_masks
257
258     def add_generated_image(self, input, loss_masks, model):
259         t_img, t_nul = self.token2id["<img>"], self.token2id["<nul>"]
260         nb_img_tokens = self.height * self.width + 1
261
262         input = F.pad(input, (0, nb_img_tokens), value=t_nul)
263         loss_masks = F.pad(loss_masks, (0, nb_img_tokens), value=0)
264         t = (input == t_nul).long()
265         i = (t.cumsum(dim=1) == 1).nonzero(as_tuple=True)
266         input[i] = t_img
267
268         j = (
269             i[0][:, None],
270             i[1][:, None]
271             + 1
272             + torch.arange(nb_img_tokens - 1, device=input.device)[None, :],
273         )
274         ar_masks = input.new_zeros(input.size(), dtype=torch.int64)
275         ar_masks[j] = 1
276         forbidden_tokens = (
277             torch.arange(self.vocabulary_size(), device=input.device) == t_nul
278         )
279         with torch.autograd.no_grad():
280             t = model.training
281             model.eval()
282             masked_inplace_autoregression(
283                 model,
284                 self.batch_size,
285                 input,
286                 ar_masks,
287                 forbidden_tokens,
288                 device=self.device,
289             )
290             model.train(t)
291
292         input, loss_masks = self.trim((input, loss_masks))
293
294         return input, loss_masks
295
296     ######################
297
298     def __init__(
299         self,
300         nb_train_samples,
301         nb_test_samples,
302         batch_size,
303         height,
304         width,
305         nb_colors=5,
306         device=torch.device("cpu"),
307         pruner_train=None,
308         pruner_eval=None,
309     ):
310         def generate_descr(nb, cache_suffix, pruner):
311             return picoclvr.generate(
312                 nb,
313                 height=self.height,
314                 width=self.width,
315                 nb_colors=nb_colors,
316                 pruner=pruner,
317             )
318
319         self.height = height
320         self.width = width
321         self.batch_size = batch_size
322         self.device = device
323         self.pruner_train = pruner_train
324         self.pruner_eval = pruner_eval
325
326         param = {
327             "nb_train_samples": nb_train_samples,
328             "nb_test_samples": nb_test_samples,
329             "height": height,
330             "width": width,
331             "nb_colors": nb_colors,
332             "batch_size": batch_size,
333             "rng_state": list(torch.get_rng_state()),
334         }
335
336         log_string(
337             f"generating {nb_train_samples+nb_test_samples} samples (can take some time)"
338         )
339         self.train_descr = generate_descr(
340             nb_train_samples, "train", pruner=self.pruner_train
341         )
342         self.test_descr = generate_descr(nb_test_samples, "test", pruner=None)
343
344         # Build the tokenizer
345         tokens = {"<nul>", "<img>"}
346         for d in [self.train_descr, self.test_descr]:
347             for s in d:
348                 for t in s.strip().split(" "):
349                     tokens.add(t)
350         # make this set a sorted list to get the same tensors given
351         # the same descr
352         tokens = list(tokens)
353         tokens.sort()
354         self.token2id = dict([(t, n) for n, t in enumerate(tokens)])
355         self.id2token = dict([(n, t) for n, t in enumerate(tokens)])
356
357         # Tokenize the train and test sets
358         self.train_input = self.tensorize(self.train_descr)
359         self.test_input = self.tensorize(self.test_descr)
360
361     def batches(self, split="train"):
362         assert split in {"train", "test"}
363         input = self.train_input if split == "train" else self.test_input
364         for batch in tqdm.tqdm(
365             input.split(self.batch_size), dynamic_ncols=True, desc=f"epoch-{split}"
366         ):
367             yield self.trim(batch)
368
369     def vocabulary_size(self):
370         return len(self.token2id)
371
372     def compute_missing_properties(self, n_epoch, model, pruner=None):
373         acc_nb_requested_properties = []
374         acc_nb_missing_properties = []
375         acc_nb_results = 0
376
377         for input in tqdm.tqdm(
378             self.test_input.split(self.batch_size),
379             dynamic_ncols=True,
380             desc=f"test-properties",
381         ):
382             tape, loss_masks, _ = self.excise_last_image(input)
383             tape, loss_masks = self.add_generated_image(tape, loss_masks, model)
384             result_descr = self.detensorize(tape)
385             np = picoclvr.nb_properties(
386                 result_descr,
387                 height=self.height,
388                 width=self.width,
389                 pruner=pruner,
390             )
391             nb_requested_properties, _, nb_missing_properties = zip(*np)
392             acc_nb_requested_properties += nb_requested_properties
393             acc_nb_missing_properties += nb_missing_properties
394             acc_nb_results += len(result_descr)
395
396         nb_requested_properties = sum(acc_nb_requested_properties)
397         nb_missing_properties = sum(acc_nb_missing_properties)
398
399         prefix = "" if pruner is None else "pruned_"
400         log_string(f"nb_{prefix}samples {n_epoch} {acc_nb_results}")
401         log_string(
402             f"property_{prefix}nb {n_epoch} requested {sum(acc_nb_requested_properties)} missing {sum(acc_nb_missing_properties)}"
403         )
404         log_string(
405             f"property_{prefix}miss {n_epoch} {100*nb_missing_properties/nb_requested_properties:.02f}%"
406         )
407
408     ######################################################################
409
410     def produce_results(self, n_epoch, model):
411         self.compute_missing_properties(n_epoch, model)
412
413         if self.pruner_eval is not None:
414             self.compute_missing_properties(n_epoch, model, self.pruner_eval)
415
416         nb_tokens_to_generate = self.height * self.width + 3
417         result_descr = []
418         nb_per_primer = 8
419         primer = []
420
421         for primer_descr in [
422             "red above green <sep> green top <sep> blue right of red",
423             "there is red <sep> there is yellow <sep> there is blue",
424             "red below yellow <sep> yellow below green <sep> green below blue <sep> red right <sep> yellow left <sep> green right <sep> blue left",
425             "green bottom <sep> yellow bottom <sep> green left of blue <sep> yellow right of blue <sep> blue top",
426         ]:
427             primer += [primer_descr] * nb_per_primer
428
429         tape = self.tensorize(primer)
430         loss_masks = 1 - (tape == self.token2id["<nul>"]).long()
431         tape, loss_masks = self.add_generated_image(tape, loss_masks, model)
432         result_descr = self.detensorize(tape)
433
434         np = picoclvr.nb_properties(result_descr, height=self.height, width=self.width)
435
436         acc_nb_requested_properties, _, acc_nb_missing_properties = zip(*np)
437         acc_nb_results = len(result_descr)
438
439         nb_requested_properties = sum(acc_nb_requested_properties)
440         nb_missing_properties = sum(acc_nb_missing_properties)
441
442         prefix = "demo_"
443         log_string(f"nb_{prefix}samples {n_epoch} {acc_nb_results}")
444         log_string(
445             f"property_{prefix}nb {n_epoch} requested {sum(acc_nb_requested_properties)} missing {sum(acc_nb_missing_properties)}"
446         )
447         log_string(
448             f"property_{prefix}miss {n_epoch} {100*nb_missing_properties/nb_requested_properties:.02f}%"
449         )
450
451         img = picoclvr.descr2img(result_descr, height=self.height, width=self.width)
452
453         if img.dim() == 5:
454             if img.size(1) == 1:
455                 img = F.pad(img.squeeze(1), pad=(1, 1, 1, 1), value=64)
456             else:
457                 img = torch.cat(
458                     [
459                         torchvision.utils.make_grid(x, padding=1, pad_value=64)[None]
460                         for x in img
461                     ],
462                     0,
463                 )
464
465         image_name = os.path.join(args.result_dir, f"picoclvr_result_{n_epoch:04d}.png")
466         torchvision.utils.save_image(
467             img / 255.0, image_name, nrow=nb_per_primer, padding=1, pad_value=1.0
468         )
469         log_string(f"wrote {image_name}")
470
471
472 ######################################################################
473
474
475 class TaskMNIST(Task):
476     def __init__(self, batch_size, device=torch.device("cpu")):
477         self.device = device
478         self.batch_size = batch_size
479
480     def batches(self, split="train"):
481         assert split in {"train", "test"}
482         data_set = torchvision.datasets.MNIST(
483             root="./data", train=(split == "train"), download=True
484         )
485         data_input = data_set.data.view(-1, 28 * 28).long()
486         if args.nb_train_samples is not None:
487             data_input = data_input[: args.nb_train_samples]
488         for batch in tqdm.tqdm(
489             data_input.split(self.batch_size), desc=f"epoch-{split}"
490         ):
491             yield batch
492
493     def vocabulary_size(self):
494         return 256
495
496     def produce_results(self, n_epoch, model):
497         results = torch.empty(64, 28 * 28, device=self.device, dtype=torch.int64)
498         ar_mask = torch.full_like(results, 1)
499         masked_inplace_autoregression(
500             model, self.batch_size, results, ar_mask, device=self.device
501         )
502         image_name = os.path.join(args.result_dir, f"mnist_result_{n_epoch:04d}.png")
503         torchvision.utils.save_image(
504             1 - results.reshape(-1, 1, 28, 28) / 255.0,
505             image_name,
506             nrow=16,
507             pad_value=0.8,
508         )
509         log_string(f"wrote {image_name}")
510
511
512 ######################################################################
513
514 import maze
515
516
517 class TaskMaze(Task):
518     def map2seq(self, *m):
519         return torch.cat([x.flatten(1) for x in m], 1)
520
521     def seq2map(self, s):
522         s = s.reshape(s.size(0), -1, self.height, self.width)
523         return (s[:, k] for k in range(s.size(1)))
524
525     def __init__(
526         self,
527         nb_train_samples,
528         nb_test_samples,
529         batch_size,
530         height,
531         width,
532         nb_walls,
533         device=torch.device("cpu"),
534     ):
535         self.batch_size = batch_size
536         self.height = height
537         self.width = width
538         self.device = device
539
540         train_mazes, train_paths, _ = maze.create_maze_data(
541             nb_train_samples,
542             height=height,
543             width=width,
544             nb_walls=nb_walls,
545             progress_bar=lambda x: tqdm.tqdm(x, dynamic_ncols=True, desc=f"data-train"),
546         )
547         self.train_input = self.map2seq(train_mazes.to(device), train_paths.to(device))
548
549         test_mazes, test_paths, _ = maze.create_maze_data(
550             nb_test_samples,
551             height=height,
552             width=width,
553             nb_walls=nb_walls,
554             progress_bar=lambda x: tqdm.tqdm(x, dynamic_ncols=True, desc=f"data-test"),
555         )
556         self.test_input = self.map2seq(test_mazes.to(device), test_paths.to(device))
557
558         self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
559
560     def batches(self, split="train", nb_to_use=-1, desc=None):
561         assert split in {"train", "test"}
562         input = self.train_input if split == "train" else self.test_input
563         if nb_to_use > 0:
564             input = input[:nb_to_use]
565         if desc is None:
566             desc = f"epoch-{split}"
567         for batch in tqdm.tqdm(
568             input.split(self.batch_size), dynamic_ncols=True, desc=desc
569         ):
570             yield batch
571
572     def vocabulary_size(self):
573         return self.nb_codes
574
575     def compute_error(self, model, split="train", nb_to_use=-1):
576         nb_total, nb_correct = 0, 0
577         for input in task.batches(split, nb_to_use):
578             result = input.clone()
579             ar_mask = result.new_zeros(result.size())
580             ar_mask[:, self.height * self.width :] = 1
581             result *= 1 - ar_mask
582             masked_inplace_autoregression(
583                 model, self.batch_size, result, ar_mask, device=self.device
584             )
585             mazes, paths = self.seq2map(result)
586             nb_correct += maze.path_correctness(mazes, paths).long().sum()
587             nb_total += mazes.size(0)
588
589         return nb_total, nb_correct
590
591     def produce_results(self, n_epoch, model):
592         with torch.autograd.no_grad():
593             t = model.training
594             model.eval()
595
596             train_nb_total, train_nb_correct = self.compute_error(
597                 model, "train", nb_to_use=1000
598             )
599             log_string(
600                 f"accuracy_train nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%"
601             )
602
603             test_nb_total, test_nb_correct = self.compute_error(
604                 model, "test", nb_to_use=1000
605             )
606             log_string(
607                 f"accuracy_test nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
608             )
609
610             input = self.test_input[:48]
611             result = input.clone()
612             ar_mask = result.new_zeros(result.size())
613             ar_mask[:, self.height * self.width :] = 1
614             result *= 1 - ar_mask
615             masked_inplace_autoregression(
616                 model, self.batch_size, result, ar_mask, device=self.device
617             )
618
619             mazes, paths = self.seq2map(input)
620             _, predicted_paths = self.seq2map(result)
621
622             filename = os.path.join(args.result_dir, f"maze_result_{n_epoch:04d}.png")
623             maze.save_image(
624                 filename,
625                 mazes=mazes,
626                 target_paths=paths,
627                 predicted_paths=predicted_paths,
628                 path_correct=maze.path_correctness(mazes, predicted_paths),
629             )
630             log_string(f"wrote {filename}")
631
632             model.train(t)
633
634
635 ######################################################################
636
637
638 def generate_snake_sequences(
639     nb, height, width, nb_colors, length, device=torch.device("cpu")
640 ):
641     worlds = torch.randint(nb_colors, (nb, height, width), device=device)
642     # nb x 2
643     snake_position = torch.cat(
644         (
645             torch.randint(height, (nb, 1), device=device),
646             torch.randint(width, (nb, 1), device=device),
647         ),
648         1,
649     )
650     snake_direction = torch.randint(4, (nb,), device=device)
651     sequences = torch.empty(nb, 2 * length, device=device, dtype=torch.int64)
652     i = torch.arange(nb, device=device)  # [:,None]
653
654     for l in range(length):
655         # nb x 3
656         snake_next_direction = torch.cat(
657             (
658                 (snake_direction[:, None] - 1) % 4,
659                 snake_direction[:, None],
660                 (snake_direction[:, None] + 1) % 4,
661             ),
662             1,
663         )
664
665         # nb x 3
666         vh = (snake_next_direction + 1) % 2 * (snake_next_direction - 1)
667         vw = snake_next_direction % 2 * (snake_next_direction - 2)
668
669         # nb x 3 x 2
670         snake_next_speed = torch.cat((vh[:, :, None], vw[:, :, None]), 2)
671         snake_next_position = snake_position[:, None, :] + snake_next_speed
672
673         # nb x 3
674         val = torch.logical_and(
675             torch.logical_and(
676                 snake_next_position[:, :, 0] >= 0, snake_next_position[:, :, 0] < height
677             ),
678             torch.logical_and(
679                 snake_next_position[:, :, 1] >= 0, snake_next_position[:, :, 1] < width
680             ),
681         ).float()
682         val = (
683             torch.rand_like(val) * val * torch.tensor([[1.0, 4.0, 1.0]], device=device)
684         )
685
686         # nb
687         j = val.argmax(1)
688         snake_direction = snake_next_direction[i, j]
689
690         sequences[:, 2 * l] = worlds[i, snake_position[:, 0], snake_position[:, 1]] + 4
691         sequences[:, 2 * l + 1] = snake_direction
692
693         # nb x 2
694         snake_position = snake_next_position[i, j]
695
696     return sequences, worlds
697
698
699 # generate_snake_sequences(nb=1, height=4, width=6, nb_colors=3, length=20)
700 # exit(0)
701
702
703 class TaskSnake(Task):
704     def __init__(
705         self,
706         nb_train_samples,
707         nb_test_samples,
708         batch_size,
709         height,
710         width,
711         nb_colors,
712         length,
713         device=torch.device("cpu"),
714     ):
715         self.batch_size = batch_size
716         self.height = height
717         self.width = width
718         self.device = device
719
720         self.train_input, self.train_worlds = generate_snake_sequences(
721             nb_train_samples, height, width, nb_colors, length, self.device
722         )
723         self.test_input, self.test_worlds = generate_snake_sequences(
724             nb_test_samples, height, width, nb_colors, length, self.device
725         )
726
727         self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
728
729     def batches(self, split="train", nb_to_use=-1, desc=None):
730         assert split in {"train", "test"}
731         input = self.train_input if split == "train" else self.test_input
732         if nb_to_use > 0:
733             input = input[:nb_to_use]
734         if desc is None:
735             desc = f"epoch-{split}"
736         for batch in tqdm.tqdm(
737             input.split(self.batch_size), dynamic_ncols=True, desc=desc
738         ):
739             yield batch
740
741     def vocabulary_size(self):
742         return self.nb_codes
743
744     def produce_results(self, n_epoch, model):
745         with torch.autograd.no_grad():
746             t = model.training
747             model.eval()
748
749             def compute_nb_correct(input):
750                 result = input.clone()
751                 i = torch.arange(result.size(1), device=result.device)
752                 ar_mask = torch.logical_and(i >= i.size(0) // 2, i % 2 == 0)[
753                     None, :
754                 ].long()
755                 result *= 1 - ar_mask
756                 masked_inplace_autoregression(
757                     model, self.batch_size, result, ar_mask, device=self.device
758                 )
759
760                 nb_total = ar_mask.sum() * input.size(0)
761                 nb_correct = ((result == input).long() * ar_mask).sum()
762
763                 # nb_total = result.size(0)
764                 # nb_correct = ((result - input).abs().sum(1) == 0).sum()
765
766                 return nb_total, nb_correct
767
768             train_nb_total, train_nb_correct = compute_nb_correct(self.train_input)
769
770             log_string(
771                 f"accuracy_train nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%"
772             )
773
774             test_nb_total, test_nb_correct = compute_nb_correct(self.test_input)
775
776             log_string(
777                 f"accuracy_test nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
778             )
779
780             model.train(t)
781
782
783 ######################################################################
784
785
786 def picoclvr_pruner_horizontal_green(p):
787     return not ("green" in p and ("left" in p or "right" in p))
788
789
790 picoclvr_pruner_train = (
791     picoclvr_pruner_horizontal_green
792     if args.picocvlr_prune_properties in {"train+eval"}
793     else None
794 )
795
796 picoclvr_pruner_eval = (
797     (lambda p: not picoclvr_pruner_horizontal_green(p))
798     if args.picocvlr_prune_properties in {"train+eval", "eval"}
799     else None
800 )
801
802 ######################################################################
803
804 if args.task == "picoclvr":
805     task = TaskPicoCLVR(
806         nb_train_samples=args.nb_train_samples,
807         nb_test_samples=args.nb_test_samples,
808         batch_size=args.batch_size,
809         height=args.picoclvr_height,
810         width=args.picoclvr_width,
811         nb_colors=args.picoclvr_nb_colors,
812         device=device,
813         pruner_train=picoclvr_pruner_train,
814         pruner_eval=picoclvr_pruner_eval,
815     )
816
817 elif args.task == "mnist":
818     task = TaskMNIST(
819         batch_size=args.batch_size,
820         device=device,
821     )
822
823 elif args.task == "maze":
824     task = TaskMaze(
825         nb_train_samples=args.nb_train_samples,
826         nb_test_samples=args.nb_test_samples,
827         batch_size=args.batch_size,
828         height=args.maze_height,
829         width=args.maze_width,
830         nb_walls=args.maze_nb_walls,
831         device=device,
832     )
833
834 elif args.task == "snake":
835     task = TaskSnake(
836         nb_train_samples=args.nb_train_samples,
837         nb_test_samples=args.nb_test_samples,
838         batch_size=args.batch_size,
839         height=args.snake_height,
840         width=args.snake_width,
841         nb_colors=args.snake_nb_colors,
842         length=args.snake_length,
843         device=device,
844     )
845
846 else:
847     raise ValueError(f"Unknown task {args.task}")
848
849 ######################################################################
850
851 log_string(f"device {device}")
852
853 vocabulary_size = task.vocabulary_size()
854
855 log_string(f"vocabulary_size {vocabulary_size}")
856
857 ##############################
858
859 model = mygpt.MyGPT(
860     vocabulary_size=vocabulary_size,
861     dim_model=args.dim_model,
862     dim_keys=args.dim_keys,
863     dim_hidden=args.dim_hidden,
864     nb_heads=args.nb_heads,
865     nb_blocks=args.nb_blocks,
866     causal=True,
867     dropout=args.dropout,
868 )
869
870 model.to(device)
871
872 nb_parameters = sum(p.numel() for p in model.parameters())
873 log_string(f"nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)")
874
875 ######################################################################
876
877 nb_epochs_finished = 0
878
879 if args.no_checkpoint:
880     log_string(f"not trying to load checkpoint.")
881
882 else:
883     try:
884         checkpoint_name = os.path.join(args.result_dir, args.checkpoint_name)
885         checkpoint = torch.load(checkpoint_name)
886         nb_epochs_finished = checkpoint["nb_epochs_finished"]
887         model.load_state_dict(checkpoint["model_state"])
888         torch.set_rng_state(checkpoint["rng_state"])
889         if torch.cuda.is_available():
890             torch.cuda.set_rng_state(checkpoint["cuda_rng_state"])
891
892         log_string(f"checkpoint loaded with {nb_epochs_finished} epochs finished.")
893
894     except FileNotFoundError:
895         log_string("starting from scratch.")
896
897     except:
898         log_string("error when loading the checkpoint.")
899         exit(1)
900
901 ######################################################################
902
903 nb_epochs = args.nb_epochs if args.nb_epochs > 0 else nb_epochs_default
904
905 token_count = 0
906 for input in task.batches(split="train"):
907     token_count += F.one_hot(input, num_classes=task.vocabulary_size()).sum((0, 1))
908 token_probas = token_count / token_count.sum()
909 entropy = -torch.xlogy(token_probas, token_probas).sum()
910 train_set_perplexity = math.exp(entropy)
911
912 ##############################
913
914 if args.learning_rate_schedule == "cos":
915     learning_rate_schedule = {}
916     for n_epoch in range(args.nb_epochs):
917         u = n_epoch / args.nb_epochs * math.pi
918         learning_rate_schedule[n_epoch] = args.learning_rate * 0.5 * (1 + math.cos(u))
919 else:
920     u = {
921         int(k): float(v)
922         for k, v in [
923             tuple(x.split(":")) for x in args.learning_rate_schedule.split(",")
924         ]
925     }
926
927     learning_rate_schedule = {}
928     learning_rate = args.learning_rate
929     for n_epoch in range(args.nb_epochs):
930         if n_epoch in u:
931             learning_rate = u[n_epoch]
932         learning_rate_schedule[n_epoch] = learning_rate
933
934 log_string(f"learning_rate_schedule {learning_rate_schedule}")
935
936 ##############################
937
938 nb_samples_seen = 0
939
940 if nb_epochs_finished >= nb_epochs:
941     task.produce_results(nb_epochs_finished, model)
942
943 for n_epoch in range(nb_epochs_finished, nb_epochs):
944     learning_rate = learning_rate_schedule[n_epoch]
945
946     log_string(f"learning_rate {learning_rate}")
947
948     if args.optim == "sgd":
949         optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
950     elif args.optim == "adam":
951         optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
952     elif args.optim == "adamw":
953         optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
954     else:
955         raise ValueError(f"Unknown optimizer {args.optim}.")
956
957     model.train()
958
959     nb_train_samples, acc_train_loss = 0, 0.0
960
961     for input in task.batches(split="train"):
962         input = input.to(device)
963         output = model(mygpt.BracketedSequence(input)).x
964         loss = F.cross_entropy(output.transpose(1, 2), input)
965         acc_train_loss += loss.item() * input.size(0)
966         nb_train_samples += input.size(0)
967         nb_samples_seen += input.size(0)
968
969         optimizer.zero_grad()
970         loss.backward()
971         optimizer.step()
972
973     with torch.autograd.no_grad():
974         model.eval()
975
976         nb_test_samples, acc_test_loss = 0, 0.0
977
978         for input in task.batches(split="test"):
979             input = input.to(device)
980
981             # input, loss_masks, true_images = task.excise_last_image(input)
982             # input, loss_masks = task.add_true_image(input, true_images, loss_masks)
983
984             output = model(mygpt.BracketedSequence(input)).x
985             loss = F.cross_entropy(output.transpose(1, 2), input)
986             acc_test_loss += loss.item() * input.size(0)
987             nb_test_samples += input.size(0)
988
989         train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples))
990         test_perplexity = math.exp(min(100, acc_test_loss / nb_test_samples))
991
992         log_string(
993             f"perplexity {n_epoch} train_set {train_set_perplexity} train_prediction {train_perplexity} test_prediction {test_perplexity}"
994         )
995
996         task.produce_results(n_epoch, model)
997
998     checkpoint = {
999         "nb_epochs_finished": n_epoch + 1,
1000         "model_state": model.state_dict(),
1001         "rng_state": torch.get_rng_state(),
1002     }
1003
1004     if torch.cuda.is_available():
1005         checkpoint["cuda_rng_state"] = torch.cuda.get_rng_state()
1006
1007     checkpoint_name = os.path.join(args.result_dir, args.checkpoint_name)
1008     torch.save(checkpoint, checkpoint_name)
1009     log_string(f"saved checkpoint {checkpoint_name}")
1010
1011 ######################################################################