Update.
authorFrancois Fleuret <francois@fleuret.org>
Fri, 22 May 2020 11:22:54 +0000 (13:22 +0200)
committerFrancois Fleuret <francois@fleuret.org>
Fri, 22 May 2020 11:22:54 +0000 (13:22 +0200)
attentiontoy1d.py

index 6540a0f..d7f06fe 100755 (executable)
@@ -309,8 +309,9 @@ test_input = torch.cat((test_input, positional_input.expand(test_input.size(0),
 test_outputs = model((test_input - mu) / std).detach()
 
 if args.with_attention:
-    x = model[0:4]((test_input - mu) / std)
-    test_A = model[4].attention(x)
+    k = next(k for k, l in enumerate(model) if isinstance(l, AttentionLayer))
+    x = model[0:k]((test_input - mu) / std)
+    test_A = model[k].attention(x)
     test_A = test_A.detach().to('cpu')
 
 test_input = test_input.detach().to('cpu')