projects
/
picoclvr.git
/ commitdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
| commitdiff |
tree
raw
|
patch
|
inline
| side by side (parent: 6230689)
Update.
author
François Fleuret
<francois@fleuret.org>
Fri, 7 Jul 2023 15:48:30 +0000
(17:48 +0200)
committer
François Fleuret
<francois@fleuret.org>
Fri, 7 Jul 2023 15:48:30 +0000
(17:48 +0200)
mygpt.py
patch
|
blob
|
history
diff --git
a/mygpt.py
b/mygpt.py
index
c93010a
..
8cd0152
100755
(executable)
--- a/
mygpt.py
+++ b/
mygpt.py
@@ -62,9 +62,7 @@ class CacheWrapper(nn.Module):
else:
self.cache_y[:, bs.first : bs.first + bs.nb] = self.f(bs.slice())
else:
self.cache_y[:, bs.first : bs.first + bs.nb] = self.f(bs.slice())
- bs.x = self.cache_y
-
- return bs
+ return BracketedSequence(self.cache_y, bs.first, bs.nb)
##############################
##############################
@@ -76,8 +74,7 @@ class WithResidual(nn.Module):
self.f = f[0] if len(f) == 1 else nn.Sequential(*f)
def forward(self, bs):
self.f = f[0] if len(f) == 1 else nn.Sequential(*f)
def forward(self, bs):
- bs.x = bs.x + self.f(bs).x
- return bs
+ return BracketedSequence(bs.x + self.f(bs).x, bs.first, bs.nb)
##############################
##############################
@@ -108,9 +105,7 @@ class AddPositionalEncoding(nn.Module):
bs.slice() + self.pe[bs.first : bs.first + bs.nb]
)
bs.slice() + self.pe[bs.first : bs.first + bs.nb]
)
- bs.x = self.cache_y
-
- return bs
+ return BracketedSequence(self.cache_y, bs.first, bs.nb)
##############################
##############################
@@ -125,6 +120,7 @@ class QKVAttention(nn.Module):
def randw(*d):
return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))
def randw(*d):
return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))
+ assert causal, "TODO: Switch off the cache when non-causal!!!"
self.causal = causal
self.attention_dropout = attention_dropout
self.causal = causal
self.attention_dropout = attention_dropout
@@ -148,6 +144,7 @@ class QKVAttention(nn.Module):
q = torch.einsum(
"ntc,hdc->nhtd", x_q[:, bs_q.first : bs_q.first + bs_q.nb], self.w_q
)
q = torch.einsum(
"ntc,hdc->nhtd", x_q[:, bs_q.first : bs_q.first + bs_q.nb], self.w_q
)
+
self.cache_k[:, :, bs_q.first : bs_q.first + bs_q.nb] = torch.einsum(
"ntc,hdc->nhtd", x_q[:, bs_q.first : bs_q.first + bs_q.nb], self.w_k
)
self.cache_k[:, :, bs_q.first : bs_q.first + bs_q.nb] = torch.einsum(
"ntc,hdc->nhtd", x_q[:, bs_q.first : bs_q.first + bs_q.nb], self.w_k
)
@@ -181,9 +178,7 @@ class QKVAttention(nn.Module):
self.cache_y[:, bs_q.first : bs_q.first + bs_q.nb] = y @ self.w_o
self.cache_y[:, bs_q.first : bs_q.first + bs_q.nb] = y @ self.w_o
- bs_q.x = self.cache_y
-
- return bs_q
+ return BracketedSequence(self.cache_y, bs_q.first, bs_q.nb)
##############################
##############################
@@ -252,7 +247,7 @@ class MyGPT(nn.Module):
m.weight.fill_(1.0)
def forward(self, bs):
m.weight.fill_(1.0)
def forward(self, bs):
-        bs.x = F.pad(bs.x, (1, -1))
+        bs = BracketedSequence(F.pad(bs.x, (1, -1)), bs.first, bs.nb)
bs = self.embedding(bs)
bs = self.trunk(bs)
bs = self.readout(bs)
bs = self.embedding(bs)
bs = self.trunk(bs)
bs = self.readout(bs)
@@ -288,27 +283,27 @@ class MyGPT(nn.Module):
if __name__ == "__main__":
print("Basic check.")
if __name__ == "__main__":
print("Basic check.")
-    vocabulary_size = 10
-    x = torch.randint(vocabulary_size, (9, 7))
+    vocabulary_size = 3
+    x = torch.randint(vocabulary_size, (1, 5))
model = MyGPT(
vocabulary_size=vocabulary_size,
model = MyGPT(
vocabulary_size=vocabulary_size,
-        dim_model=18,
-        dim_keys=50,
-        dim_hidden=100,
+        dim_model=4,
+        dim_keys=2,
+        dim_hidden=2,
nb_heads=2,
nb_blocks=1,
dropout=0.1,
nb_heads=2,
nb_blocks=1,
dropout=0.1,
+ causal=True,
)
model.eval()
y1 = model(BracketedSequence(x)).x
)
model.eval()
y1 = model(BracketedSequence(x)).x
-
y2 = torch.randn_like(y1)
for s in range(x.size(1)):
z = model(BracketedSequence(x, s, 1))
y2 = torch.randn_like(y1)
for s in range(x.size(1)):
z = model(BracketedSequence(x, s, 1))
-        y2[:, s] = z.x[:, s]
+        y2[:, s] = z.slice()
print(f"error={((y1 - y2).norm() / (y1.norm() + y2.norm())).item()}")
print(f"error={((y1 - y2).norm() / (y1.norm() + y2.norm())).item()}")