projects
/
mygptrnn.git
/ commitdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
| commitdiff |
tree
raw
|
patch
|
inline
| side by side (parent:
de08313
)
Update.
author
François Fleuret
<francois@fleuret.org>
Wed, 10 Jan 2024 07:43:27 +0000
(08:43 +0100)
committer
François Fleuret
<francois@fleuret.org>
Wed, 10 Jan 2024 07:43:27 +0000
(08:43 +0100)
mygpt.py
patch
|
blob
|
history
diff --git
a/mygpt.py
b/mygpt.py
index
17f2f6d
..
ed4b2a7
100755
(executable)
--- a/
mygpt.py
+++ b/
mygpt.py
@@ -540,6 +540,9 @@ class Caterpillar(nn.Module):
self.cache_Y = X.new_zeros(N, T, DM)
self.cache_Y = X.new_zeros(N, T, DM)
+ V = torch.einsum("ntc,hdc->nhtd", X, self.w_V)
+ K = torch.einsum("ntc,hdc->nhtd", X, self.w_K)
+
######################################################################
# Compute the recurrent state
######################################################################
# Compute the recurrent state
@@ -558,24 +561,21 @@ class Caterpillar(nn.Module):
G = G / G.sum(1, keepdim=True).clamp(min=1)
G = G / G.sum(1, keepdim=True).clamp(min=1)
- if self.training and self.proba_gate_dropout > 0.0:
- warnings.warn("gate dropout", RuntimeWarning)
- epsilon = 0.5
-
- V = torch.einsum("ntc,hdc->nhtd", X, self.w_V)
- K = torch.einsum("ntc,hdc->nhtd", X, self.w_K)
-
# We prepare the arguments for the parallel scan
A = 1 - G.sum(1)
gated_V = torch.einsum("nhet,nhtd->netd", G, V)
gated_K = torch.einsum("nhet,nhtd->netd", G, K)
# We prepare the arguments for the parallel scan
A = 1 - G.sum(1)
gated_V = torch.einsum("nhet,nhtd->netd", G, V)
gated_K = torch.einsum("nhet,nhtd->netd", G, K)
-        # Initial recurrent state
+        # We start from cached values, which matters in inference
init_rec_V = self.rec_V[:, :, t0 - CL : t0]
init_rec_K = self.rec_K[:, :, t0 - CL : t0]
init_rec_V = self.rec_V[:, :, t0 - CL : t0]
init_rec_K = self.rec_K[:, :, t0 - CL : t0]
+ if self.training and self.proba_gate_dropout > 0.0:
+ warnings.warn("gate dropout", RuntimeWarning)
+ epsilon = 0.5
+
#################################################################
# Associative scan
#################################################################
# Associative scan