X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=eingather.py;h=03b713c1f21fe1e5e792d75d8ca56475097652e1;hb=7cf92d14892ccce7c5a1eaa38c0d6b8fff03e751;hp=c7552d740d34ce53dbc01e4c7be4b7dc0b24d563;hpb=a2ccdd2f5e9fb3e7ed52492729b880f815ddfbcb;p=pytorch.git diff --git a/eingather.py b/eingather.py index c7552d7..03b713c 100755 --- a/eingather.py +++ b/eingather.py @@ -71,7 +71,7 @@ def lambda_eingather(op, src_shape, *indexes_shape): idx.append(lambda indexes: a) print(f"{idx=}") - return lambda indexes: [ f(indexes) for f in idx] + return lambda indexes: [f(indexes) for f in idx] f = do(src_shape, s_src) print(f"{f(0)=}") @@ -102,12 +102,12 @@ index2 = torch.randint(src.size(3), (src.size(1),)) # result[a, c, e] = src[c, a, index1[e, a, e], index2[a]] -#result = eingather("ca(eae)(a) -> ace", src, index1, index2) +# result = eingather("ca(eae)(a) -> ace", src, index1, index2) from functorch.dim import dims -a,c,e=dims(3) -result=src[c,a,index1[e,a,e],index2[a]].order(a,c,e) +a, c, e = dims(3) +result = src[c, a, index1[e, a, e], index2[a]].order(a, c, e) # Check