Update: use randw's default amplitude for w_G, init_K_rec, and init_V_rec (was 1e-5), and rework gate dropout to draw one time step per head, zeroing the gates after it (head term disabled).
index 7c8e9f4..185df38 100755 (executable)
--- a/mygpt.py
+++ b/mygpt.py
@@ -487,7 +487,7 @@ class Caterpillar(nn.Module):
 
         self.proba_gate_dropout = 0.0
 
-        self.w_G = randw(nb_heads, caterpillar_height, dim_model, amplitude=1e-5)
+        self.w_G = randw(nb_heads, caterpillar_height, dim_model)
         self.b_G = nn.Parameter(
             torch.full(
                 (nb_heads, caterpillar_height), -math.log(caterpillar_height - 1)
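
With the amplitude=1e-5 override removed, w_G falls back to randw's default scale. A minimal sketch of the helper, assuming the usual definition earlier in mygpt.py (the 1/sqrt(fan-in) default and the optional `amplitude` keyword are inferred from the call sites, not quoted from this diff), together with the reason for the b_G value: sigmoid(-log(H - 1)) = 1/(1 + (H - 1)) = 1/H, so every gate starts at 1/caterpillar_height.

    import math
    import torch
    import torch.nn as nn

    # Hedged sketch of the randw helper these calls rely on; the optional
    # `amplitude` keyword is inferred from the removed call sites, and the
    # 1/sqrt(fan-in) default scale is an assumption.
    def randw(*d, amplitude=None):
        if amplitude is None:
            amplitude = 1 / math.sqrt(d[-1])
        return nn.Parameter(amplitude * torch.randn(*d))

    # b_G = -log(H - 1) puts each sigmoid gate at exactly 1/H at init.
    H = 4  # example caterpillar_height
    assert abs(torch.sigmoid(torch.tensor(-math.log(H - 1))).item() - 1 / H) < 1e-6
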
@@ -500,10 +500,14 @@ class Caterpillar(nn.Module):
         self.w_O = randw(dim_v * nb_heads, dim_model)
 
         self.init_K_rec = randw(
-            caterpillar_height, caterpillar_length, dim_qk, amplitude=1e-5
+            caterpillar_height,
+            caterpillar_length,
+            dim_qk,
         )
         self.init_V_rec = randw(
-            caterpillar_height, caterpillar_length, dim_v, amplitude=1e-5
+            caterpillar_height,
+            caterpillar_length,
+            dim_v,
         )
 
     def reset_inner_loss(self):
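
init_K_rec and init_V_rec get the same treatment: learned initial recurrent states of shape (caterpillar_height, caterpillar_length, dim_qk) and (caterpillar_height, caterpillar_length, dim_v), now initialised at full scale rather than near zero. A minimal sketch of how such per-model states are typically broadcast to a batch at the start of forward; the names N, rec_K, rec_V and the expand itself are illustrative assumptions, not code from this file.

    import torch

    # Hypothetical use of the learned init states (all sizes illustrative).
    N = 2                               # batch size
    init_K_rec = torch.randn(4, 8, 16)  # (caterpillar_height, caterpillar_length, dim_qk)
    init_V_rec = torch.randn(4, 8, 32)  # (caterpillar_height, caterpillar_length, dim_v)
    rec_K = init_K_rec[None].expand(N, -1, -1, -1)  # (N, CH, CL, dim_qk)
    rec_V = init_V_rec[None].expand(N, -1, -1, -1)  # (N, CH, CL, dim_v)
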
@@ -572,30 +576,24 @@ class Caterpillar(nn.Module):
             warnings.warn("gate dropout", RuntimeWarning)
             epsilon = 0.5
 
-            dropout_start = (
-                (
-                    torch.rand(G.size(), device=G.device)
-                    .flatten(2, 3)
-                    .sort(dim=2)
-                    .indices
-                    == 0
-                )
-                .unflatten(2, (CH, t1 - t0))
+            dropout_head = (
+                (torch.rand(N, H, 1, t1 - t0, device=G.device).sort(dim=3).indices == 0)
+                .expand_as(G)
                 .float()
             )
 
-            dropout_tail = dropout_start.cumsum(dim=3) - dropout_start
+            dropout_tail = dropout_head.cumsum(dim=3) - dropout_head
 
             dropout_active = (
                 torch.rand(N, 1, 1, 1, device=G.device) < self.proba_gate_dropout
             ).long()
 
-            dropout_start *= dropout_active
+            dropout_head *= dropout_active
             dropout_tail *= dropout_active
 
             G = (
                 G
-                + dropout_start * (1 - epsilon - G.detach())
+                # + dropout_head * (1 - epsilon - G.detach())
                 - dropout_tail * G.detach()
             )
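
The rewritten mask construction is dense; below is a hedged, self-contained demo of what it computes, on a tiny gate tensor (all sizes illustrative). The sort(dim=3).indices == 0 comparison works because the rank of the first of T i.i.d. uniforms is itself uniform over the T positions, so it yields a one-hot at a uniformly random time step per (batch, head); the new mask applies that one time step at every height via expand_as, where the old code picked a single position in the flattened (height, time) plane. The cumsum then marks every step strictly after it. With the head term commented out, only the tail is zeroed in the forward pass (the head term would otherwise pin the head gate at 1 - epsilon = 0.5), while the detach lets gradients flow through G unchanged.

    import torch

    N, H, CH, T = 1, 2, 2, 6  # batch, heads, caterpillar height, t1 - t0
    G = torch.rand(N, H, CH, T)

    # One-hot at a uniformly random time step per (n, h), broadcast over CH.
    dropout_head = (
        (torch.rand(N, H, 1, T).sort(dim=3).indices == 0).expand_as(G).float()
    )

    # Every step strictly after the head: cumsum includes the head itself,
    # so it is subtracted back out.
    dropout_tail = dropout_head.cumsum(dim=3) - dropout_head

    # Per-sequence Bernoulli switch; the source gates it on proba_gate_dropout.
    dropout_active = (torch.rand(N, 1, 1, 1) < 0.5).long()
    dropout_head = dropout_head * dropout_active
    dropout_tail = dropout_tail * dropout_active

    # Forward: G is zeroed on the tail (G - G == 0 there); backward: the
    # detach blocks the subtracted copy, so dG_dropped/dG == 1 everywhere,
    # a straight-through-style estimator.
    G_dropped = G - dropout_tail * G.detach()
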