projects
/
mygpt.git
/ commitdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
| commitdiff |
tree
raw
|
patch
|
inline
| side by side (parent:
9425d32
)
Update.
author
Francois Fleuret
<francois@fleuret.org>
Mon, 13 Jun 2022 13:33:56 +0000
(15:33 +0200)
committer
Francois Fleuret
<francois@fleuret.org>
Mon, 13 Jun 2022 13:33:56 +0000
(15:33 +0200)
mygpt.py
patch
|
blob
|
history
diff --git a/mygpt.py b/mygpt.py
index a23470b..080083a 100755 (executable)
--- a/mygpt.py
+++ b/mygpt.py
@@ -47,16 +47,16 @@ class QKVAttention(nn.Module):
def randw(*d):
return nn.Parameter(torch.empty(*d).normal_(0, 1 / math.sqrt(d[-1])))
def randw(*d):
return nn.Parameter(torch.empty(*d).normal_(0, 1 / math.sqrt(d[-1])))
- self.wq = randw(nb_heads, dim_qk, dim_in)
- self.wk = randw(nb_heads, dim_qk, dim_in)
- self.wv = randw(nb_heads, dim_v, dim_in)
+ self.w_q = randw(nb_heads, dim_qk, dim_in)
+ self.w_k = randw(nb_heads, dim_qk, dim_in)
+ self.w_v = randw(nb_heads, dim_v, dim_in)
self.causal = causal
self.attention_dropout = attention_dropout
def forward(self, x):
self.causal = causal
self.attention_dropout = attention_dropout
def forward(self, x):
- q = torch.einsum('ntc,hdc->nhtd', x, self.wq)
- k = torch.einsum('ntc,hdc->nhtd', x, self.wk)
- v = torch.einsum('ntc,hdc->nhtd', x, self.wv)
+ q = torch.einsum('ntc,hdc->nhtd', x, self.w_q)
+ k = torch.einsum('ntc,hdc->nhtd', x, self.w_k)
+ v = torch.einsum('ntc,hdc->nhtd', x, self.w_v)
r = math.sqrt(q.size(3))
a = torch.einsum('nhtd,nhsd->nhts', q, k).div(r)
if self.causal:
r = math.sqrt(q.size(3))
a = torch.einsum('nhtd,nhsd->nhts', q, k).div(r)
if self.causal: