projects
/
beaver.git
/ blobdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
|
commitdiff
|
tree
raw
|
inline
| side by side
Update
[beaver.git]
/
mygpt.py
diff --git
a/mygpt.py
b/mygpt.py
index
06b56df
..
4555b1e
100755
(executable)
--- a/
mygpt.py
+++ b/
mygpt.py
@@
-148,8
+148,7
@@
class QKVAttention(nn.Module):
if amm_generator is None:
self.amm_generator = (
if amm_generator is None:
self.amm_generator = (
- lambda d: torch.arange(d)[:, None]
- < torch.arange(d)[None, :]
+ lambda d: torch.arange(d)[:, None] < torch.arange(d)[None, :]
)
else:
self.amm_generator = amm_generator
)
else:
self.amm_generator = amm_generator
@@
-190,7
+189,9
@@
class QKVAttention(nn.Module):
if self.causal:
if bs_q.first == 0:
if self.causal:
if bs_q.first == 0:
- self.cache_attzero = self.amm_generator(x_q.size(1)).to(q.device)[None, None,:,:]
+ self.cache_attzero = self.amm_generator(x_q.size(1)).to(q.device)[
+ None, None, :, :
+ ]
a = a.masked_fill(
self.cache_attzero[
:, :, bs_q.first : bs_q.first + bs_q.nb, : bs_q.first + bs_q.nb
a = a.masked_fill(
self.cache_attzero[
:, :, bs_q.first : bs_q.first + bs_q.nb, : bs_q.first + bs_q.nb