X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=mygpt.py;h=a62cf4908ba88622a3f567d082c7a94711887fde;hb=4f5d03d3371b124121e8f9fc0ff583553fea1e38;hp=9d3abb62cc8b6d95ce8be5b291d1c9e36e7f100d;hpb=cebc20b3608a41bfd27b2ab9d950c082f9b7ea89;p=mygptrnn.git diff --git a/mygpt.py b/mygpt.py index 9d3abb6..a62cf49 100755 --- a/mygpt.py +++ b/mygpt.py @@ -569,17 +569,11 @@ class Caterpillar(nn.Module): # Roll the gating indexes warnings.warn("rotating barrel", RuntimeWarning) - n_barrel = torch.arange(N, device=G.device)[:, None, None, None] - h_barrel = torch.arange(H, device=G.device)[None, :, None, None] + r_barrel = torch.arange(R, device=G.device)[None, None, :, None] t_barrel = torch.arange(t1 - t0, device=G.device)[None, None, None, :] - r_barrel = (r_barrel + t_barrel + t0) % R - - # print(f"({N}, {H}, {R}, {t1-t0}) {G.size()=}") - - G = G[n_barrel, h_barrel, r_barrel, t_barrel] - - # print(G.sum()) + r_barrel = (r_barrel + (t_barrel + t0) // L) % R + G = G.gather(dim=2, index=r_barrel.expand_as(G)) ###################################################################### # The "flashbacks"