Update.
[mygptrnn.git] / main.py
diff --git a/main.py b/main.py
index 18c0730..c51035c 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -24,6 +24,17 @@ else:
 
 ######################################################################
 
+
+def str2bool(x):
+    x = x.lower()
+    if x in {"1", "true", "yes"}:
+        return True
+    elif x in {"0", "false", "no"}:
+        return False
+    else:
+        raise ValueError
+
+
 parser = argparse.ArgumentParser(
     description="An implementation of GPT with cache.",
     formatter_class=argparse.ArgumentDefaultsHelpFormatter,
@@ -68,7 +79,7 @@ parser.add_argument("--min_learning_rate", type=float, default=6e-5)
 
 # legacy
 
-parser.add_argument("--legacy_lr_schedule", action="store_true", default=False)
+parser.add_argument("--legacy_lr_schedule", type=str2bool, default=True)
 
 parser.add_argument("--legacy_large_lr", type=float, default=1e-4)
 
@@ -96,8 +107,6 @@ parser.add_argument("--caterpillar_height", type=int, default=None)
 
 parser.add_argument("--rho", type=float, default=0.0)
 
-parser.add_argument("--dim_rec_v", type=int, default=None)
-
 parser.add_argument("--nb_blocks", type=int, default=None)
 
 parser.add_argument("--dropout", type=float, default=0.1)
@@ -108,7 +117,7 @@ parser.add_argument("--deterministic_synthesis", action="store_true", default=Fa
 
 parser.add_argument("--no_checkpoint", action="store_true", default=False)
 
-parser.add_argument("--overwrite_results", action="store_true", default=False)
+parser.add_argument("--continue_training", action="store_true", default=False)
 
 parser.add_argument("--checkpoint_name", type=str, default="checkpoint.pth")
 
@@ -321,7 +330,6 @@ default_model_args = {
         "dim_keys": 32,
         "dim_hidden": 32,
         "nb_heads": 2,
-        "dim_rec_v": 16,
         "nb_blocks": 2,
     },
     "17K-C": {
@@ -332,7 +340,6 @@ default_model_args = {
         "nb_heads": 2,
         "nb_lines": 16,
         "caterpillar_height": 4,
-        "dim_rec_v": 16,
         "nb_blocks": 2,
     },
     "4M": {
@@ -341,7 +348,6 @@ default_model_args = {
         "dim_keys": 32,
         "dim_hidden": 1024,
         "nb_heads": 4,
-        "dim_rec_v": 64,
         "nb_blocks": 6,
     },
     "4M-C": {
@@ -352,7 +358,6 @@ default_model_args = {
         "nb_heads": 4,
         "nb_lines": 32,
         "caterpillar_height": 4,
-        "dim_rec_v": 64,  # dim_model / nb_heads
         "nb_blocks": 6,
     },
     "37M": {
@@ -361,7 +366,6 @@ default_model_args = {
         "dim_keys": 64,
         "dim_hidden": 2048,
         "nb_heads": 8,
-        "dim_rec_v": 64,
         "nb_blocks": 12,
     },
     "37M-C": {
@@ -372,7 +376,6 @@ default_model_args = {
         "nb_heads": 8,
         "nb_lines": 256,
         "caterpillar_height": 32,
-        "dim_rec_v": 64,
         "nb_blocks": 12,
     },
     "122M": {
@@ -381,7 +384,6 @@ default_model_args = {
         "dim_keys": 64,
         "dim_hidden": 2048,
         "nb_heads": 8,
-        "dim_rec_v": 96,
         "nb_blocks": 24,
     },
     "122M-C": {
@@ -391,7 +393,6 @@ default_model_args = {
         "dim_hidden": 2048,
         "nb_heads": 8,
         "nb_lines": 128,
-        "dim_rec_v": 96,
         "nb_blocks": 24,
     },
     "352M": {
@@ -400,7 +401,6 @@ default_model_args = {
         "dim_keys": 64,
         "dim_hidden": 2048,
         "nb_heads": 8,
-        "dim_rec_v": 128,
         "nb_blocks": 48,
     },
     "352M-C": {
@@ -410,7 +410,6 @@ default_model_args = {
         "dim_hidden": 2048,
         "nb_heads": 8,
         "nb_lines": 128,
-        "dim_rec_v": 128,
         "nb_blocks": 48,
     },
 }
@@ -427,7 +426,7 @@ else:
 try:
     os.mkdir(args.result_dir)
 except FileExistsError:
-    if not args.overwrite_results:
+    if not args.continue_training:
         print(f"result directory {args.result_dir} already exists")
         exit(1)
 
@@ -725,7 +724,6 @@ model = mygpt.MyGPT(
     nb_heads=args.nb_heads,
     nb_lines=args.nb_lines,
     caterpillar_height=args.caterpillar_height,
-    dim_rec_v=args.dim_rec_v,
     nb_blocks=args.nb_blocks,
     causal=True,
     dropout=args.dropout,