index1 = torch.randint(src.size(2), (src.size(3), src.size(1), src.size(3)))
index2 = torch.randint(src.size(3), (src.size(1),))
-# I want result[a, c, e] = src[c, a, index1[e, a, e], index2[a], e]
+# I want result[a, c, e] = src[c, a, index1[e, a, e], index2[a]]
result = eingather("ca(eae)(a) -> ace", src, index1, index2)