-#!/usr/bin/env python-for-pytorch
+#!/usr/bin/env python
+
+# svrt is the ``Synthetic Visual Reasoning Test'', an image
+# generator for evaluating classification performance of machine
+# learning systems, humans and primates.
+#
+# Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/
+# Written by Francois Fleuret <francois.fleuret@idiap.ch>
+#
+# This file is part of svrt.
+#
+# svrt is free software: you can redistribute it and/or modify it
+# under the terms of the GNU General Public License version 3 as
+# published by the Free Software Foundation.
+#
+# svrt is distributed in the hope that it will be useful, but
+# WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+# General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with selector. If not, see <http://www.gnu.org/licenses/>.
import time
from torch.nn import functional as fn
from torchvision import datasets, transforms, utils
-from _ext import svrt
+import svrt
######################################################################
-# The data
def generate_set(p, n):
target = torch.LongTensor(n).bernoulli_(0.5)
for b in range(0, nb_test_samples, bs):
output = model.forward(test_input.narrow(0, b, bs))
- _, wta = torch.max(output.data, 1)
+ wta_prediction = output.data.max(1)[1].view(-1)
for i in range(0, bs):
- if wta[i][0] != test_target.narrow(0, b, bs).data[i]:
+ if wta_prediction[i] != test_target.narrow(0, b, bs).data[i]:
nb_test_errors = nb_test_errors + 1
print('TEST_ERROR {:.02f}% ({:d}/{:d})'.format(
t1 = time.time()
train_input, train_target = generate_set(p, nb_train_samples)
test_input, test_target = generate_set(p, nb_test_samples)
+
if torch.cuda.is_available():
train_input, train_target = train_input.cuda(), train_target.cuda()
test_input, test_target = test_input.cuda(), test_target.cuda()