+def snake_solver(input, ar_mask):
+ for n in range(input.size(0)):
+ i, j, memory = 0, 0, {}
+ # print(input[n])
+ # print(ar_mask[n])
+ for l in range(input.size(1) // 2):
+ if ar_mask[n, 2 * l] == 1:
+ if memory.get((i, j)) is None:
+ input[n, 2 * l] = -1
+ else:
+ input[n, 2 * l] = memory[(i, j)]
+ else:
+ # print(f'@3 {memory=}')
+ if memory.get((i, j)) is None:
+ memory[(i, j)] = input[n, 2 * l]
+ else:
+ assert memory[(i, j)] == input[n, 2 * l], f"n={n} l={l}"
+ # print(f'@1 {i=} {j=}')
+ d = input[n, 2 * l + 1].item()
+ i += (d + 1) % 2 * (d - 1)
+ j += d % 2 * (d - 2)
+ # print(f'@2 {i=} {j=}')
+
+