From: Francois Fleuret Date: Mon, 25 Jul 2022 19:03:46 +0000 (+0200) Subject: Added --no_checkpoint X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=6cefbf0f80402effa39d13ed6952f5b08a0688e1;p=mygpt.git Added --no_checkpoint --- diff --git a/main.py b/main.py index 77c4b9e..f496e99 100755 --- a/main.py +++ b/main.py @@ -69,6 +69,9 @@ parser.add_argument('--dropout', parser.add_argument('--synthesis_sampling', action='store_true', default = True) +parser.add_argument('--no_checkpoint', + action='store_true', default = False) + parser.add_argument('--checkpoint_name', type = str, default = 'checkpoint.pth') @@ -429,19 +432,23 @@ else: nb_epochs_finished = 0 -try: - checkpoint = torch.load(args.checkpoint_name, map_location = device) - nb_epochs_finished = checkpoint['nb_epochs_finished'] - model.load_state_dict(checkpoint['model_state']) - optimizer.load_state_dict(checkpoint['optimizer_state']) - print(f'Checkpoint loaded with {nb_epochs_finished} epochs finished.') +if args.no_checkpoint: + log_string(f'Not trying to load checkpoint.') -except FileNotFoundError: - print('Starting from scratch.') - -except: - print('Error when loading the checkpoint.') - exit(1) +else: + try: + checkpoint = torch.load(args.checkpoint_name, map_location = device) + nb_epochs_finished = checkpoint['nb_epochs_finished'] + model.load_state_dict(checkpoint['model_state']) + optimizer.load_state_dict(checkpoint['optimizer_state']) + log_string(f'Checkpoint loaded with {nb_epochs_finished} epochs finished.') + + except FileNotFoundError: + log_string('Starting from scratch.') + + except: + log_string('Error when loading the checkpoint.') + exit(1) ######################################################################