projects
/
culture.git
/ blobdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
|
commitdiff
|
tree
raw
|
inline
| side by side
Update.
[culture.git]
/
mygpt.py
diff --git a/mygpt.py b/mygpt.py
index ac1c55e..0cf70e0 100755 (executable)
--- a/mygpt.py
+++ b/mygpt.py
@@ -46,7 +46,7 @@ class BracketedSequence:
return self.x[:, self.first : self.first + self.nb]
def complete(self):
return self.x[:, self.first : self.first + self.nb]
def complete(self):
- return self.first == 0 and self.nb == x.size(1)
+ return self.first == 0 and self.nb == self.x.size(1)
######################################################################
######################################################################
@@ -169,9 +169,6 @@ class QKVAttention(nn.Module):
"nhtd,nhsd->nhts", q, self.cache_k[:, :, : bs_q.first + bs_q.nb]
) / math.sqrt(self.w_q.size(1))
"nhtd,nhsd->nhts", q, self.cache_k[:, :, : bs_q.first + bs_q.nb]
) / math.sqrt(self.w_q.size(1))
- if self.record_attention:
- self.a = a
-
if self.causal:
if bs_q.first == 0:
self.cache_attzero = (
if self.causal:
if bs_q.first == 0:
self.cache_attzero = (
@@ -186,6 +183,10 @@ class QKVAttention(nn.Module):
)
a = a.softmax(dim=3)
)
a = a.softmax(dim=3)
+
+ if self.record_attention:
+ self.a = a
+
a = F.dropout(a, self.attention_dropout, self.training)
y = torch.einsum(
a = F.dropout(a, self.attention_dropout, self.training)
y = torch.einsum(