Update.
[picoclvr.git] / main.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
# torch.backends.cuda.matmul.allow_tf32
9 # torch.autocast(torch.bfloat16)
10
11 import math, sys, argparse, time, tqdm, os
12
13 import torch, torchvision
14 from torch import nn
15 from torch.nn import functional as F
16
17 import mygpt, tensorstack
18
19 ######################################################################
20
# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if device.type == "cuda":
    # Let CUDA matrix products use TF32.
    torch.backends.cuda.matmul.allow_tf32 = True
26
27 ######################################################################
28
# Command line interface. Options whose default is None (nb_epochs,
# batch_size, nb_train_samples, nb_test_samples) are expected to be
# filled in after parsing from a per-task table of default values.

parser = argparse.ArgumentParser(
    description="An implementation of GPT with cache.",
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)

parser.add_argument(
    "--task", type=str, default="picoclvr", help="picoclvr, mnist, maze, snake, stack"
)

parser.add_argument("--log_filename", type=str, default="train.log", help=" ")

parser.add_argument("--result_dir", type=str, default="results_default")

parser.add_argument("--seed", type=int, default=0)

parser.add_argument("--nb_epochs", type=int, default=None)

parser.add_argument("--batch_size", type=int, default=None)

# These two used to default to 250000 / 10000, which made the per-task
# values unreachable since only None options get a task default.
parser.add_argument(
    "--nb_train_samples", type=int, default=None, help="None: per-task default"
)

parser.add_argument(
    "--nb_test_samples", type=int, default=None, help="None: per-task default"
)

parser.add_argument("--optim", type=str, default="adam")

parser.add_argument("--learning_rate", type=float, default=1e-4)

parser.add_argument("--learning_rate_schedule", type=str, default="10: 2e-5,30: 4e-6")

parser.add_argument("--dim_model", type=int, default=512)

parser.add_argument("--dim_keys", type=int, default=64)

parser.add_argument("--dim_hidden", type=int, default=2048)

parser.add_argument("--nb_heads", type=int, default=8)

parser.add_argument("--nb_blocks", type=int, default=12)

parser.add_argument("--dropout", type=float, default=0.1)

parser.add_argument("--deterministic_synthesis", action="store_true", default=False)

parser.add_argument("--no_checkpoint", action="store_true", default=False)

parser.add_argument("--overwrite_results", action="store_true", default=False)

parser.add_argument("--checkpoint_name", type=str, default="checkpoint.pth")

##############################
# picoclvr options

parser.add_argument("--picoclvr_nb_colors", type=int, default=5)

parser.add_argument("--picoclvr_height", type=int, default=12)

parser.add_argument("--picoclvr_width", type=int, default=16)

# The historical spelling "picocvlr" is kept, both as an option and as
# the attribute name, so that existing command lines and the rest of
# the code keep working. The correctly spelled option is an alias.
parser.add_argument(
    "--picocvlr_prune_properties",
    "--picoclvr_prune_properties",
    dest="picocvlr_prune_properties",
    type=str,
    default="none",
    choices=["none", "train+eval", "eval"],
)

##############################
# Maze options

parser.add_argument("--maze_height", type=int, default=13)

parser.add_argument("--maze_width", type=int, default=21)

parser.add_argument("--maze_nb_walls", type=int, default=15)

##############################
# Snake options

parser.add_argument("--snake_height", type=int, default=6)

parser.add_argument("--snake_width", type=int, default=8)

parser.add_argument("--snake_nb_colors", type=int, default=5)

parser.add_argument("--snake_length", type=int, default=200)

##############################
# Stack options

parser.add_argument("--stack_nb_steps", type=int, default=100)

parser.add_argument("--stack_nb_stacks", type=int, default=1)

parser.add_argument("--stack_nb_values", type=int, default=10)
117
118 ######################################################################
119
args = parser.parse_args()

# Explicit check instead of an assert, which is stripped under `python -O`.
if args.picocvlr_prune_properties not in {"none", "train+eval", "eval"}:
    parser.error(
        f"invalid picocvlr_prune_properties {args.picocvlr_prune_properties!r}"
    )

# Create the result directory (with its parents if needed), and refuse
# to reuse an existing one unless --overwrite_results was given.
try:
    os.makedirs(args.result_dir)
except FileExistsError:
    if not args.overwrite_results:
        print(f"result directory {args.result_dir} already exists")
        # sys.exit, not the exit() helper meant for the interactive shell
        sys.exit(1)

# Opened in append mode and kept open at module level; log_string()
# writes to it.
log_file = open(os.path.join(args.result_dir, args.log_filename), "a")

# A negative seed leaves the random generators unseeded.
if args.seed >= 0:
    # torch.backends.cudnn.deterministic = True
    # torch.backends.cudnn.benchmark = False
    # torch.use_deterministic_algorithms(True)
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)
140
141 ######################################################################
142
# Per-task fallback values, used only for the options that are still
# None once the command line has been parsed.
default_args = {
    "picoclvr": dict(
        nb_epochs=25,
        batch_size=25,
        nb_train_samples=250000,
        nb_test_samples=10000,
    ),
    "mnist": dict(
        nb_epochs=25,
        batch_size=10,
        nb_train_samples=250000,
        nb_test_samples=10000,
    ),
    "maze": dict(
        nb_epochs=25,
        batch_size=25,
        nb_train_samples=250000,
        nb_test_samples=10000,
    ),
    "snake": dict(
        nb_epochs=5,
        batch_size=25,
        nb_train_samples=250000,
        nb_test_samples=10000,
    ),
    "stack": dict(
        nb_epochs=5,
        batch_size=25,
        nb_train_samples=100000,
        nb_test_samples=1000,
    ),
}

