######################################################################
+
+def str2bool(x):
+ x = x.lower()
+ if x in {"1", "true", "yes"}:
+ return True
+ elif x in {"0", "false", "no"}:
+ return False
+ else:
+ raise ValueError
+
+
parser = argparse.ArgumentParser(
description="An implementation of GPT with cache.",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
# legacy
-parser.add_argument("--legacy_lr_schedule", action="store_true", default=False)
+parser.add_argument("--legacy_lr_schedule", type=str2bool, default=True)
parser.add_argument("--legacy_large_lr", type=float, default=1e-4)
if it < args.nb_warmup_iter:
return args.legacy_large_lr * it / args.nb_warmup_iter
- elif it < args.legacy_nb_epoch_large_lr:
+ elif n_epoch < args.legacy_nb_epoch_large_lr:
return args.legacy_large_lr
else:
return args.legacy_small_lr