Update.
[pytorch.git] / eingather.py
index c7552d7..03b713c 100755 (executable)
@@ -71,7 +71,7 @@ def lambda_eingather(op, src_shape, *indexes_shape):
                 idx.append(lambda indexes: a)
 
         print(f"{idx=}")
-        return lambda indexes: [ f(indexes) for f in idx]
+        return lambda indexes: [f(indexes) for f in idx]
 
     f = do(src_shape, s_src)
     print(f"{f(0)=}")
@@ -102,12 +102,12 @@ index2 = torch.randint(src.size(3), (src.size(1),))
 
 # result[a, c, e] = src[c, a, index1[e, a, e], index2[a]]
 
-#result = eingather("ca(eae)(a) -> ace", src, index1, index2)
+# result = eingather("ca(eae)(a) -> ace", src, index1, index2)
 
 from functorch.dim import dims
 
-a,c,e=dims(3)
-result=src[c,a,index1[e,a,e],index2[a]].order(a,c,e)
+a, c, e = dims(3)
+result = src[c, a, index1[e, a, e], index2[a]].order(a, c, e)
 
 # Check