# Roll the gating indexes
warnings.warn("rotating barrel", RuntimeWarning)
# Roll the gating indexes
warnings.warn("rotating barrel", RuntimeWarning)
n_barrel = torch.arange(N, device=G.device)[:, None, None, None]
h_barrel = torch.arange(H, device=G.device)[None, :, None, None]
r_barrel = torch.arange(R, device=G.device)[None, None, :, None]
t_barrel = torch.arange(t1 - t0, device=G.device)[None, None, None, :]
n_barrel = torch.arange(N, device=G.device)[:, None, None, None]
h_barrel = torch.arange(H, device=G.device)[None, :, None, None]
r_barrel = torch.arange(R, device=G.device)[None, None, :, None]
t_barrel = torch.arange(t1 - t0, device=G.device)[None, None, None, :]
- r_barrel = (r_barrel + t_barrel + t0) % R
-
- # print(f"({N}, {H}, {R}, {t1-t0}) {G.size()=}")
+ r_barrel = (r_barrel + (t_barrel + t0) // L) % R
G = G[n_barrel, h_barrel, r_barrel, t_barrel]
G = G[n_barrel, h_barrel, r_barrel, t_barrel]