- V = torch.einsum("ntc,hdc->nhtd", X, self.w_V)
- K = torch.einsum("ntc,hdc->nhtd", X, self.w_K)
+ # Clip the gating to avoid values greater than 1 when several
+ # heads hit the same row
+
+ G = G / G.sum(1, keepdim=True).clamp(min=1)
+
+ ######################################################################
+
+ def recurrence(G, V, K):
+ # We prepare the arguments for the parallel scan
+
+ A = 1 - G.sum(1)
+
+ gated_V = torch.einsum("nhrt,nhtd->nrtd", G, V)
+ gated_K = torch.einsum("nhrt,nhtd->nrtd", G, K)
+
+ # We start from cached values, which matters in inference
+
+ init_rec_V = self.rec_V[:, :, t0 - L : t0]
+ init_rec_K = self.rec_K[:, :, t0 - L : t0]
+
+ # Associative scan