Update.
[mygptrnn.git] / mygpt.py
index 24ba345..5ea927e 100755 (executable)
--- a/mygpt.py
+++ b/mygpt.py
@@ -656,16 +656,13 @@ class Caterpillar(nn.Module):
         self.rec_K[:, :, t0:t1] = next_K.flatten(2, 3)
 
         if self.training and self.proba_flashback:
-            # insert_flash_back(
-            # self.rec_V,
-            # V,
-            # self.rec_K,
-            # K,
-            # t0,
-            # t1,
-            # CL,
-            # proba=self.proba_flashback / CL,
-            # )
+            # insert_flash_back(self.rec_V,V,self.rec_K,K,t0,t1,CL,proba=self.proba_flashback / CL,)
+
+            # This piece of code makes the assumption that there is
+            # nothing informative before t0, otherwise we'd have to
+            # implement a cache for V and K too. This should not be
+            # too much of a problem since this is used only during
+            # train, where full sequence are available
 
             n = torch.arange(N, device=X.device)[:, None, None, None]
             t = torch.arange(t0, t1, device=X.device)[None, None, :, None]