- amm_generator = lambda d: torch.logical_and(
- torch.arange(d)[None, None, :, None] < torch.arange(d)[None, None, None, :],
- torch.logical_or(
- torch.arange(d)[None, None, :, None] >= d // 2,
- torch.arange(d)[None, None, None, :] >= d // 2,
- ),
- )