projects
/
pytorch.git
/ blobdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
|
commitdiff
|
tree
raw
|
inline
| side by side
Oups.
[pytorch.git]
/
attentiontoy1d.py
diff --git a/attentiontoy1d.py b/attentiontoy1d.py
index cff8350..92d90cf 100755 (executable)
--- a/attentiontoy1d.py
+++ b/attentiontoy1d.py
@@ -1,18 +1,20 @@
#!/usr/bin/env python
#!/usr/bin/env python
-# @XREMOTE_HOST: elk.fleuret.org
-# @XREMOTE_EXEC: /home/fleuret/conda/bin/python
-# @XREMOTE_PRE: killall -q -9 python || echo "Nothing killed"
-# @XREMOTE_GET: *.pdf *.log
+# Any copyright is dedicated to the Public Domain.
+# https://creativecommons.org/publicdomain/zero/1.0/
+
+# Written by Francois Fleuret <francois@fleuret.org>
import torch, math, sys, argparse
from torch import nn
from torch.nn import functional as F
import torch, math, sys, argparse
from torch import nn
from torch.nn import functional as F
+import matplotlib.pyplot as plt
+
######################################################################
######################################################################
-parser = argparse.ArgumentParser(description='Toy RNN.')
+parser = argparse.ArgumentParser(description='Toy attention model.')
parser.add_argument('--nb_epochs',
type = int, default = 250)
parser.add_argument('--nb_epochs',
type = int, default = 250)
@@ -29,8 +31,15 @@
parser.add_argument('--positional_encoding',
help = 'Provide a positional encoding',
action='store_true', default=False)
help = 'Provide a positional encoding',
action='store_true', default=False)
+parser.add_argument('--seed',
+ type = int, default = 0,
+ help = 'Random seed (default 0, < 0 is no seeding)')
+
args = parser.parse_args()
args = parser.parse_args()
+if args.seed >= 0:
+ torch.manual_seed(args.seed)
+
######################################################################
label=''
######################################################################
label=''
@@ -60,8 +69,6 @@
if torch.cuda.is_available():
else:
device = torch.device('cpu')
else:
device = torch.device('cpu')
-torch.manual_seed(1)
-
######################################################################
seq_height_min, seq_height_max = 1.0, 25.0
######################################################################
seq_height_min, seq_height_max = 1.0, 25.0
@@ -146,9 +153,6 @@
def generate_sequences(nb):
######################################################################
######################################################################
-import matplotlib.pyplot as plt
-import matplotlib.collections as mc
-
def save_sequence_images(filename, sequences, tr = None, bx = None):
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)
def save_sequence_images(filename, sequences, tr = None, bx = None):
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)
@@ -310,8 +314,9 @@
test_input = torch.cat((test_input, positional_input.expand(test_input.size(0),
test_outputs = model((test_input - mu) / std).detach()
if args.with_attention:
test_outputs = model((test_input - mu) / std).detach()
if args.with_attention:
- x = model[0:4]((test_input - mu) / std)
- test_A = model[4].attention(x)
+ k = next(k for k, l in enumerate(model) if isinstance(l, AttentionLayer))
+ x = model[0:k]((test_input - mu) / std)
+ test_A = model[k].attention(x)
test_A = test_A.detach().to('cpu')
test_input = test_input.detach().to('cpu')
test_A = test_A.detach().to('cpu')
test_input = test_input.detach().to('cpu')