Update.
author	François Fleuret <francois@fleuret.org>
Fri, 20 Sep 2024 11:40:39 +0000 (13:40 +0200)
committer	François Fleuret <francois@fleuret.org>
Fri, 20 Sep 2024 11:40:39 +0000 (13:40 +0200)
attae.py
grids.py
main.py

diff --git a/attae.py b/attae.py
index 1e5e122..c04c5d3 100755 (executable)
--- a/attae.py
+++ b/attae.py
@@ -101,7 +101,6 @@ class AttentionAE(nn.Module):
         dim_hidden,
         nb_heads,
         nb_blocks,
-        attention=vanilla_attention,
         dropout=0.0,
         len_max=1e5,
     ):
@@ -127,7 +126,7 @@ class AttentionAE(nn.Module):
                         dim_qk=dim_keys,
                         dim_v=dim_model // nb_heads,
                         nb_heads=nb_heads,
-                        attention=attention,
+                        attention=vanilla_attention,
                         attention_dropout=dropout,
                     ),
                 ),
@@ -163,7 +162,23 @@ class AttentionAE(nn.Module):
 ######################################################################
 
 
-class FunctionalAttentionAE(AttentionAE):
+class WithMaskedResidual(nn.Module):
+    def __init__(self, masker, *f):
+        super().__init__()
+        self.f = f[0] if len(f) == 1 else nn.Sequential(*f)
+        self.masker = masker
+        self.mask = None
+
+    def forward(self, x):
+        if self.mask is None:
+            self.mask = self.masker(x)
+        return self.mask * x + self.f(x)
+
+
+######################################################################
+
+
+class FunctionalAttentionAE(nn.Module):
     def __init__(
         self,
         vocabulary_size,
@@ -176,6 +191,21 @@ class FunctionalAttentionAE(AttentionAE):
         dropout=0.0,
         len_max=1e5,
     ):
+        super().__init__()
+
+        assert dim_model % nb_heads == 0
+
+        self.nb_work_tokens = nb_work_tokens
+
+        self.embedding = nn.Sequential(
+            nn.Embedding(2 * vocabulary_size, dim_model),
+            nn.Dropout(dropout),
+        )
+
+        self.positional_encoding = VaswaniPositionalEncoding(len_max)
+
+        trunk_blocks = []
+
         def no_peek_attention(q, k, v):
             a = torch.einsum("nhtd,nhsd->nhts", q, k) / math.sqrt(q.size(3))
             n = self.nb_work_tokens
@@ -186,23 +216,54 @@ class FunctionalAttentionAE(AttentionAE):
             y = torch.einsum("nhts,nhsd->nhtd", a, v)
             return y
 
-        AttentionAE.__init__(
-            self,
-            vocabulary_size,
-            dim_model,
-            dim_keys,
-            dim_hidden,
-            nb_heads,
-            nb_blocks,
-            attention=no_peek_attention,
-            dropout=0.0,
-            len_max=1e5,
-        )
-        self.nb_work_tokens = nb_work_tokens
+        def masker(x):
+            m = torch.arange(x.size(1), device=x.device) >= self.nb_work_tokens
+            return m[None, :, None]
+
+        for b in range(nb_blocks):
+            trunk_blocks += [
+                WithMaskedResidual(
+                    masker,
+                    nn.LayerNorm((dim_model,)),
+                    MHAttention(
+                        dim_model=dim_model,
+                        dim_qk=dim_keys,
+                        dim_v=dim_model // nb_heads,
+                        nb_heads=nb_heads,
+                        attention=no_peek_attention,
+                        attention_dropout=dropout,
+                    ),
+                ),
+                WithMaskedResidual(
+                    masker,
+                    nn.LayerNorm((dim_model,)),
+                    nn.Linear(in_features=dim_model, out_features=dim_hidden),
+                    nn.ReLU(),
+                    nn.Linear(in_features=dim_hidden, out_features=dim_model),
+                    nn.Dropout(dropout),
+                ),
+            ]
+
+        self.trunk = nn.Sequential(*trunk_blocks)
+
+        self.readout = nn.Linear(in_features=dim_model, out_features=vocabulary_size)
+
+        with torch.no_grad():
+            for m in self.modules():
+                if isinstance(m, nn.Embedding):
+                    m.weight.normal_(mean=0, std=2e-2)
+                elif isinstance(m, nn.LayerNorm):
+                    m.bias.zero_()
+                    m.weight.fill_(1.0)
 
     def forward(self, x):
-        x = torch.cat([x.new_zeros(x.size(0), self.nb_work_tokens), x], dim=1)
-        return AttentionAE.forward(self, x)[:, self.nb_work_tokens :]
+        x = self.embedding(x)
+        x = F.pad(x, (0, 0, self.nb_work_tokens, 0))
+        x = self.positional_encoding(x)
+        x = self.trunk(x)
+        x = F.pad(x, (0, 0, -self.nb_work_tokens, 0))
+        x = self.readout(x)
+        return x
 
 
 ######################################################################
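
Note on the attae.py change: FunctionalAttentionAE no longer subclasses AttentionAE through a pluggable attention argument; it builds its own trunk of WithMaskedResidual blocks, so the nb_work_tokens positions prepended in forward() carry no residual path and are written entirely by the attention/MLP outputs, and F.pad with a negative offset strips them again before the readout. A minimal sketch of the masked-residual idea, detached from the repository classes (module names and shapes below are illustrative only):

import torch
from torch import nn

class ToyMaskedResidual(nn.Module):
    # Zero the residual path on the first nb_work_tokens positions so that
    # whatever f writes there replaces the input instead of being added to it.
    def __init__(self, nb_work_tokens, f):
        super().__init__()
        self.nb_work_tokens = nb_work_tokens
        self.f = f

    def forward(self, x):
        # mask is False on the work tokens, True elsewhere; broadcasts over (N, T, D)
        m = (torch.arange(x.size(1), device=x.device) >= self.nb_work_tokens)[None, :, None]
        return m * x + self.f(x)

x = torch.randn(2, 7, 16)
y = ToyMaskedResidual(3, nn.Linear(16, 16))(x)  # first 3 positions get no residual
print(y.shape)  # torch.Size([2, 7, 16])
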
diff --git a/grids.py b/grids.py
index 197eb5a..e5890ca 100755 (executable)
--- a/grids.py
+++ b/grids.py
@@ -134,16 +134,16 @@ def grow_islands(nb, height, width, nb_seeds, nb_iterations):
 
 
 class Grids(problem.Problem):
-    grid_gray = 64
-    thickness = 1
-    background_gray = 255
-    dots = False
-
-    # grid_gray=240
-    # thickness=1
-    # background_gray=240
+    # grid_gray = 64
+    # thickness = 1
+    # background_gray = 255
     # dots = False
 
+    grid_gray = 240
+    thickness = 0
+    background_gray = 240
+    dots = False
+
     # grid_gray = 192
     # thickness = 0
     # background_gray = 255
@@ -288,7 +288,7 @@ class Grids(problem.Problem):
 
     def vocabulary_size(self):
         warnings.warn("hack +4 to keep the vocabulary size unchanged", RuntimeWarning)
-        return self.nb_colors + 4
+        return self.nb_colors
 
     def grid2img(self, x, scale=15, grids=True):
         m = torch.logical_and(x >= 0, x < self.nb_colors).long()
@@ -369,6 +369,7 @@ class Grids(problem.Problem):
         grids=True,
         margin=12,
         delta=False,
+        delta_highlight=False,
     ):
         quizzes = quizzes.to("cpu")
 
@@ -422,6 +423,10 @@ class Grids(problem.Problem):
             self.grid2img(f_B, grids=grids), frame[None, :], thickness=thickness
         )
 
+        if delta_highlight:
+            q = (img_B == img_f_B).min(dim=1, keepdim=True).values.long()
+            img_f_B = q * (img_f_B // 4 + 192) + (1 - q) * img_f_B
+
         # predicted_parts Nx4
         # correct_parts Nx4
 
@@ -1847,6 +1852,7 @@ if __name__ == "__main__":
             "/tmp",
             t.__name__ + ".png",
             w_quizzes,
+            delta=True,
             # grids=False
             # comments=[f"{t.__name__} #{k}" for k in range(w_quizzes.size(0))],
         )
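
Note on the grids.py change: with delta_highlight set, q marks the pixels where the predicted image img_f_B agrees with the reference img_B across all channels; those pixels are remapped to img_f_B // 4 + 192, i.e. compressed into the bright 192..255 band, while differing pixels keep their original values, so only the cells that changed remain at full contrast. A toy check of that arithmetic, using assumed 1x3x1x2 integer images rather than the repository's data path:

import torch

img_B   = torch.tensor([[[[0, 255]], [[0, 255]], [[0, 255]]]])  # reference
img_f_B = torch.tensor([[[[0, 128]], [[0, 128]], [[0, 128]]]])  # prediction

q = (img_B == img_f_B).min(dim=1, keepdim=True).values.long()
img_h = q * (img_f_B // 4 + 192) + (1 - q) * img_f_B
print(img_h[0, 0, 0])  # tensor([192, 128]): matching pixel brightened, differing pixel kept
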
diff --git a/main.py b/main.py
index 06dfc5e..10e6bc0 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -63,6 +63,8 @@ parser.add_argument("--nb_mistakes_to_be_wrong", type=int, default=5)
 
 # ----------------------------------
 
+parser.add_argument("--model_type", type=str, default="standard")
+
 parser.add_argument("--model", type=str, default="37M")
 
 parser.add_argument("--dim_model", type=int, default=None)
@@ -843,9 +845,16 @@ log_string(f"vocabulary_size {vocabulary_size}")
 
 models = []
 
+if args.model_type == "standard":
+    model_constructor = attae.AttentionAE
+elif args.model_type == "functional":
+    model_constructor = attae.FunctionalAttentionAE
+else:
+    raise ValueError(f"Unknown model type {args.model_type}")
+
+
 for i in range(args.nb_models):
-    # model = attae.FunctionalAttentionAE(
-    model = attae.AttentionAE(
+    model = model_constructor(
         vocabulary_size=vocabulary_size * 2,
         dim_model=args.dim_model,
         dim_keys=args.dim_keys,
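
Note on the main.py change: model selection moves from editing the constructor call to the new --model_type switch, e.g. (other arguments assumed to keep their defaults):

./main.py --model_type standard     # attae.AttentionAE, the default
./main.py --model_type functional   # attae.FunctionalAttentionAE

Any other value raises ValueError with the offending string.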