X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?p=pysvrt.git;a=blobdiff_plain;f=cnn-svrt.py;h=c1fe3acd7a34b5a341c6a261e6ee5eda1da59d4e;hp=f731c2b7f0210ff3146fa88c49c16f85f8758344;hb=3f17c323d2725ca189fa3502b89e9833cf6caa25;hpb=131f57030ff7f533c2aa07e92bcd91630b8430bf diff --git a/cnn-svrt.py b/cnn-svrt.py index f731c2b..c1fe3ac 100755 --- a/cnn-svrt.py +++ b/cnn-svrt.py @@ -1,4 +1,4 @@ -#!/usr/bin/env python-for-pytorch +#!/usr/bin/env python # svrt is the ``Synthetic Visual Reasoning Test'', an image # generator for evaluating classification performance of machine @@ -90,10 +90,10 @@ def print_test_error(model, test_input, test_target): for b in range(0, nb_test_samples, bs): output = model.forward(test_input.narrow(0, b, bs)) - _, wta = torch.max(output.data, 1) + wta_prediction = output.data.max(1)[1].view(-1) for i in range(0, bs): - if wta[i][0] != test_target.narrow(0, b, bs).data[i]: + if wta_prediction[i] != test_target.narrow(0, b, bs).data[i]: nb_test_errors = nb_test_errors + 1 print('TEST_ERROR {:.02f}% ({:d}/{:d})'.format( @@ -113,6 +113,7 @@ for p in range(1, 24): t1 = time.time() train_input, train_target = generate_set(p, nb_train_samples) test_input, test_target = generate_set(p, nb_test_samples) + if torch.cuda.is_available(): train_input, train_target = train_input.cuda(), train_target.cuda() test_input, test_target = test_input.cuda(), test_target.cuda()