projects
/
picoclvr.git
/ blobdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
|
commitdiff
|
tree
raw
|
inline
| side by side
Update.
[picoclvr.git]
/
mygpt.py
diff --git
a/mygpt.py
b/mygpt.py
index
8cd0152
..
45b7b59
100755
(executable)
--- a/
mygpt.py
+++ b/
mygpt.py
@@
-45,6
+45,9
@@
class BracketedSequence:
def slice(self):
return self.x[:, self.first : self.first + self.nb]
def slice(self):
return self.x[:, self.first : self.first + self.nb]
+ def complete(self):
+ return self.first == 0 and self.nb == x.size(1)
+
######################################################################
######################################################################
@@
-120,7
+123,6
@@
class QKVAttention(nn.Module):
def randw(*d):
return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))
def randw(*d):
return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))
- assert causal, "TODO: Switch off the cache when non-causal!!!"
self.causal = causal
self.attention_dropout = attention_dropout
self.causal = causal
self.attention_dropout = attention_dropout
@@
-132,6
+134,10
@@
class QKVAttention(nn.Module):
def forward(self, bs_q):
x_q = bs_q.x
def forward(self, bs_q):
x_q = bs_q.x
+ assert (
+ self.causal or bs_q.complete()
+ ), "Partial evaluation is only possible for causal models"
+
if bs_q.first == 0:
self.cache_k = x_q.new_zeros(
x_q.size(0), self.w_k.size(0), x_q.size(1), self.w_k.size(1)
if bs_q.first == 0:
self.cache_k = x_q.new_zeros(
x_q.size(0), self.w_k.size(0), x_q.size(1), self.w_k.size(1)