projects
/
pytorch.git
/ blobdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
|
commitdiff
|
tree
raw
|
inline
| side by side
Update.
[pytorch.git]
/
attentiontoy1d.py
diff --git a/attentiontoy1d.py b/attentiontoy1d.py
index cff8350..d7f06fe 100755 (executable)
--- a/attentiontoy1d.py
+++ b/attentiontoy1d.py
@@ -1,18 +1,20 @@
#!/usr/bin/env python
#!/usr/bin/env python
-# @XREMOTE_HOST: elk.fleuret.org
-# @XREMOTE_EXEC: /home/fleuret/conda/bin/python
-# @XREMOTE_PRE: killall -q -9 python || echo "Nothing killed"
-# @XREMOTE_GET: *.pdf *.log
+# Any copyright is dedicated to the Public Domain.
+# https://creativecommons.org/publicdomain/zero/1.0/
+
+# Written by Francois Fleuret <francois@fleuret.org>
import torch, math, sys, argparse
from torch import nn
from torch.nn import functional as F
import torch, math, sys, argparse
from torch import nn
from torch.nn import functional as F
+import matplotlib.pyplot as plt
+
######################################################################
######################################################################
-parser = argparse.ArgumentParser(description='Toy RNN.')
+parser = argparse.ArgumentParser(description='Toy attention model.')
parser.add_argument('--nb_epochs',
type = int, default = 250)
parser.add_argument('--nb_epochs',
type = int, default = 250)
@@ -146,9 +148,6 @@ def generate_sequences(nb):
######################################################################
######################################################################
-import matplotlib.pyplot as plt
-import matplotlib.collections as mc
-
def save_sequence_images(filename, sequences, tr = None, bx = None):
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)
def save_sequence_images(filename, sequences, tr = None, bx = None):
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)
@@ -310,8 +309,9 @@ test_input = torch.cat((test_input, positional_input.expand(test_input.size(0),
test_outputs = model((test_input - mu) / std).detach()
if args.with_attention:
test_outputs = model((test_input - mu) / std).detach()
if args.with_attention:
- x = model[0:4]((test_input - mu) / std)
- test_A = model[4].attention(x)
+ k = next(k for k, l in enumerate(model) if isinstance(l, AttentionLayer))
+ x = model[0:k]((test_input - mu) / std)
+ test_A = model[k].attention(x)
test_A = test_A.detach().to('cpu')
test_input = test_input.detach().to('cpu')
test_A = test_A.detach().to('cpu')
test_input = test_input.detach().to('cpu')