projects
/
beaver.git
/ commitdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
| commitdiff |
tree
raw
|
patch
|
inline
| side by side (parent:
d0561c9
)
Update.
author
François Fleuret
<francois@fleuret.org>
Sat, 25 Mar 2023 20:02:38 +0000
(21:02 +0100)
committer
François Fleuret
<francois@fleuret.org>
Sat, 25 Mar 2023 20:02:38 +0000
(21:02 +0100)
mygpt.py
patch
|
blob
|
history
diff --git a/mygpt.py b/mygpt.py
index 7166788..06b56df 100755 (executable)
--- a/mygpt.py
+++ b/mygpt.py
@@ -148,8 +148,8 @@ class QKVAttention(nn.Module):
if amm_generator is None:
self.amm_generator = (
if amm_generator is None:
self.amm_generator = (
- lambda d: torch.arange(d)[None, None, :, None]
- < torch.arange(d)[None, None, None, :]
+ lambda d: torch.arange(d)[:, None]
+ < torch.arange(d)[None, :]
)
else:
self.amm_generator = amm_generator
)
else:
self.amm_generator = amm_generator
@@ -190,7 +190,7 @@ class QKVAttention(nn.Module):
if self.causal:
if bs_q.first == 0:
if self.causal:
if bs_q.first == 0:
- self.cache_attzero = self.amm_generator(x_q.size(1)).to(q.device)
+ self.cache_attzero = self.amm_generator(x_q.size(1)).to(q.device)[None, None,:,:]
a = a.masked_fill(
self.cache_attzero[
:, :, bs_q.first : bs_q.first + bs_q.nb, : bs_q.first + bs_q.nb
a = a.masked_fill(
self.cache_attzero[
:, :, bs_q.first : bs_q.first + bs_q.nb, : bs_q.first + bs_q.nb