                     type = str, default = 'adam')
 
 parser.add_argument('--learning_rate',
-                    type = float, default = 1e-4)
+                    type = float, default = 1e-3)
+
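+# Final learning rate for the log-space schedule in the training loop below;
+# a negative value disables the schedule and keeps the learning rate constant
+# at --learning_rate.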
+parser.add_argument('--learning_rate_end',
+                    type = float, default = 1e-6)
 
 parser.add_argument('--dim_model',
                     type = int, default = 512)
 
 parser.add_argument('--dropout',
                     type = float, default = 0.1)
 
-parser.add_argument('--synthesis_sampling',
-                    action='store_true', default = True)
+parser.add_argument('--deterministic_synthesis',
+                    action='store_true', default = False)
 
 parser.add_argument('--no_checkpoint',
                     action='store_true', default = False)
 for s in range(first, input.size(1)):
     output = model(input)
     logits = output[:, s]
-    if args.synthesis_sampling:
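+    # Either pick the most likely next token (greedy decoding) or sample it
+    # from the categorical distribution defined by the logits.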
+    if args.deterministic_synthesis:
+        t_next = logits.argmax(1)
+    else:
         dist = torch.distributions.categorical.Categorical(logits = logits)
         t_next = dist.sample()
-    else:
-        t_next = logits.argmax(1)
     input[:, s] = t_next
 
 return results
 for n_epoch in range(nb_epochs_finished, nb_epochs):
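+    # Interpolate the learning rate log-linearly (i.e. geometrically) between
+    # --learning_rate at the first epoch and --learning_rate_end at the last
+    # one; assumes nb_epochs > 1, since the ramp divides by nb_epochs - 1.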
+    if args.learning_rate_end < 0:
+        lr = args.learning_rate
+    else:
+        u = n_epoch / (nb_epochs - 1)
+        lr = math.exp((1 - u) * math.log(args.learning_rate) +
+                      u * math.log(args.learning_rate_end))
+    log_string(f'learning_rate {lr}')
+
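+    # Re-create the optimizer at every epoch so that it uses the scheduled
+    # lr; for Adam/AdamW this also resets the running moment estimates.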
     if args.optim == 'sgd':
-        optimizer = torch.optim.SGD(model.parameters(), lr = args.learning_rate)
+        optimizer = torch.optim.SGD(model.parameters(), lr = lr)
     elif args.optim == 'adam':
-        optimizer = torch.optim.Adam(model.parameters(), lr = args.learning_rate)
+        optimizer = torch.optim.Adam(model.parameters(), lr = lr)
     elif args.optim == 'adamw':
-        optimizer = torch.optim.AdamW(model.parameters(), lr = args.learning_rate)
+        optimizer = torch.optim.AdamW(model.parameters(), lr = lr)
     else:
         raise ValueError(f'Unknown optimizer {args.optim}.')