From: François Fleuret
Date: Wed, 10 Jan 2024 16:58:03 +0000 (+0100)
Subject: Update.
X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=2bf045a277bb300dc851ffb8a93db3a6726faa60;p=mygptrnn.git

Update.
---

diff --git a/mygpt.py b/mygpt.py
index 7c8e9f4..ba93851 100755
--- a/mygpt.py
+++ b/mygpt.py
@@ -485,7 +485,7 @@ class Caterpillar(nn.Module):
         self.caterpillar_height = caterpillar_height
         self.attention_dropout = attention_dropout
 
-        self.proba_gate_dropout = 0.0
+        self.proba_gate_dropout = 0.25
 
         self.w_G = randw(nb_heads, caterpillar_height, dim_model, amplitude=1e-5)
         self.b_G = nn.Parameter(
@@ -572,7 +572,7 @@ class Caterpillar(nn.Module):
             warnings.warn("gate dropout", RuntimeWarning)
             epsilon = 0.5
 
-            dropout_start = (
+            dropout_head = (
                 (
                     torch.rand(G.size(), device=G.device)
                     .flatten(2, 3)
@@ -584,18 +584,18 @@ class Caterpillar(nn.Module):
                 .float()
             )
 
-            dropout_tail = dropout_start.cumsum(dim=3) - dropout_start
+            dropout_tail = dropout_head.cumsum(dim=3) - dropout_head
 
             dropout_active = (
                 torch.rand(N, 1, 1, 1, device=G.device) < self.proba_gate_dropout
             ).long()
 
-            dropout_start *= dropout_active
+            dropout_head *= dropout_active
             dropout_tail *= dropout_active
 
             G = (
                 G
-                + dropout_start * (1 - epsilon - G.detach())
+                # + dropout_head * (1 - epsilon - G.detach())
                 - dropout_tail * G.detach()
             )
 
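
For reference, here is a minimal standalone sketch of the gate-dropout
mechanism this commit modifies. It is not the repository's exact code: the
diff elides the lines that build dropout_head, so the tensor layout
(N, H, R, T), the one-hot head construction, and the function name
gate_dropout below are assumptions made for illustration.

import torch

def gate_dropout(G, proba_gate_dropout=0.25):
    # G: gate activations of shape (N, H, R, T) with values in (0, 1),
    # i.e. batch x heads x caterpillar rows x time steps.
    N, H, R, T = G.size()

    # Pick one random time step per (n, h, r) row and mark it with a
    # one-hot "head" indicator along the time dimension.
    head_index = torch.randint(T, (N, H, R, 1), device=G.device)
    dropout_head = torch.zeros_like(G).scatter_(3, head_index, 1.0)

    # The "tail" marks every time step strictly after the head: the
    # cumulative sum of a one-hot is 1 from the head onward, and
    # subtracting the head itself leaves only the later positions.
    dropout_tail = dropout_head.cumsum(dim=3) - dropout_head

    # Perturb only a random fraction proba_gate_dropout of the sequences.
    dropout_active = (
        torch.rand(N, 1, 1, 1, device=G.device) < proba_gate_dropout
    ).long()
    dropout_head = dropout_head * dropout_active
    dropout_tail = dropout_tail * dropout_active

    # Zero the tail gates in the forward pass while leaving the gradient
    # path through G intact, since the subtracted term is detached. The
    # commit comments out the matching head term, which would have pushed
    # the head gate toward 1 - epsilon in the same straight-through style.
    return G - dropout_tail * G.detach()

Besides renaming dropout_start to dropout_head, the commit turns the
mechanism on by raising proba_gate_dropout from 0.0 to 0.25 and disables the
head term of the gate update, so that dropout_head now only serves to derive
dropout_tail.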