Update.
authorFrançois Fleuret <francois@fleuret.org>
Fri, 11 Oct 2024 16:03:34 +0000 (18:03 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Fri, 11 Oct 2024 16:03:34 +0000 (18:03 +0200)
attae.py
main.py

index 0d36a33..bb97ed4 100755 (executable)
--- a/attae.py
+++ b/attae.py
@@ -76,12 +76,17 @@ class AdHocPositionalEncoding(nn.Module):
 
 
 class WithResidual(nn.Module):
-    def __init__(self, *f):
+    def __init__(self, f, masker=None):
         super().__init__()
         self.f = f[0] if len(f) == 1 else nn.Sequential(*f)
+        self.masker = masker
 
     def forward(self, x):
-        return x + self.f(x)
+        if self.masker is None:
+            mask = 1
+        else:
+            mask = self.masker(x)
+        return mask * x + self.f(x)
 
 
 ######################################################################
@@ -135,28 +140,42 @@ class MHAttention(nn.Module):
 ######################################################################
 
 
-def create_trunk(dim_model, dim_keys, dim_hidden, nb_heads, nb_blocks, dropout=0.0):
+def create_trunk(
+    dim_model,
+    dim_keys,
+    dim_hidden,
+    nb_heads,
+    nb_blocks,
+    dropout=0.0,
+    residual_masker=None,
+):
     trunk_blocks = []
 
     for b in range(nb_blocks):
         trunk_blocks += [
             WithResidual(
-                nn.LayerNorm((dim_model,)),
-                MHAttention(
-                    dim_model=dim_model,
-                    dim_qk=dim_keys,
-                    dim_v=dim_model // nb_heads,
-                    nb_heads=nb_heads,
-                    attention=vanilla_attention,
-                    attention_dropout=dropout,
+                masker=residual_masker,
+                f=(
+                    nn.LayerNorm((dim_model,)),
+                    MHAttention(
+                        dim_model=dim_model,
+                        dim_qk=dim_keys,
+                        dim_v=dim_model // nb_heads,
+                        nb_heads=nb_heads,
+                        attention=vanilla_attention,
+                        attention_dropout=dropout,
+                    ),
                 ),
             ),
             WithResidual(
-                nn.LayerNorm((dim_model,)),
-                nn.Linear(in_features=dim_model, out_features=dim_hidden),
-                nn.ReLU(),
-                nn.Linear(in_features=dim_hidden, out_features=dim_model),
-                nn.Dropout(dropout),
+                masker=residual_masker,
+                f=(
+                    nn.LayerNorm((dim_model,)),
+                    nn.Linear(in_features=dim_model, out_features=dim_hidden),
+                    nn.ReLU(),
+                    nn.Linear(in_features=dim_hidden, out_features=dim_model),
+                    nn.Dropout(dropout),
+                ),
             ),
         ]
 
@@ -214,12 +233,12 @@ class AttentionAE(nn.Module):
     def forward(self, x):
         x = self.embedding(x)
 
-        warnings.warn("flipping order for symmetry check", RuntimeWarning)
-        x = torch.cat([x[:, 200:], x[:, :200]], dim=1)
+        warnings.warn("flipping order for symmetry check", RuntimeWarning)
+
         x = self.positional_encoding(x)
-        x = torch.cat([x[:, 200:], x[:, :200]], dim=1)
 
         x = self.trunk(x)
+
         x = self.readout(x)
 
         return x
@@ -228,22 +247,6 @@ class AttentionAE(nn.Module):
 ######################################################################
 
 
-class WithMaskedResidual(nn.Module):
-    def __init__(self, masker, *f):
-        super().__init__()
-        self.f = f[0] if len(f) == 1 else nn.Sequential(*f)
-        self.masker = masker
-        self.mask = None
-
-    def forward(self, x):
-        if self.mask is None:
-            self.mask = self.masker(x)
-        return self.mask * x + self.f(x)
-
-
-######################################################################
-
-
 class FunctionalAttentionAE(nn.Module):
     def __init__(
         self,
@@ -348,6 +351,7 @@ class Reasoning(nn.Module):
         attention=vanilla_attention,
         attention_dropout=0.0,
         len_max=1e5,
+        residual_masker=None,
     ):
         super().__init__()
 
@@ -359,47 +363,58 @@ class Reasoning(nn.Module):
 
         self.positional_encoding = VaswaniPositionalEncoding(len_max)
 
-        self.trunk_joint = create_trunk(
+        self.trunk_A = create_trunk(
             dim_model=dim_model,
             dim_keys=dim_qk,
             dim_hidden=dim_hidden,
             nb_heads=nb_heads,
             nb_blocks=nb_blocks,
             dropout=attention_dropout,
+            residual_masker=residual_masker,
         )
 
-        self.trunk_marginal = create_trunk(
+        self.trunk_B = create_trunk(
             dim_model=dim_model,
             dim_keys=dim_qk,
             dim_hidden=dim_hidden,
             nb_heads=nb_heads,
             nb_blocks=nb_blocks,
             dropout=attention_dropout,
+            residual_masker=residual_masker,
         )
 
     def forward(self, x_q):
-        #!!!!!!!!!!!!!!!!!!!!
-        # x_q = torch.cat([x_q[:,200:,:], x_q[:,:200,:]],dim=1)
-
         T, S = x_q.size(1), self.x_star.size(0)
         nb, dim, nc = x_q.size(0), x_q.size(2), self.nb_chunks
+        f = self.x_star.reshape(1, S, dim).expand(nb, S, dim)
 
-        x_star = self.x_star.reshape(1, S, dim).expand(nb, S, dim)
-
-        x = torch.cat([x_star, x_q], dim=1)
-        x = self.trunk_joint(x)
-
-        f, x = x[:, :S, :], x[:, S:, :]
-        x = x.reshape(nb * nc, T // nc, dim)
+        x = x_q
+        x = x.reshape(nb, nc, T // nc, dim).reshape(nb * nc, T // nc, dim)
         f = f.repeat(nc, 1, 1)
         x = torch.cat([f, x], dim=1)
-        x = self.trunk_marginal(x)
+        x = self.trunk_A(x)
+        k = torch.arange(nb, device=x_q.device)
+        u = x[k * 2, :S]
+        x[k * 2, :S] = x[k * 2 + 1, :S]
+        x[k * 2 + 1, :S] = u
+        x = self.trunk_B(x)
+        x = x[:, S:]
+        x = x.reshape(nb, nc, T // nc, dim).reshape(nb, T, dim)
 
-        x = x[:, S:, :]
-        x = x.reshape(nb, T, dim)
+        return x
 
-        #!!!!!!!!!!!!!!!!!!!!
-        # x = torch.cat([x[:,200:,:], x[:,:200,:]],dim=1)
+    def forward_one_vector(self, x_q):
+        T, S = x_q.size(1), self.x_star.size(0)
+        nb, dim, nc = x_q.size(0), x_q.size(2), self.nb_chunks
+        x_star = self.x_star.reshape(1, S, dim).expand(nb, S, dim)
+        x = torch.cat([x_star, x_q], dim=1)
+        x = self.trunk_A(x)
+        f = x[:, :S, :]
+        x = x_q
+        x = x + 1e-3 * f.mean(dim=1, keepdim=True)
+        x = x.reshape(nb, nc, T // nc, dim).reshape(nb * nc, T // nc, dim)
+        x = self.trunk_B(x)
+        x = x.reshape(nb, nc, T // nc, dim).reshape(nb, T, dim)
 
         return x
 
diff --git a/main.py b/main.py
index d5c1c5c..3b1caa9 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -527,15 +527,15 @@ def one_epoch(model, n_epoch, c_quizzes, train=True, local_device=main_device):
     imt_set = torch.cat([b_p, b_g])
     imt_set = imt_set[torch.randperm(imt_set.size(0), device=imt_set.device)]
 
+    batch_size = args.batch_size
+
     if train:
         label = "train"
         model.train().to(local_device)
         optimizer_to(model.optimizer, local_device)
-        batch_size = args.train_batch_size
     else:
         label = "test"
         model.eval().to(local_device)
-        batch_size = args.eval_batch_size
 
     nb_samples, acc_loss = 0, 0.0
 
@@ -932,20 +932,30 @@ if args.test == "aebn":
     # )
 
     i = torch.arange(400)[:, None]
-    k = [2**k for k in range(4)] + [10 * 2**k for k in range(4)] + [100, 200]
+    k = [1, 2, 4, 8, 16, 10, 20, 40, 80, 160, 100, 200]
     k = torch.tensor(k)[None, :]
-    pe = (i // k) % 2
+    pe = 2.0 * ((i // k) % 2) - 1.0
+
+    model.positional_encoding = attae.AdHocPositionalEncoding(
+        args.dim_model,
+        pe,  # trainable=True
+    )
+
+    nb_f_tokens = 100
 
-    model.positional_encoding = attae.AdHocPositionalEncoding(args.dim_model, pe)
+    def no_f_residual(x):
+        m = x.new_full((1, x.size(1), 1), 1.0)
+        m[:, :nb_f_tokens, :] = 0
+        return m
 
     model.trunk = attae.Reasoning(
-        nb_f_tokens=8,
+        nb_f_tokens=nb_f_tokens,
         nb_chunks=2,
         dim_model=args.dim_model,
         dim_qk=args.dim_keys,
         dim_hidden=args.dim_hidden,
         nb_heads=args.nb_heads,
-        nb_blocks=args.nb_blocks // 2,
+        nb_blocks=args.nb_blocks,
         attention_dropout=args.dropout,
     )
 
@@ -962,6 +972,16 @@ if args.test == "aebn":
             test_c_quizzes=None,
             local_device=main_device,
         )
+        filename = f"aebn_{model.id:03d}.pth"
+        torch.save(
+            {
+                "state_dict": model.state_dict(),
+                "optimizer_state_dict": model.optimizer.state_dict(),
+                "test_accuracy": model.test_accuracy,
+                "nb_epochs": model.nb_epochs,
+            },
+            os.path.join(args.result_dir, filename),
+        )
 
     exit(0)