X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?p=pysvrt.git;a=blobdiff_plain;f=cnn-svrt.py;h=966ea4a60262e71e7c6a5fd53347f9580058cb9b;hp=8baaacbc4fe7b4b37d8e6be27dd229b5b44bf6cc;hb=a25efec7618acd54806bd6ce69b30d473b7845f8;hpb=349b55a2d9ca213718df8941058d42689ba68163 diff --git a/cnn-svrt.py b/cnn-svrt.py index 8baaacb..966ea4a 100755 --- a/cnn-svrt.py +++ b/cnn-svrt.py @@ -24,8 +24,10 @@ import time import argparse import math + import distutils.util import re +import signal from colorama import Fore, Back, Style @@ -83,6 +85,9 @@ parser.add_argument('--compress_vignettes', type = distutils.util.strtobool, default = 'True', help = 'Use lossless compression to reduce the memory footprint') +parser.add_argument('--save_test_mistakes', + type = distutils.util.strtobool, default = 'False') + parser.add_argument('--model', type = str, default = 'deepnet', help = 'What model to use') @@ -127,7 +132,24 @@ def log_string(s, remark = ''): log_file.write(re.sub(' ', '_', time.ctime()) + ' ' + elapsed + ' ' + s + '\n') log_file.flush() - print(Fore.BLUE + time.ctime() + ' ' + Fore.GREEN + elapsed + Style.RESET_ALL + ' ' + s + Fore.CYAN + remark + Style.RESET_ALL) + print(Fore.BLUE + time.ctime() + ' ' + Fore.GREEN + elapsed \ + + Style.RESET_ALL + + ' ' \ + + s + Fore.CYAN + remark \ + + Style.RESET_ALL) + +###################################################################### + +def handler_sigint(signum, frame): + log_string('got sigint') + exit(0) + +def handler_sigterm(signum, frame): + log_string('got sigterm') + exit(0) + +signal.signal(signal.SIGINT, handler_sigint) +signal.signal(signal.SIGTERM, handler_sigterm) ###################################################################### @@ -319,7 +341,7 @@ class DeepNet3(nn.Module): ###################################################################### -def nb_errors(model, data_set): +def nb_errors(model, data_set, mistake_filename_pattern = None): ne = 0 for b in range(0, data_set.nb_batches): input, target = data_set.get_batch(b) @@ -329,7 +351,14 @@ def nb_errors(model, data_set): for i in range(0, data_set.batch_size): if wta_prediction[i] != target[i]: ne = ne + 1 - + if mistake_filename_pattern is not None: + img = input[i].clone() + img.sub_(img.min()) + img.div_(img.max()) + k = b * data_set.batch_size + i + filename = mistake_filename_pattern.format(k, target[i]) + torchvision.utils.save_image(img, filename) + print(Fore.RED + 'Wrote ' + filename + Style.RESET_ALL) return ne ###################################################################### @@ -531,7 +560,8 @@ for problem_number in map(int, args.problems.split(',')): args.nb_test_samples, args.batch_size, cuda = torch.cuda.is_available()) - nb_test_errors = nb_errors(model, test_set) + nb_test_errors = nb_errors(model, test_set, + mistake_filename_pattern = 'mistake_{:06d}_{:d}.png') log_string('test_error {:d} {:.02f}% {:d} {:d}'.format( problem_number,