Added the rng state in the checkpoint.
authorFrancois Fleuret <francois@fleuret.org>
Sun, 7 Aug 2022 19:50:36 +0000 (21:50 +0200)
committerFrancois Fleuret <francois@fleuret.org>
Sun, 7 Aug 2022 19:50:36 +0000 (21:50 +0200)
main.py

diff --git a/main.py b/main.py
index b01ea0a..d4a8cfb 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -430,17 +430,6 @@ log_string(f'nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)')
 
 ######################################################################
 
-if args.optim == 'sgd':
-    optimizer = torch.optim.SGD(model.parameters(), lr = args.learning_rate)
-elif args.optim == 'adam':
-    optimizer = torch.optim.Adam(model.parameters(), lr = args.learning_rate)
-elif args.optim == 'adamw':
-    optimizer = torch.optim.AdamW(model.parameters(), lr = args.learning_rate)
-else:
-    raise ValueError(f'Unknown optimizer {args.optim}.')
-
-######################################################################
-
 nb_epochs_finished = 0
 
 if args.no_checkpoint:
@@ -448,10 +437,12 @@ if args.no_checkpoint:
 
 else:
     try:
-        checkpoint = torch.load(args.checkpoint_name, map_location = device)
+        checkpoint = torch.load(args.checkpoint_name)
         nb_epochs_finished = checkpoint['nb_epochs_finished']
         model.load_state_dict(checkpoint['model_state'])
-        optimizer.load_state_dict(checkpoint['optimizer_state'])
+        torch.set_rng_state(checkpoint['rng_state'])
+        if torch.cuda.is_available():
+            torch.cuda.set_rng_state(checkpoint['cuda_rng_state'])
         log_string(f'checkpoint loaded with {nb_epochs_finished} epochs finished.')
 
     except FileNotFoundError:
@@ -472,7 +463,16 @@ token_probas = token_count / token_count.sum()
 entropy = -torch.xlogy(token_probas, token_probas).sum()
 train_set_perplexity = math.exp(entropy)
 
-for k in range(nb_epochs_finished, nb_epochs):
+for n_epoch in range(nb_epochs_finished, nb_epochs):
+
+    if args.optim == 'sgd':
+        optimizer = torch.optim.SGD(model.parameters(), lr = args.learning_rate)
+    elif args.optim == 'adam':
+        optimizer = torch.optim.Adam(model.parameters(), lr = args.learning_rate)
+    elif args.optim == 'adamw':
+        optimizer = torch.optim.AdamW(model.parameters(), lr = args.learning_rate)
+    else:
+        raise ValueError(f'Unknown optimizer {args.optim}.')
 
     model.train()
 
@@ -505,16 +505,19 @@ for k in range(nb_epochs_finished, nb_epochs):
         train_perplexity = math.exp(min(100, acc_train_loss/nb_train_samples))
         test_perplexity = math.exp(min(100, acc_test_loss/nb_test_samples))
 
-        log_string(f'perplexity {k} train_set {train_set_perplexity} train_prediction {train_perplexity} test_prediction {test_perplexity}')
+        log_string(f'perplexity {n_epoch} train_set {train_set_perplexity} train_prediction {train_perplexity} test_prediction {test_perplexity}')
 
-        task.produce_results(k, model)
+        task.produce_results(n_epoch, model)
 
     checkpoint = {
-        'nb_epochs_finished': k + 1,
+        'nb_epochs_finished': n_epoch + 1,
         'model_state': model.state_dict(),
-        'optimizer_state': optimizer.state_dict()
+        'rng_state': torch.get_rng_state(),
     }
 
+    if torch.cuda.is_available():
+        checkpoint['cuda_rng_state'] = torch.cuda.get_rng_state()
+
     torch.save(checkpoint, args.checkpoint_name)
 
 ######################################################################