Fixed the interpreter name.
authorFrancois Fleuret <francois@fleuret.org>
Thu, 15 Jun 2017 10:06:18 +0000 (12:06 +0200)
committerFrancois Fleuret <francois@fleuret.org>
Thu, 15 Jun 2017 10:06:18 +0000 (12:06 +0200)
cnn-svrt.py

index f731c2b..c1fe3ac 100755 (executable)
@@ -1,4 +1,4 @@
-#!/usr/bin/env python-for-pytorch
+#!/usr/bin/env python
 
 #  svrt is the ``Synthetic Visual Reasoning Test'', an image
 #  generator for evaluating classification performance of machine
@@ -90,10 +90,10 @@ def print_test_error(model, test_input, test_target):
 
     for b in range(0, nb_test_samples, bs):
         output = model.forward(test_input.narrow(0, b, bs))
-        _, wta = torch.max(output.data, 1)
+        wta_prediction = output.data.max(1)[1].view(-1)
 
         for i in range(0, bs):
-            if wta[i][0] != test_target.narrow(0, b, bs).data[i]:
+            if wta_prediction[i] != test_target.narrow(0, b, bs).data[i]:
                 nb_test_errors = nb_test_errors + 1
 
     print('TEST_ERROR {:.02f}% ({:d}/{:d})'.format(
@@ -113,6 +113,7 @@ for p in range(1, 24):
     t1 = time.time()
     train_input, train_target = generate_set(p, nb_train_samples)
     test_input, test_target = generate_set(p, nb_test_samples)
+
     if torch.cuda.is_available():
         train_input, train_target = train_input.cuda(), train_target.cuda()
         test_input, test_target = test_input.cuda(), test_target.cuda()