-#!/usr/bin/env python-for-pytorch
+#!/usr/bin/env python
# svrt is the ``Synthetic Visual Reasoning Test'', an image
# generator for evaluating classification performance of machine
for b in range(0, nb_test_samples, bs):
output = model.forward(test_input.narrow(0, b, bs))
- _, wta = torch.max(output.data, 1)
+ wta_prediction = output.data.max(1)[1].view(-1)
for i in range(0, bs):
- if wta[i][0] != test_target.narrow(0, b, bs).data[i]:
+ if wta_prediction[i] != test_target.narrow(0, b, bs).data[i]:
nb_test_errors = nb_test_errors + 1
print('TEST_ERROR {:.02f}% ({:d}/{:d})'.format(
t1 = time.time()
train_input, train_target = generate_set(p, nb_train_samples)
test_input, test_target = generate_set(p, nb_test_samples)
+
if torch.cuda.is_available():
train_input, train_target = train_input.cuda(), train_target.cuda()
test_input, test_target = test_input.cuda(), test_target.cuda()