From: Francois Fleuret
Date: Mon, 8 Aug 2022 15:59:08 +0000 (+0200)
Subject: Added args.learning_rate_end for an exponential decay.
X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=13a6ecc6e00a75ce5a95c54c11ce6f60902f57f1;p=mygpt.git

Added args.learning_rate_end for an exponential decay.
---

diff --git a/main.py b/main.py
index d4a8cfb..f6934b7 100755
--- a/main.py
+++ b/main.py
@@ -42,7 +42,10 @@ parser.add_argument('--optim',
                     type = str, default = 'adam')
 
 parser.add_argument('--learning_rate',
-                    type = float, default = 1e-4)
+                    type = float, default = 1e-3)
+
+parser.add_argument('--learning_rate_end',
+                    type = float, default = 1e-6)
 
 parser.add_argument('--dim_model',
                     type = int, default = 512)
@@ -465,12 +468,20 @@ train_set_perplexity = math.exp(entropy)
 
 for n_epoch in range(nb_epochs_finished, nb_epochs):
 
+    if args.learning_rate_end < 0:
+        lr = args.learning_rate
+    else:
+        u = n_epoch / (nb_epochs - 1)
+        lr = math.exp((1 - u) * math.log(args.learning_rate) +
+                      u * math.log(args.learning_rate_end))
+    log_string(f'learning_rate {lr}')
+
     if args.optim == 'sgd':
-        optimizer = torch.optim.SGD(model.parameters(), lr = args.learning_rate)
+        optimizer = torch.optim.SGD(model.parameters(), lr = lr)
     elif args.optim == 'adam':
-        optimizer = torch.optim.Adam(model.parameters(), lr = args.learning_rate)
+        optimizer = torch.optim.Adam(model.parameters(), lr = lr)
     elif args.optim == 'adamw':
-        optimizer = torch.optim.AdamW(model.parameters(), lr = args.learning_rate)
+        optimizer = torch.optim.AdamW(model.parameters(), lr = lr)
     else:
         raise ValueError(f'Unknown optimizer {args.optim}.')
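
The schedule added above interpolates the learning rate linearly in log-space, i.e. geometrically, from args.learning_rate at the first epoch down to args.learning_rate_end at the last one. A minimal standalone sketch of that computation, with assumed example values for the three parameters (not taken from the commit):

    import math

    learning_rate, learning_rate_end, nb_epochs = 1e-3, 1e-6, 10  # assumed example values

    for n_epoch in range(nb_epochs):
        u = n_epoch / (nb_epochs - 1)
        # linear interpolation of log(lr), i.e. geometric decay of lr
        lr = math.exp((1 - u) * math.log(learning_rate) + u * math.log(learning_rate_end))
        print(n_epoch, lr)  # goes from 1e-3 down to 1e-6

With these values the learning rate is multiplied by the same factor, 10**(-3/9) ≈ 0.464, at every epoch.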