- V = torch.einsum("ntc,hdc->nhtd", X, self.w_V)
- K = torch.einsum("ntc,hdc->nhtd", X, self.w_K)
+ G = G / G.sum(1, keepdim=True).clamp(min=1)
+
+ ######################################################################
+ # Roll the gating indexes
+
+ # warnings.warn("rotating barrel", RuntimeWarning)
+
+ # r_barrel = torch.arange(R, device=G.device)[None, None, :, None]
+ # t_barrel = torch.arange(t1 - t0, device=G.device)[None, None, None, :]
+ # r_barrel = (r_barrel + (t_barrel + t0) // L) % R
+ # G = G.gather(dim=2, index=r_barrel.expand_as(G))
+
+ ######################################################################
+ # The "flashbacks"
+
+ if self.training and self.proba_gate_dropout > 0.0:
+ # This is a better implementation of "flashbacks".
+
+ # G is NxHxExT where e is the caterpillar's row.
+
+ warnings.warn("gate dropout", RuntimeWarning)
+ epsilon = 0.5
+
+ dropout_head = (
+ (torch.rand(N, H, 1, t1 - t0, device=G.device).sort(dim=3).indices == 0)
+ .expand_as(G)
+ .float()
+ )
+
+ dropout_tail = dropout_head.cumsum(dim=3) - dropout_head
+
+ dropout_active = (
+ torch.rand(N, 1, 1, 1, device=G.device) < self.proba_gate_dropout
+ ).long()
+
+ dropout_head *= dropout_active
+ dropout_tail *= dropout_active
+
+ G = (
+ G
+ + dropout_head * (1 - epsilon - G.detach())
+ - dropout_tail * G.detach()
+ )
+
+ ######################################################################