+######################################################################
+
+nb_epochs_finished = 0
+
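+# Resume from an existing checkpoint unless args.no_checkpoint is set. A
+# missing checkpoint file is not an error: training simply starts from scratch.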
+if args.no_checkpoint:
+    log_string('not trying to load checkpoint.')
+
+else:
+    try:
+        checkpoint = torch.load(args.checkpoint_name, map_location = device)
+        nb_epochs_finished = checkpoint['nb_epochs_finished']
+        model.load_state_dict(checkpoint['model_state'])
+        optimizer.load_state_dict(checkpoint['optimizer_state'])
+        log_string(f'checkpoint loaded with {nb_epochs_finished} epochs finished.')
+
+    except FileNotFoundError:
+        log_string('starting from scratch.')
+
+    except:
+        log_string('error when loading the checkpoint.')
+        exit(1)
+
+######################################################################
+
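+# Number of training epochs: use the command-line value when it is positive,
+# otherwise fall back to nb_epochs_default.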
+nb_epochs = args.nb_epochs if args.nb_epochs > 0 else nb_epochs_default
+
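+# Train-set perplexity from the empirical unigram token distribution: count
+# token occurrences over the training batches, normalize to probabilities p,
+# compute the entropy H = -sum_t p_t * log(p_t), and report exp(H).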
+token_count = 0
+for input in task.batches(split = 'train'):
+    token_count += F.one_hot(input, num_classes = task.vocabulary_size()).sum((0, 1))
+token_probas = token_count / token_count.sum()
+h = -torch.xlogy(token_probas, token_probas).sum()
+train_set_perplexity = math.exp(h)
+log_string(f'train set perplexity {train_set_perplexity}')
+
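+# The epoch loop starts at nb_epochs_finished, so a resumed run does not redo
+# epochs already completed before the checkpoint was saved.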
+for k in range(nb_epochs_finished, nb_epochs):