From: Francois Fleuret Date: Thu, 15 Jun 2017 22:27:31 +0000 (+0200) Subject: Make the name of the saved model more explicit. X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=c8bc2db12cddbf90b851798cd101632e7e9511ba;p=pysvrt.git Make the name of the saved model more explicit. --- diff --git a/cnn-svrt.py b/cnn-svrt.py index 084606a..a2ab1a3 100755 --- a/cnn-svrt.py +++ b/cnn-svrt.py @@ -107,6 +107,7 @@ class AfrozeShallowNet(nn.Module): self.conv3 = nn.Conv2d(16, 120, kernel_size=18) self.fc1 = nn.Linear(120, 84) self.fc2 = nn.Linear(84, 2) + self.name = 'shallownet' def forward(self, x): x = fn.relu(fn.max_pool2d(self.conv1(x), kernel_size=2)) @@ -117,6 +118,8 @@ class AfrozeShallowNet(nn.Module): x = self.fc2(x) return x +###################################################################### + def train_model(model, train_set): batch_size = args.batch_size criterion = nn.CrossEntropyLoss() @@ -178,7 +181,7 @@ for problem_number in range(1, 24): nb_parameters += p.numel() log_string('nb_parameters {:d}'.format(nb_parameters)) - model_filename = 'model_' + str(problem_number) + '.param' + model_filename = model.name + '_' + str(problem_number) + '_' + str(train_set.nb_batches) + '.param' try: model.load_state_dict(torch.load(model_filename))