- n_barrel = torch.arange(N, device=G.device)[:, None, None, None]
- h_barrel = torch.arange(H, device=G.device)[None, :, None, None]
- r_barrel = torch.arange(R, device=G.device)[None, None, :, None]
- t_barrel = torch.arange(t1 - t0, device=G.device)[None, None, None, :]
- r_barrel = (r_barrel + (t_barrel + t0) // L) % R
-
- # GG = G.gather(dim=2,index=r_barrel)
- G = G[n_barrel, h_barrel, r_barrel, t_barrel]
-
- # print("SANITY", (GG-G).abs())
- # exit(0)
+ # G = (
+ # torch.einsum("ntc,hrc->nhrt", X, self.w_G) + self.b_G[None, :, :, None]
+ # ).softmax(dim=2)