From 6e87fe0cb8bd8a0042bbf7b2ede9d8ed0372fb6b Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Thu, 11 Jan 2024 21:49:52 +0100 Subject: [PATCH] Update. --- mygpt.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/mygpt.py b/mygpt.py index 9d3abb6..633ad64 100755 --- a/mygpt.py +++ b/mygpt.py @@ -569,17 +569,20 @@ class Caterpillar(nn.Module): # Roll the gating indexes warnings.warn("rotating barrel", RuntimeWarning) + + # print(f"SANITY2 {N=} {H=} {R=} {t0=} {t1=} {G.size()=}") + n_barrel = torch.arange(N, device=G.device)[:, None, None, None] h_barrel = torch.arange(H, device=G.device)[None, :, None, None] r_barrel = torch.arange(R, device=G.device)[None, None, :, None] t_barrel = torch.arange(t1 - t0, device=G.device)[None, None, None, :] - r_barrel = (r_barrel + t_barrel + t0) % R - - # print(f"({N}, {H}, {R}, {t1-t0}) {G.size()=}") + r_barrel = (r_barrel + (t_barrel + t0) // L) % R + # GG = G.gather(dim=2,index=r_barrel) G = G[n_barrel, h_barrel, r_barrel, t_barrel] - # print(G.sum()) + # print("SANITY", (GG-G).abs()) + # exit(0) ###################################################################### # The "flashbacks" -- 2.39.5