- V = torch.einsum("ntc,hdc->nhtd", X, self.w_V)
- K = torch.einsum("ntc,hdc->nhtd", X, self.w_K)
+ if self.training and self.proba_gate_dropout > 0.0:
+ # This is a better implementation of "flashbacks".
+
+ # G is N x H x E x T, where E indexes the caterpillar's rows.
+
+ warnings.warn("gate dropout", RuntimeWarning)
+ epsilon = 0.5
+
+ dropout_start = (
+ (
+ torch.rand(G.size(), device=G.device)
+ .flatten(2, 3)
+ .sort(dim=2)
+ .indices
+ == 0
+ )
+ .unflatten(2, (CH, t1 - t0))
+ .float()
+ )
+
+ dropout_tail = dropout_start.cumsum(dim=3) - dropout_start
+
+ dropout_active = (
+ torch.rand(N, 1, 1, 1, device=G.device) < self.proba_gate_dropout
+ ).long()
+
+ dropout_start *= dropout_active
+ dropout_tail *= dropout_active
+
+ G = (
+ G
+ + dropout_start * (1 - epsilon - G.detach())
+ - dropout_tail * G.detach()
+ )
+
+ ######################################################################