Update.
[mygptrnn.git] / main.py
diff --git a/main.py b/main.py
index df46652..74e1d6c 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -24,6 +24,17 @@ else:
 
 ######################################################################
 
+
+def str2bool(x):
+    x = x.lower()
+    if x in {"1", "true", "yes"}:
+        return True
+    elif x in {"0", "false", "no"}:
+        return False
+    else:
+        raise ValueError
+
+
 parser = argparse.ArgumentParser(
     description="An implementation of GPT with cache.",
     formatter_class=argparse.ArgumentDefaultsHelpFormatter,
@@ -66,6 +77,16 @@ parser.add_argument("--learning_rate", type=float, default=6e-4)
 
 parser.add_argument("--min_learning_rate", type=float, default=6e-5)
 
+# legacy
+
+parser.add_argument("--legacy_lr_schedule", type=str2bool, default=True)
+
+parser.add_argument("--legacy_large_lr", type=float, default=1e-4)
+
+parser.add_argument("--legacy_small_lr", type=float, default=2e-5)
+
+parser.add_argument("--legacy_nb_epoch_large_lr", type=float, default=10)
+
 ########################################
 
 parser.add_argument("--model", type=str, default=None)
@@ -86,8 +107,6 @@ parser.add_argument("--caterpillar_height", type=int, default=None)
 
 parser.add_argument("--rho", type=float, default=0.0)
 
-parser.add_argument("--dim_rec_v", type=int, default=None)
-
 parser.add_argument("--nb_blocks", type=int, default=None)
 
 parser.add_argument("--dropout", type=float, default=0.1)
@@ -311,7 +330,6 @@ default_model_args = {
         "dim_keys": 32,
         "dim_hidden": 32,
         "nb_heads": 2,
-        "dim_rec_v": 16,
         "nb_blocks": 2,
     },
     "17K-C": {
@@ -322,7 +340,6 @@ default_model_args = {
         "nb_heads": 2,
         "nb_lines": 16,
         "caterpillar_height": 4,
-        "dim_rec_v": 16,
         "nb_blocks": 2,
     },
     "4M": {
@@ -331,7 +348,6 @@ default_model_args = {
         "dim_keys": 32,
         "dim_hidden": 1024,
         "nb_heads": 4,
-        "dim_rec_v": 64,
         "nb_blocks": 6,
     },
     "4M-C": {
@@ -342,15 +358,14 @@ default_model_args = {
         "nb_heads": 4,
         "nb_lines": 32,
         "caterpillar_height": 4,
-        "dim_rec_v": 64,  # dim_model / nb_heads
         "nb_blocks": 6,
     },
     "37M": {
+        "attention": "mha",
         "dim_model": 512,
         "dim_keys": 64,
         "dim_hidden": 2048,
         "nb_heads": 8,
-        "dim_rec_v": 64,
         "nb_blocks": 12,
     },
     "37M-C": {
@@ -361,7 +376,6 @@ default_model_args = {
         "nb_heads": 8,
         "nb_lines": 256,
         "caterpillar_height": 32,
-        "dim_rec_v": 64,
         "nb_blocks": 12,
     },
     "122M": {
@@ -370,7 +384,6 @@ default_model_args = {
         "dim_keys": 64,
         "dim_hidden": 2048,
         "nb_heads": 8,
-        "dim_rec_v": 96,
         "nb_blocks": 24,
     },
     "122M-C": {
@@ -380,7 +393,6 @@ default_model_args = {
         "dim_hidden": 2048,
         "nb_heads": 8,
         "nb_lines": 128,
-        "dim_rec_v": 96,
         "nb_blocks": 24,
     },
     "352M": {
@@ -389,7 +401,6 @@ default_model_args = {
         "dim_keys": 64,
         "dim_hidden": 2048,
         "nb_heads": 8,
-        "dim_rec_v": 128,
         "nb_blocks": 48,
     },
     "352M-C": {
@@ -399,7 +410,6 @@ default_model_args = {
         "dim_hidden": 2048,
         "nb_heads": 8,
         "nb_lines": 128,
-        "dim_rec_v": 128,
         "nb_blocks": 48,
     },
 }
@@ -459,10 +469,21 @@ for n in vars(args):
 
 ######################################################################
 
-# from nanoGPT
 
+def get_lr(n_epoch, it):
+    if args.legacy_lr_schedule:
+        # my crude scheduling to compare to previous baseline, added
+        # warmup though
+
+        if it < args.nb_warmup_iter:
+            return args.legacy_large_lr * it / args.nb_warmup_iter
+        elif n_epoch < args.legacy_nb_epoch_large_lr:
+            return args.legacy_large_lr
+        else:
+            return args.legacy_small_lr
+
+    # from nanoGPT
 
-def get_lr(it):
     # 1) linear warmup for warmup_iter steps
     if it < args.nb_warmup_iter:
         return args.learning_rate * it / args.nb_warmup_iter
@@ -703,7 +724,6 @@ model = mygpt.MyGPT(
     nb_heads=args.nb_heads,
     nb_lines=args.nb_lines,
     caterpillar_height=args.caterpillar_height,
-    dim_rec_v=args.dim_rec_v,
     nb_blocks=args.nb_blocks,
     causal=True,
     dropout=args.dropout,
@@ -847,7 +867,7 @@ for n_epoch in range(nb_epochs_finished, nb_epochs):
         total_loss = loss + (args.rho * inner_loss if args.rho > 0 else 0.0)
 
         it += 1
-        lr = get_lr(it)
+        lr = get_lr(n_epoch, it)
         for param_group in optimizer.param_groups:
             param_group["lr"] = lr