From: François Fleuret Date: Thu, 11 Jan 2024 21:04:31 +0000 (+0100) Subject: Update. X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=4f5d03d3371b124121e8f9fc0ff583553fea1e38;p=mygptrnn.git Update. --- diff --git a/mygpt.py b/mygpt.py index 633ad64..a62cf49 100755 --- a/mygpt.py +++ b/mygpt.py @@ -570,19 +570,10 @@ class Caterpillar(nn.Module): warnings.warn("rotating barrel", RuntimeWarning) - # print(f"SANITY2 {N=} {H=} {R=} {t0=} {t1=} {G.size()=}") - - n_barrel = torch.arange(N, device=G.device)[:, None, None, None] - h_barrel = torch.arange(H, device=G.device)[None, :, None, None] r_barrel = torch.arange(R, device=G.device)[None, None, :, None] t_barrel = torch.arange(t1 - t0, device=G.device)[None, None, None, :] r_barrel = (r_barrel + (t_barrel + t0) // L) % R - - # GG = G.gather(dim=2,index=r_barrel) - G = G[n_barrel, h_barrel, r_barrel, t_barrel] - - # print("SANITY", (GG-G).abs()) - # exit(0) + G = G.gather(dim=2, index=r_barrel.expand_as(G)) ###################################################################### # The "flashbacks"