Added args.learning_rate_end for an exponential decay of the learning rate.
author    Francois Fleuret <francois@fleuret.org>
          Mon, 8 Aug 2022 15:59:08 +0000 (17:59 +0200)
committer Francois Fleuret <francois@fleuret.org>
          Mon, 8 Aug 2022 15:59:08 +0000 (17:59 +0200)
main.py

diff --git a/main.py b/main.py
index d4a8cfb..f6934b7 100755
--- a/main.py
+++ b/main.py
@@ -42,7 +42,10 @@ parser.add_argument('--optim',
                     type = str, default = 'adam')
 
 parser.add_argument('--learning_rate',
-                    type = float, default = 1e-4)
+                    type = float, default = 1e-3)
+
+parser.add_argument('--learning_rate_end',
+                    type = float, default = 1e-6)
 
 parser.add_argument('--dim_model',
                     type = int, default = 512)
@@ -465,12 +468,20 @@ train_set_perplexity = math.exp(entropy)
 
 for n_epoch in range(nb_epochs_finished, nb_epochs):
 
+    if args.learning_rate_end < 0:
+        lr = args.learning_rate
+    else:
+        u = n_epoch / (nb_epochs - 1)
+        lr = math.exp((1 - u) * math.log(args.learning_rate) +
+                      u * math.log(args.learning_rate_end))
+        log_string(f'learning_rate {lr}')
+
     if args.optim == 'sgd':
-        optimizer = torch.optim.SGD(model.parameters(), lr = args.learning_rate)
+        optimizer = torch.optim.SGD(model.parameters(), lr = lr)
     elif args.optim == 'adam':
-        optimizer = torch.optim.Adam(model.parameters(), lr = args.learning_rate)
+        optimizer = torch.optim.Adam(model.parameters(), lr = lr)
     elif args.optim == 'adamw':
-        optimizer = torch.optim.AdamW(model.parameters(), lr = args.learning_rate)
+        optimizer = torch.optim.AdamW(model.parameters(), lr = lr)
     else:
         raise ValueError(f'Unknown optimizer {args.optim}.')
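
For reference, a minimal standalone sketch of the schedule this hunk implements: the learning rate is interpolated linearly in log space, i.e. decays geometrically, from --learning_rate at the first epoch down to --learning_rate_end at the last one, and a negative --learning_rate_end keeps the rate constant. The function name learning_rate_for_epoch and the epoch count below are illustrative, not taken from main.py.

import math

def learning_rate_for_epoch(n_epoch, nb_epochs, lr_start, lr_end):
    # A negative lr_end disables the decay, mirroring the
    # `args.learning_rate_end < 0` test in the hunk above.
    if lr_end < 0:
        return lr_start
    # Linear interpolation in log space between lr_start and lr_end,
    # which gives an exponential (geometric) decay over the epochs.
    u = n_epoch / (nb_epochs - 1)
    return math.exp((1 - u) * math.log(lr_start) + u * math.log(lr_end))

# With the new defaults (1e-3 down to 1e-6) over, say, 10 epochs:
for e in range(10):
    print(e, learning_rate_for_epoch(e, 10, 1e-3, 1e-6))

The first epoch uses exactly --learning_rate and the last epoch exactly --learning_rate_end; since the optimizer is re-created each epoch in the loop above, the per-epoch value is simply passed as lr.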