projects
/
mygptrnn.git
/ commitdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
| commitdiff |
tree
raw
|
patch
|
inline
| side by side (parent:
2bf045a
)
Update.
author
François Fleuret
<francois@fleuret.org>
Wed, 10 Jan 2024 18:44:26 +0000
(19:44 +0100)
committer
François Fleuret
<francois@fleuret.org>
Wed, 10 Jan 2024 18:44:26 +0000
(19:44 +0100)
mygpt.py
patch
|
blob
|
history
diff --git a/mygpt.py b/mygpt.py
index ba93851..185df38 100755 (executable)
--- a/mygpt.py
+++ b/mygpt.py
@@ -485,9 +485,9 @@ class Caterpillar(nn.Module):
self.caterpillar_height = caterpillar_height
self.attention_dropout = attention_dropout
self.caterpillar_height = caterpillar_height
self.attention_dropout = attention_dropout
- self.proba_gate_dropout = 0.25
+ self.proba_gate_dropout = 0.0
- self.w_G = randw(nb_heads, caterpillar_height, dim_model, amplitude=1e-5)
+ self.w_G = randw(nb_heads, caterpillar_height, dim_model)
self.b_G = nn.Parameter(
torch.full(
(nb_heads, caterpillar_height), -math.log(caterpillar_height - 1)
self.b_G = nn.Parameter(
torch.full(
(nb_heads, caterpillar_height), -math.log(caterpillar_height - 1)
@@ -500,10 +500,14 @@ class Caterpillar(nn.Module):
self.w_O = randw(dim_v * nb_heads, dim_model)
self.init_K_rec = randw(
self.w_O = randw(dim_v * nb_heads, dim_model)
self.init_K_rec = randw(
- caterpillar_height, caterpillar_length, dim_qk, amplitude=1e-5
+ caterpillar_height,
+ caterpillar_length,
+ dim_qk,
)
self.init_V_rec = randw(
)
self.init_V_rec = randw(
- caterpillar_height, caterpillar_length, dim_v, amplitude=1e-5
+ caterpillar_height,
+ caterpillar_length,
+ dim_v,
)
def reset_inner_loss(self):
)
def reset_inner_loss(self):
@@ -573,14 +577,8 @@ class Caterpillar(nn.Module):
epsilon = 0.5
dropout_head = (
epsilon = 0.5
dropout_head = (
- (
- torch.rand(G.size(), device=G.device)
- .flatten(2, 3)
- .sort(dim=2)
- .indices
- == 0
- )
- .unflatten(2, (CH, t1 - t0))
+ (torch.rand(N, H, 1, t1 - t0, device=G.device).sort(dim=3).indices == 0)
+ .expand_as(G)
.float()
)
.float()
)