projects
/
pysvrt.git
/ commitdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
| commitdiff |
tree
raw
|
patch
|
inline
| side by side (parent:
131f570
)
Fixed the interpreter name.
author
Francois Fleuret
<francois@fleuret.org>
Thu, 15 Jun 2017 10:06:18 +0000
(12:06 +0200)
committer
Francois Fleuret
<francois@fleuret.org>
Thu, 15 Jun 2017 10:06:18 +0000
(12:06 +0200)
cnn-svrt.py
patch
|
blob
|
history
diff --git
a/cnn-svrt.py
b/cnn-svrt.py
index
f731c2b
..
c1fe3ac
100755
(executable)
--- a/
cnn-svrt.py
+++ b/
cnn-svrt.py
@@
-1,4
+1,4
@@
-#!/usr/bin/env python
-for-pytorch
+#!/usr/bin/env python
# svrt is the ``Synthetic Visual Reasoning Test'', an image
# generator for evaluating classification performance of machine
# svrt is the ``Synthetic Visual Reasoning Test'', an image
# generator for evaluating classification performance of machine
@@
-90,10
+90,10
@@
def print_test_error(model, test_input, test_target):
for b in range(0, nb_test_samples, bs):
output = model.forward(test_input.narrow(0, b, bs))
for b in range(0, nb_test_samples, bs):
output = model.forward(test_input.narrow(0, b, bs))
-
_, wta = torch.max(output.data,
1)
+
wta_prediction = output.data.max(1)[1].view(-
1)
for i in range(0, bs):
for i in range(0, bs):
- if wta
[i][0
] != test_target.narrow(0, b, bs).data[i]:
+ if wta
_prediction[i
] != test_target.narrow(0, b, bs).data[i]:
nb_test_errors = nb_test_errors + 1
print('TEST_ERROR {:.02f}% ({:d}/{:d})'.format(
nb_test_errors = nb_test_errors + 1
print('TEST_ERROR {:.02f}% ({:d}/{:d})'.format(
@@
-113,6
+113,7
@@
for p in range(1, 24):
t1 = time.time()
train_input, train_target = generate_set(p, nb_train_samples)
test_input, test_target = generate_set(p, nb_test_samples)
t1 = time.time()
train_input, train_target = generate_set(p, nb_train_samples)
test_input, test_target = generate_set(p, nb_test_samples)
+
if torch.cuda.is_available():
train_input, train_target = train_input.cuda(), train_target.cuda()
test_input, test_target = test_input.cuda(), test_target.cuda()
if torch.cuda.is_available():
train_input, train_target = train_input.cuda(), train_target.cuda()
test_input, test_target = test_input.cuda(), test_target.cuda()