Fixed the interpreter name.
[pysvrt.git] / cnn-svrt.py
index d5685f4..c1fe3ac 100755 (executable)
@@ -1,4 +1,25 @@
-#!/usr/bin/env python-for-pytorch
+#!/usr/bin/env python
+
+#  svrt is the ``Synthetic Visual Reasoning Test'', an image
+#  generator for evaluating classification performance of machine
+#  learning systems, humans and primates.
+#
+#  Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/
+#  Written by Francois Fleuret <francois.fleuret@idiap.ch>
+#
+#  This file is part of svrt.
+#
+#  svrt is free software: you can redistribute it and/or modify it
+#  under the terms of the GNU General Public License version 3 as
+#  published by the Free Software Foundation.
+#
+#  svrt is distributed in the hope that it will be useful, but
+#  WITHOUT ANY WARRANTY; without even the implied warranty of
+#  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+#  General Public License for more details.
+#
+#  You should have received a copy of the GNU General Public License
+#  along with selector.  If not, see <http://www.gnu.org/licenses/>.
 
 import time
 
@@ -11,10 +32,9 @@ from torch import nn
 from torch.nn import functional as fn
 from torchvision import datasets, transforms, utils
 
-from _ext import svrt
+import svrt
 
 ######################################################################
-# The data
 
 def generate_set(p, n):
     target = torch.LongTensor(n).bernoulli_(0.5)
@@ -70,10 +90,10 @@ def print_test_error(model, test_input, test_target):
 
     for b in range(0, nb_test_samples, bs):
         output = model.forward(test_input.narrow(0, b, bs))
-        _, wta = torch.max(output.data, 1)
+        wta_prediction = output.data.max(1)[1].view(-1)
 
         for i in range(0, bs):
-            if wta[i][0] != test_target.narrow(0, b, bs).data[i]:
+            if wta_prediction[i] != test_target.narrow(0, b, bs).data[i]:
                 nb_test_errors = nb_test_errors + 1
 
     print('TEST_ERROR {:.02f}% ({:d}/{:d})'.format(
@@ -93,6 +113,7 @@ for p in range(1, 24):
     t1 = time.time()
     train_input, train_target = generate_set(p, nb_train_samples)
     test_input, test_target = generate_set(p, nb_test_samples)
+
     if torch.cuda.is_available():
         train_input, train_target = train_input.cuda(), train_target.cuda()
         test_input, test_target = test_input.cuda(), test_target.cuda()