# An unknown task has no entry in the table and changes nothing here.
for name, value in default_args.get(args.task, {}).items():
    if getattr(args, name) is None:
        setattr(args, name, value)
180
181 ######################################################################
182
183
def log_string(s):
    """Write the string s, prefixed with the local date and time, to the
    log file (when there is one) and to the standard output, flushing
    both right away."""
    line = time.strftime("%Y%m%d-%H:%M:%S ", time.localtime()) + s

    if log_file is not None:
        log_file.write(line + "\n")
        log_file.flush()

    print(line, flush=True)
193
194
# Record the value of every option at the top of the log
for name, value in vars(args).items():
    log_string(f"args.{name} {value}")
197
198 ######################################################################
199
200
# ar_mask is a 0/1 integer mask, with 1s on the values to generate
202
203
def masked_inplace_autoregression(
    model,
    batch_size,
    input,
    ar_mask,
    forbidden_tokens=None,
    progress_bar_desc="autoregression",
    device=torch.device("cpu"),
):
    """Generate, in place, the tokens of input selected by ar_mask.

    input and ar_mask are processed by chunks of batch_size rows. In
    each chunk, the positions between the first and the last column
    having a 1 in ar_mask are visited from left to right: the model
    is run on that single position, a token is picked from the
    resulting logits, and it is written in the rows whose mask is 1
    at that position, the other rows keeping their value.

    The chunks returned by split() are views, so the tensor passed as
    input is the one modified, and nothing is returned.

    model -- called on mygpt.BracketedSequence objects, its result
        must have an attribute x holding the logits
    batch_size -- number of rows per chunk
    input -- 2d tensor of tokens
    ar_mask -- 0/1 integer tensor of same shape as input (it is used
        in arithmetic expressions, so not a boolean tensor)
    forbidden_tokens -- optional boolean mask over the vocabulary,
        the logits of the masked tokens are set to -inf
    progress_bar_desc -- description of the tqdm progress bar, None
        for no progress bar
    device -- not used in this function

    The token is the argmax of the logits if the global option
    args.deterministic_synthesis is set, and is sampled otherwise.
    """

    # p = logits.softmax(1)
    # entropy[:,s]= p.xlogy(p).sum(1) / math.log(2)
    batches = zip(input.split(batch_size), ar_mask.split(batch_size))

    if progress_bar_desc is not None:
        # The wrapper has to be the object iterated on, otherwise the
        # bar is never displayed
        batches = tqdm.tqdm(
            batches,
            dynamic_ncols=True,
            desc=progress_bar_desc,
            # rounded up to count a last incomplete chunk
            total=(input.size(0) + batch_size - 1) // batch_size,
        )

    for input, ar_mask in batches:
        i = (ar_mask.sum(0) > 0).nonzero()
        if i.numel() == 0:
            # Nothing to generate in this chunk, and min() / max()
            # would fail on an empty tensor
            continue
        if i.min() > 0:
            model(
                mygpt.BracketedSequence(input, 0, i.min())
            )  # Needed to initialize the model's cache
        for s in range(i.min(), i.max() + 1):
            output = model(mygpt.BracketedSequence(input, s, 1)).x
            logits = output[:, s]
            if forbidden_tokens is not None:
                logits = logits.masked_fill(forbidden_tokens, float("-inf"))
            if args.deterministic_synthesis:
                t_next = logits.argmax(1)
            else:
                dist = torch.distributions.categorical.Categorical(logits=logits)
                t_next = dist.sample()
            # Only the rows with a 1 in the mask get the new token
            input[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * input[:, s]
240
241
242 ######################################################################
243
244
class Task:
    """Common interface of the tasks. The three methods do nothing
    here and return None, the task classes redefine them."""

    def batches(self, split="train"):
        """Batches of the given split."""
        return None

    def vocabulary_size(self):
        """Number of distinct tokens."""
        return None

    def produce_results(self, n_epoch, model):
        """Evaluation of the model at epoch n_epoch."""
        return None
254
255
256 ######################################################################
257
258 import picoclvr
259
260
class TaskPicoCLVR(Task):
    """Task built on the scene descriptions generated by the picoclvr
    module.

    The descriptions are strings of tokens separated by spaces. They
    are converted to tensors of token indexes, padded on the right
    with the token "<nul>". The evaluation removes the last image of
    each test sequence, lets the model generate a new one, and counts
    with picoclvr.nb_properties how many of the requested properties
    are missing from it.
    """

    # Make a tensor from a list of strings
    def tensorize(self, descr):
        token_descr = [s.strip().split(" ") for s in descr]
        l = max([len(s) for s in token_descr])
        # Pad on the right so that all the sequences have the same length
        token_descr = [s + ["<nul>"] * (l - len(s)) for s in token_descr]
        id_descr = [[self.token2id[u] for u in s] for s in token_descr]
        return torch.tensor(id_descr, device=self.device)

    # Make a list of strings from a tensor
    def detensorize(self, x):
        return [" ".join([self.id2token[t.item()] for t in r]) for r in x]

    # trim all the tensors in the tuple z to remove as much token from
    # left and right in the first tensor. If z is a tuple, all its
    # elements are trimed according to the triming for the first
    def trim(self, z, token="<nul>"):
        n = self.token2id[token]
        is_tuple = isinstance(z, tuple)
        x = z[0] if is_tuple else z
        # After padding with one column of the token on each side, i
        # counts, up to each column, the columns which are not made
        # only of the token
        i = (1 - (F.pad(x, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
        # Hence [a, b) is the range of columns from the first to the
        # last one containing something else than the token
        a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
        if is_tuple:
            return tuple(t[:, a:b] for t in z)
        return x[:, a:b]

    ######################
    # Not the cleanest part of the code

    # Extract the last image of each sequence, from the last <img>
    # included, and set to <nul> all the tokens from the beginning of
    # that image to the end
    def excise_last_image(self, input):
        t_img, t_nul = self.token2id["<img>"], self.token2id["<nul>"]
        nb_img_tokens = self.height * self.width + 1

        # Work on a copy, the argument is not modified
        input = input.clone()
        t = (input == t_img).long()
        # 1s from the last <img> of each row to the end of the row
        tail_masks = (t.cumsum(dim=1) == t.sum(dim=1, keepdim=True)).long()
        # Row and column indexes of the last <img> of each row
        i = (t * tail_masks).nonzero(as_tuple=True)
        # Indexes of the nb_img_tokens tokens starting at that <img>
        j = (
            i[0][:, None],
            i[1][:, None] + torch.arange(nb_img_tokens, device=input.device)[None, :],
        )
        images = self.trim(input[j])
        input[j] = t_nul
        loss_masks = 1 - tail_masks
        input, loss_masks = self.trim((input, loss_masks))
        return input, loss_masks, images

    # Write the given images after the descriptions, and set the loss
    # masks to 1 on them
    def add_true_image(self, input, images, loss_masks):
        t_nul = self.token2id["<nul>"]
        nb_img_tokens = self.height * self.width + 1
        # Make room on the right for one image
        input = F.pad(input, (0, nb_img_tokens), value=t_nul)
        loss_masks = F.pad(loss_masks, (0, nb_img_tokens), value=0)
        t = (input == t_nul).long()
        # Positions where exactly one <nul> has been met so far, that
        # is the first <nul> of each row if the <nul> are only used
        # as padding on the right
        i = (t.cumsum(dim=1) == 1).nonzero(as_tuple=True)
        j = (
            i[0][:, None],
            i[1][:, None] + torch.arange(nb_img_tokens, device=input.device)[None, :],
        )
        input[j] = images
        loss_masks[j] = 1
        input, loss_masks = self.trim((input, loss_masks))
        return input, loss_masks

    # Add an <img> after the descriptions and let the model generate
    # the tokens of the image that follows
    def add_generated_image(self, input, loss_masks, model):
        t_img, t_nul = self.token2id["<img>"], self.token2id["<nul>"]
        nb_img_tokens = self.height * self.width + 1

        input = F.pad(input, (0, nb_img_tokens), value=t_nul)
        loss_masks = F.pad(loss_masks, (0, nb_img_tokens), value=0)
        t = (input == t_nul).long()
        # Same indexing as in add_true_image
        i = (t.cumsum(dim=1) == 1).nonzero(as_tuple=True)
        input[i] = t_img

        # The nb_img_tokens - 1 positions following the <img>
        j = (
            i[0][:, None],
            i[1][:, None]
            + 1
            + torch.arange(nb_img_tokens - 1, device=input.device)[None, :],
        )
        ar_masks = input.new_zeros(input.size(), dtype=torch.int64)
        ar_masks[j] = 1
        # The model must not generate any <nul>
        forbidden_tokens = (
            torch.arange(self.vocabulary_size(), device=input.device) == t_nul
        )
        with torch.autograd.no_grad():
            t = model.training
            model.eval()
            masked_inplace_autoregression(
                model,
                self.batch_size,
                input,
                ar_masks,
                forbidden_tokens,
                progress_bar_desc=None,
                device=self.device,
            )
            model.train(t)

        input, loss_masks = self.trim((input, loss_masks))

        return input, loss_masks

    ######################

    def __init__(
        self,
        nb_train_samples,
        nb_test_samples,
        batch_size,
        height,
        width,
        nb_colors=5,
        device=torch.device("cpu"),
        pruner_train=None,
        pruner_eval=None,
    ):
        def generate_descr(nb, pruner):
            return picoclvr.generate(
                nb,
                height=height,
                width=width,
                nb_colors=nb_colors,
                pruner=pruner,
            )

        self.height = height
        self.width = width
        self.batch_size = batch_size
        self.device = device
        self.pruner_train = pruner_train
        self.pruner_eval = pruner_eval

        log_string(
            f"generating {nb_train_samples+nb_test_samples} samples (can take some time)"
        )
        # Only the train descriptions are pruned at generation
        self.train_descr = generate_descr(nb_train_samples, pruner=self.pruner_train)
        self.test_descr = generate_descr(nb_test_samples, pruner=None)

        # Build the tokenizer
        tokens = {"<nul>", "<img>"}
        for d in [self.train_descr, self.test_descr]:
            for s in d:
                tokens.update(s.strip().split(" "))
        # make this set a sorted list to get the same tensors given
        # the same descr
        tokens = sorted(tokens)
        self.token2id = {t: n for n, t in enumerate(tokens)}
        self.id2token = dict(enumerate(tokens))

        # Tokenize the train and test sets
        self.train_input = self.tensorize(self.train_descr)
        self.test_input = self.tensorize(self.test_descr)

    def batches(self, split="train"):
        assert split in {"train", "test"}
        input = self.train_input if split == "train" else self.test_input
        for batch in tqdm.tqdm(
            input.split(self.batch_size), dynamic_ncols=True, desc=f"epoch-{split}"
        ):
            # Remove the columns of padding useless for this batch
            yield self.trim(batch)

    def vocabulary_size(self):
        return len(self.token2id)

    # Log the number of samples, and the numbers and percentage of
    # requested properties which are missing
    def _log_property_stats(
        self, prefix, n_epoch, nb_results, nb_requested, nb_missing
    ):
        log_string(f"nb_{prefix}samples {n_epoch} {nb_results}")
        log_string(
            f"property_{prefix}nb {n_epoch} requested {nb_requested} missing {nb_missing}"
        )
        log_string(
            f"property_{prefix}miss {n_epoch} {100*nb_missing/nb_requested:.02f}%"
        )

    # Replace the last image of every test sequence with one
    # generated by the model, and log how many properties are missing
    def compute_missing_properties(self, n_epoch, model, pruner=None):
        nb_requested_properties = 0
        nb_missing_properties = 0
        nb_results = 0

        for input in tqdm.tqdm(
            self.test_input.split(self.batch_size),
            dynamic_ncols=True,
            desc="test-properties",
        ):
            tape, loss_masks, _ = self.excise_last_image(input)
            tape, loss_masks = self.add_generated_image(tape, loss_masks, model)
            result_descr = self.detensorize(tape)
            properties = picoclvr.nb_properties(
                result_descr,
                height=self.height,
                width=self.width,
                pruner=pruner,
            )
            # One triplet per description, the second value is not used
            requested, _, missing = zip(*properties)
            nb_requested_properties += sum(requested)
            nb_missing_properties += sum(missing)
            nb_results += len(result_descr)

        prefix = "" if pruner is None else "pruned_"
        self._log_property_stats(
            prefix,
            n_epoch,
            nb_results,
            nb_requested_properties,
            nb_missing_properties,
        )

    ######################################################################

    def produce_results(self, n_epoch, model):
        self.compute_missing_properties(n_epoch, model)

        if self.pruner_eval is not None:
            self.compute_missing_properties(n_epoch, model, self.pruner_eval)

        # Generate images for a few hand-written descriptions
        nb_per_primer = 8
        primer = []

        for primer_descr in [
            "red above green <sep> green top <sep> blue right of red",
            "there is red <sep> there is yellow <sep> there is blue",
            "red below yellow <sep> yellow below green <sep> green below blue <sep> red right <sep> yellow left <sep> green right <sep> blue left",
            "green bottom <sep> yellow bottom <sep> green left of blue <sep> yellow right of blue <sep> blue top",
        ]:
            primer += [primer_descr] * nb_per_primer

        tape = self.tensorize(primer)
        loss_masks = 1 - (tape == self.token2id["<nul>"]).long()
        tape, loss_masks = self.add_generated_image(tape, loss_masks, model)
        result_descr = self.detensorize(tape)

        properties = picoclvr.nb_properties(
            result_descr, height=self.height, width=self.width
        )

        requested, _, missing = zip(*properties)

        self._log_property_stats(
            "demo_", n_epoch, len(result_descr), sum(requested), sum(missing)
        )

        img = picoclvr.descr2img(result_descr, height=self.height, width=self.width)

        # A 5d result holds several images per description, which are
        # combined into a single one
        if img.dim() == 5:
            if img.size(1) == 1:
                img = F.pad(img.squeeze(1), pad=(1, 1, 1, 1), value=64)
            else:
                img = torch.cat(
                    [
                        torchvision.utils.make_grid(x, padding=1, pad_value=64)[None]
                        for x in img
                    ],
                    0,
                )

        image_name = os.path.join(args.result_dir, f"picoclvr_result_{n_epoch:04d}.png")
        torchvision.utils.save_image(
            img / 255.0, image_name, nrow=nb_per_primer, padding=1, pad_value=0.0
        )
        log_string(f"wrote {image_name}")
542
543
544 ######################################################################
545
546
class TaskMNIST(Task):
    """Task on the MNIST images, each one seen as a sequence of
    28 * 28 tokens which are the pixel values, hence a vocabulary of
    256 tokens."""

    def __init__(self, batch_size, device=torch.device("cpu")):
        self.device = device
        self.batch_size = batch_size

    def batches(self, split="train"):
        assert split in {"train", "test"}
        # The data set is loaded again at every call, and downloaded
        # in ./data if it is not there
        data_set = torchvision.datasets.MNIST(
            root="./data", train=(split == "train"), download=True
        )
        data_input = data_set.data.view(-1, 28 * 28).long()
        # Each split is truncated according to its own option, the
        # test split used to be truncated with nb_train_samples
        nb_samples = (
            args.nb_train_samples if split == "train" else args.nb_test_samples
        )
        if nb_samples is not None:
            data_input = data_input[:nb_samples]
        for batch in tqdm.tqdm(
            data_input.split(self.batch_size),
            dynamic_ncols=True,
            desc=f"epoch-{split}",
        ):
            yield batch

    def vocabulary_size(self):
        return 256

    def produce_results(self, n_epoch, model):
        # Generate 64 images from scratch and save them as one picture
        #
        # Zeros rather than an uninitialized tensor, so that no
        # arbitrary value is ever given to the model as a token. All
        # the positions are generated, the result does not depend on
        # these initial values.
        results = torch.zeros(64, 28 * 28, device=self.device, dtype=torch.int64)
        ar_mask = torch.full_like(results, 1)
        # No gradient and eval mode during the synthesis, as in the
        # other tasks, and the training mode is restored afterward
        with torch.autograd.no_grad():
            t = model.training
            model.eval()
            masked_inplace_autoregression(
                model, self.batch_size, results, ar_mask, device=self.device
            )
            model.train(t)
        image_name = os.path.join(args.result_dir, f"mnist_result_{n_epoch:04d}.png")
        torchvision.utils.save_image(
            1 - results.reshape(-1, 1, 28, 28) / 255.0,
            image_name,
            nrow=16,
            pad_value=0.8,
        )
        log_string(f"wrote {image_name}")
582
583
584 ######################################################################
585
586 import maze
587
588
class TaskMaze(Task):
    """Task on the mazes generated by the maze module.

    A sample is the sequence made of a flattened maze followed by
    the flattened map of the path, so height * width tokens each.
    The model is evaluated by generating the second half given the
    first.
    """

    # Flatten each map and concatenate them in a single sequence
    def map2seq(self, *m):
        return torch.cat([x.flatten(1) for x in m], 1)

    # Inverse of map2seq, returns a generator of the maps
    def seq2map(self, s):
        s = s.reshape(s.size(0), -1, self.height, self.width)
        return (s[:, k] for k in range(s.size(1)))

    def __init__(
        self,
        nb_train_samples,
        nb_test_samples,
        batch_size,
        height,
        width,
        nb_walls,
        device=torch.device("cpu"),
    ):
        self.batch_size = batch_size
        self.height = height
        self.width = width
        self.device = device

        train_mazes, train_paths, _ = maze.create_maze_data(
            nb_train_samples,
            height=height,
            width=width,
            nb_walls=nb_walls,
            progress_bar=lambda x: tqdm.tqdm(x, dynamic_ncols=True, desc="data-train"),
        )
        self.train_input = self.map2seq(train_mazes.to(device), train_paths.to(device))

        test_mazes, test_paths, _ = maze.create_maze_data(
            nb_test_samples,
            height=height,
            width=width,
            nb_walls=nb_walls,
            progress_bar=lambda x: tqdm.tqdm(x, dynamic_ncols=True, desc="data-test"),
        )
        self.test_input = self.map2seq(test_mazes.to(device), test_paths.to(device))

        # This is a 0d tensor, not a python int
        self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1

    def batches(self, split="train", nb_to_use=-1, desc=None):
        assert split in {"train", "test"}
        input = self.train_input if split == "train" else self.test_input
        # A non-positive nb_to_use means all the samples
        if nb_to_use > 0:
            input = input[:nb_to_use]
        if desc is None:
            desc = f"epoch-{split}"
        for batch in tqdm.tqdm(
            input.split(self.batch_size), dynamic_ncols=True, desc=desc
        ):
            yield batch

    def vocabulary_size(self):
        return self.nb_codes

    # Generate the paths of the nb_to_use first samples of the split.
    # Returns the number of samples, the number of paths for which
    # maze.path_correctness is true, and a matrix count such that
    # count[i, j] is the number of such paths whose target has i
    # path cells and whose prediction has j (None if there is none)
    def compute_error(self, model, split="train", nb_to_use=-1):
        nb_total, nb_correct = 0, 0
        count = torch.zeros(
            self.width * self.height,
            self.width * self.height,
            device=self.device,
            dtype=torch.int64,
        )
        # self.batches already shows a progress bar, and it was
        # previously called through the global variable task
        for input in self.batches(split, nb_to_use, desc="test-mazes"):
            result = input.clone()
            ar_mask = result.new_zeros(result.size())
            # Erase and generate the second half of the sequences
            ar_mask[:, self.height * self.width :] = 1
            result *= 1 - ar_mask
            masked_inplace_autoregression(
                model,
                self.batch_size,
                result,
                ar_mask,
                progress_bar_desc=None,
                device=self.device,
            )
            mazes, paths = self.seq2map(result)
            path_correctness = maze.path_correctness(mazes, paths)
            nb_correct += path_correctness.long().sum()
            nb_total += mazes.size(0)

            optimal_path_lengths = (
                (input[:, self.height * self.width :] == maze.v_path).long().sum(1)
            )
            predicted_path_lengths = (
                (result[:, self.height * self.width :] == maze.v_path).long().sum(1)
            )
            optimal_path_lengths = optimal_path_lengths[path_correctness]
            predicted_path_lengths = predicted_path_lengths[path_correctness]
            # count[a, b] += 1 adds 1 only once to an entry indexed
            # several times, accumulate=True is required to count
            # every occurrence of the same pair of lengths
            count.index_put_(
                (optimal_path_lengths, predicted_path_lengths),
                torch.ones_like(optimal_path_lengths),
                accumulate=True,
            )

        if count.max() == 0:
            count = None
        else:
            # Remove the empty rows and columns at the end
            count = count[
                : count.sum(1).nonzero().max() + 1, : count.sum(0).nonzero().max() + 1
            ]

        return nb_total, nb_correct, count

    def produce_results(self, n_epoch, model):
        with torch.autograd.no_grad():
            t = model.training
            model.eval()

            # The count matrix of the train split is not used
            train_nb_total, train_nb_correct, _ = self.compute_error(
                model, "train", nb_to_use=1000
            )
            log_string(
                f"accuracy_train nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%"
            )

            test_nb_total, test_nb_correct, count = self.compute_error(
                model, "test", nb_to_use=1000
            )
            log_string(
                f"accuracy_test nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
            )

            if count is not None:
                # The diagonal is where the predicted path has as
                # many cells as the target one
                proportion_optimal = count.diagonal().sum().float() / count.sum()
                log_string(f"proportion_optimal_test {proportion_optimal*100:.02f}%")
                with open(
                    os.path.join(args.result_dir, f"maze_result_{n_epoch:04d}.txt"), "w"
                ) as f:
                    for i in range(count.size(0)):
                        for j in range(count.size(1)):
                            eol = " " if j < count.size(1) - 1 else "\n"
                            f.write(f"{count[i,j]}{eol}")

            # Picture of the paths generated for the 48 first test mazes
            input = self.test_input[:48]
            result = input.clone()
            ar_mask = result.new_zeros(result.size())
            ar_mask[:, self.height * self.width :] = 1
            result *= 1 - ar_mask
            masked_inplace_autoregression(
                model, self.batch_size, result, ar_mask, device=self.device
            )

            mazes, paths = self.seq2map(input)
            _, predicted_paths = self.seq2map(result)

            filename = os.path.join(args.result_dir, f"maze_result_{n_epoch:04d}.png")
            maze.save_image(
                filename,
                mazes=mazes,
                target_paths=paths,
                predicted_paths=predicted_paths,
                path_correct=maze.path_correctness(mazes, predicted_paths),
                path_optimal=maze.path_optimality(paths, predicted_paths),
            )
            log_string(f"wrote {filename}")

            model.train(t)
750
751
752 ######################################################################
753
754
755 import snake
756
757
class TaskSnake(Task):
    """Task on the sequences generated by the snake module. The model
    is evaluated by generating again the tokens at the even positions
    located after the prompt."""

    def __init__(
        self,
        nb_train_samples,
        nb_test_samples,
        batch_size,
        height,
        width,
        nb_colors,
        length,
        prompt_length,
        device=torch.device("cpu"),
    ):
        self.batch_size = batch_size
        self.height = height
        self.width = width
        self.device = device
        self.prompt_length = prompt_length

        # Sequences and prior visits of one split, the two other
        # values returned by the generator are not used
        def make_split(nb_samples):
            sequences, prior_visits, _, _ = snake.generate_sequences(
                nb_samples,
                height,
                width,
                nb_colors,
                length,
                prompt_length,
                self.device,
            )
            return sequences, prior_visits

        self.train_input, self.train_prior_visits = make_split(nb_train_samples)
        self.test_input, self.test_prior_visits = make_split(nb_test_samples)

        self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1

    def batches(self, split="train", nb_to_use=-1, desc=None):
        assert split in {"train", "test"}
        data = self.test_input if split != "train" else self.train_input
        # A non-positive nb_to_use means all the samples
        if nb_to_use > 0:
            data = data[:nb_to_use]
        yield from tqdm.tqdm(
            data.split(self.batch_size),
            dynamic_ncols=True,
            desc=f"epoch-{split}" if desc is None else desc,
        )

    def vocabulary_size(self):
        return self.nb_codes

    def produce_results(self, n_epoch, model):
        with torch.autograd.no_grad():
            was_training = model.training
            model.eval()

            # Only the 1000 first test samples are used
            sequences = self.test_input[:1000]
            visits = self.test_prior_visits[:1000]

            generated = sequences.clone()
            positions = torch.arange(generated.size(1), device=generated.device)[
                None, :
            ]
            # 1s at the even positions starting from 2 * prompt_length
            to_generate = (positions >= self.prompt_length * 2) & (positions % 2 == 0)
            ar_mask = to_generate.long().expand_as(generated)
            # Erase the tokens that the model has to generate
            generated *= 1 - ar_mask

            masked_inplace_autoregression(
                model, self.batch_size, generated, ar_mask, device=self.device
            )

            # Only the generated positions with a positive prior
            # visit value are scored
            scored = (visits > 0) * ar_mask
            test_nb_total = scored.sum()
            test_nb_correct = ((generated == sequences).long() * scored).sum()

            log_string(
                f"accuracy_test nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
            )

            model.train(was_training)
862
863
864 ######################################################################
865
866
867 import stack
868
869
class TaskStack(Task):
    """Task whose sequences come from `stack.generate_sequences`.

    Evaluation measures how well the model regenerates the tokens that
    `stack.remove_poped_values` alters (presumably the values returned by
    pop operations -- confirm against the `stack` module).
    """

    def __init__(
        self,
        nb_train_samples,
        nb_test_samples,
        batch_size,
        nb_steps,
        nb_stacks,
        nb_values,
        device=torch.device("cpu"),
    ):
        """Generate the train and test sets and log a test-set statistic.

        `nb_steps`, `nb_stacks`, `nb_values` and `device` are forwarded
        unchanged to `stack.generate_sequences`, once with
        `nb_train_samples` and once with `nb_test_samples`.
        """
        self.batch_size = batch_size
        self.nb_steps = nb_steps
        self.nb_stacks = nb_stacks
        self.nb_values = nb_values
        self.device = device

        self.train_input, self.train_stack_counts = stack.generate_sequences(
            nb_train_samples, nb_steps, nb_stacks, nb_values, self.device
        )

        self.test_input, self.test_stack_counts = stack.generate_sequences(
            nb_test_samples, nb_steps, nb_stacks, nb_values, self.device
        )

        # Mark the test tokens that stack.remove_poped_values alters (its
        # return value is unused, so it has to edit `mask` in place), then
        # log how often each stack count occurs at those positions.
        mask = self.test_input.clone()
        stack.remove_poped_values(mask, self.nb_stacks)
        mask = mask != self.test_input
        counts = self.test_stack_counts.flatten()[mask.flatten()]
        counts = F.one_hot(counts).sum(0)
        log_string(f"stack_count {counts}")

        # Kept as a plain int rather than a 0-dim tensor: the value is used
        # as a size (vocabulary size, number of one-hot classes).
        self.nb_codes = int(max(self.train_input.max(), self.test_input.max())) + 1

    def batches(self, split="train", nb_to_use=-1, desc=None):
        """Yield chunks of at most `self.batch_size` sequences.

        `split` is "train" or "test". When `nb_to_use` is positive only the
        first `nb_to_use` sequences are used. `desc` is the progress-bar
        label and defaults to "epoch-<split>".
        """
        assert split in {"train", "test"}
        input = self.train_input if split == "train" else self.test_input
        if nb_to_use > 0:
            input = input[:nb_to_use]
        if desc is None:
            desc = f"epoch-{split}"
        for batch in tqdm.tqdm(
            input.split(self.batch_size), dynamic_ncols=True, desc=desc
        ):
            yield batch

    def vocabulary_size(self):
        """Return the number of token codes: largest code seen plus one."""
        return self.nb_codes

    def produce_results(self, n_epoch, model):
        """Log the generation accuracy on the first 1000 test sequences.

        The tokens altered by `stack.remove_poped_values` are zeroed and
        regenerated with `masked_inplace_autoregression`; accuracy is the
        share of those tokens that match the original sequence. The model's
        training flag is restored even if the evaluation raises. `n_epoch`
        is not used.
        """
        with torch.autograd.no_grad():
            t = model.training
            model.eval()

            def compute_nb_correct(input):
                result = input.clone()
                stack.remove_poped_values(result, self.nb_stacks)
                # 1 where the token was altered, i.e. must be regenerated.
                ar_mask = (result != input).long()
                result *= 1 - ar_mask

                masked_inplace_autoregression(
                    model, self.batch_size, result, ar_mask, device=self.device
                )

                nb_total = ar_mask.sum()
                nb_correct = ((result == input).long() * ar_mask).sum()

                return nb_total, nb_correct

            try:
                test_nb_total, test_nb_correct = compute_nb_correct(
                    self.test_input[:1000]
                )

                log_string(
                    f"accuracy_test nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
                )
            finally:
                # Do not leave the model in eval mode if evaluation failed.
                model.train(t)
949
950
951 ######################################################################
952
953
def picoclvr_pruner_horizontal_green(p):
    """Return False when `p` holds "green" along with "left" or "right".

    Membership is tested with `in`, so `p` may be any container or string
    supporting it. Every other `p` yields True.
    """
    if "green" not in p:
        return True
    return "left" not in p and "right" not in p
956
957
# Choose the pruners handed to the picoclvr task from the command line:
#   "train+eval" -> train pruner is picoclvr_pruner_horizontal_green and the
#                   eval pruner is its negation,
#   "eval"       -> only the eval pruner is set,
#   other values -> no pruner at all.
prune_mode = args.picocvlr_prune_properties

if prune_mode == "train+eval":
    picoclvr_pruner_train = picoclvr_pruner_horizontal_green
else:
    picoclvr_pruner_train = None

if prune_mode in ("train+eval", "eval"):
    picoclvr_pruner_eval = lambda p: not picoclvr_pruner_horizontal_green(p)
else:
    picoclvr_pruner_eval = None
969
970 ######################################################################
971
# Keyword arguments shared by every task except "mnist".
sampled_task_args = dict(
    nb_train_samples=args.nb_train_samples,
    nb_test_samples=args.nb_test_samples,
    batch_size=args.batch_size,
    device=device,
)

# Instantiate the task selected with --task.
if args.task == "picoclvr":
    task = TaskPicoCLVR(
        height=args.picoclvr_height,
        width=args.picoclvr_width,
        nb_colors=args.picoclvr_nb_colors,
        pruner_train=picoclvr_pruner_train,
        pruner_eval=picoclvr_pruner_eval,
        **sampled_task_args,
    )

elif args.task == "mnist":
    task = TaskMNIST(batch_size=args.batch_size, device=device)

elif args.task == "maze":
    task = TaskMaze(
        height=args.maze_height,
        width=args.maze_width,
        nb_walls=args.maze_nb_walls,
        **sampled_task_args,
    )

elif args.task == "snake":
    task = TaskSnake(
        height=args.snake_height,
        width=args.snake_width,
        nb_colors=args.snake_nb_colors,
        length=args.snake_length,
        # The prompt covers the first half of the snake length.
        prompt_length=args.snake_length // 2,
        **sampled_task_args,
    )

elif args.task == "stack":
    task = TaskStack(
        nb_steps=args.stack_nb_steps,
        nb_stacks=args.stack_nb_stacks,
        nb_values=args.stack_nb_values,
        **sampled_task_args,
    )

else:
    raise ValueError(f"Unknown task {args.task}")
1028
1029 ######################################################################
1030
# Log the device in use and the vocabulary size given by the task.
log_string(f"device {device}")
vocabulary_size = task.vocabulary_size()
log_string(f"vocabulary_size {vocabulary_size}")

##############################

# The model dimensions all come from the command line; only the vocabulary
# size is dictated by the task.
model_args = dict(
    vocabulary_size=vocabulary_size,
    dim_model=args.dim_model,
    dim_keys=args.dim_keys,
    dim_hidden=args.dim_hidden,
    nb_heads=args.nb_heads,
    nb_blocks=args.nb_blocks,
    causal=True,
    dropout=args.dropout,
)
model = mygpt.MyGPT(**model_args)

model.to(device)

# Total number of scalar parameters, also logged in whole millions.
nb_parameters = 0
for param in model.parameters():
    nb_parameters += param.numel()
log_string(f"nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)")
1054
1055 ######################################################################
1056
# Resume from a checkpoint when one exists: restore the number of finished
# epochs, the model weights and the random generator states.
nb_epochs_finished = 0

if args.no_checkpoint:
    log_string("not trying to load checkpoint.")

else:
    try:
        checkpoint_name = os.path.join(args.result_dir, args.checkpoint_name)
        checkpoint = torch.load(checkpoint_name)
        nb_epochs_finished = checkpoint["nb_epochs_finished"]
        model.load_state_dict(checkpoint["model_state"])
        torch.set_rng_state(checkpoint["rng_state"])
        if torch.cuda.is_available():
            torch.cuda.set_rng_state(checkpoint["cuda_rng_state"])

        log_string(f"checkpoint loaded with {nb_epochs_finished} epochs finished.")

    except FileNotFoundError:
        # No checkpoint file yet: this is a new run.
        log_string("starting from scratch.")

    except Exception as e:
        # Any other failure is fatal. Report what went wrong instead of
        # hiding it, and let KeyboardInterrupt / SystemExit propagate.
        log_string(f"error when loading the checkpoint: {type(e).__name__}: {e}")
        sys.exit(1)
1080
1081 ######################################################################
1082
# Number of epochs: the command-line value when positive, else the default.
if args.nb_epochs > 0:
    nb_epochs = args.nb_epochs
else:
    nb_epochs = nb_epochs_default

# Perplexity of the token frequencies of the train set: count every token
# over all the train batches, normalize, and take exp of the entropy.
nb_token_codes = task.vocabulary_size()
token_count = 0
for batch in task.batches(split="train"):
    one_hot_tokens = F.one_hot(batch, num_classes=nb_token_codes)
    token_count += one_hot_tokens.sum((0, 1))

token_probas = token_count / token_count.sum()
# xlogy(p, p) is p * log(p), and is 0 where p is 0.
entropy = -torch.xlogy(token_probas, token_probas).sum()
train_set_perplexity = math.exp(entropy)
1091
1092 ##############################
1093
# Build the per-epoch learning rates as a dict {epoch index: rate}.
#
# The schedule spans `nb_epochs`, the epoch count the training loop actually
# uses, so that every epoch the loop visits has an entry even when the
# default epoch count replaces the command-line value.
if args.learning_rate_schedule == "cos":
    # Half-cosine decay from args.learning_rate at epoch 0 down towards 0.
    learning_rate_schedule = {}
    for n_epoch in range(nb_epochs):
        u = n_epoch / nb_epochs * math.pi
        learning_rate_schedule[n_epoch] = args.learning_rate * 0.5 * (1 + math.cos(u))
else:
    # Piecewise-constant schedule given as "epoch: rate,epoch: rate,...":
    # from each listed epoch on, the rate takes the listed value.
    u = {
        int(k): float(v)
        for k, v in [
            tuple(x.split(":")) for x in args.learning_rate_schedule.split(",")
        ]
    }

    learning_rate_schedule = {}
    learning_rate = args.learning_rate
    for n_epoch in range(nb_epochs):
        if n_epoch in u:
            learning_rate = u[n_epoch]
        learning_rate_schedule[n_epoch] = learning_rate

log_string(f"learning_rate_schedule {learning_rate_schedule}")
1115
1116 ##############################
1117
nb_samples_seen = 0

# The loaded checkpoint already covers every requested epoch: the loop below
# will not run, so only produce the task's results.
if nb_epochs_finished >= nb_epochs:
    task.produce_results(nb_epochs_finished, model)

optimizer_classes = {
    "sgd": torch.optim.SGD,
    "adam": torch.optim.Adam,
    "adamw": torch.optim.AdamW,
}

for n_epoch in range(nb_epochs_finished, nb_epochs):
    learning_rate = learning_rate_schedule[n_epoch]

    log_string(f"learning_rate {learning_rate}")

    # A new optimizer is created at every epoch with that epoch's rate.
    if args.optim not in optimizer_classes:
        raise ValueError(f"Unknown optimizer {args.optim}.")
    optimizer = optimizer_classes[args.optim](model.parameters(), lr=learning_rate)

    # Training pass.
    model.train()

    acc_train_loss, nb_train_samples = 0.0, 0

    for batch in task.batches(split="train"):
        batch = batch.to(device)
        nb_in_batch = batch.size(0)

        logits = model(mygpt.BracketedSequence(batch)).x
        # Swap axes 1 and 2 so that the class scores sit on axis 1, where
        # F.cross_entropy expects them.
        loss = F.cross_entropy(logits.transpose(1, 2), batch)

        acc_train_loss += loss.item() * nb_in_batch
        nb_train_samples += nb_in_batch
        nb_samples_seen += nb_in_batch

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Test pass, without gradients and in eval mode.
    with torch.autograd.no_grad():
        model.eval()

        acc_test_loss, nb_test_samples = 0.0, 0

        for batch in task.batches(split="test"):
            batch = batch.to(device)
            nb_in_batch = batch.size(0)

            logits = model(mygpt.BracketedSequence(batch)).x
            loss = F.cross_entropy(logits.transpose(1, 2), batch)

            acc_test_loss += loss.item() * nb_in_batch
            nb_test_samples += nb_in_batch

        # The mean loss is capped at 100 before exponentiation.
        mean_train_loss = acc_train_loss / nb_train_samples
        mean_test_loss = acc_test_loss / nb_test_samples
        train_perplexity = math.exp(min(100, mean_train_loss))
        test_perplexity = math.exp(min(100, mean_test_loss))

        log_string(
            f"perplexity {n_epoch} train_set {train_set_perplexity} train_prediction {train_perplexity} test_prediction {test_perplexity}"
        )

        task.produce_results(n_epoch, model)

    # Save what is needed to resume after this epoch: the epoch count, the
    # weights and the random generator state(s).
    checkpoint = {
        "nb_epochs_finished": n_epoch + 1,
        "model_state": model.state_dict(),
        "rng_state": torch.get_rng_state(),
    }

    if torch.cuda.is_available():
        checkpoint["cuda_rng_state"] = torch.cuda.get_rng_state()

    checkpoint_name = os.path.join(args.result_dir, args.checkpoint_name)
    torch.save(checkpoint, checkpoint_name)
    log_string(f"saved checkpoint {checkpoint_name}")
1187
1188 ######################################################################