Update.
authorFrancois Fleuret <francois@fleuret.org>
Sat, 2 Jul 2022 19:07:54 +0000 (21:07 +0200)
committerFrancois Fleuret <francois@fleuret.org>
Sat, 2 Jul 2022 19:07:54 +0000 (21:07 +0200)
main.py

diff --git a/main.py b/main.py
index 3bf7587..85cf4cf 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -69,6 +69,9 @@ parser.add_argument('--dropout',
 parser.add_argument('--synthesis_sampling',
                     type = bool, default = True)
 
+parser.add_argument('--checkpoint_name',
+                    type = str, default = 'checkpoint.pth')
+
 ######################################################################
 
 args = parser.parse_args()
@@ -366,11 +369,11 @@ model = mygpt.MyGPT(
     nb_heads = args.nb_heads, nb_blocks = args.nb_blocks, dropout = args.dropout
 )
 
+model.to(device)
+
 nb_parameters = sum(p.numel() for p in model.parameters())
 log_string(f'nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)')
 
-model.to(device)
-
 ######################################################################
 
 if args.optim == 'sgd':
@@ -382,7 +385,27 @@ elif args.optim == 'adamw':
 else:
     raise ValueError(f'Unknown optimizer {args.optim}.')
 
-for k in range(args.nb_epochs):
+######################################################################
+
+nb_epochs_finished = 0
+
+try:
+    checkpoint = torch.load(args.checkpoint_name, map_location = device)
+    nb_epochs_finished = checkpoint['nb_epochs_finished']
+    model.load_state_dict(checkpoint['model_state'])
+    optimizer.load_state_dict(checkpoint['optimizer_state'])
+    print(f'Checkpoint loaded with {nb_epochs_finished} epochs finished.')
+
+except FileNotFoundError:
+    print('Starting from scratch.')
+
+except:
+    print('Error when loading the checkpoint.')
+    exit(1)
+
+######################################################################
+
+for k in range(nb_epochs_finished, args.nb_epochs):
 
     model.train()
 
@@ -419,4 +442,12 @@ for k in range(args.nb_epochs):
 
         task.produce_results(k, model)
 
+    checkpoint = {
+        'nb_epochs_finished': k + 1,
+        'model_state': model.state_dict(),
+        'optimizer_state': optimizer.state_dict()
+    }
+
+    torch.save(checkpoint, args.checkpoint_name)
+
 ######################################################################