projects
/
pytorch.git
/ commitdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
| commitdiff |
tree
raw
|
patch
|
inline
| side by side (parent:
663ddb2
)
Update.
author
Francois Fleuret
<francois@fleuret.org>
Thu, 15 Nov 2018 10:21:25 +0000
(11:21 +0100)
committer
Francois Fleuret
<francois@fleuret.org>
Thu, 15 Nov 2018 10:21:25 +0000
(11:21 +0100)
mine_mnist.py
patch
|
blob
|
history
diff --git
a/mine_mnist.py
b/mine_mnist.py
index
c6dc287
..
82f6530
100755
(executable)
--- a/
mine_mnist.py
+++ b/
mine_mnist.py
@@
-94,20
+94,21
@@
for e in range(nb_epochs):
input_br = input_b[torch.randperm(input_b.size(0))]
input_br = input_b[torch.randperm(input_b.size(0))]
- mi = 0.0
+
acc_
mi = 0.0
for batch_a, batch_b, batch_br in zip(input_a.split(batch_size),
input_b.split(batch_size),
input_br.split(batch_size)):
for batch_a, batch_b, batch_br in zip(input_a.split(batch_size),
input_b.split(batch_size),
input_br.split(batch_size)):
- loss = - (model(batch_a, batch_b).mean() - model(batch_a, batch_br).exp().mean().log())
- mi -= loss.item()
+ mi = model(batch_a, batch_b).mean() - model(batch_a, batch_br).exp().mean().log()
+ loss = - mi
+ acc_mi += mi.item()
optimizer.zero_grad()
loss.backward()
optimizer.step()
optimizer.zero_grad()
loss.backward()
optimizer.step()
- mi /= (input_a.size(0) // batch_size)
+
acc_
mi /= (input_a.size(0) // batch_size)
- print('%d %.04f %.04f'%(e, mi / math.log(2), class_entropy / math.log(2)))
+ print('%d %.04f %.04f'%(e,
acc_
mi / math.log(2), class_entropy / math.log(2)))
sys.stdout.flush()
sys.stdout.flush()
@@
-122,16
+123,16
@@
for e in range(nb_epochs):
input_br = input_b[torch.randperm(input_b.size(0))]
input_br = input_b[torch.randperm(input_b.size(0))]
- mi = 0.0
+
acc_
mi = 0.0
for batch_a, batch_b, batch_br in zip(input_a.split(batch_size),
input_b.split(batch_size),
input_br.split(batch_size)):
loss = - (model(batch_a, batch_b).mean() - model(batch_a, batch_br).exp().mean().log())
for batch_a, batch_b, batch_br in zip(input_a.split(batch_size),
input_b.split(batch_size),
input_br.split(batch_size)):
loss = - (model(batch_a, batch_b).mean() - model(batch_a, batch_br).exp().mean().log())
- mi -= loss.item()
+
acc_
mi -= loss.item()
- mi /= (input_a.size(0) // batch_size)
+
acc_
mi /= (input_a.size(0) // batch_size)
-print('test %.04f %.04f'%(mi / math.log(2), class_entropy / math.log(2)))
+print('test %.04f %.04f'%(
acc_
mi / math.log(2), class_entropy / math.log(2)))
######################################################################
######################################################################