attention_matrices, # list of 2d tensors T1xT2, T2xT3, ..., Tk-1xTk
# do not draw links with a lesser attention
min_link_attention=0,
attention_matrices, # list of 2d tensors T1xT2, T2xT3, ..., Tk-1xTk
# do not draw links with a lesser attention
min_link_attention=0,
min_total_attention=None,
# draw only the top k links
k_top=None,
min_total_attention=None,
# draw only the top k links
k_top=None,
ni = torch.arange(at.size(0))[:, None].expand_as(at)
nj = torch.arange(at.size(1))[None, :].expand_as(at)
at = at.flatten()
ni = torch.arange(at.size(0))[:, None].expand_as(at)
nj = torch.arange(at.size(1))[None, :].expand_as(at)
at = at.flatten()