X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=cnn-svrt.py;h=084606aa67b18191a969043c214e075d22825fe0;hb=c71899cfec905c50302be54725a97d7fbff08f54;hp=bbce4c92e6426d48d15676372822ff86962066ed;hpb=6c83bf23d43bdbf2a8cae2df4654b26d46d53046;p=pysvrt.git diff --git a/cnn-svrt.py b/cnn-svrt.py index bbce4c9..084606a 100755 --- a/cnn-svrt.py +++ b/cnn-svrt.py @@ -36,7 +36,7 @@ from torch import nn from torch.nn import functional as fn from torchvision import datasets, transforms, utils -import svrt +from vignette_set import VignetteSet, CompressedVignetteSet ###################################################################### @@ -67,7 +67,7 @@ parser.add_argument('--log_file', parser.add_argument('--compress_vignettes', action='store_true', default = False, - help = 'Should we use lossless compression of vignette to reduce the memory footprint') + help = 'Use lossless compression to reduce the memory footprint') args = parser.parse_args() @@ -85,73 +85,6 @@ def log_string(s): ###################################################################### -class VignetteSet: - def __init__(self, problem_number, nb_batches): - self.batch_size = args.batch_size - self.problem_number = problem_number - self.nb_batches = nb_batches - self.nb_samples = self.nb_batches * self.batch_size - self.targets = [] - self.inputs = [] - - acc = 0.0 - acc_sq = 0.0 - - for k in range(0, self.nb_batches): - target = torch.LongTensor(self.batch_size).bernoulli_(0.5) - input = svrt.generate_vignettes(problem_number, target) - input = input.float().view(input.size(0), 1, input.size(1), input.size(2)) - if torch.cuda.is_available(): - input = input.cuda() - target = target.cuda() - acc += input.float().sum() / input.numel() - acc_sq += input.float().pow(2).sum() / input.numel() - self.targets.append(target) - self.inputs.append(input) - - mean = acc / self.nb_batches - std = math.sqrt(acc_sq / self.nb_batches - mean * mean) - for k in range(0, self.nb_batches): - self.inputs[k].sub_(mean).div_(std) - - def get_batch(self, b): - return self.inputs[b], self.targets[b] - -class CompressedVignetteSet: - def __init__(self, problem_number, nb_batches): - self.batch_size = args.batch_size - self.problem_number = problem_number - self.nb_batches = nb_batches - self.nb_samples = self.nb_batches * self.batch_size - self.targets = [] - self.input_storages = [] - - acc = 0.0 - acc_sq = 0.0 - for k in range(0, self.nb_batches): - target = torch.LongTensor(self.batch_size).bernoulli_(0.5) - input = svrt.generate_vignettes(problem_number, target) - acc += input.float().sum() / input.numel() - acc_sq += input.float().pow(2).sum() / input.numel() - self.targets.append(target) - self.input_storages.append(svrt.compress(input.storage())) - - self.mean = acc / self.nb_batches - self.std = math.sqrt(acc_sq / self.nb_batches - self.mean * self.mean) - - def get_batch(self, b): - input = torch.ByteTensor(svrt.uncompress(self.input_storages[b])).float() - input = input.view(self.batch_size, 1, 128, 128).sub_(self.mean).div_(self.std) - target = self.targets[b] - - if torch.cuda.is_available(): - input = input.cuda() - target = target.cuda() - - return input, target - -###################################################################### - # Afroze's ShallowNet # map size nb. maps @@ -193,7 +126,7 @@ def train_model(model, train_set): optimizer = optim.SGD(model.parameters(), lr = 1e-2) - for k in range(0, args.nb_epochs): + for e in range(0, args.nb_epochs): acc_loss = 0.0 for b in range(0, train_set.nb_batches): input, target = train_set.get_batch(b) @@ -203,7 +136,7 @@ def train_model(model, train_set): model.zero_grad() loss.backward() optimizer.step() - log_string('train_loss {:d} {:f}'.format(k, acc_loss)) + log_string('train_loss {:d} {:f}'.format(e + 1, acc_loss)) return model @@ -229,11 +162,11 @@ for arg in vars(args): for problem_number in range(1, 24): if args.compress_vignettes: - train_set = CompressedVignetteSet(problem_number, args.nb_train_batches) - test_set = CompressedVignetteSet(problem_number, args.nb_test_batches) + train_set = CompressedVignetteSet(problem_number, args.nb_train_batches, args.batch_size) + test_set = CompressedVignetteSet(problem_number, args.nb_test_batches, args.batch_size) else: - train_set = VignetteSet(problem_number, args.nb_train_batches) - test_set = VignetteSet(problem_number, args.nb_test_batches) + train_set = VignetteSet(problem_number, args.nb_train_batches, args.batch_size) + test_set = VignetteSet(problem_number, args.nb_test_batches, args.batch_size) model = AfrozeShallowNet()