# Use the CLI-specified epoch count when positive, otherwise fall back to
# the model-size-dependent default.
nb_epochs = args.nb_epochs if args.nb_epochs > 0 else nb_epochs_default

# Estimate the training-set perplexity from empirical token frequencies:
# accumulate per-token occurrence counts over the whole training split,
# normalize to probabilities, and exponentiate the entropy.
token_count = 0
for batch in task.batches(split = 'train'):  # renamed from `input` (shadowed builtin)
    # One-hot over the vocabulary, summed over the batch and sequence
    # dimensions, gives occurrence counts per token id for this batch.
    token_count += F.one_hot(batch, num_classes = task.vocabulary_size()).sum((0, 1))
token_probas = token_count / token_count.sum()
# torch.xlogy treats p == 0 terms as 0, so zero-frequency tokens are safe.
entropy = -torch.xlogy(token_probas, token_probas).sum()
train_set_perplexity = math.exp(entropy)
+
# Main training loop. Resumes from `nb_epochs_finished` (presumably restored
# from a checkpoint elsewhere in the file — TODO confirm) up to `nb_epochs`.
for n_epoch in range(nb_epochs_finished, nb_epochs):

    # Select the optimizer from the command-line flag.
    # NOTE(review): the optimizer is constructed afresh on every epoch, which
    # resets Adam/AdamW moment estimates each time — confirm this is intentional.
    if args.optim == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(), lr = args.learning_rate)
    elif args.optim == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), lr = args.learning_rate)
    elif args.optim == 'adamw':
        optimizer = torch.optim.AdamW(model.parameters(), lr = args.learning_rate)
    else:
        # Fail fast on an unrecognized --optim value.
        raise ValueError(f'Unknown optimizer {args.optim}.')