##############################
+# Returns a tensor with an additional index at rank win_dim, that moves
+# along the same dimension as dim, over the domain {0...win_size-1}, while
+# dim is restricted to a domain reduced by win_size-1 values.
+
+
def moving_window(x, dim, win_dim, win_size):
    size, stride = x.size(), x.stride()
    size = size[:dim] + (size[dim] - win_size + 1,) + size[dim + 1 :]
    size = size[:win_dim] + (win_size,) + size[win_dim:]
    stride = stride[:win_dim] + (stride[dim],) + stride[win_dim:]
    return x.as_strided(size=size, stride=stride)
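# Usage sketch (an illustration, not part of the original code): for a
# contiguous x of shape (N, T), moving_window(x, dim=1, win_dim=2, win_size=3)
# returns a view y of shape (N, T - 2, 3) with y[n, t, k] == x[n, t + k];
# since it only manipulates sizes and strides, no data is copied.
#
#   x = torch.arange(10).reshape(2, 5)
#   y = moving_window(x, dim=1, win_dim=2, win_size=3)
#   assert y.size() == (2, 3, 3)
#   assert (y[:, 1, :] == x[:, 1:4]).all()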
T = bs.x.size(1)
DV = self.w_V.size(1)
DK = self.w_K.size(1)
- Dout = self.w_O.size(1)
+ DM = self.w_O.size(1)
CH = self.caterpillar_height
CL = self.caterpillar_length
assert (
    t0 >= CL and (t1 - t0) % CL == 0
), "bs.first should be at least caterpillar_length, and bs.nb should be a multiple of caterpillar_length"
+ # We cache values to deal efficiently with auto-regression
+
if bs.init_cache:
self.rec_V = X.new_zeros(N, CH, T, DV)
self.rec_K = X.new_zeros(N, CH, T, DK)
self.rec_V[:, :, t0 - CL : t0] = self.init_V_rec[None, :, :, :]
self.rec_K[:, :, t0 - CL : t0] = self.init_K_rec[None, :, :, :]
- self.cache_Y = X.new_zeros(N, T, Dout)
+ self.cache_Y = X.new_zeros(N, T, DM)
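# Cache layout (an interpretation, not from the original comments): rec_V is
# (N, CH, T, DV) and rec_K is (N, CH, T, DK), one recurrent K/V row per
# caterpillar level and per time step; the CL slots before t0 hold the
# learned initial state and later slots are filled as the recurrence
# advances, which is what makes auto-regressive calls cheap.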
######################################################################
# Compute the recurrent state
- # This is the Gating sequence that modulates if they key and
- # values should be stored in one of the CH pairs of the
- # current stack. The CH gating values are independent, which
- # means that the same thing could be stored up to CH times or
- # not at all
+ # This is the Gating sequence that modulates the storing of
+ # the new key and value in the CH pairs of the current
+ # stack. The CH gating values are independent, which means
+ # that the current K/V could be stored in all the pairs of the
+ # recurrent state, or not at all.
G = (
torch.einsum("ntc,hec->nhet", X, self.w_G) + self.b_G[None, :, :, None]
).sigmoid()
+ G = F.dropout(G, self.attention_dropout, self.training)
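# Shape sketch (the dimension naming is an assumption): X is (N, T, DM) and
# self.w_G is (H, CH, DM), so the einsum yields G of shape (N, H, CH, T),
# i.e. one gate in (0, 1) per head, per caterpillar level, per time step.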
+
V = torch.einsum("ntc,hdc->nhtd", X, self.w_V)
K = torch.einsum("ntc,hdc->nhtd", X, self.w_K)
init_rec_V = self.rec_V[:, :, t0 - CL : t0]
init_rec_K = self.rec_K[:, :, t0 - CL : t0]
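# The CL last recurrent rows computed for the previous chunk (or the learned
# initial state on the first call) serve as the initial values of the
# parallel scan below.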
- # Here there is a trick: The parallel scan operates with a
- # period of L, so we split the sequence indexing in two axes,
- # the second of size CL, and run the parallel scan using the
- # other alone as the sequence index.
+ # Here there is a trick: Since the stack at time t is computed
+ # by updating the one at time t-CL, the parallel scan operates
+ # with a period of CL. To do so we split the time indexing in
+ # two axes, the second of size CL, and run the parallel scan
+ # using the other as the sequence index.
A = A.unflatten(2, (-1, CL))
gated_V = gated_V.unflatten(2, (-1, CL))
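# Index sketch (an illustration, assuming A is (N, CH, T) before the
# unflatten): axis 2 is split so that A[:, :, u, r] is the old
# A[:, :, u * CL + r], and the scan then runs along u, chaining together
# exactly the time steps r, r + CL, r + 2*CL, ..., which implements the
# period-CL recurrence described above.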