- def recurrence(G, V, K):
- # We prepare the arguments for the parallel scan
-
- A = 1 - G.sum(dim=1)
-
- gated_V = torch.einsum("nhrt,nhtd->nrtd", G, V)
- gated_K = torch.einsum("nhrt,nhtd->nrtd", G, K)
-
- # We start from cached values, which matters in inference
-
- init_rec_V = self.rec_V[:, :, t0 - L : t0]
- init_rec_K = self.rec_K[:, :, t0 - L : t0]
-
- # Here there is a trick: Since the stack at position t is
- # computed by updating that at position t-L, the parallel
- # scan operates with a period of L. To do so we split the
- # sequence indexing in two axes, the second of size L, and
- # run the parallel scan using the first as the sequence index.
-
- A = A.unflatten(2, (-1, L))
- gated_V = gated_V.unflatten(2, (-1, L))
- gated_K = gated_K.unflatten(2, (-1, L))
-
- next_V = pscan_dim(A, gated_V, init_rec_V, dim=2).flatten(2, 3)
- next_K = pscan_dim(A, gated_K, init_rec_K, dim=2).flatten(2, 3)