Update.
[mygptrnn.git] / main.py
diff --git a/main.py b/main.py
index 18c0730..cae20f8 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -24,6 +24,17 @@ else:
 
 ######################################################################
 
+
+def str2bool(x):
+    x = x.lower()
+    if x in {"1", "true", "yes"}:
+        return True
+    elif x in {"0", "false", "no"}:
+        return False
+    else:
+        raise ValueError
+
+
 parser = argparse.ArgumentParser(
     description="An implementation of GPT with cache.",
     formatter_class=argparse.ArgumentDefaultsHelpFormatter,
@@ -68,7 +79,7 @@ parser.add_argument("--min_learning_rate", type=float, default=6e-5)
 
 # legacy
 
-parser.add_argument("--legacy_lr_schedule", action="store_true", default=False)
+parser.add_argument("--legacy_lr_schedule", type=str2bool, default=True)
 
 parser.add_argument("--legacy_large_lr", type=float, default=1e-4)