X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?p=pysvrt.git;a=blobdiff_plain;f=cnn-svrt.py;h=966ea4a60262e71e7c6a5fd53347f9580058cb9b;hp=fb1cad9f16054d1c66f502fe1052f7bf3c45ec8c;hb=a25efec7618acd54806bd6ce69b30d473b7845f8;hpb=3c9b04dccaaf2a42cca35d5ea266f442cbb726ea diff --git a/cnn-svrt.py b/cnn-svrt.py index fb1cad9..966ea4a 100755 --- a/cnn-svrt.py +++ b/cnn-svrt.py @@ -24,8 +24,10 @@ import time import argparse import math + import distutils.util import re +import signal from colorama import Fore, Back, Style @@ -83,6 +85,9 @@ parser.add_argument('--compress_vignettes', type = distutils.util.strtobool, default = 'True', help = 'Use lossless compression to reduce the memory footprint') +parser.add_argument('--save_test_mistakes', + type = distutils.util.strtobool, default = 'False') + parser.add_argument('--model', type = str, default = 'deepnet', help = 'What model to use') @@ -127,7 +132,24 @@ def log_string(s, remark = ''): log_file.write(re.sub(' ', '_', time.ctime()) + ' ' + elapsed + ' ' + s + '\n') log_file.flush() - print(Fore.BLUE + time.ctime() + ' ' + Fore.GREEN + elapsed + Style.RESET_ALL + ' ' + s + Fore.CYAN + remark + Style.RESET_ALL) + print(Fore.BLUE + time.ctime() + ' ' + Fore.GREEN + elapsed \ + + Style.RESET_ALL + + ' ' \ + + s + Fore.CYAN + remark \ + + Style.RESET_ALL) + +###################################################################### + +def handler_sigint(signum, frame): + log_string('got sigint') + exit(0) + +def handler_sigterm(signum, frame): + log_string('got sigterm') + exit(0) + +signal.signal(signal.SIGINT, handler_sigint) +signal.signal(signal.SIGTERM, handler_sigterm) ###################################################################### @@ -223,13 +245,62 @@ class DeepNet2(nn.Module): def __init__(self): super(DeepNet2, self).__init__() self.conv1 = nn.Conv2d( 1, 32, kernel_size=7, stride=4, padding=3) + self.conv2 = nn.Conv2d( 32, 256, kernel_size=5, padding=2) + self.conv3 = nn.Conv2d(256, 256, kernel_size=3, padding=1) + self.conv4 = nn.Conv2d(256, 256, kernel_size=3, padding=1) + self.conv5 = nn.Conv2d(256, 256, kernel_size=3, padding=1) + self.fc1 = nn.Linear(4096, 512) + self.fc2 = nn.Linear(512, 512) + self.fc3 = nn.Linear(512, 2) + + def forward(self, x): + x = self.conv1(x) + x = fn.max_pool2d(x, kernel_size=2) + x = fn.relu(x) + + x = self.conv2(x) + x = fn.max_pool2d(x, kernel_size=2) + x = fn.relu(x) + + x = self.conv3(x) + x = fn.relu(x) + + x = self.conv4(x) + x = fn.relu(x) + + x = self.conv5(x) + x = fn.max_pool2d(x, kernel_size=2) + x = fn.relu(x) + + x = x.view(-1, 4096) + + x = self.fc1(x) + x = fn.relu(x) + + x = self.fc2(x) + x = fn.relu(x) + + x = self.fc3(x) + + return x + +###################################################################### + +class DeepNet3(nn.Module): + name = 'deepnet3' + + def __init__(self): + super(DeepNet3, self).__init__() + self.conv1 = nn.Conv2d( 1, 32, kernel_size=7, stride=4, padding=3) self.conv2 = nn.Conv2d( 32, 128, kernel_size=5, padding=2) self.conv3 = nn.Conv2d(128, 128, kernel_size=3, padding=1) self.conv4 = nn.Conv2d(128, 128, kernel_size=3, padding=1) self.conv5 = nn.Conv2d(128, 128, kernel_size=3, padding=1) - self.fc1 = nn.Linear(2048, 512) - self.fc2 = nn.Linear(512, 512) - self.fc3 = nn.Linear(512, 2) + self.conv6 = nn.Conv2d(128, 128, kernel_size=3, padding=1) + self.conv7 = nn.Conv2d(128, 128, kernel_size=3, padding=1) + self.fc1 = nn.Linear(2048, 256) + self.fc2 = nn.Linear(256, 256) + self.fc3 = nn.Linear(256, 2) def forward(self, x): x = self.conv1(x) @@ -250,6 +321,12 @@ class DeepNet2(nn.Module): x = fn.max_pool2d(x, kernel_size=2) x = fn.relu(x) + x = self.conv6(x) + x = fn.relu(x) + + x = self.conv7(x) + x = fn.relu(x) + x = x.view(-1, 2048) x = self.fc1(x) @@ -264,7 +341,7 @@ class DeepNet2(nn.Module): ###################################################################### -def nb_errors(model, data_set): +def nb_errors(model, data_set, mistake_filename_pattern = None): ne = 0 for b in range(0, data_set.nb_batches): input, target = data_set.get_batch(b) @@ -274,7 +351,14 @@ def nb_errors(model, data_set): for i in range(0, data_set.batch_size): if wta_prediction[i] != target[i]: ne = ne + 1 - + if mistake_filename_pattern is not None: + img = input[i].clone() + img.sub_(img.min()) + img.div_(img.max()) + k = b * data_set.batch_size + i + filename = mistake_filename_pattern.format(k, target[i]) + torchvision.utils.save_image(img, filename) + print(Fore.RED + 'Wrote ' + filename + Style.RESET_ALL) return ne ###################################################################### @@ -385,7 +469,7 @@ else: ######################################## model_class = None -for m in [ AfrozeShallowNet, AfrozeDeepNet, DeepNet2 ]: +for m in [ AfrozeShallowNet, AfrozeDeepNet, DeepNet2, DeepNet3 ]: if args.model == m.name: model_class = m break @@ -415,7 +499,6 @@ for problem_number in map(int, args.problems.split(',')): ################################################## # Tries to load the model - need_to_train = False try: model_state_dict, nb_epochs_done = torch.load(model_filename) model.load_state_dict(model_state_dict) @@ -469,7 +552,7 @@ for problem_number in map(int, args.problems.split(',')): ################################################## # Test if necessary - if need_to_train or args.test_loaded_models: + if nb_epochs_done < args.nb_epochs or args.test_loaded_models: t = time.time() @@ -477,7 +560,8 @@ for problem_number in map(int, args.problems.split(',')): args.nb_test_samples, args.batch_size, cuda = torch.cuda.is_available()) - nb_test_errors = nb_errors(model, test_set) + nb_test_errors = nb_errors(model, test_set, + mistake_filename_pattern = 'mistake_{:06d}_{:d}.png') log_string('test_error {:d} {:.02f}% {:d} {:d}'.format( problem_number,