def noncausal_prompt_amm_generator(d):
q = torch.arange(d)[:, None]
k = torch.arange(d)[None, :]
s = args.maze_height * args.maze_width
def noncausal_prompt_amm_generator(d):
q = torch.arange(d)[:, None]
k = torch.arange(d)[None, :]
s = args.maze_height * args.maze_width