From 13a6ecc6e00a75ce5a95c54c11ce6f60902f57f1 Mon Sep 17 00:00:00 2001
From: Francois Fleuret
Date: Mon, 8 Aug 2022 17:59:08 +0200
Subject: [PATCH] Added args.learning_rate_end for an exponential decay.

---
 main.py | 19 +++++++++++++++----
 1 file changed, 15 insertions(+), 4 deletions(-)

diff --git a/main.py b/main.py
index d4a8cfb..f6934b7 100755
--- a/main.py
+++ b/main.py
@@ -42,7 +42,10 @@ parser.add_argument('--optim',
                     type = str, default = 'adam')
 
 parser.add_argument('--learning_rate',
-                    type = float, default = 1e-4)
+                    type = float, default = 1e-3)
+
+parser.add_argument('--learning_rate_end',
+                    type = float, default = 1e-6)
 
 parser.add_argument('--dim_model',
                     type = int, default = 512)
@@ -465,12 +468,20 @@ train_set_perplexity = math.exp(entropy)
 
 for n_epoch in range(nb_epochs_finished, nb_epochs):
 
+    if args.learning_rate_end < 0:
+        lr = args.learning_rate
+    else:
+        u = n_epoch / (nb_epochs - 1)
+        lr = math.exp((1 - u) * math.log(args.learning_rate)
+                      + u * math.log(args.learning_rate_end))
+    log_string(f'learning_rate {lr}')
+
     if args.optim == 'sgd':
-        optimizer = torch.optim.SGD(model.parameters(), lr = args.learning_rate)
+        optimizer = torch.optim.SGD(model.parameters(), lr = lr)
     elif args.optim == 'adam':
-        optimizer = torch.optim.Adam(model.parameters(), lr = args.learning_rate)
+        optimizer = torch.optim.Adam(model.parameters(), lr = lr)
     elif args.optim == 'adamw':
-        optimizer = torch.optim.AdamW(model.parameters(), lr = args.learning_rate)
+        optimizer = torch.optim.AdamW(model.parameters(), lr = lr)
     else:
         raise ValueError(f'Unknown optimizer {args.optim}.')
 
-- 
2.39.5
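
Note (not part of the patch): the hunk above interpolates linearly in log-space between args.learning_rate at the first epoch and args.learning_rate_end at the last one, i.e. a geometric (exponential) decay, with a negative learning_rate_end disabling the decay. The standalone sketch below only illustrates that schedule; the function name lr_schedule and the example values are hypothetical, only nb_epochs/n_epoch mirror the patch.

import math

def lr_schedule(n_epoch, nb_epochs, learning_rate, learning_rate_end):
    # A negative learning_rate_end disables the decay, as in the patch.
    if learning_rate_end < 0:
        return learning_rate
    # Interpolate linearly in log-space: learning_rate at epoch 0,
    # learning_rate_end at the last epoch, geometric decay in between.
    u = n_epoch / (nb_epochs - 1)
    return math.exp((1 - u) * math.log(learning_rate)
                    + u * math.log(learning_rate_end))

# Example: 10 epochs decaying from the new default 1e-3 down to 1e-6.
for e in range(10):
    print(e, lr_schedule(e, 10, 1e-3, 1e-6))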