Update.
author    François Fleuret <francois@fleuret.org>
Fri, 19 Jan 2024 13:02:37 +0000 (14:02 +0100)
committer François Fleuret <francois@fleuret.org>
Fri, 19 Jan 2024 13:02:37 +0000 (14:02 +0100)
main.py
mygpt.py

diff --git a/main.py b/main.py
index 79841f3..3aa696b 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -99,7 +99,11 @@ parser.add_argument("--nb_lines", type=int, default=None)
 
 parser.add_argument("--caterpillar_height", type=int, default=None)
 
-parser.add_argument("--rho", type=float, default=0.0)
+parser.add_argument("--gate_dropout_proba", type=float, default=0.0)
+
+parser.add_argument("--gate_dropout_sync", type=bool, default=False)
+
+parser.add_argument("--rho_inner_loss", type=float, default=0.0)
 
 parser.add_argument("--nb_blocks", type=int, default=None)
 
@@ -747,7 +751,7 @@ model = mygpt.MyGPT(
     dropout=args.dropout,
     attention_layer=args.attention,
     logger=log_string,
-    **sup_args,
+    args=args,
 )
 
 model.to(device)
@@ -905,7 +909,9 @@ for n_epoch in range(nb_epochs_finished, nb_epochs):
         nb_train_samples += input.size(0)
         nb_samples_seen += input.size(0)
 
-        total_loss = loss + (args.rho * inner_loss if args.rho > 0 else 0.0)
+        total_loss = loss + (
+            args.rho_inner_loss * inner_loss if args.rho_inner_loss > 0 else 0.0
+        )
 
         it += 1
         lr = get_lr(n_epoch, it)
diff --git a/mygpt.py b/mygpt.py
index fb24b9a..2d33574 100755 (executable)
--- a/mygpt.py
+++ b/mygpt.py
@@ -202,7 +202,7 @@ class DumbRec(nn.Module):
         attention_dropout=0.0,
         len_max=1e5,
         logger=print,
-        **kwargs,
+        args=None,
     ):
         super().__init__()
 
@@ -333,7 +333,7 @@ class KVRec(nn.Module):
         attention_dropout=0.0,
         len_max=1e5,
         logger=print,
-        **kwargs,
+        args=None,
     ):
         super().__init__()
 
@@ -487,7 +487,7 @@ class Caterpillar(nn.Module):
         attention_dropout=0.0,
         len_max=1e5,
         logger=print,
-        **kwargs,
+        args=None,
     ):
         super().__init__()
 
@@ -502,27 +502,12 @@ class Caterpillar(nn.Module):
         self.caterpillar_height = caterpillar_height
         self.attention_dropout = attention_dropout
 
-        ######################################################################
-        # sup_args
-
-        x = kwargs.get("gate_dropout")
-        if x is None:
-            self.proba_gate_dropout = 0.0
-        else:
-            self.proba_gate_dropout = float(x)
-
-        logger(f"self.proba_gate_dropout {self.proba_gate_dropout}")
-
-        x = kwargs.get("default_bg")
-        if x is None:
-            default_bg = -math.log(caterpillar_height - 1)
-        else:
-            default_bg = float(x)
-
-        logger(f"default_bg {default_bg}")
+        self.gate_dropout_proba = args.gate_dropout_proba
+        self.gate_dropout_sync = args.gate_dropout_sync
 
         ######################################################################
 
+        default_bg = -math.log(caterpillar_height - 1)
         self.w_G = randw(nb_heads, caterpillar_height, dim_model)
         self.b_G = nn.Parameter(torch.full((nb_heads, caterpillar_height), default_bg))
 
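The gate bias initialisation that used to be configurable through the default_bg keyword is now fixed at -log(caterpillar_height - 1). Assuming the gates G go through a sigmoid, as this value suggests, each gate then starts at roughly 1/caterpillar_height, i.e. close to uniform over the caterpillar rows. A quick check (H = 4 is only an illustrative value):

    import math

    H = 4  # illustrative caterpillar_height
    b = -math.log(H - 1)
    # sigmoid(b) = 1 / (1 + exp(-b)) = 1 / (1 + (H - 1)) = 1 / H
    g0 = 1.0 / (1.0 + math.exp(-b))
    assert abs(g0 - 1.0 / H) < 1e-12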
@@ -639,7 +624,7 @@ class Caterpillar(nn.Module):
 
         next_V, next_K = recurrence(G, V, K)
 
-        if self.training and self.proba_gate_dropout > 0.0:
+        if self.training and self.gate_dropout_proba > 0.0:
             # G is NxHxRxT where r is the caterpillar's row.
 
             warnings.warn("gate dropout", RuntimeWarning)
@@ -652,7 +637,7 @@ class Caterpillar(nn.Module):
 
             # Keep these masks only for some of the NxHxR rows
             kill = kill * (
-                torch.rand(N, H, R, 1, device=G.device) <= self.proba_gate_dropout
+                torch.rand(N, H, R, 1, device=G.device) <= self.gate_dropout_proba
             )
 
             # The coefficients to keep are the complementary ones
@@ -661,10 +646,10 @@ class Caterpillar(nn.Module):
             masked_next_V, masked_next_K = recurrence(G * mask, V, K)
 
             next_V = next_V.detach() + (masked_next_V - masked_next_V.detach()) / (
-                1 - self.proba_gate_dropout
+                1 - self.gate_dropout_proba
             )
             next_K = next_K.detach() + (masked_next_K - masked_next_K.detach()) / (
-                1 - self.proba_gate_dropout
+                1 - self.gate_dropout_proba
             )
 
         self.rec_V[:, :, t0:t1] = next_V
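The gate-dropout branch recomputes the recurrence with some rows of G zeroed out, then combines the two results with a detach-based construction: the forward values are those of the unmasked recurrence, while the gradient flows only through the masked recomputation, rescaled by 1 / (1 - gate_dropout_proba) to keep its expectation unchanged. A minimal, self-contained sketch of the pattern (the names and the toy function stand in for the real recurrence, they are not taken from this commit):

    import torch

    p = 0.25  # stand-in for gate_dropout_proba
    x = torch.randn(8, requires_grad=True)

    full = 2.0 * x                          # stands in for recurrence(G, V, K)
    mask = (torch.rand_like(x) > p).float()
    masked = 2.0 * x * mask                 # stands in for recurrence(G * mask, V, K)

    # Forward value is `full`; the gradient comes from `masked`, scaled by 1/(1-p).
    y = full.detach() + (masked - masked.detach()) / (1 - p)

    assert torch.allclose(y, full)          # forward pass unchanged
    y.sum().backward()
    # x.grad is 2 * mask / (1 - p): zero where the path was dropped, scaled up
    # elsewhere so the expected gradient matches that of the unmasked recurrence.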
@@ -730,7 +715,7 @@ class QKVAttention(nn.Module):
         causal=False,
         attention_dropout=0.0,
         logger=print,
-        **kwargs,
+        args=None,
     ):
         super().__init__()
 
@@ -823,7 +808,7 @@ class MyGPT(nn.Module):
         len_max=1e5,
         attention_layer="kvrec",
         logger=print,
-        **kwargs,
+        args=None,
     ):
         super().__init__()
 
@@ -861,7 +846,7 @@ class MyGPT(nn.Module):
                     causal=causal,
                     attention_dropout=dropout,
                     logger=logger,
-                    **kwargs,
+                    args=args,
                 )
             elif attention_layer == "dumbrec":
                 return DumbRec(
@@ -872,7 +857,7 @@ class MyGPT(nn.Module):
                     nb_lines=nb_lines,
                     attention_dropout=dropout,
                     logger=logger,
-                    **kwargs,
+                    args=args,
                 )
             elif attention_layer == "kvrec":
                 return KVRec(
@@ -883,7 +868,7 @@ class MyGPT(nn.Module):
                     nb_lines=nb_lines,
                     attention_dropout=dropout,
                     logger=logger,
-                    **kwargs,
+                    args=args,
                 )
             elif attention_layer == "caterpillar":
                 return Caterpillar(
@@ -895,7 +880,7 @@ class MyGPT(nn.Module):
                     caterpillar_height=self.caterpillar_height,
                     attention_dropout=dropout,
                     logger=logger,
-                    **kwargs,
+                    args=args,
                 )
             else:
                 raise ValueError(f"Unknown attention type {attention_layer}.")