T = bs.x.size(1)
DV = self.w_V.size(1)
DK = self.w_K.size(1)
- Dout = self.w_O.size(1)
+ DM = self.w_O.size(1)
CH = self.caterpillar_height
CL = self.caterpillar_length
t0 >= CL and (t1 - t0) % CL == 0
), f"bs.first should be greater than caterpillar_length, and bs.nb should be a multiple of caterpillar_length"
+ # We cache values to deal efficiently with auto-regression
+
if bs.init_cache:
self.rec_V = X.new_zeros(N, CH, T, DV)
self.rec_K = X.new_zeros(N, CH, T, DK)
self.rec_V[:, :, t0 - CL : t0] = self.init_V_rec[None, :, :, :]
self.rec_K[:, :, t0 - CL : t0] = self.init_K_rec[None, :, :, :]
- self.cache_Y = X.new_zeros(N, T, Dout)
+ self.cache_Y = X.new_zeros(N, T, DM)
######################################################################
# Compute the recurrent state