X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=cnn-svrt.py;h=a41d42c41869009588f2c3e4d977503e0220e3bd;hb=9f566b0dc34a2caac941e33e12ca529ef887a171;hp=704b003284252601960ddb81f73870f7142456e0;hpb=8cdea45a2a54fa619d670ff30cfdf96308853f7e;p=pysvrt.git diff --git a/cnn-svrt.py b/cnn-svrt.py index 704b003..a41d42c 100755 --- a/cnn-svrt.py +++ b/cnn-svrt.py @@ -25,6 +25,7 @@ import time import argparse import math import distutils.util +import re from colorama import Fore, Back, Style @@ -56,6 +57,13 @@ parser.add_argument('--nb_train_samples', parser.add_argument('--nb_test_samples', type = int, default = 10000) +parser.add_argument('--nb_validation_samples', + type = int, default = 10000) + +parser.add_argument('--validation_error_threshold', + type = float, default = 0.0, + help = 'Early training termination criterion') + parser.add_argument('--nb_epochs', type = int, default = 50) @@ -77,19 +85,25 @@ parser.add_argument('--test_loaded_models', type = distutils.util.strtobool, default = 'False', help = 'Should we compute the test errors of loaded models') +parser.add_argument('--problems', + type = str, default = '1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23', + help = 'What problems to process') + args = parser.parse_args() ###################################################################### -log_file = open(args.log_file, 'w') +log_file = open(args.log_file, 'a') pred_log_t = None +last_tag_t = time.time() print(Fore.RED + 'Logging into ' + args.log_file + Style.RESET_ALL) # Log and prints the string, with a time stamp. Does not log the # remark + def log_string(s, remark = ''): - global pred_log_t + global pred_log_t, last_tag_t t = time.time() @@ -100,10 +114,14 @@ def log_string(s, remark = ''): pred_log_t = t - log_file.write('[' + time.ctime() + '] ' + elapsed + ' ' + s + '\n') + if t > last_tag_t + 3600: + last_tag_t = t + print(Fore.RED + time.ctime() + Style.RESET_ALL) + + log_file.write(re.sub(' ', '_', time.ctime()) + ' ' + elapsed + ' ' + s + '\n') log_file.flush() - print(Fore.BLUE + '[' + time.ctime() + '] ' + Fore.GREEN + elapsed + Style.RESET_ALL + ' ' + s + Fore.CYAN + remark + Style.RESET_ALL) + print(Fore.BLUE + time.ctime() + ' ' + Fore.GREEN + elapsed + Style.RESET_ALL + ' ' + s + Fore.CYAN + remark + Style.RESET_ALL) ###################################################################### @@ -190,7 +208,22 @@ class AfrozeDeepNet(nn.Module): ###################################################################### -def train_model(model, train_set): +def nb_errors(model, data_set): + ne = 0 + for b in range(0, data_set.nb_batches): + input, target = data_set.get_batch(b) + output = model.forward(Variable(input)) + wta_prediction = output.data.max(1)[1].view(-1) + + for i in range(0, data_set.batch_size): + if wta_prediction[i] != target[i]: + ne = ne + 1 + + return ne + +###################################################################### + +def train_model(model, train_set, validation_set): batch_size = args.batch_size criterion = nn.CrossEntropyLoss() @@ -212,25 +245,24 @@ def train_model(model, train_set): loss.backward() optimizer.step() dt = (time.time() - start_t) / (e + 1) + log_string('train_loss {:d} {:f}'.format(e + 1, acc_loss), ' [ETA ' + time.ctime(time.time() + dt * (args.nb_epochs - e)) + ']') - return model + if validation_set is not None: + nb_validation_errors = nb_errors(model, validation_set) -###################################################################### + log_string('validation_error {:.02f}% {:d} {:d}'.format( + 100 * nb_validation_errors / validation_set.nb_samples, + nb_validation_errors, + validation_set.nb_samples) + ) -def nb_errors(model, data_set): - ne = 0 - for b in range(0, data_set.nb_batches): - input, target = data_set.get_batch(b) - output = model.forward(Variable(input)) - wta_prediction = output.data.max(1)[1].view(-1) + if nb_validation_errors / validation_set.nb_samples <= args.validation_error_threshold: + log_string('below validation_error_threshold') + break - for i in range(0, data_set.batch_size): - if wta_prediction[i] != target[i]: - ne = ne + 1 - - return ne + return model ###################################################################### @@ -269,6 +301,8 @@ if args.nb_train_samples%args.batch_size > 0 or args.nb_test_samples%args.batch_ print('The number of samples must be a multiple of the batch size.') raise +log_string('############### start ###############') + if args.compress_vignettes: log_string('using_compressed_vignettes') VignetteSet = svrtset.CompressedVignetteSet @@ -276,7 +310,7 @@ else: log_string('using_uncompressed_vignettes') VignetteSet = svrtset.VignetteSet -for problem_number in range(1, 24): +for problem_number in map(int, args.problems.split(',')): log_string('############### problem ' + str(problem_number) + ' ###############') @@ -323,7 +357,15 @@ for problem_number in range(1, 24): train_set.nb_samples / (time.time() - t)) ) - train_model(model, train_set) + if args.validation_error_threshold > 0.0: + validation_set = VignetteSet(problem_number, + args.nb_validation_samples, args.batch_size, + cuda = torch.cuda.is_available(), + logger = vignette_logger()) + else: + validation_set = None + + train_model(model, train_set, validation_set) torch.save(model.state_dict(), model_filename) log_string('saved_model ' + model_filename) @@ -347,10 +389,6 @@ for problem_number in range(1, 24): args.nb_test_samples, args.batch_size, cuda = torch.cuda.is_available()) - log_string('data_generation {:0.2f} samples / s'.format( - test_set.nb_samples / (time.time() - t)) - ) - nb_test_errors = nb_errors(model, test_set) log_string('test_error {:d} {:.02f}% {:d} {:d}'.format(