From: Francois Fleuret Date: Thu, 15 Jun 2017 10:06:18 +0000 (+0200) Subject: Fixed the interpreter name. X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=3f17c323d2725ca189fa3502b89e9833cf6caa25;p=pysvrt.git Fixed the interpreter name. --- diff --git a/cnn-svrt.py b/cnn-svrt.py index f731c2b..c1fe3ac 100755 --- a/cnn-svrt.py +++ b/cnn-svrt.py @@ -1,4 +1,4 @@ -#!/usr/bin/env python-for-pytorch +#!/usr/bin/env python # svrt is the ``Synthetic Visual Reasoning Test'', an image # generator for evaluating classification performance of machine @@ -90,10 +90,10 @@ def print_test_error(model, test_input, test_target): for b in range(0, nb_test_samples, bs): output = model.forward(test_input.narrow(0, b, bs)) - _, wta = torch.max(output.data, 1) + wta_prediction = output.data.max(1)[1].view(-1) for i in range(0, bs): - if wta[i][0] != test_target.narrow(0, b, bs).data[i]: + if wta_prediction[i] != test_target.narrow(0, b, bs).data[i]: nb_test_errors = nb_test_errors + 1 print('TEST_ERROR {:.02f}% ({:d}/{:d})'.format( @@ -113,6 +113,7 @@ for p in range(1, 24): t1 = time.time() train_input, train_target = generate_set(p, nb_train_samples) test_input, test_target = generate_set(p, nb_test_samples) + if torch.cuda.is_available(): train_input, train_target = train_input.cuda(), train_target.cuda() test_input, test_target = test_input.cuda(), test_target.cuda()