projects
/
picoclvr.git
/ blobdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
|
commitdiff
|
tree
raw
|
inline
| side by side
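Update: record the attention matrix (self.a) after the softmax, just before dropout, instead of recording the pre-softmax scaled scores.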
[picoclvr.git]
/
mygpt.py
diff --git
a/mygpt.py
b/mygpt.py
index
ac1c55e
..
0400b48
100755
(executable)
--- a/
mygpt.py
+++ b/
mygpt.py
@@
-169,9
+169,6
@@
class QKVAttention(nn.Module):
"nhtd,nhsd->nhts", q, self.cache_k[:, :, : bs_q.first + bs_q.nb]
) / math.sqrt(self.w_q.size(1))
"nhtd,nhsd->nhts", q, self.cache_k[:, :, : bs_q.first + bs_q.nb]
) / math.sqrt(self.w_q.size(1))
- if self.record_attention:
- self.a = a
-
if self.causal:
if bs_q.first == 0:
self.cache_attzero = (
if self.causal:
if bs_q.first == 0:
self.cache_attzero = (
@@
-186,6
+183,10
@@
class QKVAttention(nn.Module):
)
a = a.softmax(dim=3)
)
a = a.softmax(dim=3)
+
+ if self.record_attention:
+ self.a = a
+
a = F.dropout(a, self.attention_dropout, self.training)
y = torch.einsum(
a = F.dropout(a, self.attention_dropout, self.training)
y = torch.einsum(