projects
/
mygptrnn.git
/ commitdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
| commitdiff |
tree
raw
|
patch
|
inline
| side by side (parent:
f0ea1f2
)
Update.
author
François Fleuret
<francois@fleuret.org>
Wed, 10 Jan 2024 16:58:03 +0000
(17:58 +0100)
committer
François Fleuret
<francois@fleuret.org>
Wed, 10 Jan 2024 16:58:03 +0000
(17:58 +0100)
mygpt.py
patch
|
blob
|
history
diff --git a/mygpt.py b/mygpt.py
index 7c8e9f4..ba93851 100755 (executable)
--- a/mygpt.py
+++ b/mygpt.py
@@ -485,7 +485,7 @@ class Caterpillar(nn.Module):
         self.caterpillar_height = caterpillar_height
         self.attention_dropout = attention_dropout

-        self.proba_gate_dropout = 0.0
+        self.proba_gate_dropout = 0.25

         self.w_G = randw(nb_heads, caterpillar_height, dim_model, amplitude=1e-5)
         self.b_G = nn.Parameter(
@@ -572,7 +572,7 @@ class Caterpillar(nn.Module):
             warnings.warn("gate dropout", RuntimeWarning)
             epsilon = 0.5

-            dropout_start = (
+            dropout_head = (
                 (
                     torch.rand(G.size(), device=G.device)
                     .flatten(2, 3)
@@ -584,18 +584,18 @@ class Caterpillar(nn.Module):
                 .float()
             )

-            dropout_tail = dropout_start.cumsum(dim=3) - dropout_start
+            dropout_tail = dropout_head.cumsum(dim=3) - dropout_head

             dropout_active = (
                 torch.rand(N, 1, 1, 1, device=G.device) < self.proba_gate_dropout
             ).long()

-            dropout_start *= dropout_active
+            dropout_head *= dropout_active
             dropout_tail *= dropout_active

             G = (
                 G
-                + dropout_start * (1 - epsilon - G.detach())
+                # + dropout_head * (1 - epsilon - G.detach())
                 - dropout_tail * G.detach()
             )