X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=main.py;h=11cf0a31401eb0c21aff0875564f80d2e7270db3;hb=c8d0cf6842db19f84a78c1b3a4d2666b323a5d4a;hp=3bf7587a5f1e45bee2530ca56ce016738b95d117;hpb=046f35f38d629c9854104e855a53f0142449138f;p=mygpt.git

diff --git a/main.py b/main.py
index 3bf7587..11cf0a3 100755
--- a/main.py
+++ b/main.py
@@ -25,7 +25,7 @@ parser.add_argument('--log_filename',
                     type = str, default = 'train.log')
 
 parser.add_argument('--download',
-                    type = bool, default = False)
+                    action='store_true', default = False)
 
 parser.add_argument('--seed',
                     type = int, default = 0)
@@ -67,7 +67,13 @@ parser.add_argument('--dropout',
                     type = float, default = 0.1)
 
 parser.add_argument('--synthesis_sampling',
-                    type = bool, default = True)
+                    action='store_true', default = True)
+
+parser.add_argument('--checkpoint_name',
+                    type = str, default = 'checkpoint.pth')
+
+parser.add_argument('--picoclvr_many_colors',
+                    action='store_true', default = False)
 
 ######################################################################
 
@@ -350,7 +356,7 @@ if args.data == 'wiki103':
 elif args.data == 'mnist':
     task = TaskMNIST(batch_size = args.batch_size, device = device)
 elif args.data == 'picoclvr':
-    task = TaskPicoCLVR(batch_size = args.batch_size, device = device)
+    task = TaskPicoCLVR(batch_size = args.batch_size, many_colors = args.picoclvr_many_colors, device = device)
 else:
     raise ValueError(f'Unknown dataset {args.data}.')
 
@@ -366,11 +372,11 @@ model = mygpt.MyGPT(
     nb_heads = args.nb_heads, nb_blocks = args.nb_blocks,
     dropout = args.dropout
 )
+model.to(device)
+
 nb_parameters = sum(p.numel() for p in model.parameters())
 log_string(f'nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)')
 
-model.to(device)
-
 ######################################################################
 
 if args.optim == 'sgd':
@@ -382,7 +388,27 @@ elif args.optim == 'adamw':
 else:
     raise ValueError(f'Unknown optimizer {args.optim}.')
 
-for k in range(args.nb_epochs):
+######################################################################
+
+nb_epochs_finished = 0
+
+try:
+    checkpoint = torch.load(args.checkpoint_name, map_location = device)
+    nb_epochs_finished = checkpoint['nb_epochs_finished']
+    model.load_state_dict(checkpoint['model_state'])
+    optimizer.load_state_dict(checkpoint['optimizer_state'])
+    print(f'Checkpoint loaded with {nb_epochs_finished} epochs finished.')
+
+except FileNotFoundError:
+    print('Starting from scratch.')
+
+except:
+    print('Error when loading the checkpoint.')
+    exit(1)
+
+######################################################################
+
+for k in range(nb_epochs_finished, args.nb_epochs):
 
     model.train()
 
@@ -419,4 +445,12 @@ for k in range(args.nb_epochs):
 
     task.produce_results(k, model)
 
+    checkpoint = {
+        'nb_epochs_finished': k + 1,
+        'model_state': model.state_dict(),
+        'optimizer_state': optimizer.state_dict()
+    }
+
+    torch.save(checkpoint, args.checkpoint_name)
+
 ######################################################################
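
A note on the argparse change above: "type = bool" does not parse booleans the way one might expect, because argparse passes the raw command-line string to bool() and any non-empty string, including "False", is truthy. The sketch below is not part of the patch; the option name '--buggy' is made up for the demonstration, while '--download' mirrors the patched flag.

import argparse

parser = argparse.ArgumentParser()

# Old style: the string 'False' is non-empty, so bool('False') is True.
parser.add_argument('--buggy', type = bool, default = False)

# New style: the option is a flag, True when present, default otherwise.
parser.add_argument('--download', action = 'store_true', default = False)

args = parser.parse_args(['--buggy', 'False', '--download'])

print(args.buggy)     # True, despite the user writing 'False'
print(args.download)  # True, because the flag was given

Note that '--synthesis_sampling' keeps "default = True" together with action='store_true', so passing the flag leaves the value unchanged; turning sampling off would require a different default or a separate flag.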
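The larger addition is the checkpoint/resume logic wrapped around the training loop. Below is a minimal standalone sketch of the same pattern; the tiny nn.Linear model, the SGD optimizer, the fixed epoch count of 10, and the literal 'checkpoint.pth' filename are placeholders standing in for mygpt.MyGPT and the script's command-line arguments.

import torch

model = torch.nn.Linear(8, 8)                        # stand-in for mygpt.MyGPT
optimizer = torch.optim.SGD(model.parameters(), lr = 1e-2)

checkpoint_name = 'checkpoint.pth'                   # exposed as --checkpoint_name in the patch
nb_epochs_finished = 0

# Try to resume: restore the model, the optimizer, and the number of epochs
# already done, so the loop below starts where the previous run stopped.
try:
    checkpoint = torch.load(checkpoint_name, map_location = 'cpu')
    nb_epochs_finished = checkpoint['nb_epochs_finished']
    model.load_state_dict(checkpoint['model_state'])
    optimizer.load_state_dict(checkpoint['optimizer_state'])
except FileNotFoundError:
    pass                                             # no checkpoint yet: start from scratch

for k in range(nb_epochs_finished, 10):
    # ... one epoch of training would go here ...

    # Save after every epoch so an interrupted run restarts at epoch k + 1.
    torch.save(
        { 'nb_epochs_finished': k + 1,
          'model_state': model.state_dict(),
          'optimizer_state': optimizer.state_dict() },
        checkpoint_name
    )

Saving at the end of each epoch, rather than only once at the end of training, is what lets range(nb_epochs_finished, args.nb_epochs) pick up exactly where an interrupted run left off.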