X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=mygpt.py;h=4d4824707baf1f8d22d961f8331cd8fe11cb510c;hb=c0750e416e28fbdc9f6dc03cc6d7b11edd1ac333;hp=9bacaffbe7100507c93297fb7acd418474e68729;hpb=f06a70eca52e988857ee043f1379d41b09dd365d;p=mygptrnn.git diff --git a/mygpt.py b/mygpt.py index 9bacaff..4d48247 100755 --- a/mygpt.py +++ b/mygpt.py @@ -545,6 +545,8 @@ class Caterpillar(nn.Module): torch.einsum("ntc,hec->nhet", X, self.w_G) + self.b_G[None, :, :, None] ).sigmoid() + G = F.dropout(G, self.attention_dropout, self.training) + V = torch.einsum("ntc,hdc->nhtd", X, self.w_V) K = torch.einsum("ntc,hdc->nhtd", X, self.w_K) @@ -561,7 +563,7 @@ class Caterpillar(nn.Module): # by updating that at time t-L, the parallel scan operates # with a period of L. To do so we split the time indexing in # two axes, the second of size CL, and run the parallel scan - # using the other alone as the sequence index. + # using the other as the sequence index. A = A.unflatten(2, (-1, CL)) gated_V = gated_V.unflatten(2, (-1, CL))