Update.
authorFrançois Fleuret <francois@fleuret.org>
Sat, 6 Jan 2024 19:06:36 +0000 (20:06 +0100)
committerFrançois Fleuret <francois@fleuret.org>
Sat, 6 Jan 2024 19:06:36 +0000 (20:06 +0100)
mygpt.py

index 4d48247..0e94672 100755 (executable)
--- a/mygpt.py
+++ b/mygpt.py
@@ -514,7 +514,7 @@ class Caterpillar(nn.Module):
         T = bs.x.size(1)
         DV = self.w_V.size(1)
         DK = self.w_K.size(1)
-        Dout = self.w_O.size(1)
+        DM = self.w_O.size(1)
         CH = self.caterpillar_height
         CL = self.caterpillar_length
 
@@ -522,6 +522,8 @@ class Caterpillar(nn.Module):
             t0 >= CL and (t1 - t0) % CL == 0
         ), f"bs.first should be greater than caterpillar_length, and bs.nb should be a multiple of caterpillar_length"
 
+        # We cache values to deal efficiently with auto-regression
+
         if bs.init_cache:
             self.rec_V = X.new_zeros(N, CH, T, DV)
             self.rec_K = X.new_zeros(N, CH, T, DK)
@@ -530,7 +532,7 @@ class Caterpillar(nn.Module):
             self.rec_V[:, :, t0 - CL : t0] = self.init_V_rec[None, :, :, :]
             self.rec_K[:, :, t0 - CL : t0] = self.init_K_rec[None, :, :, :]
 
-            self.cache_Y = X.new_zeros(N, T, Dout)
+            self.cache_Y = X.new_zeros(N, T, DM)
 
         ######################################################################
         # Compute the recurrent state