From 57332f677ef5ee535707c1b83a541aa0e79508e6 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Wed, 10 Jan 2024 19:44:26 +0100 Subject: [PATCH] Update. --- mygpt.py | 22 ++++++++++------------ 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/mygpt.py b/mygpt.py index ba93851..185df38 100755 --- a/mygpt.py +++ b/mygpt.py @@ -485,9 +485,9 @@ class Caterpillar(nn.Module): self.caterpillar_height = caterpillar_height self.attention_dropout = attention_dropout - self.proba_gate_dropout = 0.25 + self.proba_gate_dropout = 0.0 - self.w_G = randw(nb_heads, caterpillar_height, dim_model, amplitude=1e-5) + self.w_G = randw(nb_heads, caterpillar_height, dim_model) self.b_G = nn.Parameter( torch.full( (nb_heads, caterpillar_height), -math.log(caterpillar_height - 1) @@ -500,10 +500,14 @@ class Caterpillar(nn.Module): self.w_O = randw(dim_v * nb_heads, dim_model) self.init_K_rec = randw( - caterpillar_height, caterpillar_length, dim_qk, amplitude=1e-5 + caterpillar_height, + caterpillar_length, + dim_qk, ) self.init_V_rec = randw( - caterpillar_height, caterpillar_length, dim_v, amplitude=1e-5 + caterpillar_height, + caterpillar_length, + dim_v, ) def reset_inner_loss(self): @@ -573,14 +577,8 @@ class Caterpillar(nn.Module): epsilon = 0.5 dropout_head = ( - ( - torch.rand(G.size(), device=G.device) - .flatten(2, 3) - .sort(dim=2) - .indices - == 0 - ) - .unflatten(2, (CH, t1 - t0)) + (torch.rand(N, H, 1, t1 - t0, device=G.device).sort(dim=3).indices == 0) + .expand_as(G) .float() ) -- 2.20.1