- self.rec_K[:, :, t0:t1] = (
- mask * K[n, src_head, src_time, dk]
- + (1 - mask) * self.rec_K[:, :, t0:t1]
- )
+ # Associative scan
+
+ # Here there is a trick: Since the stack at position t is
+ # computed by updating that at position t-L, the parallel
+ # scan operates with a period of L. To do so we split the
+ # sequence indexing in two axes, the second of size L, and
+ # run the parallel scan using the first as the sequence index.
+
+ A = A.unflatten(2, (-1, L))
+ gated_V = gated_V.unflatten(2, (-1, L))
+ gated_K = gated_K.unflatten(2, (-1, L))
+
+ next_V = pscan_dim(A, gated_V, init_rec_V, dim=2)
+ next_K = pscan_dim(A, gated_K, init_rec_K, dim=2)
+
+ next_V = next_V.flatten(2, 3)
+ next_K = next_K.flatten(2, 3)
+
+ return next_V, next_K
+
+ #################################################################
+
+ next_V, next_K = recurrence(G, V, K)
+
+ self.rec_V[:, :, t0:t1] = next_V
+ self.rec_K[:, :, t0:t1] = next_K