Update.
[mygptrnn.git] / main.py
diff --git a/main.py b/main.py
index 74e70b2..c22ae57 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -16,13 +16,16 @@ import mygpt, tasks, problems
 
 ######################################################################
 
-if torch.cuda.is_available():
-    device = torch.device("cuda")
-    torch.backends.cuda.matmul.allow_tf32 = True
-else:
-    device = torch.device("cpu")
 
-######################################################################
+def str2bool(x):
+    x = x.lower()
+    if x in {"1", "true", "yes"}:
+        return True
+    elif x in {"0", "false", "no"}:
+        return False
+    else:
+        raise ValueError
+
 
 parser = argparse.ArgumentParser(
     description="An implementation of GPT with cache.",
@@ -44,6 +47,8 @@ parser.add_argument("--seed", type=int, default=0)
 
 parser.add_argument("--max_percents_of_test_in_train", type=int, default=1)
 
+parser.add_argument("--force_cpu", type=str2bool, default=False)
+
 ########################################
 
 parser.add_argument("--nb_epochs", type=int, default=50)
@@ -66,6 +71,16 @@ parser.add_argument("--learning_rate", type=float, default=6e-4)
 
 parser.add_argument("--min_learning_rate", type=float, default=6e-5)
 
+# legacy
+
+parser.add_argument("--legacy_lr_schedule", type=str2bool, default=True)
+
+parser.add_argument("--legacy_large_lr", type=float, default=1e-4)
+
+parser.add_argument("--legacy_small_lr", type=float, default=2e-5)
+
+parser.add_argument("--legacy_nb_epoch_large_lr", type=float, default=10)
+
 ########################################
 
 parser.add_argument("--model", type=str, default=None)
@@ -86,8 +101,6 @@ parser.add_argument("--caterpillar_height", type=int, default=None)
 
 parser.add_argument("--rho", type=float, default=0.0)
 
-parser.add_argument("--dim_rec_v", type=int, default=None)
-
 parser.add_argument("--nb_blocks", type=int, default=None)
 
 parser.add_argument("--dropout", type=float, default=0.1)
@@ -98,7 +111,7 @@ parser.add_argument("--deterministic_synthesis", action="store_true", default=Fa
 
 parser.add_argument("--no_checkpoint", action="store_true", default=False)
 
-parser.add_argument("--overwrite_results", action="store_true", default=False)
+parser.add_argument("--continue_training", action="store_true", default=False)
 
 parser.add_argument("--checkpoint_name", type=str, default="checkpoint.pth")
 
@@ -189,15 +202,25 @@ parser.add_argument("--mixing_deterministic_start", action="store_true", default
 
 ######################################################################
 
-args = parser.parse_args()
+args = parser.parse_args()
 
-assert args.picocvlr_prune_properties in {"none", "train+eval", "eval"}
+args, sup_args = parser.parse_known_args()
+
+sup_args = dict([x.removeprefix("--").split("=") for x in sup_args])
 
 if args.result_dir is None:
     args.result_dir = f"results_{args.task}_{args.model}"
 
 ######################################################################
 
+if not args.force_cpu and torch.cuda.is_available():
+    device = torch.device("cuda")
+    torch.backends.cuda.matmul.allow_tf32 = True
+else:
+    device = torch.device("cpu")
+
+######################################################################
+
 default_task_args = {
     "addition": {
         "model": "352M",
@@ -311,7 +334,6 @@ default_model_args = {
         "dim_keys": 32,
         "dim_hidden": 32,
         "nb_heads": 2,
-        "dim_rec_v": 16,
         "nb_blocks": 2,
     },
     "17K-C": {
@@ -322,7 +344,6 @@ default_model_args = {
         "nb_heads": 2,
         "nb_lines": 16,
         "caterpillar_height": 4,
-        "dim_rec_v": 16,
         "nb_blocks": 2,
     },
     "4M": {
@@ -331,7 +352,6 @@ default_model_args = {
         "dim_keys": 32,
         "dim_hidden": 1024,
         "nb_heads": 4,
-        "dim_rec_v": 64,
         "nb_blocks": 6,
     },
     "4M-C": {
@@ -342,7 +362,6 @@ default_model_args = {
         "nb_heads": 4,
         "nb_lines": 32,
         "caterpillar_height": 4,
-        "dim_rec_v": 64,  # dim_model / nb_heads
         "nb_blocks": 6,
     },
     "37M": {
@@ -351,7 +370,6 @@ default_model_args = {
         "dim_keys": 64,
         "dim_hidden": 2048,
         "nb_heads": 8,
-        "dim_rec_v": 64,
         "nb_blocks": 12,
     },
     "37M-C": {
@@ -362,7 +380,6 @@ default_model_args = {
         "nb_heads": 8,
         "nb_lines": 256,
         "caterpillar_height": 32,
-        "dim_rec_v": 64,
         "nb_blocks": 12,
     },
     "122M": {
@@ -371,7 +388,6 @@ default_model_args = {
         "dim_keys": 64,
         "dim_hidden": 2048,
         "nb_heads": 8,
-        "dim_rec_v": 96,
         "nb_blocks": 24,
     },
     "122M-C": {
@@ -381,7 +397,6 @@ default_model_args = {
         "dim_hidden": 2048,
         "nb_heads": 8,
         "nb_lines": 128,
-        "dim_rec_v": 96,
         "nb_blocks": 24,
     },
     "352M": {
@@ -390,7 +405,6 @@ default_model_args = {
         "dim_keys": 64,
         "dim_hidden": 2048,
         "nb_heads": 8,
-        "dim_rec_v": 128,
         "nb_blocks": 48,
     },
     "352M-C": {
@@ -400,7 +414,6 @@ default_model_args = {
         "dim_hidden": 2048,
         "nb_heads": 8,
         "nb_lines": 128,
-        "dim_rec_v": 128,
         "nb_blocks": 48,
     },
 }
@@ -417,10 +430,12 @@ else:
 try:
     os.mkdir(args.result_dir)
 except FileExistsError:
-    if not args.overwrite_results:
+    if not args.continue_training:
         print(f"result directory {args.result_dir} already exists")
         exit(1)
 
+loss_file = open(os.path.join(args.result_dir, "loss.dat"), "a")
+
 log_file = open(os.path.join(args.result_dir, args.log_filename), "a")
 
 if args.seed >= 0:
@@ -450,20 +465,34 @@ with os.popen("sha256sum *.py") as f:
         log_string(f"sha256sum {l.strip()}")
 
 now = time.strftime("%Y%m%d-%H%M%S", time.localtime())
-os.system(f"tar zcvf {args.result_dir}/src-{now}.tgz *.py *.sh")
+os.system(f"tar --ignore-failed-read zcvf {args.result_dir}/src-{now}.tgz *.py *.sh")
 
 log_string(f"argv {' '.join(sys.argv)}")
 
 for n in vars(args):
     log_string(f"args.{n} {getattr(args, n)}")
 
+for n in vars(sup_args):
+    log_string(f"sup_args.{n} {getattr(sup_args, n)}")
+
 
 ######################################################################
 
-# from nanoGPT
 
+def get_lr(n_epoch, it):
+    if args.legacy_lr_schedule:
+        # my crude scheduling to compare to previous baseline, added
+        # warmup though
+
+        if it < args.nb_warmup_iter:
+            return args.legacy_large_lr * it / args.nb_warmup_iter
+        elif n_epoch < args.legacy_nb_epoch_large_lr:
+            return args.legacy_large_lr
+        else:
+            return args.legacy_small_lr
+
+    # from nanoGPT
 
-def get_lr(it):
     # 1) linear warmup for warmup_iter steps
     if it < args.nb_warmup_iter:
         return args.learning_rate * it / args.nb_warmup_iter
@@ -483,6 +512,9 @@ def get_lr(it):
 ######################################################################
 
 
+assert args.picocvlr_prune_properties in {"none", "train+eval", "eval"}
+
+
 def picoclvr_pruner_horizontal_green(p):
     return not ("green" in p and ("left" in p or "right" in p))
 
@@ -704,11 +736,12 @@ model = mygpt.MyGPT(
     nb_heads=args.nb_heads,
     nb_lines=args.nb_lines,
     caterpillar_height=args.caterpillar_height,
-    dim_rec_v=args.dim_rec_v,
     nb_blocks=args.nb_blocks,
     causal=True,
     dropout=args.dropout,
     attention_layer=args.attention,
+    logger=log_string,
+    **sup_args,
 )
 
 model.to(device)
@@ -813,10 +846,12 @@ if nb_epochs_finished >= nb_epochs:
         deterministic_synthesis=args.deterministic_synthesis,
     )
 
-time_pred_result = None
+time_pred_result = datetime.datetime.now()
 
 it = 0
 
+n_batch = 0
+
 for n_epoch in range(nb_epochs_finished, nb_epochs):
     if args.optim == "sgd":
         optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate)
@@ -848,7 +883,7 @@ for n_epoch in range(nb_epochs_finished, nb_epochs):
         total_loss = loss + (args.rho * inner_loss if args.rho > 0 else 0.0)
 
         it += 1
-        lr = get_lr(it)
+        lr = get_lr(n_epoch, it)
         for param_group in optimizer.param_groups:
             param_group["lr"] = lr
 
@@ -858,6 +893,12 @@ for n_epoch in range(nb_epochs_finished, nb_epochs):
         total_loss.backward()
         optimizer.step()
 
+        grad_norm = sum([p.grad.pow(2).sum() for p in model.parameters()]).sqrt()
+
+        loss_file.write(f"{n_epoch} {n_batch} {loss.item()} {grad_norm.item()}\n")
+
+        n_batch += 1
+
     with torch.autograd.no_grad():
         model.eval()
 
@@ -891,10 +932,9 @@ for n_epoch in range(nb_epochs_finished, nb_epochs):
         )
 
         time_current_result = datetime.datetime.now()
-        if time_pred_result is not None:
-            log_string(
-                f"next_result {time_current_result + (time_current_result - time_pred_result)}"
-            )
+        log_string(
+            f"next_result {time_current_result + (time_current_result - time_pred_result)}"
+        )
         time_pred_result = time_current_result
 
     checkpoint = {