mygptrnn.git / commitdiff (from parent 1: e3d5af8)
Update.

author     François Fleuret <francois@fleuret.org>
           Thu, 18 Jan 2024 12:06:27 +0000 (13:06 +0100)
committer  François Fleuret <francois@fleuret.org>
           Thu, 18 Jan 2024 12:06:27 +0000 (13:06 +0100)
mygpt.py
diff --git a/mygpt.py b/mygpt.py
index 492a9bb..5451584 100755
--- a/mygpt.py
+++ b/mygpt.py
@@ -617,8 +617,6 @@ class Caterpillar(nn.Module):
         init_rec_V = self.rec_V[:, :, t0 - L : t0]
         init_rec_K = self.rec_K[:, :, t0 - L : t0]
 
-        # Associative scan
-
         # Here there is a trick: Since the stack at position t is
         # computed by updating that at position t-L, the parallel
         # scan operates with a period of L. To do so we split the
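Note on the comment kept above: because the recurrence relates position t to position t - L, folding the time axis into T // L chunks of L positions turns it into an ordinary step-1 recurrence over T // L steps, run for L interleaved subsequences at once. A minimal sketch of that indexing, with illustrative names (X, N, T, D, a are not from mygpt.py) and a plain Python loop standing in for the actual parallel scan:

import torch

# Sketch of the period-L trick, assuming a toy linear recurrence
# state[t] = a * state[t - L] + X[t] on a sequence X of shape (N, T, D).
N, T, D, L = 2, 12, 4, 3             # batch, time, features, period
X = torch.randn(N, T, D)

# Position t = k * L + l, so the update t -> t - L maps (k, l) to
# (k - 1, l): after the fold, chunk index k advances the recurrence and
# l indexes the L independent subsequences.
Xf = X.reshape(N, T // L, L, D)

a = 0.9                              # toy decay; the real model uses gates
state = torch.zeros(N, L, D)
out = []
for k in range(T // L):              # T // L sequential steps instead of T
    state = a * state + Xf[:, k]     # all L subsequences advance at once
    out.append(state)
Y = torch.stack(out, dim=1).reshape(N, T, D)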
@@ -646,9 +644,16 @@ class Caterpillar(nn.Module):
             warnings.warn("gate dropout", RuntimeWarning)
 
+            # kill = (
+            #     torch.rand(G.size(), device=G.device) <= self.proba_gate_dropout
+            # ).float()
+
             kill = (
-                torch.rand(G.size(), device=G.device) <= self.proba_gate_dropout
-            ).float()
+                torch.rand(N, H, R, t1 - t0, device=G.device).sort(dim=3).indices == 0
+            ).cumsum(dim=3)
+            kill = kill * (
+                torch.rand(N, H, R, 1, device=G.device) <= self.proba_gate_dropout
+            )
             mask = 1 - kill
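This hunk changes the dropout semantics: instead of killing individual gate entries i.i.d. with probability proba_gate_dropout (the commented-out version), each (batch, head, row) slice is selected with that probability and, if selected, has its gates zeroed from a uniformly random time step onward. The sort trick builds the step function: among t1 - t0 i.i.d. uniforms, the rank of the element at index 0 is itself uniform, so indices == 0 is a one-hot at a uniform position and its cumsum is 0 strictly before that position and 1 from it on. A standalone sketch on toy shapes (seed and sizes are illustrative):

import torch

# Toy-sized demonstration of the new gate-dropout mask; N, H, R and the
# time extent T play the same roles as N, H, R, t1 - t0 in the diff above.
torch.manual_seed(0)
N, H, R, T = 1, 1, 4, 8
proba_gate_dropout = 0.5

# One-hot at a uniformly random time step per (n, h, r) row.
one_hot = torch.rand(N, H, R, T).sort(dim=3).indices == 0

# cumsum turns the one-hot into a step: 0 before the position, 1 after.
kill = one_hot.cumsum(dim=3)

# Only rows drawn with probability proba_gate_dropout are actually killed.
kill = kill * (torch.rand(N, H, R, 1) <= proba_gate_dropout)

mask = 1 - kill
print(mask[0, 0])
# Each row is either all ones (not selected) or ones followed by zeros
# starting at a uniformly random time step.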