Update.
[mygptrnn.git] / mygpt.py
index 9bacaff..4d48247 100755 (executable)
--- a/mygpt.py
+++ b/mygpt.py
@@ -545,6 +545,8 @@ class Caterpillar(nn.Module):
             torch.einsum("ntc,hec->nhet", X, self.w_G) + self.b_G[None, :, :, None]
         ).sigmoid()
 
+        G = F.dropout(G, self.attention_dropout, self.training)
+
         V = torch.einsum("ntc,hdc->nhtd", X, self.w_V)
         K = torch.einsum("ntc,hdc->nhtd", X, self.w_K)
 
@@ -561,7 +563,7 @@ class Caterpillar(nn.Module):
         # by updating that at time t-L, the parallel scan operates
         # with a period of L. To do so we split the time indexing in
         # two axes, the second of size CL, and run the parallel scan
-        # using the other alone as the sequence index.
+        # using the other as the sequence index.
 
         A = A.unflatten(2, (-1, CL))
         gated_V = gated_V.unflatten(2, (-1, CL))