Replaced --synthesis_sampling with --deterministic_synthesis.
[mygpt.git] / main.py
diff --git a/main.py b/main.py
index d4a8cfb..ee44ebe 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -42,7 +42,10 @@ parser.add_argument('--optim',
                     type = str, default = 'adam')
 
 parser.add_argument('--learning_rate',
-                    type = float, default = 1e-4)
+                    type = float, default = 1e-3)
+
+parser.add_argument('--learning_rate_end',
+                    type = float, default = 1e-6)
 
 parser.add_argument('--dim_model',
                     type = int, default = 512)
@@ -62,8 +65,8 @@ parser.add_argument('--nb_blocks',
 parser.add_argument('--dropout',
                     type = float, default = 0.1)
 
-parser.add_argument('--synthesis_sampling',
-                    action='store_true', default = True)
+parser.add_argument('--deterministic_synthesis',
+                    action='store_true', default = False)
 
 parser.add_argument('--no_checkpoint',
                     action='store_true', default = False)
@@ -129,11 +132,11 @@ def autoregression(
         for s in range(first, input.size(1)):
             output = model(input)
             logits = output[:, s]
-            if args.synthesis_sampling:
+            if args.deterministic_synthesis:
+                t_next = logits.argmax(1)
+            else:
                 dist = torch.distributions.categorical.Categorical(logits = logits)
                 t_next = dist.sample()
-            else:
-                t_next = logits.argmax(1)
             input[:, s] = t_next
 
     return results
@@ -465,12 +468,20 @@ train_set_perplexity = math.exp(entropy)
 
 for n_epoch in range(nb_epochs_finished, nb_epochs):
 
+    if args.learning_rate_end < 0:
+        lr = args.learning_rate
+    else:
+        u = n_epoch / (nb_epochs - 1)
+        lr = math.exp((1 - u) * math.log(args.learning_rate) +
+                      u * math.log(args.learning_rate_end))
+        log_string(f'learning_rate {lr}')
+
     if args.optim == 'sgd':
-        optimizer = torch.optim.SGD(model.parameters(), lr = args.learning_rate)
+        optimizer = torch.optim.SGD(model.parameters(), lr = lr)
     elif args.optim == 'adam':
-        optimizer = torch.optim.Adam(model.parameters(), lr = args.learning_rate)
+        optimizer = torch.optim.Adam(model.parameters(), lr = lr)
     elif args.optim == 'adamw':
-        optimizer = torch.optim.AdamW(model.parameters(), lr = args.learning_rate)
+        optimizer = torch.optim.AdamW(model.parameters(), lr = lr)
     else:
         raise ValueError(f'Unknown optimizer {args.optim}.')