X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?p=pytorch.git;a=blobdiff_plain;f=eingather.py;h=03b713c1f21fe1e5e792d75d8ca56475097652e1;hp=c7552d740d34ce53dbc01e4c7be4b7dc0b24d563;hb=HEAD;hpb=5598ee31cbcd2ebbedeadbd66518f082c66aaaa9 diff --git a/eingather.py b/eingather.py index c7552d7..03b713c 100755 --- a/eingather.py +++ b/eingather.py @@ -71,7 +71,7 @@ def lambda_eingather(op, src_shape, *indexes_shape): idx.append(lambda indexes: a) print(f"{idx=}") - return lambda indexes: [ f(indexes) for f in idx] + return lambda indexes: [f(indexes) for f in idx] f = do(src_shape, s_src) print(f"{f(0)=}") @@ -102,12 +102,12 @@ index2 = torch.randint(src.size(3), (src.size(1),)) # result[a, c, e] = src[c, a, index1[e, a, e], index2[a]] -#result = eingather("ca(eae)(a) -> ace", src, index1, index2) +# result = eingather("ca(eae)(a) -> ace", src, index1, index2) from functorch.dim import dims -a,c,e=dims(3) -result=src[c,a,index1[e,a,e],index2[a]].order(a,c,e) +a, c, e = dims(3) +result = src[c, a, index1[e, a, e], index2[a]].order(a, c, e) # Check