X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=main.py;h=cae20f89347b076226b6525d561dab1d89a55e8b;hb=f3f490def0be8a3ea2b9a0ac60f5bb33c5c45fb5;hp=1a17e51a63741da0fa8c9f2910a319c331667c14;hpb=a09ee76c8283b7daf4c914df47f86d1964fc25d4;p=mygptrnn.git diff --git a/main.py b/main.py index 1a17e51..cae20f8 100755 --- a/main.py +++ b/main.py @@ -24,6 +24,17 @@ else: ###################################################################### + +def str2bool(x): + x = x.lower() + if x in {"1", "true", "yes"}: + return True + elif x in {"0", "false", "no"}: + return False + else: + raise ValueError + + parser = argparse.ArgumentParser( description="An implementation of GPT with cache.", formatter_class=argparse.ArgumentDefaultsHelpFormatter, @@ -68,13 +79,13 @@ parser.add_argument("--min_learning_rate", type=float, default=6e-5) # legacy -parser.add_argument("--legacy_lr_schedule", action="store_true", default=False) +parser.add_argument("--legacy_lr_schedule", type=str2bool, default=True) -parser.add_argument("--legacy_learning_rate", type=float, default=1e-4) +parser.add_argument("--legacy_large_lr", type=float, default=1e-4) -parser.add_argument("--legacy_min_learning_rate", type=float, default=2e-5) +parser.add_argument("--legacy_small_lr", type=float, default=2e-5) -parser.add_argument("--nb_large_lr_epochs", type=float, default=10) +parser.add_argument("--legacy_nb_epoch_large_lr", type=float, default=10) ######################################## @@ -477,11 +488,11 @@ def get_lr(n_epoch, it): # warmup though if it < args.nb_warmup_iter: - return args.legacy_learning_rate * it / args.nb_warmup_iter - elif it < args.nb_large_lr_epochs: - return args.legacy_learning_rate + return args.legacy_large_lr * it / args.nb_warmup_iter + elif n_epoch < args.legacy_nb_epoch_large_lr: + return args.legacy_large_lr else: - return args.legacy_min_learning_rate + return args.legacy_small_lr # from nanoGPT