projects
/
mygptrnn.git
/ commitdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
| commitdiff |
tree
raw
|
patch
|
inline
| side by side (parent:
ec5bad2
)
Update.
author
François Fleuret
<francois@fleuret.org>
Fri, 19 Jan 2024 13:11:45 +0000
(14:11 +0100)
committer
François Fleuret
<francois@fleuret.org>
Fri, 19 Jan 2024 13:11:45 +0000
(14:11 +0100)
mygpt.py
patch
|
blob
|
history
diff --git
a/mygpt.py
b/mygpt.py
index
0414bb6
..
760a3c6
100755
(executable)
--- a/
mygpt.py
+++ b/
mygpt.py
@@
-629,15
+629,21
@@
class Caterpillar(nn.Module):
warnings.warn("gate dropout", RuntimeWarning)
warnings.warn("gate dropout", RuntimeWarning)
+ if self.gate_dropout_sync:
+ shape_kill = (N, 1, 1)
+ else:
+ shape_kill = (N, H, R)
+
# Pick a point in each of the NxHxR timeline and set this
# entry and the following to 1
kill = (
# Pick a point in each of the NxHxR timeline and set this
# entry and the following to 1
kill = (
- torch.rand(N, H, R, t1 - t0, device=G.device).sort(dim=3).indices == 0
+ torch.rand(*shape_kill, t1 - t0, device=G.device).sort(dim=3).indices
+ == 0
).cumsum(dim=3)
# Keep these mask for only some of the NxHxR
kill = kill * (
).cumsum(dim=3)
# Keep these mask for only some of the NxHxR
kill = kill * (
- torch.rand(
N, H, R
, 1, device=G.device) <= self.gate_dropout_proba
+ torch.rand(
*shape_kill
, 1, device=G.device) <= self.gate_dropout_proba
)
# The coefficient to keep are the complementary
)
# The coefficient to keep are the complementary