X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=mygpt.py;fp=mygpt.py;h=0400b48b21631db0dc6806d5504d6287f2324357;hb=ef3bef5253ff719953dfffff28d4122c19acdd77;hp=ac1c55e84d91fecb06b453533f1800aead640ed7;hpb=b59fca62aa31de18a3e0cd0bb54e395d4b1254ae;p=picoclvr.git diff --git a/mygpt.py b/mygpt.py index ac1c55e..0400b48 100755 --- a/mygpt.py +++ b/mygpt.py @@ -169,9 +169,6 @@ class QKVAttention(nn.Module): "nhtd,nhsd->nhts", q, self.cache_k[:, :, : bs_q.first + bs_q.nb] ) / math.sqrt(self.w_q.size(1)) - if self.record_attention: - self.a = a - if self.causal: if bs_q.first == 0: self.cache_attzero = ( @@ -186,6 +183,10 @@ class QKVAttention(nn.Module): ) a = a.softmax(dim=3) + + if self.record_attention: + self.a = a + a = F.dropout(a, self.attention_dropout, self.training) y = torch.einsum(