3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
6 # Written by Francois Fleuret <francois@fleuret.org>
8 import math, sys, argparse, time, tqdm, os, datetime, warnings
10 import torch, torchvision
12 from torch.nn import functional as F
15 import mygpt, tasks, problems
17 ######################################################################
# Fragment of str2bool: its `def` line, return statements and fallback branch are
# not visible in this listing. It is the argparse `type=` of the boolean options
# below; presumably "1"/"true"/"yes" -> True and "0"/"false"/"no" -> False
# (TODO confirm against the full source).
22 if x in {"1", "true", "yes"}:
24 elif x in {"0", "false", "no"}:
# Command-line interface; ArgumentDefaultsHelpFormatter makes --help show every
# option's default value.
30 parser = argparse.ArgumentParser(
31 description="An implementation of GPT with cache.",
32 formatter_class=argparse.ArgumentDefaultsHelpFormatter,
# Help text of the --task option (the option's other lines are not visible here);
# args.task is dispatched to a task object further down.
39 help="byheart, learnop, guessop, mixing, memory, twotargets, addition, picoclvr, mnist, maze, snake, stack, expr, rpl, grid, qmlp",
# The log file is created inside --result_dir; --result_dir defaults to
# results_<task>_<model> when left unset.
42 parser.add_argument("--log_filename", type=str, default="train.log", help=" ")
44 parser.add_argument("--result_dir", type=str, default=None)
46 parser.add_argument("--seed", type=int, default=0)
# Upper bound, in percent, on test samples that may also appear verbatim in the
# train set; a negative value disables that check.
48 parser.add_argument("--max_percents_of_test_in_train", type=int, default=1)
# Run on the CPU even when CUDA is available.
50 parser.add_argument("--force_cpu", type=str2bool, default=False)
52 ########################################
# Training options. --nb_epochs <= 0 falls back to nb_epochs_default.
# --batch_size, --nb_train_samples and --nb_test_samples default to None and are
# then taken from default_task_args for the selected task.
54 parser.add_argument("--nb_epochs", type=int, default=50)
56 parser.add_argument("--batch_size", type=int, default=None)
58 parser.add_argument("--nb_train_samples", type=int, default=None)
60 parser.add_argument("--nb_test_samples", type=int, default=None)
# One of "sgd", "adam", "adamw" (checked when the optimizer is built).
62 parser.add_argument("--optim", type=str, default="adam")
64 ########################################
# Learning-rate schedule, see get_lr. With --legacy_lr_schedule (the default):
# linear warmup to --legacy_large_lr over --nb_warmup_iter iterations, then
# --legacy_large_lr for the first --legacy_nb_epoch_large_lr epochs and
# --legacy_small_lr afterwards. Otherwise: linear warmup to --learning_rate, then
# cosine decay down to --min_learning_rate, reached at --nb_decay_iter.
66 parser.add_argument("--nb_warmup_iter", type=int, default=100)
68 parser.add_argument("--nb_decay_iter", type=int, default=5000)
70 parser.add_argument("--learning_rate", type=float, default=6e-4)
72 parser.add_argument("--min_learning_rate", type=float, default=6e-5)
76 parser.add_argument("--legacy_lr_schedule", type=str2bool, default=True)
78 parser.add_argument("--legacy_large_lr", type=float, default=1e-4)
80 parser.add_argument("--legacy_small_lr", type=float, default=2e-5)
82 parser.add_argument("--legacy_nb_epoch_large_lr", type=float, default=10)
84 ########################################
# Model options. Options left at None presumably take their value from the preset
# selected by --model (see default_model_args).
86 parser.add_argument("--model", type=str, default=None)
88 parser.add_argument("--attention", type=str, default=None)
# Probability that a training batch goes through add_memex in the training loop.
90 parser.add_argument("--proportion_memex", type=float, default=0)
92 parser.add_argument("--dim_model", type=int, default=None)
94 parser.add_argument("--dim_keys", type=int, default=None)
96 parser.add_argument("--dim_hidden", type=int, default=None)
98 parser.add_argument("--nb_heads", type=int, default=None)
100 parser.add_argument("--nb_lines", type=int, default=None)
102 parser.add_argument("--caterpillar_height", type=int, default=None)
104 parser.add_argument("--gate_dropout_proba", type=float, default=0.0)
106 parser.add_argument("--gate_dropout_sync", type=str2bool, default=False)
108 parser.add_argument("--gate_dropout_replace", type=str2bool, default=False)
# Weight of model.get_inner_loss() in the training objective (not added when 0).
110 parser.add_argument("--rho_inner_loss", type=float, default=0.0)
112 parser.add_argument("--nb_blocks", type=int, default=None)
114 parser.add_argument("--dropout", type=float, default=0.1)
116 ########################################
# Forwarded to task.produce_results.
118 parser.add_argument("--deterministic_synthesis", action="store_true", default=False)
# --no_checkpoint: do not try to load a checkpoint at start-up.
# --continue_training: accept an already-existing --result_dir.
# --checkpoint_name: file name of the checkpoint inside --result_dir.
120 parser.add_argument("--no_checkpoint", action="store_true", default=False)
122 parser.add_argument("--continue_training", action="store_true", default=False)
124 parser.add_argument("--checkpoint_name", type=str, default="checkpoint.pth")
126 ##############################
# Task-specific options: each group below is only read by the matching branch of
# the task dispatch further down (rpl, grid, picoclvr, maze, snake, stack, expr,
# memory, mixing).
129 parser.add_argument("--rpl_nb_starting_values", type=int, default=3)
131 parser.add_argument("--rpl_max_input", type=int, default=9)
133 parser.add_argument("--rpl_prog_len", type=int, default=8)
135 parser.add_argument("--rpl_nb_runs", type=int, default=5)
137 parser.add_argument("--rpl_no_prog", action="store_true", default=False)
139 ##############################
142 parser.add_argument("--grid_size", type=int, default=6)
144 parser.add_argument("--grid_nb_colors", type=int, default=6)
146 parser.add_argument("--grid_nb_shapes", type=int, default=6)
148 ##############################
151 parser.add_argument("--picoclvr_nb_colors", type=int, default=5)
153 parser.add_argument("--picoclvr_height", type=int, default=12)
155 parser.add_argument("--picoclvr_width", type=int, default=16)
# NOTE(review): "picocvlr" in this option name is a typo for "picoclvr"; it is
# kept because renaming it would break existing command lines. Accepted values are
# "none", "train+eval" and "eval" (validated further down).
157 parser.add_argument("--picocvlr_prune_properties", type=str, default="none")
159 ##############################
162 parser.add_argument("--maze_height", type=int, default=13)
164 parser.add_argument("--maze_width", type=int, default=21)
166 parser.add_argument("--maze_nb_walls", type=int, default=15)
168 ##############################
171 parser.add_argument("--snake_height", type=int, default=9)
173 parser.add_argument("--snake_width", type=int, default=12)
175 parser.add_argument("--snake_nb_colors", type=int, default=5)
# The prompt length passed to the snake task is half of this value.
177 parser.add_argument("--snake_length", type=int, default=200)
179 ##############################
182 parser.add_argument("--stack_nb_steps", type=int, default=100)
184 parser.add_argument("--stack_nb_stacks", type=int, default=3)
186 parser.add_argument("--stack_nb_digits", type=int, default=3)
188 parser.add_argument("--stack_fraction_values_for_train", type=float, default=0.75)
190 ##############################
193 parser.add_argument("--expr_nb_variables", type=int, default=5)
195 parser.add_argument("--expr_sequence_length", type=int, default=40)
197 parser.add_argument("--expr_operand_max", type=int, default=9)
199 parser.add_argument("--expr_result_max", type=int, default=99)
# When set together with --task expr, task.produce_results is run on this file
# right after start-up.
201 parser.add_argument("--expr_input_file", type=str, default=None)
203 ##############################
206 parser.add_argument("--memory_len_total", type=int, default=32)
208 ##############################
211 parser.add_argument("--mixing_hard", action="store_true", default=False)
213 parser.add_argument("--mixing_deterministic_start", action="store_true", default=False)
215 ######################################################################
# parse_known_args() keeps options the parser does not declare instead of
# rejecting them; they end up in sup_args ("supplementary" arguments, e.g. the
# --calibrate switch checked further down).
args, sup_args = parser.parse_known_args()

# Turn the leftover "--key=value" tokens into a {key: value} dict of strings.
# str.partition is used instead of split("=") so that a bare flag such as
# "--calibrate" maps to "" and a value that itself contains "=" is kept whole,
# rather than making dict() raise on a sequence of the wrong length.
sup_args = dict(x.removeprefix("--").partition("=")[::2] for x in sup_args)

if args.result_dir is None:
    args.result_dir = f"results_{args.task}_{args.model}"
226 ######################################################################
# Use CUDA unless --force_cpu is set or no GPU is present (the `else:` line before
# the CPU fallback is not visible here). TF32 is allowed for CUDA matmuls.
228 if not args.force_cpu and torch.cuda.is_available():
229 device = torch.device("cuda")
230 torch.backends.cuda.matmul.allow_tf32 = True
232 device = torch.device("cpu")
234 ######################################################################
# Per-task defaults. Only the nb_train_samples / nb_test_samples entries are
# visible in this listing; the task names and any other fields are not.
236 default_task_args = {
240 "nb_train_samples": 250000,
241 "nb_test_samples": 10000,
246 "nb_train_samples": 50000,
247 "nb_test_samples": 10000,
252 "nb_train_samples": 2500000,
253 "nb_test_samples": 10000,
258 "nb_train_samples": 250000,
259 "nb_test_samples": 10000,
264 "nb_train_samples": 100000,
265 "nb_test_samples": 1000,
270 "nb_train_samples": 1000000,
271 "nb_test_samples": 10000,
276 "nb_train_samples": 50000,
277 "nb_test_samples": 10000,
282 "nb_train_samples": 100000,
283 "nb_test_samples": 10000,
288 "nb_train_samples": 250000,
289 "nb_test_samples": 10000,
294 "nb_train_samples": 2500000,
295 "nb_test_samples": 10000,
300 "nb_train_samples": 250000,
301 "nb_test_samples": 10000,
306 "nb_train_samples": 100000,
307 "nb_test_samples": 1000,
312 "nb_train_samples": 50000,
313 "nb_test_samples": 10000,
318 "nb_train_samples": 25000,
319 "nb_test_samples": 10000,
324 "nb_train_samples": 250000,
325 "nb_test_samples": 10000,
330 "nb_train_samples": 60000,
331 "nb_test_samples": 10000,
# Options still at None after parsing presumably take the selected task's default
# (the assignment line is not visible here).
335 if args.task in default_task_args:
336 for k, v in default_task_args[args.task].items():
337 if getattr(args, k) is None:
340 ######################################################################
# Per-model presets. The model names and most fields are not visible in this
# listing; every visible "attention" entry is "caterpillar".
342 default_model_args = {
352 "attention": "caterpillar",
358 "caterpillar_height": 4,
370 "attention": "caterpillar",
376 "caterpillar_height": 4,
388 "attention": "caterpillar",
394 "caterpillar_height": 32,
406 "attention": "caterpillar",
423 "attention": "caterpillar",
# Options still at None presumably take the preset value. A --model that is not a
# preset (including the default None) appears to raise ValueError.
433 if args.model in default_model_args:
434 for k, v in default_model_args[args.model].items():
435 if getattr(args, k) is None:
438 raise ValueError(f"Unknown model {args.model}")
440 ######################################################################
# Create --result_dir. If it already exists and --continue_training is not set, a
# message is printed; what follows is not visible here (presumably the script
# exits). The `try:` line is not visible either.
443 os.mkdir(args.result_dir)
444 except FileExistsError:
445 if not args.continue_training:
446 print(f"result directory {args.result_dir} already exists")
# Both files are opened in append mode, so existing content is kept.
# NOTE(review): no explicit close() of either file is visible.
449 loss_file = open(os.path.join(args.result_dir, "loss.dat"), "a")
451 log_file = open(os.path.join(args.result_dir, args.log_filename), "a")
# Seed the CPU RNG and, when available, all CUDA RNGs. Bitwise reproducibility
# would also need the commented-out cuDNN / deterministic-algorithm switches.
454 # torch.backends.cudnn.deterministic = True
455 # torch.backends.cudnn.benchmark = False
456 # torch.use_deterministic_algorithms(True)
457 torch.manual_seed(args.seed)
458 if torch.cuda.is_available():
459 torch.cuda.manual_seed_all(args.seed)
461 ######################################################################
# Fragment of log_string(s); its `def` line and the rest of its body are not
# visible. The message is prefixed with a local timestamp and appended to the log
# file when one is open.
465 t = time.strftime("%Y%m%d-%H:%M:%S ", time.localtime())
467 if log_file is not None:
468 log_file.write(t + s + "\n")
# Provenance: log the sha256 of every *.py file and archive the sources into
# result_dir. Then log argv, the parsed options and every supplementary argument.
# NOTE(review): both commands go through the shell, and result_dir is
# interpolated unquoted into the tar command, so a path with spaces or shell
# metacharacters breaks it. subprocess.run with an argument list is safer.
475 with os.popen("sha256sum *.py") as f:
477 log_string(f"sha256sum {l.strip()}")
479 now = time.strftime("%Y%m%d-%H%M%S", time.localtime())
480 os.system(f"tar zcvf {args.result_dir}/src-{now}.tgz *.py *.sh")
482 log_string(f"argv {' '.join(sys.argv)}")
485 log_string(f"args.{n} {getattr(args, n)}")
487 for k, v in sup_args.items():
488 log_string(f'sup_args["{k}"] "{v}"')
def get_lr(n_epoch, it):
    """Return the learning rate for iteration ``it`` of epoch ``n_epoch``.

    Two schedules are available, selected by ``--legacy_lr_schedule``:

    * legacy (default): linear warmup (lr proportional to ``it``) up to
      ``--legacy_large_lr`` over ``--nb_warmup_iter`` iterations, then
      ``--legacy_large_lr`` while ``n_epoch < --legacy_nb_epoch_large_lr``,
      then ``--legacy_small_lr``.
    * cosine: linear warmup up to ``--learning_rate`` over
      ``--nb_warmup_iter`` iterations, cosine decay down to
      ``--min_learning_rate`` reached at ``--nb_decay_iter``, and
      ``--min_learning_rate`` from then on.

    Args:
        n_epoch: current epoch index (only used by the legacy schedule).
        it: global iteration counter.

    Returns:
        The learning rate as a float.
    """
    if args.legacy_lr_schedule:
        # Crude step schedule kept to compare against a previous baseline,
        # with a warmup added.
        if it < args.nb_warmup_iter:
            return args.legacy_large_lr * it / args.nb_warmup_iter
        if n_epoch < args.legacy_nb_epoch_large_lr:
            return args.legacy_large_lr
        return args.legacy_small_lr

    # 1) linear warmup for nb_warmup_iter steps
    if it < args.nb_warmup_iter:
        return args.learning_rate * it / args.nb_warmup_iter
    # 2) from nb_decay_iter on, return the min learning rate. Using >= rather
    #    than > yields the same value at it == nb_decay_iter (the cosine term is
    #    exactly 0 there) and avoids a ZeroDivisionError below when
    #    nb_decay_iter == nb_warmup_iter.
    if it >= args.nb_decay_iter:
        return args.min_learning_rate
    # 3) in between, use cosine decay down to the min learning rate
    decay_ratio = (it - args.nb_warmup_iter) / (
        args.nb_decay_iter - args.nb_warmup_iter
    )
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))  # coeff ranges 0..1
    return args.min_learning_rate + coeff * (
        args.learning_rate - args.min_learning_rate
    )
######################################################################

# Validate explicitly instead of with `assert`: asserts are skipped under
# `python -O`, and an invalid value would then silently take the same branches
# as "none" in the pruner selection below.
if args.picocvlr_prune_properties not in {"none", "train+eval", "eval"}:
    raise ValueError(
        "--picocvlr_prune_properties must be one of none, train+eval, eval,"
        f" got {args.picocvlr_prune_properties!r}"
    )
def picoclvr_pruner_horizontal_green(p):
    """Tell whether ``p`` is anything but a "green" + horizontal-relation property.

    Returns False exactly when ``p`` contains "green" together with "left" or
    "right" (checked with ``in``, so substrings for a string and membership for
    a collection), and True otherwise. How the caller interprets True/False
    (keep vs. prune) is decided by tasks.PicoCLVR, which is not visible here.
    """
    if "green" not in p:
        return True
    return not ("left" in p or "right" in p)
# picoclvr_pruner_horizontal_green is the training pruner only for "train+eval".
# Its negation (true exactly for green properties with "left" or "right") is the
# evaluation pruner for "train+eval" or "eval". The fallback values of both
# conditional expressions are not visible here (presumably None).
534 picoclvr_pruner_train = (
535 picoclvr_pruner_horizontal_green
536 if args.picocvlr_prune_properties in {"train+eval"}
540 picoclvr_pruner_eval = (
541 (lambda p: not picoclvr_pruner_horizontal_green(p))
542 if args.picocvlr_prune_properties in {"train+eval", "eval"}
546 ######################################################################
# Build the task object selected by --task. Every branch passes the dataset sizes
# and the batch size; the other keyword arguments come from the task-specific
# options. Several constructor names and trailing arguments are not visible in
# this listing.
550 if args.task == "byheart":
551 task = tasks.SandBox(
552 problem=problems.ProblemByHeart(),
553 nb_train_samples=args.nb_train_samples,
554 nb_test_samples=args.nb_test_samples,
555 batch_size=args.batch_size,
# For byheart, disable the test-in-train overlap check (a negative bound skips it).
559 args.max_percents_of_test_in_train = -1
561 elif args.task == "learnop":
562 task = tasks.SandBox(
563 problem=problems.ProblemLearnOperator(),
564 nb_train_samples=args.nb_train_samples,
565 nb_test_samples=args.nb_test_samples,
566 batch_size=args.batch_size,
572 elif args.task == "guessop":
573 task = tasks.SandBox(
574 problem=problems.ProblemGuessOperator(),
575 nb_train_samples=args.nb_train_samples,
576 nb_test_samples=args.nb_test_samples,
577 batch_size=args.batch_size,
583 elif args.task == "twotargets":
584 task = tasks.SandBox(
585 problem=problems.ProblemTwoTargets(),
586 nb_train_samples=args.nb_train_samples,
587 nb_test_samples=args.nb_test_samples,
588 batch_size=args.batch_size,
593 elif args.task == "memory":
594 task = tasks.SandBox(
595 problem=problems.ProblemMemory(len_total=args.memory_len_total),
596 nb_train_samples=args.nb_train_samples,
597 nb_test_samples=args.nb_test_samples,
598 batch_size=args.batch_size,
603 elif args.task == "mixing":
604 task = tasks.SandBox(
605 problem=problems.ProblemMixing(
606 hard=args.mixing_hard, random_start=not args.mixing_deterministic_start
608 nb_train_samples=args.nb_train_samples,
609 nb_test_samples=args.nb_test_samples,
610 batch_size=args.batch_size,
615 elif args.task == "addition":
616 task = tasks.SandBox(
617 problem=problems.ProblemAddition(),
618 nb_train_samples=args.nb_train_samples,
619 nb_test_samples=args.nb_test_samples,
620 batch_size=args.batch_size,
625 elif args.task == "picoclvr":
626 task = tasks.PicoCLVR(
627 nb_train_samples=args.nb_train_samples,
628 nb_test_samples=args.nb_test_samples,
629 batch_size=args.batch_size,
630 height=args.picoclvr_height,
631 width=args.picoclvr_width,
632 nb_colors=args.picoclvr_nb_colors,
635 pruner_train=picoclvr_pruner_train,
636 pruner_eval=picoclvr_pruner_eval,
639 elif args.task == "mnist":
641 nb_train_samples=args.nb_train_samples,
642 nb_test_samples=args.nb_test_samples,
643 batch_size=args.batch_size,
647 elif args.task == "maze":
649 nb_train_samples=args.nb_train_samples,
650 nb_test_samples=args.nb_test_samples,
651 batch_size=args.batch_size,
652 height=args.maze_height,
653 width=args.maze_width,
654 nb_walls=args.maze_nb_walls,
658 elif args.task == "snake":
660 nb_train_samples=args.nb_train_samples,
661 nb_test_samples=args.nb_test_samples,
662 batch_size=args.batch_size,
663 height=args.snake_height,
664 width=args.snake_width,
665 nb_colors=args.snake_nb_colors,
666 length=args.snake_length,
# The prompt length is half of --snake_length.
667 prompt_length=args.snake_length // 2,
671 elif args.task == "stack":
673 nb_train_samples=args.nb_train_samples,
674 nb_test_samples=args.nb_test_samples,
675 batch_size=args.batch_size,
677 nb_steps=args.stack_nb_steps,
678 nb_stacks=args.stack_nb_stacks,
679 nb_digits=args.stack_nb_digits,
680 fraction_values_for_train=args.stack_fraction_values_for_train,
684 elif args.task == "expr":
686 nb_train_samples=args.nb_train_samples,
687 nb_test_samples=args.nb_test_samples,
688 nb_variables=args.expr_nb_variables,
689 sequence_length=args.expr_sequence_length,
690 operand_max=args.expr_operand_max,
691 result_max=args.expr_result_max,
692 batch_size=args.batch_size,
696 elif args.task == "rpl":
698 nb_train_samples=args.nb_train_samples,
699 nb_test_samples=args.nb_test_samples,
700 batch_size=args.batch_size,
701 nb_starting_values=args.rpl_nb_starting_values,
702 max_input=args.rpl_max_input,
703 prog_len=args.rpl_prog_len,
704 nb_runs=args.rpl_nb_runs,
705 no_prog=args.rpl_no_prog,
710 elif args.task == "grid":
712 nb_train_samples=args.nb_train_samples,
713 nb_test_samples=args.nb_test_samples,
714 batch_size=args.batch_size,
716 nb_shapes=args.grid_nb_shapes,
717 nb_colors=args.grid_nb_colors,
722 elif args.task == "qmlp":
724 nb_train_samples=args.nb_train_samples,
725 nb_test_samples=args.nb_test_samples,
726 batch_size=args.batch_size,
727 result_dir=args.result_dir,
# Any other --task value is rejected (the `else:` line is not visible here).
733 raise ValueError(f"Unknown task {args.task}")
735 ######################################################################
737 log_string(f"device {device}")
739 vocabulary_size = task.vocabulary_size()
# With memex enabled the vocabulary is adjusted here (the adjusting line is not
# visible). Presumably this reserves the token vocabulary_size - 1 that add_memex
# writes into batches — TODO confirm.
741 if args.proportion_memex > 0:
744 log_string(f"vocabulary_size {vocabulary_size}")
746 ##############################
# Keyword arguments of the model constructor (the call line itself is not visible
# here), taken from the CLI options completed by the --model preset.
749 vocabulary_size=vocabulary_size,
750 dim_model=args.dim_model,
751 dim_keys=args.dim_keys,
752 dim_hidden=args.dim_hidden,
753 nb_heads=args.nb_heads,
754 nb_lines=args.nb_lines,
755 caterpillar_height=args.caterpillar_height,
756 nb_blocks=args.nb_blocks,
758 dropout=args.dropout,
759 attention_layer=args.attention,
# int() truncates: e.g. 37.9M parameters are logged as "37M".
766 nb_parameters = sum(p.numel() for p in model.parameters())
767 log_string(f"nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)")
769 ######################################################################
771 nb_epochs_finished = 0
# Resume from <result_dir>/<checkpoint_name> unless --no_checkpoint is set. This
# restores the number of finished epochs, the model weights and the CPU/CUDA RNG
# states. A missing file means starting from scratch; any other failure is only
# logged (the `try:` and generic `except` lines are not visible here).
# NOTE(review): torch.load has no map_location, so a checkpoint holding GPU
# tensors will not load on a CPU-only host; torch.load(..., map_location=device)
# avoids that.
# NOTE(review): "cuda_rng_state" is read whenever CUDA is available but only saved
# when CUDA was available at save time. A checkpoint made without CUDA and resumed
# on a CUDA host therefore hits a KeyError after the weights were already loaded.
773 if args.no_checkpoint:
774 log_string(f"not trying to load checkpoint.")
778 checkpoint_name = os.path.join(args.result_dir, args.checkpoint_name)
779 checkpoint = torch.load(checkpoint_name)
780 nb_epochs_finished = checkpoint["nb_epochs_finished"]
781 model.load_state_dict(checkpoint["model_state"])
782 torch.set_rng_state(checkpoint["rng_state"])
783 if torch.cuda.is_available():
784 torch.cuda.set_rng_state(checkpoint["cuda_rng_state"])
786 log_string(f"checkpoint loaded with {nb_epochs_finished} epochs finished.")
788 except FileNotFoundError:
789 log_string("starting from scratch.")
792 log_string("error when loading the checkpoint.")
795 ######################################################################
# With --task expr and --expr_input_file, run task.produce_results on that file
# using the (possibly restored) model. Some arguments and the lines after this
# call are not visible (presumably the script stops afterwards).
797 if args.task == "expr" and args.expr_input_file is not None:
798 task.produce_results(
799 n_epoch=nb_epochs_finished,
801 result_dir=args.result_dir,
803 deterministic_synthesis=args.deterministic_synthesis,
804 input_file=args.expr_input_file,
809 ######################################################################
# --nb_epochs <= 0 selects nb_epochs_default (defined outside the visible part of
# this listing).
811 nb_epochs = args.nb_epochs if args.nb_epochs > 0 else nb_epochs_default
813 # Compute the entropy of the training tokens
# token_count's initialisation is not visible here. Batches are presumably 2D
# (batch, length) token ids, so summing the one-hot over dims 0 and 1 leaves one
# count per vocabulary entry. exp(entropy) is the perplexity of predicting every
# token from the train-set unigram frequencies, logged as a reference each epoch.
816 for input in task.batches(split="train"):
817 token_count += F.one_hot(input, num_classes=task.vocabulary_size()).sum((0, 1))
818 token_probas = token_count / token_count.sum()
819 entropy = -torch.xlogy(token_probas, token_probas).sum()
820 train_set_perplexity = math.exp(entropy)
822 ######################################################################
823 # A bit of paranoia never hurts
# Leakage check: count the distinct test sequences that also occur verbatim in the
# train set, and fail if they exceed --max_percents_of_test_in_train percent.
# Both splits are turned into sets of tuples chunk by chunk (cs = 25000;
# presumably to bound memory, as part of subsets_as_tuples is not visible here).
# NOTE(review): every test chunk rescans the whole train set; tuple(x.tolist())
# would be much cheaper than a per-token .item(); and the final check is an
# assert, so it is skipped under python -O.
825 if args.max_percents_of_test_in_train >= 0:
827 def subsets_as_tuples(batches, cs):
829 for batch in batches:
831 s.add(tuple([v.item() for v in x]))
837 nb_test, nb_in_train = 0, 0
838 for test_subset in subsets_as_tuples(task.batches(split="test"), 25000):
840 for train_subset in subsets_as_tuples(task.batches(split="train"), 25000):
841 in_train.update(test_subset.intersection(train_subset))
842 nb_in_train += len(in_train)
843 nb_test += len(test_subset)
846 f"data_check {nb_in_train*100/nb_test:.02f}% ({nb_in_train}/{nb_test}) of test samples are in the train set"
850 nb_in_train <= args.max_percents_of_test_in_train * nb_test / 100
851 ), f"More than {args.max_percents_of_test_in_train}% of test samples are in the train set"
853 ##############################
# Debug path, enabled by an extra --calibrate... option collected in sup_args.
# It runs the model over the training set, then prints the moments of every
# mygpt.Calibrator found on the model's modules. The lines binding `a` and `x`,
# and what follows the printout, are not visible here.
# NOTE(review): "${n}" and "${a}" in the f-string below print a literal "$" before
# each name (JavaScript template syntax); in Python it should be "{n}" / "{a}".
855 if "calibrate" in sup_args:
856 for input in task.batches(split="train", desc="calibrate"):
857 input = input.to(device)
858 output = model(mygpt.BracketedSequence(input)).x
860 for n, m in model.named_modules():
863 if isinstance(x, mygpt.Calibrator):
864 print(f"####### ${n} | ${a} ########################")
865 mean, std = x.moments()
866 print("mean\n", mean, "\n")
867 print("std\n", std, "\n")
868 print(f"############################################\n\n")
872 ##############################
# When the loaded checkpoint already covers --nb_epochs, only produce the task
# results. The remaining arguments and what follows are not visible here.
876 if nb_epochs_finished >= nb_epochs:
877 task.produce_results(
878 n_epoch=nb_epochs_finished,
880 result_dir=args.result_dir,
882 deterministic_synthesis=args.deterministic_synthesis,
# Reference time used to extrapolate when the next epoch's results will be ready.
885 time_pred_result = datetime.datetime.now()
# Main loop, one pass per remaining epoch: train, evaluate on the test split,
# produce task results, log losses and perplexities, and save a checkpoint.
891 for n_epoch in range(nb_epochs_finished, nb_epochs):
# NOTE(review): the optimizer is rebuilt at the start of every epoch, which resets
# Adam/AdamW moment estimates, and its state is not checkpointed. If that is not
# intended, build it once before the loop and save optimizer.state_dict().
892 if args.optim == "sgd":
893 optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate)
894 elif args.optim == "adam":
895 optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
896 elif args.optim == "adamw":
897 optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate)
899 raise ValueError(f"Unknown optimizer {args.optim}.")
903 nb_train_samples, acc_train_loss, acc_train_inner_loss = 0, 0.0, 0.0
# Generator over the training batches. With probability proportion_memex per
# batch, the batch is rewritten (mostly not visible here) using a (batch, 1)
# column filled with the token vocabulary_size - 1.
905 def add_memex(batches, proportion_memex):
906 for input in batches:
907 if torch.rand(1).item() < proportion_memex:
912 (input.size(0), 1), vocabulary_size - 1, device=input.device
920 train_batches = add_memex(task.batches(split="train"), args.proportion_memex)
922 for input in train_batches:
923 model.reset_inner_loss()
924 input = input.to(device)
# Cross-entropy of the model output against the input tokens at every position.
# transpose(1, 2) presumably turns (batch, length, vocab) into the
# (batch, vocab, length) layout F.cross_entropy expects.
926 output = model(mygpt.BracketedSequence(input)).x
927 loss = F.cross_entropy(output.transpose(1, 2), input)
928 inner_loss = model.get_inner_loss()
930 acc_train_loss += loss.item() * input.size(0)
931 acc_train_inner_loss += inner_loss.item() * input.size(0)
933 nb_train_samples += input.size(0)
934 nb_samples_seen += input.size(0)
# The inner loss enters the objective only when --rho_inner_loss > 0; it is
# accumulated for logging in every case.
936 total_loss = loss + (
937 args.rho_inner_loss * inner_loss if args.rho_inner_loss > 0 else 0.0
# `it` is the global iteration counter used by get_lr (its initialisation and
# increment are not visible here).
941 lr = get_lr(n_epoch, it)
942 for param_group in optimizer.param_groups:
943 param_group["lr"] = lr
945 # log_string(f"learning_rate {lr}")
947 optimizer.zero_grad()
948 total_loss.backward()
# NOTE(review): p.grad is None for parameters that received no gradient, which
# makes this sum raise; guard with `if p.grad is not None`.
951 grad_norm = sum([p.grad.pow(2).sum() for p in model.parameters()]).sqrt()
953 loss_file.write(f"{n_epoch} {n_batch} {loss.item()} {grad_norm.item()}\n")
# Evaluation on the test split without autograd.
957 with torch.autograd.no_grad():
960 nb_test_samples, acc_test_loss = 0, 0.0
962 for input in task.batches(split="test"):
963 input = input.to(device)
965 output = model(mygpt.BracketedSequence(input)).x
966 loss = F.cross_entropy(output.transpose(1, 2), input)
967 acc_test_loss += loss.item() * input.size(0)
968 nb_test_samples += input.size(0)
971 f"loss {n_epoch} train_loss {acc_train_loss/nb_train_samples} train_inner_loss {acc_train_inner_loss/nb_train_samples} test_prediction {acc_test_loss/nb_test_samples}"
974 task.produce_results(
977 result_dir=args.result_dir,
979 deterministic_synthesis=args.deterministic_synthesis,
# Perplexities are exp of the average losses, with the exponent capped at 100 to
# avoid overflow.
982 train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples))
983 test_perplexity = math.exp(min(100, acc_test_loss / nb_test_samples))
986 f"perplexity {n_epoch} train_set {train_set_perplexity} train_prediction {train_perplexity} test_prediction {test_perplexity}"
# Predict the next result time by assuming the next epoch lasts as long as the
# previous one.
989 time_current_result = datetime.datetime.now()
991 f"next_result {time_current_result + (time_current_result - time_pred_result)}"
993 time_pred_result = time_current_result
# Checkpoint after every epoch; "cuda_rng_state" is added only when CUDA is
# available.
996 "nb_epochs_finished": n_epoch + 1,
997 "model_state": model.state_dict(),
998 "rng_state": torch.get_rng_state(),
1001 if torch.cuda.is_available():
1002 checkpoint["cuda_rng_state"] = torch.cuda.get_rng_state()
1004 checkpoint_name = os.path.join(args.result_dir, args.checkpoint_name)
1005 torch.save(checkpoint, checkpoint_name)
1006 log_string(f"saved checkpoint {checkpoint_name}")
1008 ######################################################################