- ######################################################################
- # Clip the gating to avoid values greater than 1 when several
- # heads hit the same row
-
- G = G / G.sum(1, keepdim=True).clamp(min=1)
-
- ######################################################################
- # Roll the gating indexes
-
- # warnings.warn("rotating barrel", RuntimeWarning)
-
- # r_barrel = torch.arange(R, device=G.device)[None, None, :, None]
- # t_barrel = torch.arange(t1 - t0, device=G.device)[None, None, None, :]
- # r_barrel = (r_barrel + (t_barrel + t0) // L) % R
- # G = G.gather(dim=2, index=r_barrel.expand_as(G))
+ A = A.unflatten(2, (-1, L))  # split dim 2 into two dims (-1, L): trailing size L, leading size inferred (presumably a torch tensor — TODO confirm)
+ gated_V = gated_V.unflatten(2, (-1, L))  # same (-1, L) split of dim 2 applied to gated_V
+ gated_K = gated_K.unflatten(2, (-1, L))  # same (-1, L) split of dim 2 applied to gated_K