+amm_generator = None
+
+if args.noncausal_prompt:
+ amm_generator = lambda d: torch.logical_and(
+ torch.arange(d)[None, None, :, None] < torch.arange(d)[None, None, None, :],
+ torch.logical_or(
+ torch.arange(d)[None, None, :, None] >= d // 2,
+ torch.arange(d)[None, None, None, :] >= d // 2,
+ ),
+ )
+