# Any copyright is dedicated to the Public Domain.
# https://creativecommons.org/publicdomain/zero/1.0/

# Written by Francois Fleuret <francois@fleuret.org>

# torch.backends.cuda.matmul.allow_tf32
# torch.autocast(torch.bfloat16)
import math, sys, argparse, time, tqdm, os

import torch, torchvision

from torch.nn import functional as F

import mygpt, tensorstack

######################################################################
if torch.cuda.is_available():
    device = torch.device("cuda")
    torch.backends.cuda.matmul.allow_tf32 = True
else:
    device = torch.device("cpu")

######################################################################
parser = argparse.ArgumentParser(
    description="An implementation of GPT with cache.",
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)

parser.add_argument(
    "--task",
    type=str,
    default="picoclvr",  # assumed default
    help="picoclvr, mnist, maze, snake, stack, expr",
)
41 parser.add_argument("--log_filename", type=str, default="train.log", help=" ")
43 parser.add_argument("--result_dir", type=str, default=None)
45 parser.add_argument("--seed", type=int, default=0)
47 parser.add_argument("--nb_epochs", type=int, default=None)
49 parser.add_argument("--batch_size", type=int, default=None)
51 parser.add_argument("--nb_train_samples", type=int, default=None)
53 parser.add_argument("--nb_test_samples", type=int, default=None)
55 parser.add_argument("--optim", type=str, default="adam")
57 parser.add_argument("--learning_rate", type=float, default=1e-4)
59 parser.add_argument("--learning_rate_schedule", type=str, default="10: 2e-5,30: 4e-6")
61 parser.add_argument("--dim_model", type=int, default=512)
63 parser.add_argument("--dim_keys", type=int, default=64)
65 parser.add_argument("--dim_hidden", type=int, default=2048)
67 parser.add_argument("--nb_heads", type=int, default=8)
69 parser.add_argument("--nb_blocks", type=int, default=12)
71 parser.add_argument("--dropout", type=float, default=0.1)
73 parser.add_argument("--deterministic_synthesis", action="store_true", default=False)
75 parser.add_argument("--no_checkpoint", action="store_true", default=False)
77 parser.add_argument("--overwrite_results", action="store_true", default=False)
79 parser.add_argument("--checkpoint_name", type=str, default="checkpoint.pth")
##############################
# picoclvr options

parser.add_argument("--picoclvr_nb_colors", type=int, default=5)

parser.add_argument("--picoclvr_height", type=int, default=12)

parser.add_argument("--picoclvr_width", type=int, default=16)

# (the "picocvlr" spelling is used consistently by this flag)
parser.add_argument("--picocvlr_prune_properties", type=str, default="none")
##############################
# maze options

parser.add_argument("--maze_height", type=int, default=13)

parser.add_argument("--maze_width", type=int, default=21)

parser.add_argument("--maze_nb_walls", type=int, default=15)
##############################
# snake options

parser.add_argument("--snake_height", type=int, default=6)

parser.add_argument("--snake_width", type=int, default=8)

parser.add_argument("--snake_nb_colors", type=int, default=5)

parser.add_argument("--snake_length", type=int, default=200)
##############################
# stack options

parser.add_argument("--stack_nb_steps", type=int, default=100)

parser.add_argument("--stack_nb_stacks", type=int, default=1)

parser.add_argument("--stack_nb_digits", type=int, default=3)

parser.add_argument("--stack_fraction_values_for_train", type=float, default=None)
##############################
# expr options

parser.add_argument("--expr_nb_variables", type=int, default=5)

parser.add_argument("--expr_sequence_length", type=int, default=30)
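# Example invocation (illustrative; the script name is assumed):
#
#   python main.py --task maze --nb_epochs 5
#
# Options left to None are filled in from the per-task defaults below.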
######################################################################

args = parser.parse_args()

assert args.picocvlr_prune_properties in {"none", "train+eval", "eval"}

if args.result_dir is None:
    args.result_dir = f"results_{args.task}"

######################################################################
145 "nb_train_samples": 250000,
146 "nb_test_samples": 10000,
151 "nb_train_samples": 250000,
152 "nb_test_samples": 10000,
157 "nb_train_samples": 250000,
158 "nb_test_samples": 10000,
163 "nb_train_samples": 250000,
164 "nb_test_samples": 10000,
169 "nb_train_samples": 100000,
170 "nb_test_samples": 1000,
175 "nb_train_samples": 250000,
176 "nb_test_samples": 10000,
180 if args.task in default_args:
181 for k, v in default_args[args.task].items():
182 if getattr(args, k) is None:
######################################################################

try:
    os.mkdir(args.result_dir)
except FileExistsError:
    if not args.overwrite_results:
        print(f"result directory {args.result_dir} already exists")
        exit(1)

log_file = open(os.path.join(args.result_dir, args.log_filename), "a")
if args.seed >= 0:
    # torch.backends.cudnn.deterministic = True
    # torch.backends.cudnn.benchmark = False
    # torch.use_deterministic_algorithms(True)
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)

######################################################################
def log_string(s):
    t = time.strftime("%Y%m%d-%H:%M:%S ", time.localtime())

    if log_file is not None:
        log_file.write(t + s + "\n")
        log_file.flush()

    print(t + s)
    sys.stdout.flush()


for n in vars(args):
    log_string(f"args.{n} {getattr(args, n)}")
######################################################################


# ar_mask is boolean, with 1s at the positions whose values must be generated
def masked_inplace_autoregression(
    model,
    batch_size,
    input,
    ar_mask,
    forbidden_tokens=None,
    progress_bar_desc="autoregression",
    device=torch.device("cpu"),
):
    batches = zip(input.split(batch_size), ar_mask.split(batch_size))

    if progress_bar_desc is not None:
        batches = tqdm.tqdm(
            batches,
            dynamic_ncols=True,
            desc=progress_bar_desc,
            total=input.size(0) // batch_size,
        )

    for input, ar_mask in batches:
        i = (ar_mask.sum(0) > 0).nonzero()
        if i.min() > 0:
            model(
                mygpt.BracketedSequence(input, 0, i.min())
            )  # Needed to initialize the model's cache
        for s in range(i.min(), i.max() + 1):
            output = model(mygpt.BracketedSequence(input, s, 1)).x
            logits = output[:, s]
            if forbidden_tokens is not None:
                logits = logits.masked_fill(forbidden_tokens, float("-inf"))
            if args.deterministic_synthesis:
                t_next = logits.argmax(1)
            else:
                dist = torch.distributions.categorical.Categorical(logits=logits)
                t_next = dist.sample()
            input[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * input[:, s]
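# A minimal usage sketch of masked_inplace_autoregression (the tensors
# below are illustrative, not from the source): keep the first two
# columns of each row as a fixed prompt and generate the remaining
# three, in place.
#
#   x = torch.randint(task.vocabulary_size(), (4, 5), device=device)
#   ar_mask = torch.tensor([[0, 0, 1, 1, 1]], device=device).expand(4, -1)
#   masked_inplace_autoregression(model, 4, x, ar_mask, device=device)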
######################################################################


class Task:
    def batches(self, split="train"):
        pass

    def vocabulary_size(self):
        pass

    def produce_results(self, n_epoch, model):
        pass


######################################################################

import picoclvr
class TaskPicoCLVR(Task):
    # Make a tensor from a list of strings
    def tensorize(self, descr):
        token_descr = [s.strip().split(" ") for s in descr]
        l = max([len(s) for s in token_descr])
        token_descr = [s + ["<nul>"] * (l - len(s)) for s in token_descr]
        id_descr = [[self.token2id[u] for u in s] for s in token_descr]
        return torch.tensor(id_descr, device=self.device)

    # Make a list of strings from a tensor
    def detensorize(self, x):
        return [" ".join([self.id2token[t.item()] for t in r]) for r in x]

    # Trim the tensor z, or the tuple of tensors z, to remove as many
    # tokens as possible from the left and the right. If z is a tuple,
    # all its elements are trimmed according to the trimming computed
    # for the first one.
    def trim(self, z, token="<nul>"):
        n = self.token2id[token]
        if type(z) == tuple:
            x = z[0]
            i = (1 - (F.pad(x, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
            a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
            return tuple([t[:, a:b] for t in z])
        else:
            i = (1 - (F.pad(z, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
            a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
            return z[:, a:b]
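    # A small illustration of trim (assuming <nul> has id 0): columns
    # that contain only <nul> across the whole batch are removed from
    # both ends, inner ones are kept.
    #
    #   x = torch.tensor([[0, 0, 5, 3, 0],
    #                     [0, 7, 2, 0, 0]])
    #   task.trim(x)  # -> tensor([[0, 5, 3],
    #                 #            [7, 2, 0]])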
    ######################
    # Not the cleanest part of the code

    # Extract the last image of each sequence, from the last <img>
    # included, and set to <nul> all the tokens from the beginning of
    # that image to the end
    def excise_last_image(self, input):
        t_img, t_nul = self.token2id["<img>"], self.token2id["<nul>"]
        nb_img_tokens = self.height * self.width + 1

        input = input.clone()
        t = (input == t_img).long()
        tail_masks = (t.cumsum(dim=1) == t.sum(dim=1, keepdim=True)).long()
        i = (t * tail_masks).nonzero(as_tuple=True)
        j = (
            i[0][:, None],
            i[1][:, None] + torch.arange(nb_img_tokens, device=input.device)[None, :],
        )
        images = self.trim(input[j])
        input[j] = t_nul
        loss_masks = 1 - tail_masks
        input, loss_masks = self.trim((input, loss_masks))
        return input, loss_masks, images
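    # The sequence layout assumed above: a textual description followed
    # by one or more images, an image being the <img> token followed by
    # height * width pixel tokens (hence nb_img_tokens = h * w + 1).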
    def add_true_image(self, input, images, loss_masks):
        t_nul = self.token2id["<nul>"]
        nb_img_tokens = self.height * self.width + 1
        input = F.pad(input, (0, nb_img_tokens), value=t_nul)
        loss_masks = F.pad(loss_masks, (0, nb_img_tokens), value=0)
        t = (input == t_nul).long()
        i = (t.cumsum(dim=1) == 1).nonzero(as_tuple=True)
        j = (
            i[0][:, None],
            i[1][:, None] + torch.arange(nb_img_tokens, device=input.device)[None, :],
        )
        input[j] = images
        loss_masks[j] = 1
        input, loss_masks = self.trim((input, loss_masks))
        return input, loss_masks
    def add_generated_image(self, input, loss_masks, model):
        t_img, t_nul = self.token2id["<img>"], self.token2id["<nul>"]
        nb_img_tokens = self.height * self.width + 1

        # Append room for one image and write the <img> marker at the
        # first <nul> of each sequence
        input = F.pad(input, (0, nb_img_tokens), value=t_nul)
        loss_masks = F.pad(loss_masks, (0, nb_img_tokens), value=0)
        t = (input == t_nul).long()
        i = (t.cumsum(dim=1) == 1).nonzero(as_tuple=True)
        input[i] = t_img

        # The h * w pixel tokens after the marker are to be generated,
        # with the <nul> token forbidden
        j = (
            i[0][:, None],
            i[1][:, None]
            + 1
            + torch.arange(nb_img_tokens - 1, device=input.device)[None, :],
        )
        ar_masks = input.new_zeros(input.size(), dtype=torch.int64)
        ar_masks[j] = 1
        forbidden_tokens = (
            torch.arange(self.vocabulary_size(), device=input.device) == t_nul
        )
        with torch.autograd.no_grad():
            t = model.training
            model.eval()
            masked_inplace_autoregression(
                model,
                self.batch_size,
                input,
                ar_masks,
                forbidden_tokens,
                progress_bar_desc=None,
                device=self.device,
            )
            model.train(t)

        input, loss_masks = self.trim((input, loss_masks))

        return input, loss_masks
    ######################

    def __init__(
        self,
        nb_train_samples,
        nb_test_samples,
        batch_size,
        height,
        width,
        nb_colors=5,
        device=torch.device("cpu"),
        pruner_train=None,
        pruner_eval=None,
    ):
        def generate_descr(nb, cache_suffix, pruner):
            return picoclvr.generate(
                nb,
                height=self.height,
                width=self.width,
                nb_colors=nb_colors,
                pruner=pruner,
            )

        self.height = height
        self.width = width
        self.batch_size = batch_size
        self.device = device
        self.pruner_train = pruner_train
        self.pruner_eval = pruner_eval

        param = {
            "nb_train_samples": nb_train_samples,
            "nb_test_samples": nb_test_samples,
            "height": height,
            "width": width,
            "nb_colors": nb_colors,
            "batch_size": batch_size,
            "rng_state": list(torch.get_rng_state()),
        }

        log_string(
            f"generating {nb_train_samples+nb_test_samples} samples (can take some time)"
        )
        self.train_descr = generate_descr(
            nb_train_samples, "train", pruner=self.pruner_train
        )
        self.test_descr = generate_descr(nb_test_samples, "test", pruner=None)

        # Build the tokenizer
        tokens = {"<nul>", "<img>"}
        for d in [self.train_descr, self.test_descr]:
            for s in d:
                for t in s.strip().split(" "):
                    tokens.add(t)
        # Make this set a sorted list to get the same tensors given
        # the same descriptions
        tokens = list(tokens)
        tokens.sort()
        self.token2id = dict([(t, n) for n, t in enumerate(tokens)])
        self.id2token = dict([(n, t) for n, t in enumerate(tokens)])

        # Tokenize the train and test sets
        self.train_input = self.tensorize(self.train_descr)
        self.test_input = self.tensorize(self.test_descr)
    def batches(self, split="train"):
        assert split in {"train", "test"}
        input = self.train_input if split == "train" else self.test_input
        for batch in tqdm.tqdm(
            input.split(self.batch_size), dynamic_ncols=True, desc=f"epoch-{split}"
        ):
            yield self.trim(batch)

    def vocabulary_size(self):
        return len(self.token2id)
    def compute_missing_properties(self, n_epoch, model, pruner=None):
        acc_nb_requested_properties = []
        acc_nb_missing_properties = []
        acc_nb_results = 0

        for input in tqdm.tqdm(
            self.test_input.split(self.batch_size),
            dynamic_ncols=True,
            desc=f"test-properties",
        ):
            tape, loss_masks, _ = self.excise_last_image(input)
            tape, loss_masks = self.add_generated_image(tape, loss_masks, model)
            result_descr = self.detensorize(tape)
            np = picoclvr.nb_properties(
                result_descr,
                height=self.height,
                width=self.width,
                pruner=pruner,
            )
            nb_requested_properties, _, nb_missing_properties = zip(*np)
            acc_nb_requested_properties += nb_requested_properties
            acc_nb_missing_properties += nb_missing_properties
            acc_nb_results += len(result_descr)

        nb_requested_properties = sum(acc_nb_requested_properties)
        nb_missing_properties = sum(acc_nb_missing_properties)

        prefix = "" if pruner is None else "pruned_"
        log_string(f"nb_{prefix}samples {n_epoch} {acc_nb_results}")
        log_string(
            f"property_{prefix}nb {n_epoch} requested {nb_requested_properties} missing {nb_missing_properties}"
        )
        log_string(
            f"property_{prefix}miss {n_epoch} {100*nb_missing_properties/nb_requested_properties:.02f}%"
        )
    ######################################################################

    def produce_results(self, n_epoch, model):
        self.compute_missing_properties(n_epoch, model)

        if self.pruner_eval is not None:
            self.compute_missing_properties(n_epoch, model, self.pruner_eval)

        nb_tokens_to_generate = self.height * self.width + 3
        result_descr = []
        nb_per_primer = 8
        primer = []

        for primer_descr in [
            "red above green <sep> green top <sep> blue right of red",
            "there is red <sep> there is yellow <sep> there is blue",
            "red below yellow <sep> yellow below green <sep> green below blue <sep> red right <sep> yellow left <sep> green right <sep> blue left",
            "green bottom <sep> yellow bottom <sep> green left of blue <sep> yellow right of blue <sep> blue top",
        ]:
            primer += [primer_descr] * nb_per_primer

        tape = self.tensorize(primer)
        loss_masks = 1 - (tape == self.token2id["<nul>"]).long()
        tape, loss_masks = self.add_generated_image(tape, loss_masks, model)
        result_descr = self.detensorize(tape)

        np = picoclvr.nb_properties(result_descr, height=self.height, width=self.width)

        acc_nb_requested_properties, _, acc_nb_missing_properties = zip(*np)
        acc_nb_results = len(result_descr)

        nb_requested_properties = sum(acc_nb_requested_properties)
        nb_missing_properties = sum(acc_nb_missing_properties)

        prefix = "demo_"  # tag for these demo generations (name assumed)
        log_string(f"nb_{prefix}samples {n_epoch} {acc_nb_results}")
        log_string(
            f"property_{prefix}nb {n_epoch} requested {nb_requested_properties} missing {nb_missing_properties}"
        )
        log_string(
            f"property_{prefix}miss {n_epoch} {100*nb_missing_properties/nb_requested_properties:.02f}%"
        )

        img = picoclvr.descr2img(result_descr, height=self.height, width=self.width)

        if img.dim() == 5:
            if img.size(1) == 1:
                img = F.pad(img.squeeze(1), pad=(1, 1, 1, 1), value=64)
            else:
                img = torch.cat(
                    [
                        torchvision.utils.make_grid(x, padding=1, pad_value=64)[None]
                        for x in img
                    ],
                    0,
                )

        image_name = os.path.join(args.result_dir, f"picoclvr_result_{n_epoch:04d}.png")
        torchvision.utils.save_image(
            img / 255.0, image_name, nrow=nb_per_primer, padding=1, pad_value=0.0
        )
        log_string(f"wrote {image_name}")
######################################################################


class TaskMNIST(Task):
    def __init__(self, batch_size, device=torch.device("cpu")):
        self.device = device
        self.batch_size = batch_size

    def batches(self, split="train"):
        assert split in {"train", "test"}
        data_set = torchvision.datasets.MNIST(
            root="./data", train=(split == "train"), download=True
        )
        data_input = data_set.data.view(-1, 28 * 28).long()
        # Cap the number of samples according to the split
        nb_samples = (
            args.nb_train_samples if split == "train" else args.nb_test_samples
        )
        if nb_samples is not None:
            data_input = data_input[:nb_samples]
        for batch in tqdm.tqdm(
            data_input.split(self.batch_size), desc=f"epoch-{split}"
        ):
            yield batch

    def vocabulary_size(self):
        # One token per gray level
        return 256

    def produce_results(self, n_epoch, model):
        results = torch.empty(64, 28 * 28, device=self.device, dtype=torch.int64)
        ar_mask = torch.full_like(results, 1)
        masked_inplace_autoregression(
            model, self.batch_size, results, ar_mask, device=self.device
        )
        image_name = os.path.join(args.result_dir, f"mnist_result_{n_epoch:04d}.png")
        torchvision.utils.save_image(
            1 - results.reshape(-1, 1, 28, 28) / 255.0,
            image_name,
            nrow=16,
            pad_value=0.8,
        )
        log_string(f"wrote {image_name}")
######################################################################

import maze


class TaskMaze(Task):
    def map2seq(self, *m):
        return torch.cat([x.flatten(1) for x in m], 1)

    def seq2map(self, s):
        s = s.reshape(s.size(0), -1, self.height, self.width)
        return (s[:, k] for k in range(s.size(1)))
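    # Shape round-trip (illustrative sizes, not from the source):
    # map2seq flattens and concatenates per-sample maps into one
    # sequence, and seq2map undoes it.
    #
    #   m = torch.randint(4, (8, 13, 21))  # batch of 8 height x width maps
    #   s = task.map2seq(m, m)             # -> shape (8, 2 * 13 * 21)
    #   a, b = task.seq2map(s)             # each of shape (8, 13, 21)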
    def __init__(
        self,
        nb_train_samples,
        nb_test_samples,
        batch_size,
        height,
        width,
        nb_walls,
        device=torch.device("cpu"),
    ):
        self.batch_size = batch_size
        self.height = height
        self.width = width
        self.device = device

        train_mazes, train_paths, _ = maze.create_maze_data(
            nb_train_samples,
            height=height,
            width=width,
            nb_walls=nb_walls,
            progress_bar=lambda x: tqdm.tqdm(x, dynamic_ncols=True, desc=f"data-train"),
        )
        self.train_input = self.map2seq(train_mazes.to(device), train_paths.to(device))

        test_mazes, test_paths, _ = maze.create_maze_data(
            nb_test_samples,
            height=height,
            width=width,
            nb_walls=nb_walls,
            progress_bar=lambda x: tqdm.tqdm(x, dynamic_ncols=True, desc=f"data-test"),
        )
        self.test_input = self.map2seq(test_mazes.to(device), test_paths.to(device))

        self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
    def batches(self, split="train", nb_to_use=-1, desc=None):
        assert split in {"train", "test"}
        input = self.train_input if split == "train" else self.test_input
        if nb_to_use > 0:
            input = input[:nb_to_use]
        if desc is None:
            desc = f"epoch-{split}"
        for batch in tqdm.tqdm(
            input.split(self.batch_size), dynamic_ncols=True, desc=desc
        ):
            yield batch

    def vocabulary_size(self):
        return self.nb_codes
    def compute_error(self, model, split="train", nb_to_use=-1):
        nb_total, nb_correct = 0, 0
        # count[i, j] is the number of correct paths of optimal length i
        # predicted with length j
        count = torch.zeros(
            self.width * self.height,
            self.width * self.height,
            device=self.device,
            dtype=torch.int64,
        )
        for input in tqdm.tqdm(
            task.batches(split, nb_to_use),
            dynamic_ncols=True,
            desc=f"test-mazes",
        ):
            result = input.clone()
            ar_mask = result.new_zeros(result.size())
            ar_mask[:, self.height * self.width :] = 1
            result *= 1 - ar_mask
            masked_inplace_autoregression(
                model,
                self.batch_size,
                result,
                ar_mask,
                progress_bar_desc=None,
                device=self.device,
            )
            mazes, paths = self.seq2map(result)
            path_correctness = maze.path_correctness(mazes, paths)
            nb_correct += path_correctness.long().sum()
            nb_total += mazes.size(0)

            optimal_path_lengths = (
                (input[:, self.height * self.width :] == maze.v_path).long().sum(1)
            )
            predicted_path_lengths = (
                (result[:, self.height * self.width :] == maze.v_path).long().sum(1)
            )
            optimal_path_lengths = optimal_path_lengths[path_correctness]
            predicted_path_lengths = predicted_path_lengths[path_correctness]
            count[optimal_path_lengths, predicted_path_lengths] += 1

        if count.max() == 0:
            count = None
        else:
            count = count[
                : count.sum(1).nonzero().max() + 1, : count.sum(0).nonzero().max() + 1
            ]

        return nb_total, nb_correct, count
    def produce_results(self, n_epoch, model):
        with torch.autograd.no_grad():
            t = model.training
            model.eval()

            train_nb_total, train_nb_correct, count = self.compute_error(
                model, "train", nb_to_use=1000
            )
            log_string(
                f"accuracy_train {n_epoch} nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%"
            )

            test_nb_total, test_nb_correct, count = self.compute_error(
                model, "test", nb_to_use=1000
            )
            log_string(
                f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
            )

            if count is not None:
                proportion_optimal = count.diagonal().sum().float() / count.sum()
                log_string(f"proportion_optimal_test {proportion_optimal*100:.02f}%")
                with open(
                    os.path.join(args.result_dir, f"maze_result_{n_epoch:04d}.txt"), "w"
                ) as f:
                    for i in range(count.size(0)):
                        for j in range(count.size(1)):
                            eol = " " if j < count.size(1) - 1 else "\n"
                            f.write(f"{count[i,j]}{eol}")

            input = self.test_input[:48]
            result = input.clone()
            ar_mask = result.new_zeros(result.size())
            ar_mask[:, self.height * self.width :] = 1
            result *= 1 - ar_mask
            masked_inplace_autoregression(
                model, self.batch_size, result, ar_mask, device=self.device
            )

            mazes, paths = self.seq2map(input)
            _, predicted_paths = self.seq2map(result)

            filename = os.path.join(args.result_dir, f"maze_result_{n_epoch:04d}.png")
            maze.save_image(
                filename,
                mazes=mazes,
                target_paths=paths,
                predicted_paths=predicted_paths,
                path_correct=maze.path_correctness(mazes, predicted_paths),
                path_optimal=maze.path_optimality(paths, predicted_paths),
            )
            log_string(f"wrote {filename}")

            model.train(t)
######################################################################

import snake


class TaskSnake(Task):
    def __init__(
        self,
        nb_train_samples,
        nb_test_samples,
        batch_size,
        height,
        width,
        nb_colors,
        length,
        prompt_length,
        device=torch.device("cpu"),
    ):
        self.batch_size = batch_size
        self.height = height
        self.width = width
        self.device = device
        self.prompt_length = prompt_length

        self.train_input, self.train_prior_visits, _, _ = snake.generate_sequences(
            nb_train_samples,
            height,
            width,
            nb_colors,
            length,
            prompt_length,
            self.device,
        )
        self.test_input, self.test_prior_visits, _, _ = snake.generate_sequences(
            nb_test_samples,
            height,
            width,
            nb_colors,
            length,
            prompt_length,
            self.device,
        )

        self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
    def batches(self, split="train", nb_to_use=-1, desc=None):
        assert split in {"train", "test"}
        input = self.train_input if split == "train" else self.test_input
        if nb_to_use > 0:
            input = input[:nb_to_use]
        if desc is None:
            desc = f"epoch-{split}"
        for batch in tqdm.tqdm(
            input.split(self.batch_size), dynamic_ncols=True, desc=desc
        ):
            yield batch

    def vocabulary_size(self):
        return self.nb_codes
    def produce_results(self, n_epoch, model):
        with torch.autograd.no_grad():
            t = model.training
            model.eval()

            def compute_nb_correct(input, prior_visits):
                result = input.clone()
                i = torch.arange(result.size(1), device=result.device)[None, :]
                # 1s on the even positions past the prompt, i.e. the
                # tokens to re-generate
                ar_mask = (
                    torch.logical_and(i >= self.prompt_length * 2, i % 2 == 0)
                    .long()
                    .expand_as(result)
                )
                result *= 1 - ar_mask

                # snake.solver(result,ar_mask)

                masked_inplace_autoregression(
                    model, self.batch_size, result, ar_mask, device=self.device
                )

                nb_total = ((prior_visits > 0) * ar_mask).sum()

                nb_correct = (
                    (result == input).long() * (prior_visits > 0) * ar_mask
                ).sum()

                # nb_total = result.size(0)
                # nb_correct = ((result - input).abs().sum(1) == 0).sum()

                return nb_total, nb_correct

            # train_nb_total, train_nb_correct = compute_nb_correct(
            #     self.train_input, self.train_prior_visits
            # )

            # log_string(
            #     f"accuracy_train nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%"
            # )

            test_nb_total, test_nb_correct = compute_nb_correct(
                self.test_input[:1000], self.test_prior_visits[:1000]
            )

            log_string(
                f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
            )

            model.train(t)
######################################################################

import stack


class TaskStack(Task):
    def __init__(
        self,
        nb_train_samples,
        nb_test_samples,
        batch_size,
        nb_steps,
        nb_stacks,
        nb_digits,
        fraction_values_for_train=None,
        device=torch.device("cpu"),
    ):
        self.batch_size = batch_size
        self.nb_steps = nb_steps
        self.nb_stacks = nb_stacks
        self.nb_digits = nb_digits
        self.device = device

        if fraction_values_for_train is None:
            values_for_train = None
            values_for_test = None
        else:
            all = torch.randperm(10**nb_digits)
            nb_for_train = int(all.size(0) * fraction_values_for_train)
            values_for_train = all[:nb_for_train]
            values_for_test = all[nb_for_train:]

        self.train_input, self.train_stack_counts = stack.generate_sequences(
            nb_train_samples,
            nb_steps,
            nb_stacks,
            nb_digits,
            values_for_train,
            self.device,
        )

        self.test_input, self.test_stack_counts = stack.generate_sequences(
            nb_test_samples,
            nb_steps,
            nb_stacks,
            nb_digits,
            values_for_test,
            self.device,
        )

        i = torch.logical_and(self.test_input % 2 == 1, self.test_input < 2 * nb_stacks)
        counts = self.test_stack_counts.flatten()[i.flatten()]
        counts = F.one_hot(counts).sum(0)
        log_string(f"test_pop_stack_counts {counts}")

        self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
    def batches(self, split="train", nb_to_use=-1, desc=None):
        assert split in {"train", "test"}
        input = self.train_input if split == "train" else self.test_input
        if nb_to_use > 0:
            input = input[:nb_to_use]
        if desc is None:
            desc = f"epoch-{split}"
        for batch in tqdm.tqdm(
            input.split(self.batch_size), dynamic_ncols=True, desc=desc
        ):
            yield batch

    def vocabulary_size(self):
        return self.nb_codes
    def produce_results(self, n_epoch, model):
        with torch.autograd.no_grad():
            t = model.training
            model.eval()

            def compute_nb_correct(input):
                result = input.clone()
                stack.remove_popped_values(result, self.nb_stacks, self.nb_digits)
                ar_mask = (result != input).long()
                masked_inplace_autoregression(
                    model, self.batch_size, result, ar_mask, device=self.device
                )

                # Group the errors by value, a value being made of
                # 1 + nb_digits tokens, and count a value as correct
                # only if all its tokens are
                errors = ((result != input).long() * ar_mask).reshape(
                    -1, 1 + self.nb_digits
                )
                ar_mask = ar_mask.reshape(-1, 1 + self.nb_digits)

                nb_total = ar_mask.max(1).values.sum()
                nb_correct = nb_total - errors.max(1).values.sum()

                return nb_total, nb_correct

            test_nb_total, test_nb_correct = compute_nb_correct(self.test_input[:1000])

            log_string(
                f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
            )

            ##############################################################
            # Log a few generated sequences
            input = self.test_input[:10, : 12 * (1 + self.nb_digits)]
            result = input.clone()
            stack.remove_popped_values(result, self.nb_stacks, self.nb_digits)
            ar_mask = (result != input).long()
            for n in range(result.size(0)):
                log_string(
                    f"test_before {stack.seq_to_str(result[n],nb_stacks=self.nb_stacks,nb_digits=self.nb_digits)}"
                )
            masked_inplace_autoregression(
                model, self.batch_size, result, ar_mask, device=self.device
            )
            for n in range(result.size(0)):
                log_string(
                    f"test_after {stack.seq_to_str(result[n],nb_stacks=self.nb_stacks,nb_digits=self.nb_digits)}"
                )
            ##############################################################

            model.train(t)
######################################################################

import expr


class TaskExpr(Task):
    def __init__(
        self,
        nb_train_samples,
        nb_test_samples,
        nb_variables,
        sequence_length,
        batch_size,
        device=torch.device("cpu"),
    ):
        self.batch_size = batch_size
        self.device = device

        train_sequences = expr.generate_sequences(
            nb_train_samples,
            nb_variables=nb_variables,
            length=sequence_length,
            # length=2 * sequence_length,
            # randomize_length=True,
        )
        test_sequences = expr.generate_sequences(
            nb_test_samples,
            nb_variables=nb_variables,
            length=sequence_length,
        )
        self.char2id = dict(
            [
                (c, n)
                for n, c in enumerate(
                    # sorted for reproducibility across runs
                    sorted(set("#" + "".join(train_sequences + test_sequences)))
                )
            ]
        )
        self.id2char = dict([(n, c) for c, n in self.char2id.items()])

        self.filler, self.space = self.char2id["#"], self.char2id[" "]

        len_max = max([len(x) for x in train_sequences])
        self.train_input = torch.cat(
            [
                torch.tensor(
                    [
                        [self.char2id[c] for c in s + "#" * (len_max - len(s))]
                        for s in train_sequences
                    ]
                )
            ],
            0,
        ).to(device)

        len_max = max([len(x) for x in test_sequences])
        self.test_input = torch.cat(
            [
                torch.tensor(
                    [
                        [self.char2id[c] for c in s + "#" * (len_max - len(s))]
                        for s in test_sequences
                    ]
                )
            ],
            0,
        ).to(device)

        self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
    def batches(self, split="train", nb_to_use=-1, desc=None):
        assert split in {"train", "test"}
        input = self.train_input if split == "train" else self.test_input
        if nb_to_use > 0:
            input = input[:nb_to_use]
        if desc is None:
            desc = f"epoch-{split}"
        for batch in tqdm.tqdm(
            input.split(self.batch_size), dynamic_ncols=True, desc=desc
        ):
            if split == "train":
                # Remove the trailing columns that contain only fillers
                last = (batch != self.filler).max(0).values.nonzero().max() + 1
                batch = batch[:, :last]
            yield batch

    def vocabulary_size(self):
        return self.nb_codes

    def seq2str(self, s):
        return "".join([self.id2char[k.item()] for k in s])
    def produce_results(self, n_epoch, model):
        with torch.autograd.no_grad():
            t = model.training
            model.eval()

            def compute_nb_correct(input):
                result = input.clone()
                # 1s from the first space onward: the part to generate
                ar_mask = (result == self.space).long().cumsum(dim=1).clamp(max=1)
                result = (1 - ar_mask) * result + ar_mask * self.filler
                masked_inplace_autoregression(
                    model, self.batch_size, result, ar_mask, device=self.device
                )

                nb_total = input.size(0)
                nb_correct = (input == result).long().min(1).values.sum()

                #######################################################################
                # Compute predicted vs. true variable values

                nb_delta = torch.zeros(5, dtype=torch.int64)
                nb_missed = 0

                values_input = expr.extract_results([self.seq2str(s) for s in input])
                values_result = expr.extract_results([self.seq2str(s) for s in result])

                for i, r in zip(values_input, values_result):
                    for n, vi in i.items():
                        vr = r.get(n)
                        if vr is None or vr < 0:
                            nb_missed += 1
                        else:
                            d = abs(vr - vi)
                            if d >= nb_delta.size(0):
                                nb_missed += 1
                            else:
                                nb_delta[d] += 1

                ######################################################################

                return nb_total, nb_correct, nb_delta, nb_missed

            (
                test_nb_total,
                test_nb_correct,
                test_nb_delta,
                test_nb_missed,
            ) = compute_nb_correct(self.test_input[:1000])

            log_string(
                f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
            )

            nb_total = test_nb_delta.sum() + test_nb_missed
            for d in range(test_nb_delta.size(0)):
                log_string(
                    f"error_value {n_epoch} delta {d} {test_nb_delta[d]} {test_nb_delta[d]*100/nb_total:.02f}%"
                )
            log_string(
                f"error_value {n_epoch} missed {test_nb_missed} {test_nb_missed*100/nb_total:.02f}%"
            )
            ##############################################################
            # Log a few generated sequences
            input = self.test_input[:10]
            result = input.clone()
            ar_mask = (result == self.space).long().cumsum(dim=1).clamp(max=1)
            result = (1 - ar_mask) * result + ar_mask * self.filler
            for n in range(result.size(0)):
                log_string(f"test_before {self.seq2str(result[n])}")
            masked_inplace_autoregression(
                model, self.batch_size, result, ar_mask, device=self.device
            )
            correct = (1 - ar_mask) * self.space + ar_mask * input
            for n in range(result.size(0)):
                comment = "GOOD" if (result[n] - input[n]).abs().max() == 0 else ""
                log_string(f"test_after {self.seq2str(result[n])} {comment}")
                log_string(f"correct {self.seq2str(correct[n])}")
            ##############################################################

            model.train(t)
######################################################################


def picoclvr_pruner_horizontal_green(p):
    return not ("green" in p and ("left" in p or "right" in p))


picoclvr_pruner_train = (
    picoclvr_pruner_horizontal_green
    if args.picocvlr_prune_properties in {"train+eval"}
    else None
)

picoclvr_pruner_eval = (
    (lambda p: not picoclvr_pruner_horizontal_green(p))
    if args.picocvlr_prune_properties in {"train+eval", "eval"}
    else None
)

######################################################################
if args.task == "picoclvr":
    task = TaskPicoCLVR(
        nb_train_samples=args.nb_train_samples,
        nb_test_samples=args.nb_test_samples,
        batch_size=args.batch_size,
        height=args.picoclvr_height,
        width=args.picoclvr_width,
        nb_colors=args.picoclvr_nb_colors,
        device=device,
        pruner_train=picoclvr_pruner_train,
        pruner_eval=picoclvr_pruner_eval,
    )

elif args.task == "mnist":
    task = TaskMNIST(
        batch_size=args.batch_size,
        device=device,
    )

elif args.task == "maze":
    task = TaskMaze(
        nb_train_samples=args.nb_train_samples,
        nb_test_samples=args.nb_test_samples,
        batch_size=args.batch_size,
        height=args.maze_height,
        width=args.maze_width,
        nb_walls=args.maze_nb_walls,
        device=device,
    )

elif args.task == "snake":
    task = TaskSnake(
        nb_train_samples=args.nb_train_samples,
        nb_test_samples=args.nb_test_samples,
        batch_size=args.batch_size,
        height=args.snake_height,
        width=args.snake_width,
        nb_colors=args.snake_nb_colors,
        length=args.snake_length,
        prompt_length=args.snake_length // 2,
        device=device,
    )

elif args.task == "stack":
    task = TaskStack(
        nb_train_samples=args.nb_train_samples,
        nb_test_samples=args.nb_test_samples,
        batch_size=args.batch_size,
        nb_steps=args.stack_nb_steps,
        nb_stacks=args.stack_nb_stacks,
        nb_digits=args.stack_nb_digits,
        fraction_values_for_train=args.stack_fraction_values_for_train,
        device=device,
    )

elif args.task == "expr":
    task = TaskExpr(
        nb_train_samples=args.nb_train_samples,
        nb_test_samples=args.nb_test_samples,
        nb_variables=args.expr_nb_variables,
        sequence_length=args.expr_sequence_length,
        batch_size=args.batch_size,
        device=device,
    )

else:
    raise ValueError(f"Unknown task {args.task}")
######################################################################

log_string(f"device {device}")

vocabulary_size = task.vocabulary_size()

log_string(f"vocabulary_size {vocabulary_size}")

##############################

model = mygpt.MyGPT(
    vocabulary_size=vocabulary_size,
    dim_model=args.dim_model,
    dim_keys=args.dim_keys,
    dim_hidden=args.dim_hidden,
    nb_heads=args.nb_heads,
    nb_blocks=args.nb_blocks,
    causal=True,
    dropout=args.dropout,
)

model.to(device)

nb_parameters = sum(p.numel() for p in model.parameters())
log_string(f"nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)")
######################################################################

nb_epochs_finished = 0

if args.no_checkpoint:
    log_string("not trying to load checkpoint.")

else:
    try:
        checkpoint_name = os.path.join(args.result_dir, args.checkpoint_name)
        checkpoint = torch.load(checkpoint_name)
        nb_epochs_finished = checkpoint["nb_epochs_finished"]
        model.load_state_dict(checkpoint["model_state"])
        torch.set_rng_state(checkpoint["rng_state"])
        if torch.cuda.is_available():
            torch.cuda.set_rng_state(checkpoint["cuda_rng_state"])

        log_string(f"checkpoint loaded with {nb_epochs_finished} epochs finished.")

    except FileNotFoundError:
        log_string("starting from scratch.")

    except:
        log_string("error when loading the checkpoint.")
        exit(1)

######################################################################
# Fallback number of epochs when --nb_epochs is not set by the per-task
# defaults (value assumed)
nb_epochs_default = 25

nb_epochs = args.nb_epochs if args.nb_epochs > 0 else nb_epochs_default

token_count = 0
for input in task.batches(split="train"):
    token_count += F.one_hot(input, num_classes=task.vocabulary_size()).sum((0, 1))
token_probas = token_count / token_count.sum()
entropy = -torch.xlogy(token_probas, token_probas).sum()
train_set_perplexity = math.exp(entropy)
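# The quantity above is exp(H) of the empirical token distribution,
# i.e. the best perplexity a context-free model of the training tokens
# could reach. For instance (illustrative values, not from the source):
#
#   p = torch.tensor([0.5, 0.25, 0.25])
#   math.exp(-torch.xlogy(p, p).sum())  # ~2.83, between 1 and len(p)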
##############################

if args.learning_rate_schedule == "cos":
    learning_rate_schedule = {}
    for n_epoch in range(args.nb_epochs):
        u = n_epoch / args.nb_epochs * math.pi
        learning_rate_schedule[n_epoch] = args.learning_rate * 0.5 * (1 + math.cos(u))
else:
    u = {
        int(k): float(v)
        for k, v in [
            tuple(x.split(":")) for x in args.learning_rate_schedule.split(",")
        ]
    }

    learning_rate_schedule = {}
    learning_rate = args.learning_rate
    for n_epoch in range(args.nb_epochs):
        if n_epoch in u:
            learning_rate = u[n_epoch]
        learning_rate_schedule[n_epoch] = learning_rate

log_string(f"learning_rate_schedule {learning_rate_schedule}")
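# With the default "10: 2e-5,30: 4e-6", the parsing above yields
# u = {10: 2e-5, 30: 4e-6}: the learning rate starts at
# --learning_rate, switches to 2e-5 at epoch 10 and to 4e-6 at epoch
# 30, staying constant in between.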
##############################

nb_samples_seen = 0

if nb_epochs_finished >= nb_epochs:
    task.produce_results(nb_epochs_finished, model)

for n_epoch in range(nb_epochs_finished, nb_epochs):
    learning_rate = learning_rate_schedule[n_epoch]

    log_string(f"learning_rate {learning_rate}")
    if args.optim == "sgd":
        optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
    elif args.optim == "adam":
        optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    elif args.optim == "adamw":
        optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
    else:
        raise ValueError(f"Unknown optimizer {args.optim}.")
    model.train()

    nb_train_samples, acc_train_loss = 0, 0.0

    for input in task.batches(split="train"):
        input = input.to(device)
        output = model(mygpt.BracketedSequence(input)).x
        loss = F.cross_entropy(output.transpose(1, 2), input)
        acc_train_loss += loss.item() * input.size(0)
        nb_train_samples += input.size(0)
        nb_samples_seen += input.size(0)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    with torch.autograd.no_grad():
        model.eval()

        nb_test_samples, acc_test_loss = 0, 0.0

        for input in task.batches(split="test"):
            input = input.to(device)

            output = model(mygpt.BracketedSequence(input)).x
            loss = F.cross_entropy(output.transpose(1, 2), input)
            acc_test_loss += loss.item() * input.size(0)
            nb_test_samples += input.size(0)

        train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples))
        test_perplexity = math.exp(min(100, acc_test_loss / nb_test_samples))

        log_string(
            f"perplexity {n_epoch} train_set {train_set_perplexity} train_prediction {train_perplexity} test_prediction {test_perplexity}"
        )

        task.produce_results(n_epoch, model)
1413 "nb_epochs_finished": n_epoch + 1,
1414 "model_state": model.state_dict(),
1415 "rng_state": torch.get_rng_state(),
1418 if torch.cuda.is_available():
1419 checkpoint["cuda_rng_state"] = torch.cuda.get_rng_state()
1421 checkpoint_name = os.path.join(args.result_dir, args.checkpoint_name)
1422 torch.save(checkpoint, checkpoint_name)
1423 log_string(f"saved checkpoint {checkpoint_name}")
1425 ######################################################################