X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=cnn-svrt.py;h=fb1cad9f16054d1c66f502fe1052f7bf3c45ec8c;hb=3c9b04dccaaf2a42cca35d5ea266f442cbb726ea;hp=07c11b3e011d3ca9bc38dca6c0603d01d3b0a6e6;hpb=cdcaa8361d43a497d50c2f3703f9f5b4be9c2298;p=pysvrt.git diff --git a/cnn-svrt.py b/cnn-svrt.py index 07c11b3..fb1cad9 100755 --- a/cnn-svrt.py +++ b/cnn-svrt.py @@ -229,7 +229,7 @@ class DeepNet2(nn.Module): self.conv5 = nn.Conv2d(128, 128, kernel_size=3, padding=1) self.fc1 = nn.Linear(2048, 512) self.fc2 = nn.Linear(512, 512) - self.fc3 = nn.Linear(256, 2) + self.fc3 = nn.Linear(512, 2) def forward(self, x): x = self.conv1(x) @@ -279,7 +279,7 @@ def nb_errors(model, data_set): ###################################################################### -def train_model(model, train_set, validation_set): +def train_model(model, model_filename, train_set, validation_set, nb_epochs_done = 0): batch_size = args.batch_size criterion = nn.CrossEntropyLoss() @@ -290,7 +290,7 @@ def train_model(model, train_set, validation_set): start_t = time.time() - for e in range(0, args.nb_epochs): + for e in range(nb_epochs_done, args.nb_epochs): acc_loss = 0.0 for b in range(0, train_set.nb_batches): input, target = train_set.get_batch(b) @@ -305,6 +305,8 @@ def train_model(model, train_set, validation_set): log_string('train_loss {:d} {:f}'.format(e + 1, acc_loss), ' [ETA ' + time.ctime(time.time() + dt * (args.nb_epochs - e)) + ']') + torch.save([ model.state_dict(), e + 1 ], model_filename) + if validation_set is not None: nb_validation_errors = nb_errors(model, validation_set) @@ -404,7 +406,7 @@ for problem_number in map(int, args.problems.split(',')): model_filename = model.name + '_pb:' + \ str(problem_number) + '_ns:' + \ - int_to_suffix(args.nb_train_samples) + '.param' + int_to_suffix(args.nb_train_samples) + '.state' nb_parameters = 0 for p in model.parameters(): nb_parameters += p.numel() @@ -415,15 +417,17 @@ for problem_number in map(int, args.problems.split(',')): need_to_train = False try: - model.load_state_dict(torch.load(model_filename)) + model_state_dict, nb_epochs_done = torch.load(model_filename) + model.load_state_dict(model_state_dict) log_string('loaded_model ' + model_filename) except: - need_to_train = True + nb_epochs_done = 0 + ################################################## # Train if necessary - if need_to_train: + if nb_epochs_done < args.nb_epochs: log_string('training_model ' + model_filename) @@ -450,8 +454,7 @@ for problem_number in map(int, args.problems.split(',')): else: validation_set = None - train_model(model, train_set, validation_set) - torch.save(model.state_dict(), model_filename) + train_model(model, model_filename, train_set, validation_set, nb_epochs_done = nb_epochs_done) log_string('saved_model ' + model_filename) nb_train_errors = nb_errors(model, train_set)