help = 'Provide a positional encoding',
action='store_true', default=False)
+parser.add_argument('--seed',
+ type = int, default = 0,
+ help = 'Random seed (default 0, < 0 is no seeding)')
+
args = parser.parse_args()
+if args.seed >= 0:
+ torch.manual_seed(args.seed)
+
######################################################################
label=''
else:
device = torch.device('cpu')
-torch.manual_seed(1)
-
######################################################################
seq_height_min, seq_height_max = 1.0, 25.0
test_outputs = model((test_input - mu) / std).detach()
if args.with_attention:
- x = model[0:4]((test_input - mu) / std)
- test_A = model[4].attention(x)
+ k = next(k for k, l in enumerate(model) if isinstance(l, AttentionLayer))
+ x = model[0:k]((test_input - mu) / std)
+ test_A = model[k].attention(x)
test_A = test_A.detach().to('cpu')
test_input = test_input.detach().to('cpu')