# attention_matrices = [ torch.rand(3,5), torch.rand(8,3), torch.rand(5,8) ]
# for a in attention_matrices: a=a/a.sum(-1,keepdim=True)
# attention_matrices = [ torch.rand(3,5), torch.rand(8,3), torch.rand(5,8) ]
# for a in attention_matrices: a=a/a.sum(-1,keepdim=True)