Added --no_checkpoint
author    Francois Fleuret <francois@fleuret.org>
          Mon, 25 Jul 2022 19:03:46 +0000 (21:03 +0200)
committer Francois Fleuret <francois@fleuret.org>
          Mon, 25 Jul 2022 19:03:46 +0000 (21:03 +0200)
main.py

diff --git a/main.py b/main.py
index 77c4b9e..f496e99 100755
--- a/main.py
+++ b/main.py
@@ -69,6 +69,9 @@ parser.add_argument('--dropout',
 parser.add_argument('--synthesis_sampling',
                     action='store_true', default = True)
 
+parser.add_argument('--no_checkpoint',
+                    action='store_true', default = False)
+
 parser.add_argument('--checkpoint_name',
                     type = str, default = 'checkpoint.pth')
 
@@ -429,19 +432,23 @@ else:
 
 nb_epochs_finished = 0
 
-try:
-    checkpoint = torch.load(args.checkpoint_name, map_location = device)
-    nb_epochs_finished = checkpoint['nb_epochs_finished']
-    model.load_state_dict(checkpoint['model_state'])
-    optimizer.load_state_dict(checkpoint['optimizer_state'])
-    print(f'Checkpoint loaded with {nb_epochs_finished} epochs finished.')
+if args.no_checkpoint:
+    log_string(f'Not trying to load checkpoint.')
 
-except FileNotFoundError:
-    print('Starting from scratch.')
-
-except:
-    print('Error when loading the checkpoint.')
-    exit(1)
+else:
+    try:
+        checkpoint = torch.load(args.checkpoint_name, map_location = device)
+        nb_epochs_finished = checkpoint['nb_epochs_finished']
+        model.load_state_dict(checkpoint['model_state'])
+        optimizer.load_state_dict(checkpoint['optimizer_state'])
+        log_string(f'Checkpoint loaded with {nb_epochs_finished} epochs finished.')
+
+    except FileNotFoundError:
+        log_string('Starting from scratch.')
+
+    except:
+        log_string('Error when loading the checkpoint.')
+        exit(1)
 
 ######################################################################
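With this change, an existing checkpoint can be ignored by passing the new flag, e.g. ./main.py --no_checkpoint, instead of deleting checkpoint.pth by hand. For reference, the loader above only reads three keys from the checkpoint dictionary and expects the file at args.checkpoint_name. The saving side is not part of this diff; the snippet below is a minimal sketch of what a compatible save could look like, with its placement in the training loop assumed.

    # Sketch only (not shown in this commit): write a checkpoint that the
    # loader above can consume. The three keys match what it reads; where
    # exactly this runs in the training loop is an assumption.
    checkpoint = {
        'nb_epochs_finished': nb_epochs_finished + 1,
        'model_state': model.state_dict(),
        'optimizer_state': optimizer.state_dict(),
    }

    torch.save(checkpoint, args.checkpoint_name)

On the next run without --no_checkpoint, torch.load picks this file up and training resumes from nb_epochs_finished; with --no_checkpoint, the load is skipped and training starts from epoch 0.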