warnings.warn("gate dropout", RuntimeWarning)
- # kill = (
- # torch.rand(G.size(), device=G.device) <= self.proba_gate_dropout
- # ).float()
-
+ # Pick a point in each of the NxHxR timelines and set that
+ # entry and all the following ones to 1
kill = (
torch.rand(N, H, R, t1 - t0, device=G.device).sort(dim=3).indices == 0
).cumsum(dim=3)
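+ # (sort(dim=3).indices == 0 puts a single 1 at a uniformly
+ # random time step, since the rank of entry 0 among t1 - t0
+ # iid uniforms is itself uniform; the cumsum then sets that
+ # step and every later one to 1)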
+
+ # Keep this mask for only some of the NxHxR timelines
kill = kill * (
torch.rand(N, H, R, 1, device=G.device) <= self.proba_gate_dropout
)
+ # The coefficients to keep are the complement of the kill mask
mask = 1 - kill
masked_next_V, masked_next_K = recurrence(G * mask, V, K)
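For reference, a minimal standalone sketch of this gate-dropout mask construction, with arbitrary toy sizes (`N`, `H`, `R`, `T` and `proba_gate_dropout` are placeholders for the values the module reads from its input and attributes):

```python
import torch

# Toy sizes, hypothetical; in the module these come from the
# input shape and self.proba_gate_dropout
N, H, R, T = 2, 3, 4, 10
proba_gate_dropout = 0.25

# A single 1 at a uniformly random step of each timeline,
# extended by the cumsum to all the following steps
kill = (torch.rand(N, H, R, T).sort(dim=3).indices == 0).cumsum(dim=3)

# Keep the kill mask on only a Bernoulli-selected subset of timelines
kill = kill * (torch.rand(N, H, R, 1) <= proba_gate_dropout)

# The surviving gates are the complement
mask = 1 - kill
assert mask.shape == (N, H, R, T)
```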
Q = torch.einsum("ntc,hdc->nhtd", X, self.w_Q)
- # We build tensors NxHxTxFxL where N is the sample index, H
- # the head, T the time, F the row in the caterpillar, and L
+ # We build tensors NxHxTxRxL where N is the sample index, H
+ # the head, T the time, R the row in the caterpillar, and L
# the column in the caterpillar
windowed_V = moving_window(
# We have an attention score for each of the RxL values
ar = torch.einsum(
- "nhtd,nftld->nhtfl",
+ "nhtd,nrtld->nhtrl",
Q,
windowed_K,
) / math.sqrt(DK)
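And a shape-check sketch for the attention scores over the caterpillar window. The NxRxTxLxDK layout of `windowed_K` is read off the einsum subscripts `nrtld`; the unfold-based window construction here is an assumption for illustration, not the repo's actual `moving_window`:

```python
import math
import torch

N, H, R, T, L, DK = 2, 3, 4, 10, 5, 8  # toy sizes, hypothetical

Q = torch.randn(N, H, T, DK)
K = torch.randn(N, R, T, DK)

# Assumed stand-in for moving_window: left-pad time so step t sees
# K[..., t - L + 1 : t + 1, :], then unfold a length-L window
K_padded = torch.cat([K.new_zeros(N, R, L - 1, DK), K], dim=2)
windowed_K = K_padded.unfold(2, L, 1).permute(0, 1, 2, 4, 3)  # N,R,T,L,DK

# One attention score per (head, time, row, column) of the caterpillar
ar = torch.einsum("nhtd,nrtld->nhtrl", Q, windowed_K) / math.sqrt(DK)
assert ar.shape == (N, H, T, R, L)
```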