if args.noncausal_prompt:
amm_generator = lambda d: torch.logical_and(
torch.arange(d)[None, None, :, None] < torch.arange(d)[None, None, None, :],
if args.noncausal_prompt:
amm_generator = lambda d: torch.logical_and(
torch.arange(d)[None, None, :, None] < torch.arange(d)[None, None, None, :],