+ # Clip the gating to avoid values greater than 1 when several
+ # heads hit the same row
+
+ G = G / G.sum(1, keepdim=True).clamp(min=1)
+
+ ######################################################################
+ # Roll the gating indexes
+
+ # warnings.warn("rotating barrel", RuntimeWarning)
+
+ # r_barrel = torch.arange(R, device=G.device)[None, None, :, None]
+ # t_barrel = torch.arange(t1 - t0, device=G.device)[None, None, None, :]
+ # r_barrel = (r_barrel + (t_barrel + t0) // L) % R
+ # G = G.gather(dim=2, index=r_barrel.expand_as(G))
+
+ ######################################################################
+ # The "flashbacks"
+