+
+def noncausal_prompt_amm_generator(d):
+ q = torch.arange(d)[:, None]
+ k = torch.arange(d)[None, :]
+ s = args.maze_height * args.maze_width
+ # return torch.logical_and(q < k, torch.logical_or(q >= s, k >= s))
+ return q < k
+
+
+amm_generator = None
+
+if args.noncausal_prompt:
+ amm_generator = noncausal_prompt_amm_generator
+