X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=cnn-svrt.py;h=7fe2db2d569aefc92e904ff21d11cdfec8e04d61;hb=e9f012349010d2a4f5d2ed0869974611fade32f1;hp=f3d350eb9ea9203f408a0603e7c0458b88801e95;hpb=aca8ab8e7d30f1f79829d57897238469df5468b0;p=pysvrt.git diff --git a/cnn-svrt.py b/cnn-svrt.py index f3d350e..7fe2db2 100755 --- a/cnn-svrt.py +++ b/cnn-svrt.py @@ -25,18 +25,22 @@ import time import argparse import math import distutils.util +import re from colorama import Fore, Back, Style # Pytorch import torch +import torchvision from torch import optim +from torch import multiprocessing from torch import FloatTensor as Tensor from torch.autograd import Variable from torch import nn from torch.nn import functional as fn + from torchvision import datasets, transforms, utils # SVRT @@ -72,6 +76,9 @@ parser.add_argument('--batch_size', parser.add_argument('--log_file', type = str, default = 'default.log') +parser.add_argument('--nb_exemplar_vignettes', + type = int, default = -1) + parser.add_argument('--compress_vignettes', type = distutils.util.strtobool, default = 'True', help = 'Use lossless compression to reduce the memory footprint') @@ -94,13 +101,15 @@ args = parser.parse_args() log_file = open(args.log_file, 'a') pred_log_t = None +last_tag_t = time.time() print(Fore.RED + 'Logging into ' + args.log_file + Style.RESET_ALL) # Log and prints the string, with a time stamp. Does not log the # remark + def log_string(s, remark = ''): - global pred_log_t + global pred_log_t, last_tag_t t = time.time() @@ -111,10 +120,14 @@ def log_string(s, remark = ''): pred_log_t = t - log_file.write('[' + time.ctime() + '] ' + elapsed + ' ' + s + '\n') + if t > last_tag_t + 3600: + last_tag_t = t + print(Fore.RED + time.ctime() + Style.RESET_ALL) + + log_file.write(re.sub(' ', '_', time.ctime()) + ' ' + elapsed + ' ' + s + '\n') log_file.flush() - print(Fore.BLUE + '[' + time.ctime() + '] ' + Fore.GREEN + elapsed + Style.RESET_ALL + ' ' + s + Fore.CYAN + remark + Style.RESET_ALL) + print(Fore.BLUE + time.ctime() + ' ' + Fore.GREEN + elapsed + Style.RESET_ALL + ' ' + s + Fore.CYAN + remark + Style.RESET_ALL) ###################################################################### @@ -288,6 +301,21 @@ class vignette_logger(): ) self.last_t = t +def save_examplar_vignettes(data_set, nb, name): + n = torch.randperm(data_set.nb_samples).narrow(0, 0, nb) + + for k in range(0, nb): + b = n[k] // data_set.batch_size + m = n[k] % data_set.batch_size + i, t = data_set.get_batch(b) + i = i[m].float() + i.sub_(i.min()) + i.div_(i.max()) + if k == 0: patchwork = Tensor(nb, 1, i.size(1), i.size(2)) + patchwork[k].copy_(i) + + torchvision.utils.save_image(patchwork, name) + ###################################################################### if args.nb_train_samples%args.batch_size > 0 or args.nb_test_samples%args.batch_size > 0: @@ -350,6 +378,10 @@ for problem_number in map(int, args.problems.split(',')): train_set.nb_samples / (time.time() - t)) ) + if args.nb_exemplar_vignettes > 0: + save_examplar_vignettes(train_set, args.nb_exemplar_vignettes, + 'examplar_{:d}.png'.format(problem_number)) + if args.validation_error_threshold > 0.0: validation_set = VignetteSet(problem_number, args.nb_validation_samples, args.batch_size,