torch.einsum("ntc,hec->nhet", X, self.w_G) + self.b_G[None, :, :, None]
).sigmoid()
+ G = F.dropout(G, self.attention_dropout, self.training)
+
V = torch.einsum("ntc,hdc->nhtd", X, self.w_V)
K = torch.einsum("ntc,hdc->nhtd", X, self.w_K)
# by updating that at time t-L, the parallel scan operates
# with a period of L. To do so we split the time indexing in
# two axes, the second of size CL, and run the parallel scan
- # using the other alone as the sequence index.
+ # using the other as the sequence index.
A = A.unflatten(2, (-1, CL))
gated_V = gated_V.unflatten(2, (-1, CL))