# Any copyright is dedicated to the Public Domain.
# https://creativecommons.org/publicdomain/zero/1.0/

# Written by Francois Fleuret <francois@fleuret.org>

import math, sys, argparse, time, tqdm, os, datetime, warnings

import torch, torchvision

from torch.nn import functional as F

import mygpt, tasks, problems

######################################################################
if torch.cuda.is_available():
    device = torch.device("cuda")
    torch.backends.cuda.matmul.allow_tf32 = True
else:
    device = torch.device("cpu")

######################################################################
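# Map the usual command-line spellings of a boolean to a bool, for use
# as an argparse type= converter.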
def str2bool(x):
    x = x.lower()
    if x in {"1", "true", "yes"}:
        return True
    elif x in {"0", "false", "no"}:
        return False
    else:
        raise ValueError(f"cannot interpret {x} as a boolean")

parser = argparse.ArgumentParser(
    description="An implementation of GPT with cache.",
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)

parser.add_argument(
    "--task",
    type=str,
    default="twotargets",  # assumed default; the original value is elided in this copy
    help="byheart, learnop, guessop, mixing, memory, twotargets, addition, picoclvr, mnist, maze, snake, stack, expr, rpl, grid, qmlp",
)
parser.add_argument("--log_filename", type=str, default="train.log", help=" ")

parser.add_argument("--result_dir", type=str, default=None)

parser.add_argument("--seed", type=int, default=0)

parser.add_argument("--max_percents_of_test_in_train", type=int, default=1)

########################################

parser.add_argument("--nb_epochs", type=int, default=50)

parser.add_argument("--batch_size", type=int, default=None)

parser.add_argument("--nb_train_samples", type=int, default=None)

parser.add_argument("--nb_test_samples", type=int, default=None)

parser.add_argument("--optim", type=str, default="adam")

########################################

parser.add_argument("--nb_warmup_iter", type=int, default=100)

parser.add_argument("--nb_decay_iter", type=int, default=5000)

parser.add_argument("--learning_rate", type=float, default=6e-4)

parser.add_argument("--min_learning_rate", type=float, default=6e-5)

parser.add_argument("--legacy_lr_schedule", type=str2bool, default=True)

parser.add_argument("--legacy_large_lr", type=float, default=1e-4)

parser.add_argument("--legacy_small_lr", type=float, default=2e-5)

parser.add_argument("--legacy_nb_epoch_large_lr", type=float, default=10)

########################################

parser.add_argument("--model", type=str, default=None)

parser.add_argument("--attention", type=str, default=None)

parser.add_argument("--dim_model", type=int, default=None)

parser.add_argument("--dim_keys", type=int, default=None)

parser.add_argument("--dim_hidden", type=int, default=None)

parser.add_argument("--nb_heads", type=int, default=None)

parser.add_argument("--nb_lines", type=int, default=None)

parser.add_argument("--caterpillar_height", type=int, default=None)

parser.add_argument("--rho", type=float, default=0.0)

parser.add_argument("--nb_blocks", type=int, default=None)

parser.add_argument("--dropout", type=float, default=0.1)

########################################

parser.add_argument("--deterministic_synthesis", action="store_true", default=False)

parser.add_argument("--no_checkpoint", action="store_true", default=False)

parser.add_argument("--continue_training", action="store_true", default=False)

parser.add_argument("--checkpoint_name", type=str, default="checkpoint.pth")

##############################

parser.add_argument("--rpl_nb_starting_values", type=int, default=3)

parser.add_argument("--rpl_max_input", type=int, default=9)

parser.add_argument("--rpl_prog_len", type=int, default=8)

parser.add_argument("--rpl_nb_runs", type=int, default=5)

parser.add_argument("--rpl_no_prog", action="store_true", default=False)

##############################

parser.add_argument("--grid_size", type=int, default=6)

##############################

parser.add_argument("--picoclvr_nb_colors", type=int, default=5)

parser.add_argument("--picoclvr_height", type=int, default=12)

parser.add_argument("--picoclvr_width", type=int, default=16)

parser.add_argument("--picocvlr_prune_properties", type=str, default="none")

##############################

parser.add_argument("--maze_height", type=int, default=13)

parser.add_argument("--maze_width", type=int, default=21)

parser.add_argument("--maze_nb_walls", type=int, default=15)

##############################

parser.add_argument("--snake_height", type=int, default=9)

parser.add_argument("--snake_width", type=int, default=12)

parser.add_argument("--snake_nb_colors", type=int, default=5)

parser.add_argument("--snake_length", type=int, default=200)

##############################

parser.add_argument("--stack_nb_steps", type=int, default=100)

parser.add_argument("--stack_nb_stacks", type=int, default=3)

parser.add_argument("--stack_nb_digits", type=int, default=3)

parser.add_argument("--stack_fraction_values_for_train", type=float, default=0.75)

##############################

parser.add_argument("--expr_nb_variables", type=int, default=5)

parser.add_argument("--expr_sequence_length", type=int, default=40)

parser.add_argument("--expr_operand_max", type=int, default=9)

parser.add_argument("--expr_result_max", type=int, default=99)

parser.add_argument("--expr_input_file", type=str, default=None)

##############################

parser.add_argument("--memory_len_total", type=int, default=32)

##############################

parser.add_argument("--mixing_hard", action="store_true", default=False)

parser.add_argument("--mixing_deterministic_start", action="store_true", default=False)

######################################################################
args = parser.parse_args()

assert args.picocvlr_prune_properties in {"none", "train+eval", "eval"}

if args.result_dir is None:
    args.result_dir = f"results_{args.task}_{args.model}"

######################################################################
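# Per-task defaults are applied below only to arguments that are still
# at None, so anything given explicitly on the command line wins.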
default_task_args = {
    # The task keys below are reconstructed from the --task help list;
    # each entry originally also carried per-task "model" and
    # "batch_size" defaults that are elided in this copy.
    "addition": {
        "nb_train_samples": 250000,
        "nb_test_samples": 10000,
    },
    "byheart": {
        "nb_train_samples": 50000,
        "nb_test_samples": 10000,
    },
    "expr": {
        "nb_train_samples": 2500000,
        "nb_test_samples": 10000,
    },
    "grid": {
        "nb_train_samples": 250000,
        "nb_test_samples": 10000,
    },
    "qmlp": {
        "nb_train_samples": 100000,
        "nb_test_samples": 1000,
    },
    "guessop": {
        "nb_train_samples": 1000000,
        "nb_test_samples": 10000,
    },
    "learnop": {
        "nb_train_samples": 50000,
        "nb_test_samples": 10000,
    },
    "maze": {
        "nb_train_samples": 100000,
        "nb_test_samples": 10000,
    },
    "picoclvr": {
        "nb_train_samples": 250000,
        "nb_test_samples": 10000,
    },
    "rpl": {
        "nb_train_samples": 2500000,
        "nb_test_samples": 10000,
    },
    "snake": {
        "nb_train_samples": 250000,
        "nb_test_samples": 10000,
    },
    "stack": {
        "nb_train_samples": 100000,
        "nb_test_samples": 1000,
    },
    "twotargets": {
        "nb_train_samples": 50000,
        "nb_test_samples": 10000,
    },
    "memory": {
        "nb_train_samples": 25000,
        "nb_test_samples": 10000,
    },
    "mixing": {
        "nb_train_samples": 250000,
        "nb_test_samples": 10000,
    },
    "mnist": {
        "nb_train_samples": 60000,
        "nb_test_samples": 10000,
    },
}

if args.task in default_task_args:
    for k, v in default_task_args[args.task].items():
        if getattr(args, k) is None:
            setattr(args, k, v)

######################################################################
default_model_args = {
    # Most of this table is elided in this copy: only fragments of the
    # entries using the "caterpillar" attention layer survive. The keys
    # below and any field not shown (dim_model, dim_keys, dim_hidden,
    # nb_heads, nb_lines, nb_blocks, and the "mha" variants) are
    # assumptions and must be supplied on the command line.
    "17K-C": {"attention": "caterpillar", "caterpillar_height": 4},
    "4M-C": {"attention": "caterpillar", "caterpillar_height": 4},
    "37M-C": {"attention": "caterpillar", "caterpillar_height": 32},
    "122M-C": {"attention": "caterpillar"},
    "352M-C": {"attention": "caterpillar"},
}

if args.model in default_model_args:
    for k, v in default_model_args[args.model].items():
        if getattr(args, k) is None:
            setattr(args, k, v)
else:
    raise ValueError(f"Unknown model {args.model}")

######################################################################
try:
    os.mkdir(args.result_dir)
except FileExistsError:
    if not args.continue_training:
        print(f"result directory {args.result_dir} already exists")
        exit(1)

log_file = open(os.path.join(args.result_dir, args.log_filename), "a")
if args.seed >= 0:
    # torch.backends.cudnn.deterministic = True
    # torch.backends.cudnn.benchmark = False
    # torch.use_deterministic_algorithms(True)
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)

######################################################################
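# Log to both stdout and the log file, prefixing every line with a
# timestamp so runs can be timed and compared afterwards.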
def log_string(s):
    t = time.strftime("%Y%m%d-%H:%M:%S ", time.localtime())

    if log_file is not None:
        log_file.write(t + s + "\n")
        log_file.flush()

    print(t + s)
    sys.stdout.flush()


with os.popen("sha256sum *.py") as f:
    for l in f:
        log_string(f"sha256sum {l.strip()}")
now = time.strftime("%Y%m%d-%H%M%S", time.localtime())
os.system(f"tar zcvf {args.result_dir}/src-{now}.tgz *.py *.sh")

log_string(f"argv {' '.join(sys.argv)}")

for n in vars(args):
    log_string(f"args.{n} {getattr(args, n)}")


######################################################################
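# Learning rate as a function of the epoch and of the global iteration
# counter: either the legacy two-plateau schedule, or a nanoGPT-style
# linear warmup followed by a cosine decay from learning_rate down to
# min_learning_rate after nb_decay_iter iterations.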
def get_lr(n_epoch, it):
    if args.legacy_lr_schedule:
        # my crude scheduling to compare to previous baseline, added
        # warmup though

        if it < args.nb_warmup_iter:
            return args.legacy_large_lr * it / args.nb_warmup_iter
        elif n_epoch < args.legacy_nb_epoch_large_lr:
            return args.legacy_large_lr
        else:
            return args.legacy_small_lr

    # from nanoGPT

    # 1) linear warmup for warmup_iter steps
    if it < args.nb_warmup_iter:
        return args.learning_rate * it / args.nb_warmup_iter
    # 2) if it > nb_decay_iter, return min learning rate
    if it > args.nb_decay_iter:
        return args.min_learning_rate
    # 3) in between, use cosine decay down to min learning rate
    decay_ratio = (it - args.nb_warmup_iter) / (
        args.nb_decay_iter - args.nb_warmup_iter
    )
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))  # coeff ranges 0..1
    return args.min_learning_rate + coeff * (
        args.learning_rate - args.min_learning_rate
    )


######################################################################
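# Optional picoclvr property pruning: the train pruner drops
# descriptions of a green object being left/right of something, and the
# eval pruner keeps only those, presumably to test generalization to
# property types unseen during training (--picocvlr_prune_properties).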
def picoclvr_pruner_horizontal_green(p):
    return not ("green" in p and ("left" in p or "right" in p))


picoclvr_pruner_train = (
    picoclvr_pruner_horizontal_green
    if args.picocvlr_prune_properties in {"train+eval"}
    else None
)

picoclvr_pruner_eval = (
    (lambda p: not picoclvr_pruner_horizontal_green(p))
    if args.picocvlr_prune_properties in {"train+eval", "eval"}
    else None
)

######################################################################
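# Instantiate the task selected by --task. (Some constructor lines were
# missing in this copy; the tasks.* class names and the trailing
# logger=/device= arguments are reconstructions.)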
if args.task == "byheart":
    task = tasks.SandBox(
        problem=problems.ProblemByHeart(),
        nb_train_samples=args.nb_train_samples,
        nb_test_samples=args.nb_test_samples,
        batch_size=args.batch_size,
        logger=log_string,
        device=device,
    )
    args.max_percents_of_test_in_train = -1

elif args.task == "learnop":
    task = tasks.SandBox(
        problem=problems.ProblemLearnOperator(),
        nb_train_samples=args.nb_train_samples,
        nb_test_samples=args.nb_test_samples,
        batch_size=args.batch_size,
        logger=log_string,
        device=device,
    )

elif args.task == "guessop":
    task = tasks.SandBox(
        problem=problems.ProblemGuessOperator(),
        nb_train_samples=args.nb_train_samples,
        nb_test_samples=args.nb_test_samples,
        batch_size=args.batch_size,
        logger=log_string,
        device=device,
    )

elif args.task == "twotargets":
    task = tasks.SandBox(
        problem=problems.ProblemTwoTargets(),
        nb_train_samples=args.nb_train_samples,
        nb_test_samples=args.nb_test_samples,
        batch_size=args.batch_size,
        logger=log_string,
        device=device,
    )

elif args.task == "memory":
    task = tasks.SandBox(
        problem=problems.ProblemMemory(len_total=args.memory_len_total),
        nb_train_samples=args.nb_train_samples,
        nb_test_samples=args.nb_test_samples,
        batch_size=args.batch_size,
        logger=log_string,
        device=device,
    )

elif args.task == "mixing":
    task = tasks.SandBox(
        problem=problems.ProblemMixing(
            hard=args.mixing_hard, random_start=not args.mixing_deterministic_start
        ),
        nb_train_samples=args.nb_train_samples,
        nb_test_samples=args.nb_test_samples,
        batch_size=args.batch_size,
        logger=log_string,
        device=device,
    )

elif args.task == "addition":
    task = tasks.SandBox(
        problem=problems.ProblemAddition(),
        nb_train_samples=args.nb_train_samples,
        nb_test_samples=args.nb_test_samples,
        batch_size=args.batch_size,
        logger=log_string,
        device=device,
    )

elif args.task == "picoclvr":
    task = tasks.PicoCLVR(
        nb_train_samples=args.nb_train_samples,
        nb_test_samples=args.nb_test_samples,
        batch_size=args.batch_size,
        height=args.picoclvr_height,
        width=args.picoclvr_width,
        nb_colors=args.picoclvr_nb_colors,
        logger=log_string,
        device=device,
        pruner_train=picoclvr_pruner_train,
        pruner_eval=picoclvr_pruner_eval,
    )

elif args.task == "mnist":
    task = tasks.MNIST(
        nb_train_samples=args.nb_train_samples,
        nb_test_samples=args.nb_test_samples,
        batch_size=args.batch_size,
        device=device,
    )

elif args.task == "maze":
    task = tasks.Maze(
        nb_train_samples=args.nb_train_samples,
        nb_test_samples=args.nb_test_samples,
        batch_size=args.batch_size,
        height=args.maze_height,
        width=args.maze_width,
        nb_walls=args.maze_nb_walls,
        device=device,
    )

elif args.task == "snake":
    task = tasks.Snake(
        nb_train_samples=args.nb_train_samples,
        nb_test_samples=args.nb_test_samples,
        batch_size=args.batch_size,
        height=args.snake_height,
        width=args.snake_width,
        nb_colors=args.snake_nb_colors,
        length=args.snake_length,
        prompt_length=args.snake_length // 2,
        device=device,
    )

elif args.task == "stack":
    task = tasks.Stack(
        nb_train_samples=args.nb_train_samples,
        nb_test_samples=args.nb_test_samples,
        batch_size=args.batch_size,
        logger=log_string,
        nb_steps=args.stack_nb_steps,
        nb_stacks=args.stack_nb_stacks,
        nb_digits=args.stack_nb_digits,
        fraction_values_for_train=args.stack_fraction_values_for_train,
        device=device,
    )

elif args.task == "expr":
    task = tasks.Expr(
        nb_train_samples=args.nb_train_samples,
        nb_test_samples=args.nb_test_samples,
        nb_variables=args.expr_nb_variables,
        sequence_length=args.expr_sequence_length,
        operand_max=args.expr_operand_max,
        result_max=args.expr_result_max,
        batch_size=args.batch_size,
        device=device,
    )

elif args.task == "rpl":
    task = tasks.RPL(
        nb_train_samples=args.nb_train_samples,
        nb_test_samples=args.nb_test_samples,
        batch_size=args.batch_size,
        nb_starting_values=args.rpl_nb_starting_values,
        max_input=args.rpl_max_input,
        prog_len=args.rpl_prog_len,
        nb_runs=args.rpl_nb_runs,
        no_prog=args.rpl_no_prog,
        logger=log_string,
        device=device,
    )

elif args.task == "grid":
    task = tasks.Grid(
        nb_train_samples=args.nb_train_samples,
        nb_test_samples=args.nb_test_samples,
        batch_size=args.batch_size,
        size=args.grid_size,
        logger=log_string,
        device=device,
    )

elif args.task == "qmlp":
    task = tasks.QMLP(
        nb_train_samples=args.nb_train_samples,
        nb_test_samples=args.nb_test_samples,
        batch_size=args.batch_size,
        result_dir=args.result_dir,
        logger=log_string,
        device=device,
    )

else:
    raise ValueError(f"Unknown task {args.task}")

######################################################################
log_string(f"device {device}")

vocabulary_size = task.vocabulary_size()

log_string(f"vocabulary_size {vocabulary_size}")

##############################
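# Build the model from the resolved arguments; nb_lines and
# caterpillar_height are presumably only used by the "caterpillar"
# attention layer.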
model = mygpt.MyGPT(
    vocabulary_size=vocabulary_size,
    dim_model=args.dim_model,
    dim_keys=args.dim_keys,
    dim_hidden=args.dim_hidden,
    nb_heads=args.nb_heads,
    nb_lines=args.nb_lines,
    caterpillar_height=args.caterpillar_height,
    nb_blocks=args.nb_blocks,
    causal=True,
    dropout=args.dropout,
    attention_layer=args.attention,
)

model.to(device)

nb_parameters = sum(p.numel() for p in model.parameters())
log_string(f"nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)")
######################################################################
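# Unless --no_checkpoint is given, try to resume from the checkpoint in
# the result directory, restoring the model weights, the RNG states,
# and the number of epochs already finished.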
nb_epochs_finished = 0

if args.no_checkpoint:
    log_string("not trying to load checkpoint.")

else:
    try:
        checkpoint_name = os.path.join(args.result_dir, args.checkpoint_name)
        checkpoint = torch.load(checkpoint_name)
        nb_epochs_finished = checkpoint["nb_epochs_finished"]
        model.load_state_dict(checkpoint["model_state"])
        torch.set_rng_state(checkpoint["rng_state"])
        if torch.cuda.is_available():
            torch.cuda.set_rng_state(checkpoint["cuda_rng_state"])

        log_string(f"checkpoint loaded with {nb_epochs_finished} epochs finished.")

    except FileNotFoundError:
        log_string("starting from scratch.")

    except Exception:
        log_string("error when loading the checkpoint.")
        exit(1)
######################################################################

if args.task == "expr" and args.expr_input_file is not None:
    task.produce_results(
        n_epoch=nb_epochs_finished,
        model=model,
        result_dir=args.result_dir,
        logger=log_string,
        deterministic_synthesis=args.deterministic_synthesis,
        input_file=args.expr_input_file,
    )

    exit(0)

######################################################################
nb_epochs = args.nb_epochs if args.nb_epochs > 0 else nb_epochs_default

# Compute the entropy of the training tokens

token_count = 0
for input in task.batches(split="train"):
    token_count += F.one_hot(input, num_classes=task.vocabulary_size()).sum((0, 1))
token_probas = token_count / token_count.sum()
entropy = -torch.xlogy(token_probas, token_probas).sum()
train_set_perplexity = math.exp(entropy)
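# exp(entropy) of the empirical token distribution is the perplexity a
# memoryless (unigram) predictor would achieve on the training tokens;
# it is logged later as a reference point for the model's perplexity.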
######################################################################
# A bit of paranoia never hurts
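# Count how many test sequences also occur verbatim in the train set,
# streaming both sets as chunks of 25000 tuples to bound memory, and
# abort if the overlap exceeds --max_percents_of_test_in_train.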
if args.max_percents_of_test_in_train >= 0:

    def subsets_as_tuples(batches, cs):
        s = set()
        for batch in batches:
            for x in batch:
                s.add(tuple([v.item() for v in x]))
                if len(s) == cs:
                    yield s
                    s = set()
        yield s

    nb_test, nb_in_train = 0, 0
    for test_subset in subsets_as_tuples(task.batches(split="test"), 25000):
        in_train = set()
        for train_subset in subsets_as_tuples(task.batches(split="train"), 25000):
            in_train.update(test_subset.intersection(train_subset))
        nb_in_train += len(in_train)
        nb_test += len(test_subset)

    log_string(
        f"data_check {nb_in_train*100/nb_test:.02f}% ({nb_in_train}/{nb_test}) of test samples are in the train set"
    )

    assert (
        nb_in_train <= args.max_percents_of_test_in_train * nb_test / 100
    ), f"More than {args.max_percents_of_test_in_train}% of test samples are in the train set"
##############################

if nb_epochs_finished >= nb_epochs:
    task.produce_results(
        n_epoch=nb_epochs_finished,
        model=model,
        result_dir=args.result_dir,
        logger=log_string,
        deterministic_synthesis=args.deterministic_synthesis,
    )

time_pred_result = None

it = 0
nb_samples_seen = 0
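# Main loop: a fresh optimizer per epoch, a training pass in which the
# model's optional inner loss, weighted by --rho, is added to the
# cross-entropy, then an evaluation pass, task-specific result
# generation, and a checkpoint.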
for n_epoch in range(nb_epochs_finished, nb_epochs):
    if args.optim == "sgd":
        optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate)
    elif args.optim == "adam":
        optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
    elif args.optim == "adamw":
        optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate)
    else:
        raise ValueError(f"Unknown optimizer {args.optim}.")
    model.train()

    nb_train_samples, acc_train_loss, acc_train_inner_loss = 0, 0.0, 0.0

    for input in task.batches(split="train"):
        model.reset_inner_loss()
        input = input.to(device)

        output = model(mygpt.BracketedSequence(input)).x
        loss = F.cross_entropy(output.transpose(1, 2), input)
        inner_loss = model.get_inner_loss()

        acc_train_loss += loss.item() * input.size(0)
        acc_train_inner_loss += inner_loss.item() * input.size(0)

        nb_train_samples += input.size(0)
        nb_samples_seen += input.size(0)

        total_loss = loss + (args.rho * inner_loss if args.rho > 0 else 0.0)

        it += 1
        lr = get_lr(n_epoch, it)
        for param_group in optimizer.param_groups:
            param_group["lr"] = lr

        # log_string(f"learning_rate {lr}")

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()
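    # Evaluation pass: average per-token cross-entropy on the test set,
    # plus whatever samples and metrics the task's produce_results
    # generates.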
    with torch.autograd.no_grad():
        model.eval()

        nb_test_samples, acc_test_loss = 0, 0.0

        for input in task.batches(split="test"):
            input = input.to(device)

            output = model(mygpt.BracketedSequence(input)).x
            loss = F.cross_entropy(output.transpose(1, 2), input)
            acc_test_loss += loss.item() * input.size(0)
            nb_test_samples += input.size(0)

        log_string(
            f"loss {n_epoch} train_loss {acc_train_loss/nb_train_samples} train_inner_loss {acc_train_inner_loss/nb_train_samples} test_prediction {acc_test_loss/nb_test_samples}"
        )

        task.produce_results(
            n_epoch=n_epoch,
            model=model,
            result_dir=args.result_dir,
            logger=log_string,
            deterministic_synthesis=args.deterministic_synthesis,
        )

        train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples))
        test_perplexity = math.exp(min(100, acc_test_loss / nb_test_samples))

        log_string(
            f"perplexity {n_epoch} train_set {train_set_perplexity} train_prediction {train_perplexity} test_prediction {test_perplexity}"
        )
        time_current_result = datetime.datetime.now()
        if time_pred_result is not None:
            log_string(
                f"next_result {time_current_result + (time_current_result - time_pred_result)}"
            )
        time_pred_result = time_current_result
920 "nb_epochs_finished": n_epoch + 1,
921 "model_state": model.state_dict(),
922 "rng_state": torch.get_rng_state(),
925 if torch.cuda.is_available():
926 checkpoint["cuda_rng_state"] = torch.cuda.get_rng_state()
928 checkpoint_name = os.path.join(args.result_dir, args.checkpoint_name)
929 torch.save(checkpoint, checkpoint_name)
930 log_string(f"saved checkpoint {checkpoint_name}")
######################################################################