Cosmetics.
authorFrancois Fleuret <francois@fleuret.org>
Thu, 15 Jun 2017 21:37:20 +0000 (23:37 +0200)
committerFrancois Fleuret <francois@fleuret.org>
Thu, 15 Jun 2017 21:37:20 +0000 (23:37 +0200)
cnn-svrt.py

index bbce4c9..8840c4b 100755 (executable)
@@ -67,7 +67,7 @@ parser.add_argument('--log_file',
 
 parser.add_argument('--compress_vignettes',
                     action='store_true', default = False,
-                    help = 'Should we use lossless compression of vignette to reduce the memory footprint')
+                    help = 'Use lossless compression to reduce the memory footprint')
 
 args = parser.parse_args()
 
@@ -97,26 +97,28 @@ class VignetteSet:
         acc = 0.0
         acc_sq = 0.0
 
-        for k in range(0, self.nb_batches):
+        for b in range(0, self.nb_batches):
             target = torch.LongTensor(self.batch_size).bernoulli_(0.5)
             input = svrt.generate_vignettes(problem_number, target)
             input = input.float().view(input.size(0), 1, input.size(1), input.size(2))
             if torch.cuda.is_available():
                 input = input.cuda()
                 target = target.cuda()
-            acc += input.float().sum() / input.numel()
-            acc_sq += input.float().pow(2).sum() /  input.numel()
+            acc += input.sum() / input.numel()
+            acc_sq += input.pow(2).sum() /  input.numel()
             self.targets.append(target)
             self.inputs.append(input)
 
         mean = acc / self.nb_batches
         std = math.sqrt(acc_sq / self.nb_batches - mean * mean)
-        for k in range(0, self.nb_batches):
-            self.inputs[k].sub_(mean).div_(std)
+        for b in range(0, self.nb_batches):
+            self.inputs[b].sub_(mean).div_(std)
 
     def get_batch(self, b):
         return self.inputs[b], self.targets[b]
 
+######################################################################
+
 class CompressedVignetteSet:
     def __init__(self, problem_number, nb_batches):
         self.batch_size = args.batch_size
@@ -128,7 +130,7 @@ class CompressedVignetteSet:
 
         acc = 0.0
         acc_sq = 0.0
-        for k in range(0, self.nb_batches):
+        for b in range(0, self.nb_batches):
             target = torch.LongTensor(self.batch_size).bernoulli_(0.5)
             input = svrt.generate_vignettes(problem_number, target)
             acc += input.float().sum() / input.numel()
@@ -193,7 +195,7 @@ def train_model(model, train_set):
 
     optimizer = optim.SGD(model.parameters(), lr = 1e-2)
 
-    for k in range(0, args.nb_epochs):
+    for e in range(0, args.nb_epochs):
         acc_loss = 0.0
         for b in range(0, train_set.nb_batches):
             input, target = train_set.get_batch(b)
@@ -203,7 +205,7 @@ def train_model(model, train_set):
             model.zero_grad()
             loss.backward()
             optimizer.step()
-        log_string('train_loss {:d} {:f}'.format(k, acc_loss))
+        log_string('train_loss {:d} {:f}'.format(e + 1, acc_loss))
 
     return model