projects
/
pytorch.git
/ blobdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
|
commitdiff
|
tree
raw
|
inline
| side by side
Update.
[pytorch.git]
/
eingather.py
diff --git
a/eingather.py
b/eingather.py
index
c7552d7
..
03b713c
100755
(executable)
--- a/
eingather.py
+++ b/
eingather.py
@@
-71,7
+71,7
@@
def lambda_eingather(op, src_shape, *indexes_shape):
idx.append(lambda indexes: a)
print(f"{idx=}")
idx.append(lambda indexes: a)
print(f"{idx=}")
- return lambda indexes: [
f(indexes) for f in idx]
+ return lambda indexes: [f(indexes) for f in idx]
f = do(src_shape, s_src)
print(f"{f(0)=}")
f = do(src_shape, s_src)
print(f"{f(0)=}")
@@
-102,12
+102,12
@@
index2 = torch.randint(src.size(3), (src.size(1),))
# result[a, c, e] = src[c, a, index1[e, a, e], index2[a]]
# result[a, c, e] = src[c, a, index1[e, a, e], index2[a]]
-#result = eingather("ca(eae)(a) -> ace", src, index1, index2)
+#
result = eingather("ca(eae)(a) -> ace", src, index1, index2)
from functorch.dim import dims
from functorch.dim import dims
-a,
c,e=
dims(3)
-result
=src[c,a,index1[e,a,e],index2[a]].order(a,c,
e)
+a,
c, e =
dims(3)
+result
= src[c, a, index1[e, a, e], index2[a]].order(a, c,
e)
# Check
# Check