X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=cnn-svrt.py;h=c1fe3acd7a34b5a341c6a261e6ee5eda1da59d4e;hb=3f17c323d2725ca189fa3502b89e9833cf6caa25;hp=d5685f426f518233d710598ea5f1ece4c1e7ce68;hpb=664435944d9750efb805d9a2035f1d4f4c238a25;p=pysvrt.git diff --git a/cnn-svrt.py b/cnn-svrt.py index d5685f4..c1fe3ac 100755 --- a/cnn-svrt.py +++ b/cnn-svrt.py @@ -1,4 +1,25 @@ -#!/usr/bin/env python-for-pytorch +#!/usr/bin/env python + +# svrt is the ``Synthetic Visual Reasoning Test'', an image +# generator for evaluating classification performance of machine +# learning systems, humans and primates. +# +# Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/ +# Written by Francois Fleuret +# +# This file is part of svrt. +# +# svrt is free software: you can redistribute it and/or modify it +# under the terms of the GNU General Public License version 3 as +# published by the Free Software Foundation. +# +# svrt is distributed in the hope that it will be useful, but +# WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with selector. If not, see . import time @@ -11,10 +32,9 @@ from torch import nn from torch.nn import functional as fn from torchvision import datasets, transforms, utils -from _ext import svrt +import svrt ###################################################################### -# The data def generate_set(p, n): target = torch.LongTensor(n).bernoulli_(0.5) @@ -70,10 +90,10 @@ def print_test_error(model, test_input, test_target): for b in range(0, nb_test_samples, bs): output = model.forward(test_input.narrow(0, b, bs)) - _, wta = torch.max(output.data, 1) + wta_prediction = output.data.max(1)[1].view(-1) for i in range(0, bs): - if wta[i][0] != test_target.narrow(0, b, bs).data[i]: + if wta_prediction[i] != test_target.narrow(0, b, bs).data[i]: nb_test_errors = nb_test_errors + 1 print('TEST_ERROR {:.02f}% ({:d}/{:d})'.format( @@ -93,6 +113,7 @@ for p in range(1, 24): t1 = time.time() train_input, train_target = generate_set(p, nb_train_samples) test_input, test_target = generate_set(p, nb_test_samples) + if torch.cuda.is_available(): train_input, train_target = train_input.cuda(), train_target.cuda() test_input, test_target = test_input.cuda(), test_target.cuda()