- G = G / G.sum(1, keepdim=True).clamp(min=1)
-
- ######################################################################
- # Roll the gating indexes
-
- # warnings.warn("rotating barrel", RuntimeWarning)
-
- # r_barrel = torch.arange(R, device=G.device)[None, None, :, None]
- # t_barrel = torch.arange(t1 - t0, device=G.device)[None, None, None, :]
- # r_barrel = (r_barrel + (t_barrel + t0) // L) % R
- # G = G.gather(dim=2, index=r_barrel.expand_as(G))
+ # G = (
+ # torch.einsum("ntc,hrc->nhrt", X, self.w_G) + self.b_G[None, :, :, None]
+ # ).softmax(dim=2)