X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=mygpt.py;h=185df3870c1e0c8d35daf9f3e3257340d50fa1a7;hb=57332f677ef5ee535707c1b83a541aa0e79508e6;hp=7c8e9f4c894ad332e808d07f008ac4c569046bd1;hpb=f0ea1f2375fa3a0be38970a58185cddee97dccef;p=mygptrnn.git diff --git a/mygpt.py b/mygpt.py index 7c8e9f4..185df38 100755 --- a/mygpt.py +++ b/mygpt.py @@ -487,7 +487,7 @@ class Caterpillar(nn.Module): self.proba_gate_dropout = 0.0 - self.w_G = randw(nb_heads, caterpillar_height, dim_model, amplitude=1e-5) + self.w_G = randw(nb_heads, caterpillar_height, dim_model) self.b_G = nn.Parameter( torch.full( (nb_heads, caterpillar_height), -math.log(caterpillar_height - 1) @@ -500,10 +500,14 @@ class Caterpillar(nn.Module): self.w_O = randw(dim_v * nb_heads, dim_model) self.init_K_rec = randw( - caterpillar_height, caterpillar_length, dim_qk, amplitude=1e-5 + caterpillar_height, + caterpillar_length, + dim_qk, ) self.init_V_rec = randw( - caterpillar_height, caterpillar_length, dim_v, amplitude=1e-5 + caterpillar_height, + caterpillar_length, + dim_v, ) def reset_inner_loss(self): @@ -572,30 +576,24 @@ class Caterpillar(nn.Module): warnings.warn("gate dropout", RuntimeWarning) epsilon = 0.5 - dropout_start = ( - ( - torch.rand(G.size(), device=G.device) - .flatten(2, 3) - .sort(dim=2) - .indices - == 0 - ) - .unflatten(2, (CH, t1 - t0)) + dropout_head = ( + (torch.rand(N, H, 1, t1 - t0, device=G.device).sort(dim=3).indices == 0) + .expand_as(G) .float() ) - dropout_tail = dropout_start.cumsum(dim=3) - dropout_start + dropout_tail = dropout_head.cumsum(dim=3) - dropout_head dropout_active = ( torch.rand(N, 1, 1, 1, device=G.device) < self.proba_gate_dropout ).long() - dropout_start *= dropout_active + dropout_head *= dropout_active dropout_tail *= dropout_active G = ( G - + dropout_start * (1 - epsilon - G.detach()) + # + dropout_head * (1 - epsilon - G.detach()) - dropout_tail * G.detach() )