projects
/
mygptrnn.git
/ commitdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
| commitdiff |
tree
raw
|
patch
|
inline
| side by side (parent:
6e87fe0
)
Update.
author
François Fleuret
<francois@fleuret.org>
Thu, 11 Jan 2024 21:04:31 +0000
(22:04 +0100)
committer
François Fleuret
<francois@fleuret.org>
Thu, 11 Jan 2024 21:04:31 +0000
(22:04 +0100)
mygpt.py
patch
|
blob
|
history
diff --git
a/mygpt.py
b/mygpt.py
index
633ad64
..
a62cf49
100755
(executable)
--- a/
mygpt.py
+++ b/
mygpt.py
@@
-570,19
+570,10
@@
class Caterpillar(nn.Module):
warnings.warn("rotating barrel", RuntimeWarning)
warnings.warn("rotating barrel", RuntimeWarning)
- # print(f"SANITY2 {N=} {H=} {R=} {t0=} {t1=} {G.size()=}")
-
- n_barrel = torch.arange(N, device=G.device)[:, None, None, None]
- h_barrel = torch.arange(H, device=G.device)[None, :, None, None]
r_barrel = torch.arange(R, device=G.device)[None, None, :, None]
t_barrel = torch.arange(t1 - t0, device=G.device)[None, None, None, :]
r_barrel = (r_barrel + (t_barrel + t0) // L) % R
r_barrel = torch.arange(R, device=G.device)[None, None, :, None]
t_barrel = torch.arange(t1 - t0, device=G.device)[None, None, None, :]
r_barrel = (r_barrel + (t_barrel + t0) // L) % R
-
- # GG = G.gather(dim=2,index=r_barrel)
- G = G[n_barrel, h_barrel, r_barrel, t_barrel]
-
- # print("SANITY", (GG-G).abs())
- # exit(0)
+ G = G.gather(dim=2, index=r_barrel.expand_as(G))
######################################################################
# The "flashbacks"
######################################################################
# The "flashbacks"