From 4d0e56bee81c535293367628dd73cbf993d0690a Mon Sep 17 00:00:00 2001 From: Francois Fleuret Date: Fri, 22 May 2020 13:22:54 +0200 Subject: [PATCH] Update. --- attentiontoy1d.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/attentiontoy1d.py b/attentiontoy1d.py index 6540a0f..d7f06fe 100755 --- a/attentiontoy1d.py +++ b/attentiontoy1d.py @@ -309,8 +309,9 @@ test_input = torch.cat((test_input, positional_input.expand(test_input.size(0), test_outputs = model((test_input - mu) / std).detach() if args.with_attention: - x = model[0:4]((test_input - mu) / std) - test_A = model[4].attention(x) + k = next(k for k, l in enumerate(model) if isinstance(l, AttentionLayer)) + x = model[0:k]((test_input - mu) / std) + test_A = model[k].attention(x) test_A = test_A.detach().to('cpu') test_input = test_input.detach().to('cpu') -- 2.39.5