X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=main.py;h=18c0730d7608866041a35874da1c1dc0b641133f;hb=0d25f8a86e80850cf6a6e27d419f7b043c6028f1;hp=74e70b2a9780601053e51e0d359d082e7a869e1d;hpb=cb737bdbd2f112826f739e4581fbe6546aeef638;p=mygptrnn.git diff --git a/main.py b/main.py index 74e70b2..18c0730 100755 --- a/main.py +++ b/main.py @@ -66,6 +66,16 @@ parser.add_argument("--learning_rate", type=float, default=6e-4) parser.add_argument("--min_learning_rate", type=float, default=6e-5) +# legacy + +parser.add_argument("--legacy_lr_schedule", action="store_true", default=False) + +parser.add_argument("--legacy_large_lr", type=float, default=1e-4) + +parser.add_argument("--legacy_small_lr", type=float, default=2e-5) + +parser.add_argument("--legacy_nb_epoch_large_lr", type=float, default=10) + ######################################## parser.add_argument("--model", type=str, default=None) @@ -460,10 +470,21 @@ for n in vars(args): ###################################################################### -# from nanoGPT +def get_lr(n_epoch, it): + if args.legacy_lr_schedule: + # my crude scheduling to compare to previous baseline, added + # warmup though + + if it < args.nb_warmup_iter: + return args.legacy_large_lr * it / args.nb_warmup_iter + elif n_epoch < args.legacy_nb_epoch_large_lr: + return args.legacy_large_lr + else: + return args.legacy_small_lr + + # from nanoGPT -def get_lr(it): # 1) linear warmup for warmup_iter steps if it < args.nb_warmup_iter: return args.learning_rate * it / args.nb_warmup_iter @@ -848,7 +869,7 @@ for n_epoch in range(nb_epochs_finished, nb_epochs): total_loss = loss + (args.rho * inner_loss if args.rho > 0 else 0.0) it += 1 - lr = get_lr(it) + lr = get_lr(n_epoch, it) for param_group in optimizer.param_groups: param_group["lr"] = lr